diff --git a/frigate/config/__init__.py b/frigate/config/__init__.py index 1af2f08fe..2f9ec0c56 100644 --- a/frigate/config/__init__.py +++ b/frigate/config/__init__.py @@ -9,7 +9,7 @@ from .logger import * # noqa: F403 from .mqtt import * # noqa: F403 from .notification import * # noqa: F403 from .proxy import * # noqa: F403 -from .semantic_search import * # noqa: F403 +from .classification import * # noqa: F403 from .telemetry import * # noqa: F403 from .tls import * # noqa: F403 from .ui import * # noqa: F403 diff --git a/frigate/config/semantic_search.py b/frigate/config/classification.py similarity index 78% rename from frigate/config/semantic_search.py rename to frigate/config/classification.py index 66b8c7170..4e806f9d9 100644 --- a/frigate/config/semantic_search.py +++ b/frigate/config/classification.py @@ -11,6 +11,22 @@ __all__ = [ ] +class BirdClassificationConfig(FrigateBaseModel): + enabled: bool = Field(default=False, title="Enable bird classification.") + threshold: float = Field( + default=0.9, + title="Minimum classification score required to be considered a match.", + gt=0.0, + le=1.0, + ) + + +class ClassificationConfig(FrigateBaseModel): + bird: BirdClassificationConfig = Field( + default_factory=BirdClassificationConfig, title="Bird classification config." + ) + + class SemanticSearchConfig(FrigateBaseModel): enabled: bool = Field(default=False, title="Enable semantic search.") reindex: Optional[bool] = Field( diff --git a/frigate/config/config.py b/frigate/config/config.py index c4247e6f2..f3b17c5fa 100644 --- a/frigate/config/config.py +++ b/frigate/config/config.py @@ -51,17 +51,18 @@ from .camera.review import ReviewConfig from .camera.snapshots import SnapshotsConfig from .camera.timestamp import TimestampStyleConfig from .camera_group import CameraGroupConfig +from .classification import ( + ClassificationConfig, + FaceRecognitionConfig, + LicensePlateRecognitionConfig, + SemanticSearchConfig, +) from .database import DatabaseConfig from .env import EnvVars from .logger import LoggerConfig from .mqtt import MqttConfig from .notification import NotificationConfig from .proxy import ProxyConfig -from .semantic_search import ( - FaceRecognitionConfig, - LicensePlateRecognitionConfig, - SemanticSearchConfig, -) from .telemetry import TelemetryConfig from .tls import TlsConfig from .ui import UIConfig @@ -331,6 +332,9 @@ class FrigateConfig(FrigateBaseModel): default_factory=TelemetryConfig, title="Telemetry configuration." ) tls: TlsConfig = Field(default_factory=TlsConfig, title="TLS configuration.") + classification: ClassificationConfig = Field( + default_factory=ClassificationConfig, title="Object classification config." + ) semantic_search: SemanticSearchConfig = Field( default_factory=SemanticSearchConfig, title="Semantic search configuration." ) diff --git a/frigate/data_processing/real_time/bird_processor.py b/frigate/data_processing/real_time/bird_processor.py index aa9b11984..e432a186b 100644 --- a/frigate/data_processing/real_time/bird_processor.py +++ b/frigate/data_processing/real_time/bird_processor.py @@ -123,6 +123,11 @@ class BirdProcessor(RealTimeProcessorApi): return score = round(probs[best_id], 2) + + if score < self.config.classification.bird.threshold: + logger.debug(f"Score {score} is not above required threshold") + return + previous_score = self.detected_birds.get(obj_data["id"], 0.0) if score <= previous_score: @@ -145,4 +150,5 @@ class BirdProcessor(RealTimeProcessorApi): return None def expire_object(self, object_id): - pass + if object_id in self.detected_birds: + self.detected_birds.pop(object_id) diff --git a/frigate/embeddings/lpr/lpr.py b/frigate/embeddings/lpr/lpr.py index 16eba9989..d7e513c73 100644 --- a/frigate/embeddings/lpr/lpr.py +++ b/frigate/embeddings/lpr/lpr.py @@ -8,7 +8,7 @@ from pyclipper import ET_CLOSEDPOLYGON, JT_ROUND, PyclipperOffset from shapely.geometry import Polygon from frigate.comms.inter_process import InterProcessRequestor -from frigate.config.semantic_search import LicensePlateRecognitionConfig +from frigate.config.classification import LicensePlateRecognitionConfig from frigate.embeddings.embeddings import Embeddings logger = logging.getLogger(__name__) diff --git a/frigate/embeddings/maintainer.py b/frigate/embeddings/maintainer.py index 671df4917..aa0322fd7 100644 --- a/frigate/embeddings/maintainer.py +++ b/frigate/embeddings/maintainer.py @@ -79,7 +79,8 @@ class EmbeddingMaintainer(threading.Thread): if self.config.face_recognition.enabled: self.processors.append(FaceProcessor(self.config, metrics)) - self.processors.append(BirdProcessor(self.config, metrics)) + if self.config.classification.bird.enabled: + self.processors.append(BirdProcessor(self.config, metrics)) # create communication for updating event descriptions self.requestor = InterProcessRequestor()