Fix order of init

This commit is contained in:
Nicolas Mowen 2024-10-10 14:32:34 -06:00
parent 395cf4e33f
commit 72126e515b
6 changed files with 34 additions and 60 deletions

View File

@ -581,12 +581,12 @@ class FrigateApp:
self.init_recording_manager() self.init_recording_manager()
self.init_review_segment_manager() self.init_review_segment_manager()
self.init_go2rtc() self.init_go2rtc()
self.start_detectors()
self.init_embeddings_manager()
self.bind_database() self.bind_database()
self.check_db_data_migrations() self.check_db_data_migrations()
self.init_inter_process_communicator() self.init_inter_process_communicator()
self.init_dispatcher() self.init_dispatcher()
self.start_detectors()
self.init_embeddings_manager()
self.init_embeddings_client() self.init_embeddings_client()
self.start_video_output_processor() self.start_video_output_processor()
self.start_ptz_autotracker() self.start_ptz_autotracker()

View File

@ -64,6 +64,9 @@ class Dispatcher:
self.onvif = onvif self.onvif = onvif
self.ptz_metrics = ptz_metrics self.ptz_metrics = ptz_metrics
self.comms = communicators self.comms = communicators
self.camera_activity = {}
self.model_state = {}
self.embeddings_reindex = {}
self._camera_settings_handlers: dict[str, Callable] = { self._camera_settings_handlers: dict[str, Callable] = {
"audio": self._on_audio_command, "audio": self._on_audio_command,
@ -85,10 +88,6 @@ class Dispatcher:
for comm in self.comms: for comm in self.comms:
comm.subscribe(self._receive) comm.subscribe(self._receive)
self.camera_activity = {}
self.model_state = {}
self.embeddings_reindex = {}
def _receive(self, topic: str, payload: str) -> Optional[Any]: def _receive(self, topic: str, payload: str) -> Optional[Any]:
"""Handle receiving of payload from communicators.""" """Handle receiving of payload from communicators."""

View File

@ -43,7 +43,6 @@ def manage_embeddings(config: FrigateConfig) -> None:
listen() listen()
# Configure Frigate DB # Configure Frigate DB
print("connecting to db in embed")
db = SqliteVecQueueDatabase( db = SqliteVecQueueDatabase(
config.database.path, config.database.path,
pragmas={ pragmas={
@ -55,10 +54,8 @@ def manage_embeddings(config: FrigateConfig) -> None:
load_vec_extension=True, load_vec_extension=True,
) )
models = [Event] models = [Event]
print("binding db to model")
db.bind(models) db.bind(models)
print("creating embedding maintainer")
maintainer = EmbeddingMaintainer( maintainer = EmbeddingMaintainer(
db, db,
config, config,

View File

@ -61,37 +61,32 @@ class GenericONNXEmbedding:
self.tokenizer = None self.tokenizer = None
self.feature_extractor = None self.feature_extractor = None
self.session = None self.session = None
files_names = list(self.download_urls.keys()) + (
[self.tokenizer_file] if self.tokenizer_file else []
)
if not all( if not all(
os.path.exists(os.path.join(self.download_path, n)) os.path.exists(os.path.join(self.download_path, n)) for n in files_names
for n in self.download_urls.keys()
): ):
print("starting model download") logger.debug(f"starting model download for {self.model_name}")
self.downloader = ModelDownloader( self.downloader = ModelDownloader(
model_name=self.model_name, model_name=self.model_name,
download_path=self.download_path, download_path=self.download_path,
file_names=list(self.download_urls.keys()) file_names=files_names,
+ ([self.tokenizer_file] if self.tokenizer_file else []),
requestor=self.requestor, requestor=self.requestor,
download_func=self._download_model, download_func=self._download_model,
) )
self.downloader.ensure_model_files() self.downloader.ensure_model_files()
else: else:
self.downloader = None self.downloader = None
for file_name in self.download_urls.keys(): ModelDownloader.mark_files_downloaded(
self.requestor.send_data( self.requestor, self.model_name, files_names
UPDATE_MODEL_STATE, )
{
"model": f"{self.model_name}-{file_name}",
"state": ModelStatusTypesEnum.downloaded,
},
)
self._load_model_and_tokenizer() self._load_model_and_tokenizer()
print("models are already downloaded") print(f"models are already downloaded for {self.model_name}")
def _download_model(self, path: str): def _download_model(self, path: str):
try: try:
print("beginning model download process")
file_name = os.path.basename(path) file_name = os.path.basename(path)
if file_name in self.download_urls: if file_name in self.download_urls:
ModelDownloader.download_from_url(self.download_urls[file_name], path) ModelDownloader.download_from_url(self.download_urls[file_name], path)
@ -130,11 +125,9 @@ class GenericONNXEmbedding:
self.tokenizer = self._load_tokenizer() self.tokenizer = self._load_tokenizer()
else: else:
self.feature_extractor = self._load_feature_extractor() self.feature_extractor = self._load_feature_extractor()
print("creating onnx session")
self.session = self._load_model( self.session = self._load_model(
os.path.join(self.download_path, self.model_file) os.path.join(self.download_path, self.model_file)
) )
print("successfully loaded model.")
def _load_tokenizer(self): def _load_tokenizer(self):
tokenizer_path = os.path.join(f"{MODEL_CACHE_DIR}/{self.model_name}/tokenizer") tokenizer_path = os.path.join(f"{MODEL_CACHE_DIR}/{self.model_name}/tokenizer")
@ -151,16 +144,11 @@ class GenericONNXEmbedding:
) )
def _load_model(self, path: str) -> Optional[ort.InferenceSession]: def _load_model(self, path: str) -> Optional[ort.InferenceSession]:
print(f"checking if path exists {path}")
if os.path.exists(path): if os.path.exists(path):
print(
f"loading ORT session with providers {self.providers} and options {self.provider_options}"
)
return ort.InferenceSession( return ort.InferenceSession(
path, providers=self.providers, provider_options=self.provider_options path, providers=self.providers, provider_options=self.provider_options
) )
else: else:
print(f"{self.model_name} model file {path} not found.")
return None return None
def _process_image(self, image): def _process_image(self, image):
@ -174,7 +162,6 @@ class GenericONNXEmbedding:
def __call__( def __call__(
self, inputs: Union[List[str], List[Image.Image], List[str]] self, inputs: Union[List[str], List[Image.Image], List[str]]
) -> List[np.ndarray]: ) -> List[np.ndarray]:
print("beginning call for onnx embedding")
self._load_model_and_tokenizer() self._load_model_and_tokenizer()
if self.session is None or ( if self.session is None or (

View File

@ -43,9 +43,7 @@ class EmbeddingMaintainer(threading.Thread):
) -> None: ) -> None:
super().__init__(name="embeddings_maintainer") super().__init__(name="embeddings_maintainer")
self.config = config self.config = config
print("creating embeddings")
self.embeddings = Embeddings(config.semantic_search, db) self.embeddings = Embeddings(config.semantic_search, db)
print("finished creating embeddings")
# Check if we need to re-index events # Check if we need to re-index events
if config.semantic_search.reindex: if config.semantic_search.reindex:
@ -63,12 +61,10 @@ class EmbeddingMaintainer(threading.Thread):
self.stop_event = stop_event self.stop_event = stop_event
self.tracked_events = {} self.tracked_events = {}
self.genai_client = get_genai_client(config.genai) self.genai_client = get_genai_client(config.genai)
print("finished embed maintainer setup")
def run(self) -> None: def run(self) -> None:
"""Maintain a SQLite-vec database for semantic search.""" """Maintain a SQLite-vec database for semantic search."""
while not self.stop_event.is_set(): while not self.stop_event.is_set():
print("Doing another embeddings loop.")
self._process_requests() self._process_requests()
self._process_updates() self._process_updates()
self._process_finalized() self._process_finalized()
@ -85,8 +81,6 @@ class EmbeddingMaintainer(threading.Thread):
"""Process embeddings requests""" """Process embeddings requests"""
def _handle_request(topic: str, data: str) -> str: def _handle_request(topic: str, data: str) -> str:
print(f"Handling embeddings request of type {topic} with data {data}")
try: try:
if topic == EmbeddingsRequestEnum.embed_description.value: if topic == EmbeddingsRequestEnum.embed_description.value:
return serialize( return serialize(
@ -122,7 +116,6 @@ class EmbeddingMaintainer(threading.Thread):
if not camera or source_type != EventTypeEnum.tracked_object: if not camera or source_type != EventTypeEnum.tracked_object:
return return
print(f"Processing object update of type {source_type} on {camera}")
camera_config = self.config.cameras[camera] camera_config = self.config.cameras[camera]
if data["id"] not in self.tracked_events: if data["id"] not in self.tracked_events:
self.tracked_events[data["id"]] = [] self.tracked_events[data["id"]] = []
@ -130,23 +123,15 @@ class EmbeddingMaintainer(threading.Thread):
# Create our own thumbnail based on the bounding box and the frame time # Create our own thumbnail based on the bounding box and the frame time
try: try:
frame_id = f"{camera}{data['frame_time']}" frame_id = f"{camera}{data['frame_time']}"
print("trying to get frame from manager")
yuv_frame = self.frame_manager.get(frame_id, camera_config.frame_shape_yuv) yuv_frame = self.frame_manager.get(frame_id, camera_config.frame_shape_yuv)
print(f"got frame from manager and it is valid {yuv_frame is not None}")
if yuv_frame is not None: if yuv_frame is not None:
data["thumbnail"] = self._create_thumbnail(yuv_frame, data["box"]) data["thumbnail"] = self._create_thumbnail(yuv_frame, data["box"])
self.tracked_events[data["id"]].append(data) self.tracked_events[data["id"]].append(data)
self.frame_manager.close(frame_id) self.frame_manager.close(frame_id)
else:
print(
f"Unable to create embedding for thumbnail from {camera} because frame is missing."
)
except FileNotFoundError: except FileNotFoundError:
pass pass
print("Finished processing object update")
def _process_finalized(self) -> None: def _process_finalized(self) -> None:
"""Process the end of an event.""" """Process the end of an event."""
while True: while True:
@ -156,9 +141,6 @@ class EmbeddingMaintainer(threading.Thread):
break break
event_id, camera, updated_db = ended event_id, camera, updated_db = ended
print(
f"Processing finalized event for {camera} which updated the db: {updated_db}"
)
camera_config = self.config.cameras[camera] camera_config = self.config.cameras[camera]
if updated_db: if updated_db:
@ -190,9 +172,6 @@ class EmbeddingMaintainer(threading.Thread):
or set(event.zones) & set(camera_config.genai.required_zones) or set(event.zones) & set(camera_config.genai.required_zones)
) )
): ):
print(
f"Description generation for {event}, has_snapshot: {event.has_snapshot}"
)
if event.has_snapshot and camera_config.genai.use_snapshot: if event.has_snapshot and camera_config.genai.use_snapshot:
with open( with open(
os.path.join(CLIPS_DIR, f"{event.camera}-{event.id}.jpg"), os.path.join(CLIPS_DIR, f"{event.camera}-{event.id}.jpg"),
@ -245,13 +224,13 @@ class EmbeddingMaintainer(threading.Thread):
def _process_event_metadata(self): def _process_event_metadata(self):
# Check for regenerate description requests # Check for regenerate description requests
(topic, event_id, source) = self.event_metadata_subscriber.check_for_update(timeout=0.1) (topic, event_id, source) = self.event_metadata_subscriber.check_for_update(
timeout=0.1
)
if topic is None: if topic is None:
return return
print(f"Handling event metadata for id {event_id} and source {source}")
if event_id: if event_id:
self.handle_regenerate_description(event_id, source) self.handle_regenerate_description(event_id, source)
@ -284,7 +263,7 @@ class EmbeddingMaintainer(threading.Thread):
) )
if not description: if not description:
print("Failed to generate description for %s", event.id) logger.debug("Failed to generate description for %s", event.id)
return return
# fire and forget description update # fire and forget description update
@ -296,7 +275,7 @@ class EmbeddingMaintainer(threading.Thread):
# Encode the description # Encode the description
self.embeddings.upsert_description(event.id, description) self.embeddings.upsert_description(event.id, description)
print( logger.debug(
"Generated description for %s (%d images): %s", "Generated description for %s (%d images): %s",
event.id, event.id,
len(thumbnails), len(thumbnails),
@ -317,7 +296,7 @@ class EmbeddingMaintainer(threading.Thread):
thumbnail = base64.b64decode(event.thumbnail) thumbnail = base64.b64decode(event.thumbnail)
print( logger.debug(
f"Trying {source} regeneration for {event}, has_snapshot: {event.has_snapshot}" f"Trying {source} regeneration for {event}, has_snapshot: {event.has_snapshot}"
) )

View File

@ -120,6 +120,18 @@ class ModelDownloader:
if not silent: if not silent:
logger.info(f"Downloading complete: {url}") logger.info(f"Downloading complete: {url}")
@staticmethod
def mark_files_downloaded(
    requestor: InterProcessRequestor, model_name: str, files: list[str]
) -> None:
    """Report each file in ``files`` as downloaded for ``model_name``.

    One ``UPDATE_MODEL_STATE`` message is sent through ``requestor`` per
    file. The message's ``model`` field is ``"<model_name>-<file_name>"``
    and its ``state`` field is ``ModelStatusTypesEnum.downloaded``.
    Nothing is sent when ``files`` is empty.
    """
    for downloaded_file in files:
        # Build the per-file status payload first, then publish it.
        status_update = {
            "model": f"{model_name}-{downloaded_file}",
            "state": ModelStatusTypesEnum.downloaded,
        }
        requestor.send_data(UPDATE_MODEL_STATE, status_update)
def wait_for_download(self): def wait_for_download(self):
print("waiting for model download")
self.download_complete.wait() self.download_complete.wait()