Rename to not be teachable machine specific

This commit is contained in:
Nicolas Mowen 2025-05-23 06:56:28 -06:00
parent 0c05755743
commit 5df1cf6d1d
3 changed files with 20 additions and 20 deletions

View File

@ -34,35 +34,35 @@ class BirdClassificationConfig(FrigateBaseModel):
) )
class TeachableMachineStateCameraConfig(FrigateBaseModel): class CustomClassificationStateCameraConfig(FrigateBaseModel):
crop: list[int, int, int, int] = Field( crop: list[int, int, int, int] = Field(
title="Crop of image frame on this camera to run classification on." title="Crop of image frame on this camera to run classification on."
) )
class TeachableMachineStateConfig(FrigateBaseModel): class CustomClassificationStateConfig(FrigateBaseModel):
cameras: Dict[str, TeachableMachineStateCameraConfig] = Field( cameras: Dict[str, CustomClassificationStateCameraConfig] = Field(
title="Cameras to run classification on." title="Cameras to run classification on."
) )
class TeachableMachineObjectConfig(FrigateBaseModel): class CustomClassificationObjectConfig(FrigateBaseModel):
objects: list[str] = Field(title="Object types to classify.") objects: list[str] = Field(title="Object types to classify.")
class TeachableMachineConfig(FrigateBaseModel): class CustomClassificationConfig(FrigateBaseModel):
enabled: bool = Field(default=True, title="Enable running the model.") enabled: bool = Field(default=True, title="Enable running the model.")
model_path: str = Field(title="Path to teachable machine tflite model.") model_path: str = Field(title="Path to teachable machine tflite model.")
labelmap_path: str = Field(title="Path to teachable machine labelmap.") labelmap_path: str = Field(title="Path to teachable machine labelmap.")
object_config: TeachableMachineObjectConfig | None = Field(default=None) object_config: CustomClassificationObjectConfig | None = Field(default=None)
state_config: TeachableMachineStateConfig | None = Field(default=None) state_config: CustomClassificationStateConfig | None = Field(default=None)
class ClassificationConfig(FrigateBaseModel): class ClassificationConfig(FrigateBaseModel):
bird: BirdClassificationConfig = Field( bird: BirdClassificationConfig = Field(
default_factory=BirdClassificationConfig, title="Bird classification config." default_factory=BirdClassificationConfig, title="Bird classification config."
) )
teachable_machine: Dict[str, TeachableMachineConfig] = Field( custom: Dict[str, CustomClassificationConfig] = Field(
default={}, title="Teachable Machine Model Configs." default={}, title="Teachable Machine Model Configs."
) )

View File

@ -11,7 +11,7 @@ from frigate.comms.event_metadata_updater import (
EventMetadataTypeEnum, EventMetadataTypeEnum,
) )
from frigate.config import FrigateConfig from frigate.config import FrigateConfig
from frigate.config.classification import TeachableMachineConfig from frigate.config.classification import CustomClassificationConfig
from frigate.util.builtin import load_labels from frigate.util.builtin import load_labels
from frigate.util.object import calculate_region from frigate.util.object import calculate_region
@ -26,11 +26,11 @@ except ModuleNotFoundError:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class TeachableMachineStateProcessor(RealTimeProcessorApi): class CustomStateClassificationProcessor(RealTimeProcessorApi):
def __init__( def __init__(
self, self,
config: FrigateConfig, config: FrigateConfig,
model_config: TeachableMachineConfig, model_config: CustomClassificationConfig,
metrics: DataProcessorMetrics, metrics: DataProcessorMetrics,
): ):
super().__init__(config, metrics) super().__init__(config, metrics)
@ -96,11 +96,11 @@ class TeachableMachineStateProcessor(RealTimeProcessorApi):
pass pass
class TeachableMachineObjectProcessor(RealTimeProcessorApi): class CustomObjectClassificationProcessor(RealTimeProcessorApi):
def __init__( def __init__(
self, self,
config: FrigateConfig, config: FrigateConfig,
model_config: TeachableMachineConfig, model_config: CustomClassificationConfig,
sub_label_publisher: EventMetadataPublisher, sub_label_publisher: EventMetadataPublisher,
metrics: DataProcessorMetrics, metrics: DataProcessorMetrics,
): ):

View File

@ -42,14 +42,14 @@ from frigate.data_processing.post.license_plate import (
) )
from frigate.data_processing.real_time.api import RealTimeProcessorApi from frigate.data_processing.real_time.api import RealTimeProcessorApi
from frigate.data_processing.real_time.bird import BirdRealTimeProcessor from frigate.data_processing.real_time.bird import BirdRealTimeProcessor
from frigate.data_processing.real_time.custom_classification import (
CustomObjectClassificationProcessor,
CustomStateClassificationProcessor,
)
from frigate.data_processing.real_time.face import FaceRealTimeProcessor from frigate.data_processing.real_time.face import FaceRealTimeProcessor
from frigate.data_processing.real_time.license_plate import ( from frigate.data_processing.real_time.license_plate import (
LicensePlateRealTimeProcessor, LicensePlateRealTimeProcessor,
) )
from frigate.data_processing.real_time.teachable_machine import (
TeachableMachineObjectProcessor,
TeachableMachineStateProcessor,
)
from frigate.data_processing.types import DataProcessorMetrics, PostProcessDataEnum from frigate.data_processing.types import DataProcessorMetrics, PostProcessDataEnum
from frigate.events.types import EventTypeEnum, RegenerateDescriptionEnum from frigate.events.types import EventTypeEnum, RegenerateDescriptionEnum
from frigate.genai import get_genai_client from frigate.genai import get_genai_client
@ -149,9 +149,9 @@ class EmbeddingMaintainer(threading.Thread):
for model in self.config.classification.teachable_machine.values(): for model in self.config.classification.teachable_machine.values():
self.realtime_processors.append( self.realtime_processors.append(
TeachableMachineStateProcessor(self.config, model, self.metrics) CustomStateClassificationProcessor(self.config, model, self.metrics)
if model.state_config != None if model.state_config != None
else TeachableMachineObjectProcessor( else CustomObjectClassificationProcessor(
self.config, self.config,
model, model,
self.event_metadata_publisher, self.event_metadata_publisher,
@ -503,7 +503,7 @@ class EmbeddingMaintainer(threading.Thread):
if isinstance(processor, LicensePlateRealTimeProcessor): if isinstance(processor, LicensePlateRealTimeProcessor):
processor.process_frame(camera, yuv_frame, True) processor.process_frame(camera, yuv_frame, True)
if isinstance(processor, TeachableMachineStateProcessor): if isinstance(processor, CustomStateClassificationProcessor):
processor.process_frame({"camera": camera}, yuv_frame) processor.process_frame({"camera": camera}, yuv_frame)
self.frame_manager.close(frame_name) self.frame_manager.close(frame_name)