Rename to not be teachable machine specific

This commit is contained in:
Nicolas Mowen 2025-05-23 06:56:28 -06:00
parent 0c05755743
commit 5df1cf6d1d
3 changed files with 20 additions and 20 deletions

View File

@ -34,35 +34,35 @@ class BirdClassificationConfig(FrigateBaseModel):
)
class TeachableMachineStateCameraConfig(FrigateBaseModel):
class CustomClassificationStateCameraConfig(FrigateBaseModel):
crop: list[int, int, int, int] = Field(
title="Crop of image frame on this camera to run classification on."
)
class TeachableMachineStateConfig(FrigateBaseModel):
cameras: Dict[str, TeachableMachineStateCameraConfig] = Field(
class CustomClassificationStateConfig(FrigateBaseModel):
cameras: Dict[str, CustomClassificationStateCameraConfig] = Field(
title="Cameras to run classification on."
)
class TeachableMachineObjectConfig(FrigateBaseModel):
class CustomClassificationObjectConfig(FrigateBaseModel):
objects: list[str] = Field(title="Object types to classify.")
class TeachableMachineConfig(FrigateBaseModel):
class CustomClassificationConfig(FrigateBaseModel):
enabled: bool = Field(default=True, title="Enable running the model.")
model_path: str = Field(title="Path to teachable machine tflite model.")
labelmap_path: str = Field(title="Path to teachable machine labelmap.")
object_config: TeachableMachineObjectConfig | None = Field(default=None)
state_config: TeachableMachineStateConfig | None = Field(default=None)
object_config: CustomClassificationObjectConfig | None = Field(default=None)
state_config: CustomClassificationStateConfig | None = Field(default=None)
class ClassificationConfig(FrigateBaseModel):
bird: BirdClassificationConfig = Field(
default_factory=BirdClassificationConfig, title="Bird classification config."
)
teachable_machine: Dict[str, TeachableMachineConfig] = Field(
custom: Dict[str, CustomClassificationConfig] = Field(
default={}, title="Teachable Machine Model Configs."
)

View File

@ -11,7 +11,7 @@ from frigate.comms.event_metadata_updater import (
EventMetadataTypeEnum,
)
from frigate.config import FrigateConfig
from frigate.config.classification import TeachableMachineConfig
from frigate.config.classification import CustomClassificationConfig
from frigate.util.builtin import load_labels
from frigate.util.object import calculate_region
@ -26,11 +26,11 @@ except ModuleNotFoundError:
logger = logging.getLogger(__name__)
class TeachableMachineStateProcessor(RealTimeProcessorApi):
class CustomStateClassificationProcessor(RealTimeProcessorApi):
def __init__(
self,
config: FrigateConfig,
model_config: TeachableMachineConfig,
model_config: CustomClassificationConfig,
metrics: DataProcessorMetrics,
):
super().__init__(config, metrics)
@ -96,11 +96,11 @@ class TeachableMachineStateProcessor(RealTimeProcessorApi):
pass
class TeachableMachineObjectProcessor(RealTimeProcessorApi):
class CustomObjectClassificationProcessor(RealTimeProcessorApi):
def __init__(
self,
config: FrigateConfig,
model_config: TeachableMachineConfig,
model_config: CustomClassificationConfig,
sub_label_publisher: EventMetadataPublisher,
metrics: DataProcessorMetrics,
):

View File

@ -42,14 +42,14 @@ from frigate.data_processing.post.license_plate import (
)
from frigate.data_processing.real_time.api import RealTimeProcessorApi
from frigate.data_processing.real_time.bird import BirdRealTimeProcessor
from frigate.data_processing.real_time.custom_classification import (
CustomObjectClassificationProcessor,
CustomStateClassificationProcessor,
)
from frigate.data_processing.real_time.face import FaceRealTimeProcessor
from frigate.data_processing.real_time.license_plate import (
LicensePlateRealTimeProcessor,
)
from frigate.data_processing.real_time.teachable_machine import (
TeachableMachineObjectProcessor,
TeachableMachineStateProcessor,
)
from frigate.data_processing.types import DataProcessorMetrics, PostProcessDataEnum
from frigate.events.types import EventTypeEnum, RegenerateDescriptionEnum
from frigate.genai import get_genai_client
@ -149,9 +149,9 @@ class EmbeddingMaintainer(threading.Thread):
for model in self.config.classification.teachable_machine.values():
self.realtime_processors.append(
TeachableMachineStateProcessor(self.config, model, self.metrics)
CustomStateClassificationProcessor(self.config, model, self.metrics)
if model.state_config != None
else TeachableMachineObjectProcessor(
else CustomObjectClassificationProcessor(
self.config,
model,
self.event_metadata_publisher,
@ -503,7 +503,7 @@ class EmbeddingMaintainer(threading.Thread):
if isinstance(processor, LicensePlateRealTimeProcessor):
processor.process_frame(camera, yuv_frame, True)
if isinstance(processor, TeachableMachineStateProcessor):
if isinstance(processor, CustomStateClassificationProcessor):
processor.process_frame({"camera": camera}, yuv_frame)
self.frame_manager.close(frame_name)