Add classification config

This commit is contained in:
Nicolas Mowen 2025-01-13 07:29:43 -07:00
parent 61870184df
commit bbdb712b33
6 changed files with 36 additions and 9 deletions

View File

@ -9,7 +9,7 @@ from .logger import * # noqa: F403
from .mqtt import * # noqa: F403
from .notification import * # noqa: F403
from .proxy import * # noqa: F403
from .semantic_search import * # noqa: F403
from .classification import * # noqa: F403
from .telemetry import * # noqa: F403
from .tls import * # noqa: F403
from .ui import * # noqa: F403

View File

@ -11,6 +11,22 @@ __all__ = [
]
class BirdClassificationConfig(FrigateBaseModel):
enabled: bool = Field(default=False, title="Enable bird classification.")
threshold: float = Field(
default=0.9,
title="Minimum classification score required to be considered a match.",
gt=0.0,
le=1.0,
)
class ClassificationConfig(FrigateBaseModel):
bird: BirdClassificationConfig = Field(
default_factory=BirdClassificationConfig, title="Bird classification config."
)
class SemanticSearchConfig(FrigateBaseModel):
enabled: bool = Field(default=False, title="Enable semantic search.")
reindex: Optional[bool] = Field(

View File

@ -51,17 +51,18 @@ from .camera.review import ReviewConfig
from .camera.snapshots import SnapshotsConfig
from .camera.timestamp import TimestampStyleConfig
from .camera_group import CameraGroupConfig
from .classification import (
ClassificationConfig,
FaceRecognitionConfig,
LicensePlateRecognitionConfig,
SemanticSearchConfig,
)
from .database import DatabaseConfig
from .env import EnvVars
from .logger import LoggerConfig
from .mqtt import MqttConfig
from .notification import NotificationConfig
from .proxy import ProxyConfig
from .semantic_search import (
FaceRecognitionConfig,
LicensePlateRecognitionConfig,
SemanticSearchConfig,
)
from .telemetry import TelemetryConfig
from .tls import TlsConfig
from .ui import UIConfig
@ -331,6 +332,9 @@ class FrigateConfig(FrigateBaseModel):
default_factory=TelemetryConfig, title="Telemetry configuration."
)
tls: TlsConfig = Field(default_factory=TlsConfig, title="TLS configuration.")
classification: ClassificationConfig = Field(
default_factory=ClassificationConfig, title="Object classification config."
)
semantic_search: SemanticSearchConfig = Field(
default_factory=SemanticSearchConfig, title="Semantic search configuration."
)

View File

@ -123,6 +123,11 @@ class BirdProcessor(RealTimeProcessorApi):
return
score = round(probs[best_id], 2)
if score < self.config.classification.bird.threshold:
logger.debug(f"Score {score} is not above required threshold")
return
previous_score = self.detected_birds.get(obj_data["id"], 0.0)
if score <= previous_score:
@ -145,4 +150,5 @@ class BirdProcessor(RealTimeProcessorApi):
return None
def expire_object(self, object_id):
pass
if object_id in self.detected_birds:
self.detected_birds.pop(object_id)

View File

@ -8,7 +8,7 @@ from pyclipper import ET_CLOSEDPOLYGON, JT_ROUND, PyclipperOffset
from shapely.geometry import Polygon
from frigate.comms.inter_process import InterProcessRequestor
from frigate.config.semantic_search import LicensePlateRecognitionConfig
from frigate.config.classification import LicensePlateRecognitionConfig
from frigate.embeddings.embeddings import Embeddings
logger = logging.getLogger(__name__)

View File

@ -79,7 +79,8 @@ class EmbeddingMaintainer(threading.Thread):
if self.config.face_recognition.enabled:
self.processors.append(FaceProcessor(self.config, metrics))
self.processors.append(BirdProcessor(self.config, metrics))
if self.config.classification.bird.enabled:
self.processors.append(BirdProcessor(self.config, metrics))
# create communication for updating event descriptions
self.requestor = InterProcessRequestor()