Set json schema for genai

This commit is contained in:
Nicolas Mowen 2026-01-25 18:50:14 -07:00
parent 03ff091a71
commit 50e4601358
5 changed files with 96 additions and 33 deletions

View File

@ -140,7 +140,12 @@ Each line represents a detection state, not necessarily unique individuals. Pare
) as f: ) as f:
f.write(context_prompt) f.write(context_prompt)
response = self._send(context_prompt, thumbnails) json_schema = {
"name": "review_metadata",
"schema": ReviewMetadata.model_json_schema(),
"strict": True,
}
response = self._send(context_prompt, thumbnails, json_schema=json_schema)
if debug_save and response: if debug_save and response:
with open( with open(
@ -152,6 +157,8 @@ Each line represents a detection state, not necessarily unique individuals. Pare
f.write(response) f.write(response)
if response: if response:
# With JSON schema, response should already be valid JSON
# But keep regex cleanup as fallback for providers without schema support
clean_json = re.sub( clean_json = re.sub(
r"\n?```$", "", re.sub(r"^```[a-zA-Z0-9]*\n?", "", response) r"\n?```$", "", re.sub(r"^```[a-zA-Z0-9]*\n?", "", response)
) )
@ -284,8 +291,16 @@ Guidelines:
"""Initialize the client.""" """Initialize the client."""
return None return None
def _send(self, prompt: str, images: list[bytes]) -> Optional[str]: def _send(
"""Submit a request to the provider.""" self, prompt: str, images: list[bytes], json_schema: Optional[dict] = None
) -> Optional[str]:
"""Submit a request to the provider.
Args:
prompt: The text prompt to send
images: List of image bytes to include
json_schema: Optional JSON schema for structured output (provider-specific support)
"""
return None return None
def get_context_size(self) -> int: def get_context_size(self) -> int:

View File

@ -41,29 +41,46 @@ class OpenAIClient(GenAIClient):
azure_endpoint=azure_endpoint, azure_endpoint=azure_endpoint,
) )
def _send(self, prompt: str, images: list[bytes]) -> Optional[str]: def _send(
self, prompt: str, images: list[bytes], json_schema: Optional[dict] = None
) -> Optional[str]:
"""Submit a request to Azure OpenAI.""" """Submit a request to Azure OpenAI."""
encoded_images = [base64.b64encode(image).decode("utf-8") for image in images] encoded_images = [base64.b64encode(image).decode("utf-8") for image in images]
request_params = {
"model": self.genai_config.model,
"messages": [
{
"role": "user",
"content": [{"type": "text", "text": prompt}]
+ [
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{image}",
"detail": "low",
},
}
for image in encoded_images
],
},
],
"timeout": self.timeout,
}
if json_schema:
request_params["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": json_schema.get("name", "response"),
"schema": json_schema.get("schema", {}),
"strict": json_schema.get("strict", True),
},
}
try: try:
result = self.provider.chat.completions.create( result = self.provider.chat.completions.create(
model=self.genai_config.model, **request_params,
messages=[
{
"role": "user",
"content": [{"type": "text", "text": prompt}]
+ [
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{image}",
"detail": "low",
},
}
for image in encoded_images
],
},
],
timeout=self.timeout,
**self.genai_config.runtime_options, **self.genai_config.runtime_options,
) )
except Exception as e: except Exception as e:

View File

@ -41,7 +41,9 @@ class GeminiClient(GenAIClient):
http_options=types.HttpOptions(**http_options_dict), http_options=types.HttpOptions(**http_options_dict),
) )
def _send(self, prompt: str, images: list[bytes]) -> Optional[str]: def _send(
self, prompt: str, images: list[bytes], json_schema: Optional[dict] = None
) -> Optional[str]:
"""Submit a request to Gemini.""" """Submit a request to Gemini."""
contents = [ contents = [
types.Part.from_bytes(data=img, mime_type="image/jpeg") for img in images types.Part.from_bytes(data=img, mime_type="image/jpeg") for img in images
@ -51,6 +53,12 @@ class GeminiClient(GenAIClient):
generation_config_dict = {"candidate_count": 1} generation_config_dict = {"candidate_count": 1}
generation_config_dict.update(self.genai_config.runtime_options) generation_config_dict.update(self.genai_config.runtime_options)
if json_schema and "schema" in json_schema:
generation_config_dict["response_mime_type"] = "application/json"
generation_config_dict["response_schema"] = types.Schema(
json_schema=json_schema["schema"]
)
response = self.provider.models.generate_content( response = self.provider.models.generate_content(
model=self.genai_config.model, model=self.genai_config.model,
contents=contents, contents=contents,

View File

@ -50,7 +50,9 @@ class OllamaClient(GenAIClient):
logger.warning("Error initializing Ollama: %s", str(e)) logger.warning("Error initializing Ollama: %s", str(e))
return None return None
def _send(self, prompt: str, images: list[bytes]) -> Optional[str]: def _send(
self, prompt: str, images: list[bytes], json_schema: Optional[dict] = None
) -> Optional[str]:
"""Submit a request to Ollama""" """Submit a request to Ollama"""
if self.provider is None: if self.provider is None:
logger.warning( logger.warning(
@ -62,6 +64,10 @@ class OllamaClient(GenAIClient):
**self.provider_options, **self.provider_options,
**self.genai_config.runtime_options, **self.genai_config.runtime_options,
} }
if json_schema and "schema" in json_schema:
ollama_options["format"] = json_schema["schema"]
result = self.provider.generate( result = self.provider.generate(
self.genai_config.model, self.genai_config.model,
prompt, prompt,

View File

@ -31,7 +31,9 @@ class OpenAIClient(GenAIClient):
} }
return OpenAI(api_key=self.genai_config.api_key, **provider_opts) return OpenAI(api_key=self.genai_config.api_key, **provider_opts)
def _send(self, prompt: str, images: list[bytes]) -> Optional[str]: def _send(
self, prompt: str, images: list[bytes], json_schema: Optional[dict] = None
) -> Optional[str]:
"""Submit a request to OpenAI.""" """Submit a request to OpenAI."""
encoded_images = [base64.b64encode(image).decode("utf-8") for image in images] encoded_images = [base64.b64encode(image).decode("utf-8") for image in images]
messages_content = [] messages_content = []
@ -51,16 +53,31 @@ class OpenAIClient(GenAIClient):
"text": prompt, "text": prompt,
} }
) )
request_params = {
"model": self.genai_config.model,
"messages": [
{
"role": "user",
"content": messages_content,
},
],
"timeout": self.timeout,
}
if json_schema:
request_params["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": json_schema.get("name", "response"),
"schema": json_schema.get("schema", {}),
"strict": json_schema.get("strict", True),
},
}
try: try:
result = self.provider.chat.completions.create( result = self.provider.chat.completions.create(
model=self.genai_config.model, **request_params,
messages=[
{
"role": "user",
"content": messages_content,
},
],
timeout=self.timeout,
**self.genai_config.runtime_options, **self.genai_config.runtime_options,
) )
if ( if (