Add classification config

This commit is contained in:
Nicolas Mowen 2025-01-13 07:29:43 -07:00
parent 61870184df
commit bbdb712b33
6 changed files with 36 additions and 9 deletions

View File

@ -9,7 +9,7 @@ from .logger import * # noqa: F403
from .mqtt import * # noqa: F403 from .mqtt import * # noqa: F403
from .notification import * # noqa: F403 from .notification import * # noqa: F403
from .proxy import * # noqa: F403 from .proxy import * # noqa: F403
from .semantic_search import * # noqa: F403 from .classification import * # noqa: F403
from .telemetry import * # noqa: F403 from .telemetry import * # noqa: F403
from .tls import * # noqa: F403 from .tls import * # noqa: F403
from .ui import * # noqa: F403 from .ui import * # noqa: F403

View File

@ -11,6 +11,22 @@ __all__ = [
] ]
class BirdClassificationConfig(FrigateBaseModel):
enabled: bool = Field(default=False, title="Enable bird classification.")
threshold: float = Field(
default=0.9,
title="Minimum classification score required to be considered a match.",
gt=0.0,
le=1.0,
)
class ClassificationConfig(FrigateBaseModel):
bird: BirdClassificationConfig = Field(
default_factory=BirdClassificationConfig, title="Bird classification config."
)
class SemanticSearchConfig(FrigateBaseModel): class SemanticSearchConfig(FrigateBaseModel):
enabled: bool = Field(default=False, title="Enable semantic search.") enabled: bool = Field(default=False, title="Enable semantic search.")
reindex: Optional[bool] = Field( reindex: Optional[bool] = Field(

View File

@ -51,17 +51,18 @@ from .camera.review import ReviewConfig
from .camera.snapshots import SnapshotsConfig from .camera.snapshots import SnapshotsConfig
from .camera.timestamp import TimestampStyleConfig from .camera.timestamp import TimestampStyleConfig
from .camera_group import CameraGroupConfig from .camera_group import CameraGroupConfig
from .classification import (
ClassificationConfig,
FaceRecognitionConfig,
LicensePlateRecognitionConfig,
SemanticSearchConfig,
)
from .database import DatabaseConfig from .database import DatabaseConfig
from .env import EnvVars from .env import EnvVars
from .logger import LoggerConfig from .logger import LoggerConfig
from .mqtt import MqttConfig from .mqtt import MqttConfig
from .notification import NotificationConfig from .notification import NotificationConfig
from .proxy import ProxyConfig from .proxy import ProxyConfig
from .semantic_search import (
FaceRecognitionConfig,
LicensePlateRecognitionConfig,
SemanticSearchConfig,
)
from .telemetry import TelemetryConfig from .telemetry import TelemetryConfig
from .tls import TlsConfig from .tls import TlsConfig
from .ui import UIConfig from .ui import UIConfig
@ -331,6 +332,9 @@ class FrigateConfig(FrigateBaseModel):
default_factory=TelemetryConfig, title="Telemetry configuration." default_factory=TelemetryConfig, title="Telemetry configuration."
) )
tls: TlsConfig = Field(default_factory=TlsConfig, title="TLS configuration.") tls: TlsConfig = Field(default_factory=TlsConfig, title="TLS configuration.")
classification: ClassificationConfig = Field(
default_factory=ClassificationConfig, title="Object classification config."
)
semantic_search: SemanticSearchConfig = Field( semantic_search: SemanticSearchConfig = Field(
default_factory=SemanticSearchConfig, title="Semantic search configuration." default_factory=SemanticSearchConfig, title="Semantic search configuration."
) )

View File

@ -123,6 +123,11 @@ class BirdProcessor(RealTimeProcessorApi):
return return
score = round(probs[best_id], 2) score = round(probs[best_id], 2)
if score < self.config.classification.bird.threshold:
logger.debug(f"Score {score} is not above required threshold")
return
previous_score = self.detected_birds.get(obj_data["id"], 0.0) previous_score = self.detected_birds.get(obj_data["id"], 0.0)
if score <= previous_score: if score <= previous_score:
@ -145,4 +150,5 @@ class BirdProcessor(RealTimeProcessorApi):
return None return None
def expire_object(self, object_id): def expire_object(self, object_id):
pass if object_id in self.detected_birds:
self.detected_birds.pop(object_id)

View File

@ -8,7 +8,7 @@ from pyclipper import ET_CLOSEDPOLYGON, JT_ROUND, PyclipperOffset
from shapely.geometry import Polygon from shapely.geometry import Polygon
from frigate.comms.inter_process import InterProcessRequestor from frigate.comms.inter_process import InterProcessRequestor
from frigate.config.semantic_search import LicensePlateRecognitionConfig from frigate.config.classification import LicensePlateRecognitionConfig
from frigate.embeddings.embeddings import Embeddings from frigate.embeddings.embeddings import Embeddings
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -79,7 +79,8 @@ class EmbeddingMaintainer(threading.Thread):
if self.config.face_recognition.enabled: if self.config.face_recognition.enabled:
self.processors.append(FaceProcessor(self.config, metrics)) self.processors.append(FaceProcessor(self.config, metrics))
self.processors.append(BirdProcessor(self.config, metrics)) if self.config.classification.bird.enabled:
self.processors.append(BirdProcessor(self.config, metrics))
# create communication for updating event descriptions # create communication for updating event descriptions
self.requestor = InterProcessRequestor() self.requestor = InterProcessRequestor()