unload whisper model when live transcription is complete

This commit is contained in:
Josh Hawkins 2025-06-03 06:38:33 -05:00
parent 1a6fdf99ab
commit 10038b6e86
2 changed files with 68 additions and 32 deletions

View File

@ -40,7 +40,7 @@ class AudioTranscriptionRealTimeProcessor(RealTimeProcessorApi):
self.camera_config = camera_config self.camera_config = camera_config
self.requestor = requestor self.requestor = requestor
self.stream = None self.stream = None
self.whispermodel = None self.whisper_model = None
self.model_runner = model_runner self.model_runner = model_runner
self.transcription_segments = [] self.transcription_segments = []
self.audio_queue = queue.Queue() self.audio_queue = queue.Queue()
@ -51,6 +51,7 @@ class AudioTranscriptionRealTimeProcessor(RealTimeProcessorApi):
if self.config.audio_transcription.model_size == "large": if self.config.audio_transcription.model_size == "large":
# Whisper models need to be per-process and can only run one stream at a time # Whisper models need to be per-process and can only run one stream at a time
# TODO: try parallel: https://github.com/SYSTRAN/faster-whisper/issues/100 # TODO: try parallel: https://github.com/SYSTRAN/faster-whisper/issues/100
logger.debug(f"Loading Whisper model for {self.camera_config.name}")
self.whisper_model = FasterWhisperASR( self.whisper_model = FasterWhisperASR(
modelsize="tiny", modelsize="tiny",
device="cuda" device="cuda"
@ -64,6 +65,7 @@ class AudioTranscriptionRealTimeProcessor(RealTimeProcessorApi):
asr=self.whisper_model, asr=self.whisper_model,
) )
else: else:
logger.debug(f"Loading sherpa stream for {self.camera_config.name}")
self.stream = self.model_runner.model.create_stream() self.stream = self.model_runner.model.create_stream()
logger.debug( logger.debug(
f"Audio transcription (live) initialized for {self.camera_config.name}" f"Audio transcription (live) initialized for {self.camera_config.name}"
@ -101,7 +103,11 @@ class AudioTranscriptionRealTimeProcessor(RealTimeProcessorApi):
self.stream.insert_audio_chunk(audio_data) self.stream.insert_audio_chunk(audio_data)
output = self.stream.process_iter() output = self.stream.process_iter()
text = output[2].strip() text = output[2].strip()
is_endpoint = text.endswith((".", "!", "?")) is_endpoint = (
text.endswith((".", "!", "?"))
and sum(len(str(lines)) for lines in self.transcription_segments)
> 300
)
if text: if text:
self.transcription_segments.append(text) self.transcription_segments.append(text)
@ -153,10 +159,17 @@ class AudioTranscriptionRealTimeProcessor(RealTimeProcessorApi):
logger.debug( logger.debug(
f"Starting audio transcription thread for {self.camera_config.name}" f"Starting audio transcription thread for {self.camera_config.name}"
) )
# start with an empty transcription
self.requestor.send_data(
f"{self.camera_config.name}/audio/transcription",
"",
)
while not self.stop_event.is_set(): while not self.stop_event.is_set():
try: try:
# Get audio data from queue with a timeout to check stop_event # Get audio data from queue with a timeout to check stop_event
obj_data, audio = self.audio_queue.get(timeout=0.1) _, audio = self.audio_queue.get(timeout=0.1)
result = self.__process_audio_stream(audio) result = self.__process_audio_stream(audio)
if not result: if not result:
@ -172,7 +185,7 @@ class AudioTranscriptionRealTimeProcessor(RealTimeProcessorApi):
self.audio_queue.task_done() self.audio_queue.task_done()
if is_endpoint: if is_endpoint:
self.reset(obj_data["camera"]) self.reset()
except queue.Empty: except queue.Empty:
continue continue
@ -184,7 +197,16 @@ class AudioTranscriptionRealTimeProcessor(RealTimeProcessorApi):
f"Stopping audio transcription thread for {self.camera_config.name}" f"Stopping audio transcription thread for {self.camera_config.name}"
) )
def reset(self, camera: str) -> None: def clear_audio_queue(self) -> None:
# Clear the audio queue
while not self.audio_queue.empty():
try:
self.audio_queue.get_nowait()
self.audio_queue.task_done()
except queue.Empty:
break
def reset(self) -> None:
if self.config.audio_transcription.model_size == "large": if self.config.audio_transcription.model_size == "large":
# get final output from whisper # get final output from whisper
output = self.stream.finish() output = self.stream.finish()
@ -197,20 +219,41 @@ class AudioTranscriptionRealTimeProcessor(RealTimeProcessorApi):
# reset whisper # reset whisper
self.stream.init() self.stream.init()
self.transcription_segments = []
else: else:
# reset sherpa # reset sherpa
self.model_runner.model.reset(self.stream) self.model_runner.model.reset(self.stream)
# Clear the audio queue
while not self.audio_queue.empty():
try:
self.audio_queue.get_nowait()
self.audio_queue.task_done()
except queue.Empty:
break
logger.debug("Stream reset") logger.debug("Stream reset")
def check_unload_model(self) -> None:
# regularly called in the loop in audio maintainer
if (
self.config.audio_transcription.model_size == "large"
and self.whisper_model is not None
):
logger.debug(f"Unloading Whisper model for {self.camera_config.name}")
self.clear_audio_queue()
self.transcription_segments = []
self.stream = None
self.whisper_model = None
self.requestor.send_data(
f"{self.camera_config.name}/audio/transcription",
"",
)
if (
self.config.audio_transcription.model_size == "small"
and self.stream is not None
):
logger.debug(f"Clearing sherpa stream for {self.camera_config.name}")
self.stream = None
self.requestor.send_data(
f"{self.camera_config.name}/audio/transcription",
"",
)
def stop(self) -> None: def stop(self) -> None:
"""Stop the transcription thread and clean up.""" """Stop the transcription thread and clean up."""
self.stop_event.set() self.stop_event.set()

View File

@ -234,18 +234,18 @@ class AudioEventMaintainer(threading.Thread):
) )
# run audio transcription # run audio transcription
if self.transcription_processor is not None and ( if self.transcription_processor is not None:
self.camera_config.audio_transcription.live_enabled if self.camera_config.audio_transcription.live_enabled:
): # process audio until we've reached the endpoint
self.transcribing = True self.transcription_processor.process_audio(
# process audio until we've reached the endpoint {
self.transcription_processor.process_audio( "id": f"{self.camera_config.name}_audio",
{ "camera": self.camera_config.name,
"id": f"{self.camera_config.name}_audio", },
"camera": self.camera_config.name, audio,
}, )
audio, else:
) self.transcription_processor.check_unload_model()
self.expire_detections() self.expire_detections()
@ -320,13 +320,6 @@ class AudioEventMaintainer(threading.Thread):
) )
self.detections[detection["label"]] = None self.detections[detection["label"]] = None
# clear real-time transcription
if self.transcription_processor is not None:
self.transcription_processor.reset(self.camera_config.name)
self.requestor.send_data(
f"{self.camera_config.name}/audio/transcription", ""
)
def expire_all_detections(self) -> None: def expire_all_detections(self) -> None:
"""Immediately end all current detections""" """Immediately end all current detections"""
now = datetime.datetime.now().timestamp() now = datetime.datetime.now().timestamp()