mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-05-01 19:17:41 +03:00
unload whisper model when live transcription is complete
This commit is contained in:
parent
1a6fdf99ab
commit
10038b6e86
@ -40,7 +40,7 @@ class AudioTranscriptionRealTimeProcessor(RealTimeProcessorApi):
|
||||
self.camera_config = camera_config
|
||||
self.requestor = requestor
|
||||
self.stream = None
|
||||
self.whispermodel = None
|
||||
self.whisper_model = None
|
||||
self.model_runner = model_runner
|
||||
self.transcription_segments = []
|
||||
self.audio_queue = queue.Queue()
|
||||
@ -51,6 +51,7 @@ class AudioTranscriptionRealTimeProcessor(RealTimeProcessorApi):
|
||||
if self.config.audio_transcription.model_size == "large":
|
||||
# Whisper models need to be per-process and can only run one stream at a time
|
||||
# TODO: try parallel: https://github.com/SYSTRAN/faster-whisper/issues/100
|
||||
logger.debug(f"Loading Whisper model for {self.camera_config.name}")
|
||||
self.whisper_model = FasterWhisperASR(
|
||||
modelsize="tiny",
|
||||
device="cuda"
|
||||
@ -64,6 +65,7 @@ class AudioTranscriptionRealTimeProcessor(RealTimeProcessorApi):
|
||||
asr=self.whisper_model,
|
||||
)
|
||||
else:
|
||||
logger.debug(f"Loading sherpa stream for {self.camera_config.name}")
|
||||
self.stream = self.model_runner.model.create_stream()
|
||||
logger.debug(
|
||||
f"Audio transcription (live) initialized for {self.camera_config.name}"
|
||||
@ -101,7 +103,11 @@ class AudioTranscriptionRealTimeProcessor(RealTimeProcessorApi):
|
||||
self.stream.insert_audio_chunk(audio_data)
|
||||
output = self.stream.process_iter()
|
||||
text = output[2].strip()
|
||||
is_endpoint = text.endswith((".", "!", "?"))
|
||||
is_endpoint = (
|
||||
text.endswith((".", "!", "?"))
|
||||
and sum(len(str(lines)) for lines in self.transcription_segments)
|
||||
> 300
|
||||
)
|
||||
|
||||
if text:
|
||||
self.transcription_segments.append(text)
|
||||
@ -153,10 +159,17 @@ class AudioTranscriptionRealTimeProcessor(RealTimeProcessorApi):
|
||||
logger.debug(
|
||||
f"Starting audio transcription thread for {self.camera_config.name}"
|
||||
)
|
||||
|
||||
# start with an empty transcription
|
||||
self.requestor.send_data(
|
||||
f"{self.camera_config.name}/audio/transcription",
|
||||
"",
|
||||
)
|
||||
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
# Get audio data from queue with a timeout to check stop_event
|
||||
obj_data, audio = self.audio_queue.get(timeout=0.1)
|
||||
_, audio = self.audio_queue.get(timeout=0.1)
|
||||
result = self.__process_audio_stream(audio)
|
||||
|
||||
if not result:
|
||||
@ -172,7 +185,7 @@ class AudioTranscriptionRealTimeProcessor(RealTimeProcessorApi):
|
||||
self.audio_queue.task_done()
|
||||
|
||||
if is_endpoint:
|
||||
self.reset(obj_data["camera"])
|
||||
self.reset()
|
||||
|
||||
except queue.Empty:
|
||||
continue
|
||||
@ -184,7 +197,16 @@ class AudioTranscriptionRealTimeProcessor(RealTimeProcessorApi):
|
||||
f"Stopping audio transcription thread for {self.camera_config.name}"
|
||||
)
|
||||
|
||||
def reset(self, camera: str) -> None:
|
||||
def clear_audio_queue(self) -> None:
|
||||
# Clear the audio queue
|
||||
while not self.audio_queue.empty():
|
||||
try:
|
||||
self.audio_queue.get_nowait()
|
||||
self.audio_queue.task_done()
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
def reset(self) -> None:
|
||||
if self.config.audio_transcription.model_size == "large":
|
||||
# get final output from whisper
|
||||
output = self.stream.finish()
|
||||
@ -197,20 +219,41 @@ class AudioTranscriptionRealTimeProcessor(RealTimeProcessorApi):
|
||||
|
||||
# reset whisper
|
||||
self.stream.init()
|
||||
self.transcription_segments = []
|
||||
else:
|
||||
# reset sherpa
|
||||
self.model_runner.model.reset(self.stream)
|
||||
|
||||
# Clear the audio queue
|
||||
while not self.audio_queue.empty():
|
||||
try:
|
||||
self.audio_queue.get_nowait()
|
||||
self.audio_queue.task_done()
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
logger.debug("Stream reset")
|
||||
|
||||
def check_unload_model(self) -> None:
|
||||
# regularly called in the loop in audio maintainer
|
||||
if (
|
||||
self.config.audio_transcription.model_size == "large"
|
||||
and self.whisper_model is not None
|
||||
):
|
||||
logger.debug(f"Unloading Whisper model for {self.camera_config.name}")
|
||||
self.clear_audio_queue()
|
||||
self.transcription_segments = []
|
||||
self.stream = None
|
||||
self.whisper_model = None
|
||||
|
||||
self.requestor.send_data(
|
||||
f"{self.camera_config.name}/audio/transcription",
|
||||
"",
|
||||
)
|
||||
if (
|
||||
self.config.audio_transcription.model_size == "small"
|
||||
and self.stream is not None
|
||||
):
|
||||
logger.debug(f"Clearing sherpa stream for {self.camera_config.name}")
|
||||
self.stream = None
|
||||
|
||||
self.requestor.send_data(
|
||||
f"{self.camera_config.name}/audio/transcription",
|
||||
"",
|
||||
)
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the transcription thread and clean up."""
|
||||
self.stop_event.set()
|
||||
|
||||
@ -234,18 +234,18 @@ class AudioEventMaintainer(threading.Thread):
|
||||
)
|
||||
|
||||
# run audio transcription
|
||||
if self.transcription_processor is not None and (
|
||||
self.camera_config.audio_transcription.live_enabled
|
||||
):
|
||||
self.transcribing = True
|
||||
# process audio until we've reached the endpoint
|
||||
self.transcription_processor.process_audio(
|
||||
{
|
||||
"id": f"{self.camera_config.name}_audio",
|
||||
"camera": self.camera_config.name,
|
||||
},
|
||||
audio,
|
||||
)
|
||||
if self.transcription_processor is not None:
|
||||
if self.camera_config.audio_transcription.live_enabled:
|
||||
# process audio until we've reached the endpoint
|
||||
self.transcription_processor.process_audio(
|
||||
{
|
||||
"id": f"{self.camera_config.name}_audio",
|
||||
"camera": self.camera_config.name,
|
||||
},
|
||||
audio,
|
||||
)
|
||||
else:
|
||||
self.transcription_processor.check_unload_model()
|
||||
|
||||
self.expire_detections()
|
||||
|
||||
@ -320,13 +320,6 @@ class AudioEventMaintainer(threading.Thread):
|
||||
)
|
||||
self.detections[detection["label"]] = None
|
||||
|
||||
# clear real-time transcription
|
||||
if self.transcription_processor is not None:
|
||||
self.transcription_processor.reset(self.camera_config.name)
|
||||
self.requestor.send_data(
|
||||
f"{self.camera_config.name}/audio/transcription", ""
|
||||
)
|
||||
|
||||
def expire_all_detections(self) -> None:
|
||||
"""Immediately end all current detections"""
|
||||
now = datetime.datetime.now().timestamp()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user