diff --git a/frigate/genai/__init__.py b/frigate/genai/__init__.py index 7f0192912..821b449de 100644 --- a/frigate/genai/__init__.py +++ b/frigate/genai/__init__.py @@ -140,7 +140,12 @@ Each line represents a detection state, not necessarily unique individuals. Pare ) as f: f.write(context_prompt) - response = self._send(context_prompt, thumbnails) + json_schema = { + "name": "review_metadata", + "schema": ReviewMetadata.model_json_schema(), + "strict": True, + } + response = self._send(context_prompt, thumbnails, json_schema=json_schema) if debug_save and response: with open( @@ -152,6 +157,8 @@ Each line represents a detection state, not necessarily unique individuals. Pare f.write(response) if response: + # With JSON schema, response should already be valid JSON + # But keep regex cleanup as fallback for providers without schema support clean_json = re.sub( r"\n?```$", "", re.sub(r"^```[a-zA-Z0-9]*\n?", "", response) ) @@ -284,8 +291,16 @@ Guidelines: """Initialize the client.""" return None - def _send(self, prompt: str, images: list[bytes]) -> Optional[str]: - """Submit a request to the provider.""" + def _send( + self, prompt: str, images: list[bytes], json_schema: Optional[dict] = None + ) -> Optional[str]: + """Submit a request to the provider. + + Args: + prompt: The text prompt to send + images: List of image bytes to include + json_schema: Optional JSON schema for structured output (provider-specific support) + """ return None def get_context_size(self) -> int: diff --git a/frigate/genai/azure-openai.py b/frigate/genai/azure-openai.py index eb08f7786..9cc03fe75 100644 --- a/frigate/genai/azure-openai.py +++ b/frigate/genai/azure-openai.py @@ -41,29 +41,46 @@ class OpenAIClient(GenAIClient): azure_endpoint=azure_endpoint, ) - def _send(self, prompt: str, images: list[bytes]) -> Optional[str]: + def _send( + self, prompt: str, images: list[bytes], json_schema: Optional[dict] = None + ) -> Optional[str]: """Submit a request to Azure OpenAI.""" encoded_images = [base64.b64encode(image).decode("utf-8") for image in images] + + request_params = { + "model": self.genai_config.model, + "messages": [ + { + "role": "user", + "content": [{"type": "text", "text": prompt}] + + [ + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image}", + "detail": "low", + }, + } + for image in encoded_images + ], + }, + ], + "timeout": self.timeout, + } + + if json_schema: + request_params["response_format"] = { + "type": "json_schema", + "json_schema": { + "name": json_schema.get("name", "response"), + "schema": json_schema.get("schema", {}), + "strict": json_schema.get("strict", True), + }, + } + try: result = self.provider.chat.completions.create( - model=self.genai_config.model, - messages=[ - { - "role": "user", - "content": [{"type": "text", "text": prompt}] - + [ - { - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{image}", - "detail": "low", - }, - } - for image in encoded_images - ], - }, - ], - timeout=self.timeout, + **request_params, **self.genai_config.runtime_options, ) except Exception as e: diff --git a/frigate/genai/gemini.py b/frigate/genai/gemini.py index b700c33a4..1212f15ad 100644 --- a/frigate/genai/gemini.py +++ b/frigate/genai/gemini.py @@ -41,7 +41,9 @@ class GeminiClient(GenAIClient): http_options=types.HttpOptions(**http_options_dict), ) - def _send(self, prompt: str, images: list[bytes]) -> Optional[str]: + def _send( + self, prompt: str, images: list[bytes], json_schema: Optional[dict] = None + ) -> Optional[str]: """Submit a request to Gemini.""" contents = [ types.Part.from_bytes(data=img, mime_type="image/jpeg") for img in images @@ -51,6 +53,12 @@ class GeminiClient(GenAIClient): generation_config_dict = {"candidate_count": 1} generation_config_dict.update(self.genai_config.runtime_options) + if json_schema and "schema" in json_schema: + generation_config_dict["response_mime_type"] = "application/json" + generation_config_dict["response_schema"] = types.Schema( + json_schema=json_schema["schema"] + ) + response = self.provider.models.generate_content( model=self.genai_config.model, contents=contents, diff --git a/frigate/genai/ollama.py b/frigate/genai/ollama.py index ab6d3c0b3..f798cbd19 100644 --- a/frigate/genai/ollama.py +++ b/frigate/genai/ollama.py @@ -50,7 +50,9 @@ class OllamaClient(GenAIClient): logger.warning("Error initializing Ollama: %s", str(e)) return None - def _send(self, prompt: str, images: list[bytes]) -> Optional[str]: + def _send( + self, prompt: str, images: list[bytes], json_schema: Optional[dict] = None + ) -> Optional[str]: """Submit a request to Ollama""" if self.provider is None: logger.warning( @@ -62,6 +64,10 @@ class OllamaClient(GenAIClient): **self.provider_options, **self.genai_config.runtime_options, } + + if json_schema and "schema" in json_schema: + ollama_options["format"] = json_schema["schema"] + result = self.provider.generate( self.genai_config.model, prompt, diff --git a/frigate/genai/openai.py b/frigate/genai/openai.py index 1fb0dd852..a5ee6455e 100644 --- a/frigate/genai/openai.py +++ b/frigate/genai/openai.py @@ -31,7 +31,9 @@ class OpenAIClient(GenAIClient): } return OpenAI(api_key=self.genai_config.api_key, **provider_opts) - def _send(self, prompt: str, images: list[bytes]) -> Optional[str]: + def _send( + self, prompt: str, images: list[bytes], json_schema: Optional[dict] = None + ) -> Optional[str]: """Submit a request to OpenAI.""" encoded_images = [base64.b64encode(image).decode("utf-8") for image in images] messages_content = [] @@ -51,16 +53,31 @@ class OpenAIClient(GenAIClient): "text": prompt, } ) + + request_params = { + "model": self.genai_config.model, + "messages": [ + { + "role": "user", + "content": messages_content, + }, + ], + "timeout": self.timeout, + } + + if json_schema: + request_params["response_format"] = { + "type": "json_schema", + "json_schema": { + "name": json_schema.get("name", "response"), + "schema": json_schema.get("schema", {}), + "strict": json_schema.get("strict", True), + }, + } + try: result = self.provider.chat.completions.create( - model=self.genai_config.model, - messages=[ - { - "role": "user", - "content": messages_content, - }, - ], - timeout=self.timeout, + **request_params, **self.genai_config.runtime_options, ) if (