Cleanup bird mypy

This commit is contained in:
Nicolas Mowen 2026-03-26 10:58:24 -06:00
parent 178d30606d
commit 4c72a210a9
2 changed files with 25 additions and 18 deletions

View File

@ -4,8 +4,9 @@ import logging
import os import os
import threading import threading
import time import time
from typing import Optional from typing import Any, Optional
from embeddings.embeddings import Embeddings
from peewee import DoesNotExist from peewee import DoesNotExist
from frigate.comms.inter_process import InterProcessRequestor from frigate.comms.inter_process import InterProcessRequestor
@ -31,7 +32,7 @@ class AudioTranscriptionPostProcessor(PostProcessorApi):
self, self,
config: FrigateConfig, config: FrigateConfig,
requestor: InterProcessRequestor, requestor: InterProcessRequestor,
embeddings, embeddings: Embeddings,
metrics: DataProcessorMetrics, metrics: DataProcessorMetrics,
): ):
super().__init__(config, metrics, None) super().__init__(config, metrics, None)
@ -40,7 +41,7 @@ class AudioTranscriptionPostProcessor(PostProcessorApi):
self.embeddings = embeddings self.embeddings = embeddings
self.recognizer = None self.recognizer = None
self.transcription_lock = threading.Lock() self.transcription_lock = threading.Lock()
self.transcription_thread = None self.transcription_thread: threading.Thread | None = None
self.transcription_running = False self.transcription_running = False
# faster-whisper handles model downloading automatically # faster-whisper handles model downloading automatically
@ -69,7 +70,7 @@ class AudioTranscriptionPostProcessor(PostProcessorApi):
self.recognizer = None self.recognizer = None
def process_data( def process_data(
self, data: dict[str, any], data_type: PostProcessDataEnum self, data: dict[str, Any], data_type: PostProcessDataEnum
) -> None: ) -> None:
"""Transcribe audio from a recording. """Transcribe audio from a recording.
@ -141,13 +142,13 @@ class AudioTranscriptionPostProcessor(PostProcessorApi):
except Exception as e: except Exception as e:
logger.error(f"Error in audio transcription post-processing: {e}") logger.error(f"Error in audio transcription post-processing: {e}")
def __transcribe_audio(self, audio_data: bytes) -> Optional[tuple[str, float]]: def __transcribe_audio(self, audio_data: bytes) -> Optional[str]:
"""Transcribe WAV audio data using faster-whisper.""" """Transcribe WAV audio data using faster-whisper."""
if not self.recognizer: if not self.recognizer:
logger.debug("Recognizer not initialized") logger.debug("Recognizer not initialized")
return None return None
try: try: # type: ignore[unreachable]
# Save audio data to a temporary wav (faster-whisper expects a file) # Save audio data to a temporary wav (faster-whisper expects a file)
temp_wav = os.path.join(CACHE_DIR, f"temp_audio_{int(time.time())}.wav") temp_wav = os.path.join(CACHE_DIR, f"temp_audio_{int(time.time())}.wav")
with open(temp_wav, "wb") as f: with open(temp_wav, "wb") as f:
@ -176,7 +177,7 @@ class AudioTranscriptionPostProcessor(PostProcessorApi):
logger.error(f"Error transcribing audio: {e}") logger.error(f"Error transcribing audio: {e}")
return None return None
def _transcription_wrapper(self, event: dict[str, any]) -> None: def _transcription_wrapper(self, event: dict[str, Any]) -> None:
"""Wrapper to run transcription and reset running flag when done.""" """Wrapper to run transcription and reset running flag when done."""
try: try:
self.process_data( self.process_data(
@ -194,7 +195,7 @@ class AudioTranscriptionPostProcessor(PostProcessorApi):
self.requestor.send_data(UPDATE_AUDIO_TRANSCRIPTION_STATE, "idle") self.requestor.send_data(UPDATE_AUDIO_TRANSCRIPTION_STATE, "idle")
def handle_request(self, topic: str, request_data: dict[str, any]) -> str | None: def handle_request(self, topic: str, request_data: dict[str, Any]) -> str | None:
if topic == "transcribe_audio": if topic == "transcribe_audio":
event = request_data["event"] event = request_data["event"]

View File

@ -14,7 +14,7 @@ from frigate.comms.event_metadata_updater import (
from frigate.config import FrigateConfig from frigate.config import FrigateConfig
from frigate.const import MODEL_CACHE_DIR from frigate.const import MODEL_CACHE_DIR
from frigate.log import suppress_stderr_during from frigate.log import suppress_stderr_during
from frigate.util.object import calculate_region from frigate.util.image import calculate_region
from ..types import DataProcessorMetrics from ..types import DataProcessorMetrics
from .api import RealTimeProcessorApi from .api import RealTimeProcessorApi
@ -35,10 +35,10 @@ class BirdRealTimeProcessor(RealTimeProcessorApi):
metrics: DataProcessorMetrics, metrics: DataProcessorMetrics,
): ):
super().__init__(config, metrics) super().__init__(config, metrics)
self.interpreter: Interpreter = None self.interpreter: Interpreter | None = None
self.sub_label_publisher = sub_label_publisher self.sub_label_publisher = sub_label_publisher
self.tensor_input_details: dict[str, Any] = None self.tensor_input_details: list[dict[str, Any]] | None = None
self.tensor_output_details: dict[str, Any] = None self.tensor_output_details: list[dict[str, Any]] | None = None
self.detected_birds: dict[str, float] = {} self.detected_birds: dict[str, float] = {}
self.labelmap: dict[int, str] = {} self.labelmap: dict[int, str] = {}
@ -61,7 +61,7 @@ class BirdRealTimeProcessor(RealTimeProcessorApi):
self.downloader = ModelDownloader( self.downloader = ModelDownloader(
model_name="bird", model_name="bird",
download_path=download_path, download_path=download_path,
file_names=self.model_files.keys(), file_names=list(self.model_files.keys()),
download_func=self.__download_models, download_func=self.__download_models,
complete_func=self.__build_detector, complete_func=self.__build_detector,
) )
@ -102,8 +102,12 @@ class BirdRealTimeProcessor(RealTimeProcessorApi):
i += 1 i += 1
line = f.readline() line = f.readline()
def process_frame(self, obj_data, frame): def process_frame(self, obj_data: dict[str, Any], frame: np.ndarray) -> None:
if not self.interpreter: if (
not self.interpreter
or not self.tensor_input_details
or not self.tensor_output_details
):
return return
if obj_data["label"] != "bird": if obj_data["label"] != "bird":
@ -145,7 +149,7 @@ class BirdRealTimeProcessor(RealTimeProcessorApi):
self.tensor_output_details[0]["index"] self.tensor_output_details[0]["index"]
)[0] )[0]
probs = res / res.sum(axis=0) probs = res / res.sum(axis=0)
best_id = np.argmax(probs) best_id = int(np.argmax(probs))
if best_id == 964: if best_id == 964:
logger.debug("No bird classification was detected.") logger.debug("No bird classification was detected.")
@ -179,9 +183,11 @@ class BirdRealTimeProcessor(RealTimeProcessorApi):
self.config.classification = payload self.config.classification = payload
logger.debug("Bird classification config updated dynamically") logger.debug("Bird classification config updated dynamically")
def handle_request(self, topic, request_data): def handle_request(
self, topic: str, request_data: dict[str, Any]
) -> dict[str, Any] | None:
return None return None
def expire_object(self, object_id, camera): def expire_object(self, object_id: str, camera: str) -> None:
if object_id in self.detected_birds: if object_id in self.detected_birds:
self.detected_birds.pop(object_id) self.detected_birds.pop(object_id)