diff --git a/frigate/config/classification.py b/frigate/config/classification.py index 94548ee01..ba64f9007 100644 --- a/frigate/config/classification.py +++ b/frigate/config/classification.py @@ -34,35 +34,35 @@ class BirdClassificationConfig(FrigateBaseModel): ) -class TeachableMachineStateCameraConfig(FrigateBaseModel): +class CustomClassificationStateCameraConfig(FrigateBaseModel): crop: list[int, int, int, int] = Field( title="Crop of image frame on this camera to run classification on." ) -class TeachableMachineStateConfig(FrigateBaseModel): - cameras: Dict[str, TeachableMachineStateCameraConfig] = Field( +class CustomClassificationStateConfig(FrigateBaseModel): + cameras: Dict[str, CustomClassificationStateCameraConfig] = Field( title="Cameras to run classification on." ) -class TeachableMachineObjectConfig(FrigateBaseModel): +class CustomClassificationObjectConfig(FrigateBaseModel): objects: list[str] = Field(title="Object types to classify.") -class TeachableMachineConfig(FrigateBaseModel): +class CustomClassificationConfig(FrigateBaseModel): enabled: bool = Field(default=True, title="Enable running the model.") model_path: str = Field(title="Path to teachable machine tflite model.") labelmap_path: str = Field(title="Path to teachable machine labelmap.") - object_config: TeachableMachineObjectConfig | None = Field(default=None) - state_config: TeachableMachineStateConfig | None = Field(default=None) + object_config: CustomClassificationObjectConfig | None = Field(default=None) + state_config: CustomClassificationStateConfig | None = Field(default=None) class ClassificationConfig(FrigateBaseModel): bird: BirdClassificationConfig = Field( default_factory=BirdClassificationConfig, title="Bird classification config." ) - teachable_machine: Dict[str, TeachableMachineConfig] = Field( + custom: Dict[str, CustomClassificationConfig] = Field( default={}, title="Teachable Machine Model Configs." ) diff --git a/frigate/data_processing/real_time/teachable_machine.py b/frigate/data_processing/real_time/custom_classification.py similarity index 94% rename from frigate/data_processing/real_time/teachable_machine.py rename to frigate/data_processing/real_time/custom_classification.py index f7e540c3e..189b8f8d3 100644 --- a/frigate/data_processing/real_time/teachable_machine.py +++ b/frigate/data_processing/real_time/custom_classification.py @@ -11,7 +11,7 @@ from frigate.comms.event_metadata_updater import ( EventMetadataTypeEnum, ) from frigate.config import FrigateConfig -from frigate.config.classification import TeachableMachineConfig +from frigate.config.classification import CustomClassificationConfig from frigate.util.builtin import load_labels from frigate.util.object import calculate_region @@ -26,11 +26,11 @@ except ModuleNotFoundError: logger = logging.getLogger(__name__) -class TeachableMachineStateProcessor(RealTimeProcessorApi): +class CustomStateClassificationProcessor(RealTimeProcessorApi): def __init__( self, config: FrigateConfig, - model_config: TeachableMachineConfig, + model_config: CustomClassificationConfig, metrics: DataProcessorMetrics, ): super().__init__(config, metrics) @@ -96,11 +96,11 @@ class TeachableMachineStateProcessor(RealTimeProcessorApi): pass -class TeachableMachineObjectProcessor(RealTimeProcessorApi): +class CustomObjectClassificationProcessor(RealTimeProcessorApi): def __init__( self, config: FrigateConfig, - model_config: TeachableMachineConfig, + model_config: CustomClassificationConfig, sub_label_publisher: EventMetadataPublisher, metrics: DataProcessorMetrics, ): diff --git a/frigate/embeddings/maintainer.py b/frigate/embeddings/maintainer.py index c8b1a079f..975ba8ba5 100644 --- a/frigate/embeddings/maintainer.py +++ b/frigate/embeddings/maintainer.py @@ -42,14 +42,14 @@ from frigate.data_processing.post.license_plate import ( ) from frigate.data_processing.real_time.api import RealTimeProcessorApi from frigate.data_processing.real_time.bird import BirdRealTimeProcessor +from frigate.data_processing.real_time.custom_classification import ( + CustomObjectClassificationProcessor, + CustomStateClassificationProcessor, +) from frigate.data_processing.real_time.face import FaceRealTimeProcessor from frigate.data_processing.real_time.license_plate import ( LicensePlateRealTimeProcessor, ) -from frigate.data_processing.real_time.teachable_machine import ( - TeachableMachineObjectProcessor, - TeachableMachineStateProcessor, -) from frigate.data_processing.types import DataProcessorMetrics, PostProcessDataEnum from frigate.events.types import EventTypeEnum, RegenerateDescriptionEnum from frigate.genai import get_genai_client @@ -149,9 +149,9 @@ class EmbeddingMaintainer(threading.Thread): for model in self.config.classification.teachable_machine.values(): self.realtime_processors.append( - TeachableMachineStateProcessor(self.config, model, self.metrics) + CustomStateClassificationProcessor(self.config, model, self.metrics) if model.state_config != None - else TeachableMachineObjectProcessor( + else CustomObjectClassificationProcessor( self.config, model, self.event_metadata_publisher, @@ -503,7 +503,7 @@ class EmbeddingMaintainer(threading.Thread): if isinstance(processor, LicensePlateRealTimeProcessor): processor.process_frame(camera, yuv_frame, True) - if isinstance(processor, TeachableMachineStateProcessor): + if isinstance(processor, CustomStateClassificationProcessor): processor.process_frame({"camera": camera}, yuv_frame) self.frame_manager.close(frame_name)