Improve model training and creation process

This commit is contained in:
Nicolas Mowen 2025-10-23 06:39:33 -06:00
parent 9160c5168e
commit 55dc2bd211
7 changed files with 100 additions and 70 deletions

View File

@ -387,27 +387,28 @@ def config_set(request: Request, body: AppConfigSetBody):
old_config: FrigateConfig = request.app.frigate_config old_config: FrigateConfig = request.app.frigate_config
request.app.frigate_config = config request.app.frigate_config = config
if body.update_topic and body.update_topic.startswith("config/cameras/"): if body.update_topic:
_, _, camera, field = body.update_topic.split("/") if body.update_topic.startswith("config/cameras/"):
_, _, camera, field = body.update_topic.split("/")
if field == "add": if field == "add":
settings = config.cameras[camera] settings = config.cameras[camera]
elif field == "remove": elif field == "remove":
settings = old_config.cameras[camera] settings = old_config.cameras[camera]
else: else:
settings = config.get_nested_object(body.update_topic) settings = config.get_nested_object(body.update_topic)
request.app.config_publisher.publish_update( request.app.config_publisher.publish_update(
CameraConfigUpdateTopic(CameraConfigUpdateEnum[field], camera), CameraConfigUpdateTopic(CameraConfigUpdateEnum[field], camera),
settings, settings,
)
elif body.update_topic and "/config/" in body.update_topic[1:]:
# Handle nested config updates (e.g., config/classification/custom/{name})
settings = config.get_nested_object(body.update_topic)
if settings:
request.app.config_publisher.publisher.publish(
body.update_topic, settings
) )
else:
# Handle nested config updates (e.g., config/classification/custom/{name})
settings = config.get_nested_object(body.update_topic)
if settings:
request.app.config_publisher.publisher.publish(
body.update_topic, settings
)
return JSONResponse( return JSONResponse(
content=( content=(

View File

@ -3,7 +3,9 @@
import datetime import datetime
import logging import logging
import os import os
import random
import shutil import shutil
import string
from typing import Any from typing import Any
import cv2 import cv2
@ -707,7 +709,9 @@ def categorize_classification_image(request: Request, name: str, body: dict = No
status_code=404, status_code=404,
) )
new_name = f"{category}-{datetime.datetime.now().timestamp()}.png" random_id = "".join(random.choices(string.ascii_lowercase + string.digits, k=6))
timestamp = datetime.datetime.now().timestamp()
new_name = f"{category}-{timestamp}-{random_id}.png"
new_file_folder = os.path.join( new_file_folder = os.path.join(
CLIPS_DIR, sanitize_filename(name), "dataset", category CLIPS_DIR, sanitize_filename(name), "dataset", category
) )

View File

@ -353,6 +353,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
) )
def handle_request(self, topic, request_data): def handle_request(self, topic, request_data):
logger.info(f"comparing with topic {topic} and the {request_data.get('model_name')} and the {self.model_config.name}")
if topic == EmbeddingsRequestEnum.reload_classification_model.value: if topic == EmbeddingsRequestEnum.reload_classification_model.value:
if request_data.get("model_name") == self.model_config.name: if request_data.get("model_name") == self.model_config.name:
logger.info( logger.info(

View File

@ -286,16 +286,11 @@ class EmbeddingMaintainer(threading.Thread):
topic, model_config = self.classification_config_subscriber.check_for_update() topic, model_config = self.classification_config_subscriber.check_for_update()
if topic and model_config: if topic and model_config:
# Extract model name from topic: config/classification/custom/{model_name}
model_name = topic.split("/")[-1] model_name = topic.split("/")[-1]
logger.info(
f"Received classification config update for model: {model_name}"
)
self.config.classification.custom[model_name] = model_config self.config.classification.custom[model_name] = model_config
existing_processor_index = None
for i, processor in enumerate(self.realtime_processors): # Check if processor already exists
for processor in self.realtime_processors:
if isinstance( if isinstance(
processor, processor,
( (
@ -304,14 +299,10 @@ class EmbeddingMaintainer(threading.Thread):
), ),
): ):
if processor.model_config.name == model_name: if processor.model_config.name == model_name:
existing_processor_index = i logger.debug(
break f"Classification processor for model {model_name} already exists, skipping"
)
if existing_processor_index is not None: return
logger.info(
f"Removing existing classification processor for model: {model_name}"
)
self.realtime_processors.pop(existing_processor_index)
if model_config.state_config is not None: if model_config.state_config is not None:
processor = CustomStateClassificationProcessor( processor = CustomStateClassificationProcessor(
@ -326,7 +317,9 @@ class EmbeddingMaintainer(threading.Thread):
) )
self.realtime_processors.append(processor) self.realtime_processors.append(processor)
logger.info(f"Added classification processor for model: {model_name}") logger.info(
f"Added classification processor for model: {model_name} (type: {type(processor).__name__})"
)
def _process_requests(self) -> None: def _process_requests(self) -> None:
"""Process embeddings requests""" """Process embeddings requests"""

View File

@ -101,6 +101,10 @@
"title": "Generating Sample Images", "title": "Generating Sample Images",
"description": "We're pulling representative images from your recordings. This may take a moment..." "description": "We're pulling representative images from your recordings. This may take a moment..."
}, },
"training": {
"title": "Training Model",
"description": "Your model is being trained in the background. You can close this wizard and the training will continue."
},
"retryGenerate": "Retry Generation", "retryGenerate": "Retry Generation",
"selectClass": "Select class...", "selectClass": "Select class...",
"none": "None", "none": "None",

View File

@ -46,6 +46,7 @@ export default function Step3ChooseExamples({
const [imageClassifications, setImageClassifications] = useState<{ const [imageClassifications, setImageClassifications] = useState<{
[imageName: string]: string; [imageName: string]: string;
}>(initialData?.imageClassifications || {}); }>(initialData?.imageClassifications || {});
const [isTraining, setIsTraining] = useState(false);
const { data: trainImages, mutate: refreshTrainImages } = useSWR<string[]>( const { data: trainImages, mutate: refreshTrainImages } = useSWR<string[]>(
hasGenerated ? `classification/${step1Data.modelName}/train` : null, hasGenerated ? `classification/${step1Data.modelName}/train` : null,
@ -186,25 +187,25 @@ export default function Step3ChooseExamples({
}); });
// Step 2: Classify each image by moving it to the correct category folder // Step 2: Classify each image by moving it to the correct category folder
for (const [imageName, className] of Object.entries( const categorizePromises = Object.entries(imageClassifications).map(
imageClassifications, ([imageName, className]) => {
)) { if (!className) return Promise.resolve();
if (!className) continue; return axios.post(
`/classification/${step1Data.modelName}/dataset/categorize`,
await axios.post( {
`/classification/${step1Data.modelName}/dataset/categorize`, training_file: imageName,
{ category: className === "none" ? "none" : className,
training_file: imageName, },
category: className === "none" ? "none" : className, );
}, },
); );
} await Promise.all(categorizePromises);
// Step 3: Kick off training // Step 3: Kick off training
await axios.post(`/classification/${step1Data.modelName}/train`); await axios.post(`/classification/${step1Data.modelName}/train`);
toast.success(t("wizard.step3.trainingStarted")); toast.success(t("wizard.step3.trainingStarted"));
onClose(); setIsTraining(true);
} catch (error) { } catch (error) {
const axiosError = error as { const axiosError = error as {
response?: { data?: { message?: string; detail?: string } }; response?: { data?: { message?: string; detail?: string } };
@ -220,7 +221,7 @@ export default function Step3ChooseExamples({
t("wizard.step3.errors.classifyFailed", { error: errorMessage }), t("wizard.step3.errors.classifyFailed", { error: errorMessage }),
); );
} }
}, [onClose, imageClassifications, step1Data, step2Data, t]); }, [imageClassifications, step1Data, step2Data, t]);
const allImagesClassified = useMemo(() => { const allImagesClassified = useMemo(() => {
if (!unknownImages || unknownImages.length === 0) return false; if (!unknownImages || unknownImages.length === 0) return false;
@ -230,7 +231,22 @@ export default function Step3ChooseExamples({
return ( return (
<div className="flex flex-col gap-6"> <div className="flex flex-col gap-6">
{isGenerating ? ( {isTraining ? (
<div className="flex flex-col items-center gap-6 py-12">
<ActivityIndicator className="size-12" />
<div className="text-center">
<h3 className="mb-2 text-lg font-medium">
{t("wizard.step3.training.title")}
</h3>
<p className="text-sm text-muted-foreground">
{t("wizard.step3.training.description")}
</p>
</div>
<Button onClick={onClose} variant="select" className="mt-4">
{t("button.close", { ns: "common" })}
</Button>
</div>
) : isGenerating ? (
<div className="flex h-[50vh] flex-col items-center justify-center gap-4"> <div className="flex h-[50vh] flex-col items-center justify-center gap-4">
<ActivityIndicator className="size-12" /> <ActivityIndicator className="size-12" />
<div className="text-center"> <div className="text-center">
@ -321,20 +337,22 @@ export default function Step3ChooseExamples({
</div> </div>
)} )}
<div className="flex flex-col gap-3 pt-3 sm:flex-row sm:justify-end sm:gap-4"> {!isTraining && (
<Button type="button" onClick={onBack} className="sm:flex-1"> <div className="flex flex-col gap-3 pt-3 sm:flex-row sm:justify-end sm:gap-4">
{t("button.back", { ns: "common" })} <Button type="button" onClick={onBack} className="sm:flex-1">
</Button> {t("button.back", { ns: "common" })}
<Button </Button>
type="button" <Button
onClick={handleContinue} type="button"
variant="select" onClick={handleContinue}
className="flex items-center justify-center gap-2 sm:flex-1" variant="select"
disabled={!hasGenerated || isGenerating || !allImagesClassified} className="flex items-center justify-center gap-2 sm:flex-1"
> disabled={!hasGenerated || isGenerating || !allImagesClassified}
{t("button.continue", { ns: "common" })} >
</Button> {t("button.continue", { ns: "common" })}
</div> </Button>
</div>
)}
</div> </div>
); );
} }

View File

@ -30,9 +30,12 @@ export default function ModelSelectionView({
const { t } = useTranslation(["views/classificationModel"]); const { t } = useTranslation(["views/classificationModel"]);
const [page, setPage] = useState<ModelType>("objects"); const [page, setPage] = useState<ModelType>("objects");
const [pageToggle, setPageToggle] = useOptimisticState(page, setPage, 100); const [pageToggle, setPageToggle] = useOptimisticState(page, setPage, 100);
const { data: config } = useSWR<FrigateConfig>("config", { const { data: config, mutate: refreshConfig } = useSWR<FrigateConfig>(
revalidateOnFocus: false, "config",
}); {
revalidateOnFocus: false,
},
);
// data // data
@ -71,7 +74,10 @@ export default function ModelSelectionView({
<> <>
<ClassificationModelWizardDialog <ClassificationModelWizardDialog
open={newModel} open={newModel}
onClose={() => setNewModel(false)} onClose={() => {
setNewModel(false);
refreshConfig();
}}
/> />
<NoModelsView onCreateModel={() => setNewModel(true)} />; <NoModelsView onCreateModel={() => setNewModel(true)} />;
</> </>
@ -82,7 +88,10 @@ export default function ModelSelectionView({
<div className="flex size-full flex-col p-2"> <div className="flex size-full flex-col p-2">
<ClassificationModelWizardDialog <ClassificationModelWizardDialog
open={newModel} open={newModel}
onClose={() => setNewModel(false)} onClose={() => {
setNewModel(false);
refreshConfig();
}}
/> />
<div className="flex h-12 w-full items-center justify-between"> <div className="flex h-12 w-full items-center justify-between">