mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-03-19 14:48:22 +03:00
Improve model training and creation process
This commit is contained in:
parent
9160c5168e
commit
55dc2bd211
@ -387,27 +387,28 @@ def config_set(request: Request, body: AppConfigSetBody):
|
|||||||
old_config: FrigateConfig = request.app.frigate_config
|
old_config: FrigateConfig = request.app.frigate_config
|
||||||
request.app.frigate_config = config
|
request.app.frigate_config = config
|
||||||
|
|
||||||
if body.update_topic and body.update_topic.startswith("config/cameras/"):
|
if body.update_topic:
|
||||||
_, _, camera, field = body.update_topic.split("/")
|
if body.update_topic.startswith("config/cameras/"):
|
||||||
|
_, _, camera, field = body.update_topic.split("/")
|
||||||
|
|
||||||
if field == "add":
|
if field == "add":
|
||||||
settings = config.cameras[camera]
|
settings = config.cameras[camera]
|
||||||
elif field == "remove":
|
elif field == "remove":
|
||||||
settings = old_config.cameras[camera]
|
settings = old_config.cameras[camera]
|
||||||
else:
|
else:
|
||||||
settings = config.get_nested_object(body.update_topic)
|
settings = config.get_nested_object(body.update_topic)
|
||||||
|
|
||||||
request.app.config_publisher.publish_update(
|
request.app.config_publisher.publish_update(
|
||||||
CameraConfigUpdateTopic(CameraConfigUpdateEnum[field], camera),
|
CameraConfigUpdateTopic(CameraConfigUpdateEnum[field], camera),
|
||||||
settings,
|
settings,
|
||||||
)
|
|
||||||
elif body.update_topic and "/config/" in body.update_topic[1:]:
|
|
||||||
# Handle nested config updates (e.g., config/classification/custom/{name})
|
|
||||||
settings = config.get_nested_object(body.update_topic)
|
|
||||||
if settings:
|
|
||||||
request.app.config_publisher.publisher.publish(
|
|
||||||
body.update_topic, settings
|
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
# Handle nested config updates (e.g., config/classification/custom/{name})
|
||||||
|
settings = config.get_nested_object(body.update_topic)
|
||||||
|
if settings:
|
||||||
|
request.app.config_publisher.publisher.publish(
|
||||||
|
body.update_topic, settings
|
||||||
|
)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
content=(
|
content=(
|
||||||
|
|||||||
@ -3,7 +3,9 @@
|
|||||||
import datetime
|
import datetime
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import random
|
||||||
import shutil
|
import shutil
|
||||||
|
import string
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
@ -707,7 +709,9 @@ def categorize_classification_image(request: Request, name: str, body: dict = No
|
|||||||
status_code=404,
|
status_code=404,
|
||||||
)
|
)
|
||||||
|
|
||||||
new_name = f"{category}-{datetime.datetime.now().timestamp()}.png"
|
random_id = "".join(random.choices(string.ascii_lowercase + string.digits, k=6))
|
||||||
|
timestamp = datetime.datetime.now().timestamp()
|
||||||
|
new_name = f"{category}-{timestamp}-{random_id}.png"
|
||||||
new_file_folder = os.path.join(
|
new_file_folder = os.path.join(
|
||||||
CLIPS_DIR, sanitize_filename(name), "dataset", category
|
CLIPS_DIR, sanitize_filename(name), "dataset", category
|
||||||
)
|
)
|
||||||
|
|||||||
@ -353,6 +353,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def handle_request(self, topic, request_data):
|
def handle_request(self, topic, request_data):
|
||||||
|
logger.info(f"comparing with topic {topic} and the {request_data.get('model_name')} and the {self.model_config.name}")
|
||||||
if topic == EmbeddingsRequestEnum.reload_classification_model.value:
|
if topic == EmbeddingsRequestEnum.reload_classification_model.value:
|
||||||
if request_data.get("model_name") == self.model_config.name:
|
if request_data.get("model_name") == self.model_config.name:
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|||||||
@ -286,16 +286,11 @@ class EmbeddingMaintainer(threading.Thread):
|
|||||||
topic, model_config = self.classification_config_subscriber.check_for_update()
|
topic, model_config = self.classification_config_subscriber.check_for_update()
|
||||||
|
|
||||||
if topic and model_config:
|
if topic and model_config:
|
||||||
# Extract model name from topic: config/classification/custom/{model_name}
|
|
||||||
model_name = topic.split("/")[-1]
|
model_name = topic.split("/")[-1]
|
||||||
logger.info(
|
|
||||||
f"Received classification config update for model: {model_name}"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.config.classification.custom[model_name] = model_config
|
self.config.classification.custom[model_name] = model_config
|
||||||
existing_processor_index = None
|
|
||||||
|
|
||||||
for i, processor in enumerate(self.realtime_processors):
|
# Check if processor already exists
|
||||||
|
for processor in self.realtime_processors:
|
||||||
if isinstance(
|
if isinstance(
|
||||||
processor,
|
processor,
|
||||||
(
|
(
|
||||||
@ -304,14 +299,10 @@ class EmbeddingMaintainer(threading.Thread):
|
|||||||
),
|
),
|
||||||
):
|
):
|
||||||
if processor.model_config.name == model_name:
|
if processor.model_config.name == model_name:
|
||||||
existing_processor_index = i
|
logger.debug(
|
||||||
break
|
f"Classification processor for model {model_name} already exists, skipping"
|
||||||
|
)
|
||||||
if existing_processor_index is not None:
|
return
|
||||||
logger.info(
|
|
||||||
f"Removing existing classification processor for model: {model_name}"
|
|
||||||
)
|
|
||||||
self.realtime_processors.pop(existing_processor_index)
|
|
||||||
|
|
||||||
if model_config.state_config is not None:
|
if model_config.state_config is not None:
|
||||||
processor = CustomStateClassificationProcessor(
|
processor = CustomStateClassificationProcessor(
|
||||||
@ -326,7 +317,9 @@ class EmbeddingMaintainer(threading.Thread):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.realtime_processors.append(processor)
|
self.realtime_processors.append(processor)
|
||||||
logger.info(f"Added classification processor for model: {model_name}")
|
logger.info(
|
||||||
|
f"Added classification processor for model: {model_name} (type: {type(processor).__name__})"
|
||||||
|
)
|
||||||
|
|
||||||
def _process_requests(self) -> None:
|
def _process_requests(self) -> None:
|
||||||
"""Process embeddings requests"""
|
"""Process embeddings requests"""
|
||||||
|
|||||||
@ -101,6 +101,10 @@
|
|||||||
"title": "Generating Sample Images",
|
"title": "Generating Sample Images",
|
||||||
"description": "We're pulling representative images from your recordings. This may take a moment..."
|
"description": "We're pulling representative images from your recordings. This may take a moment..."
|
||||||
},
|
},
|
||||||
|
"training": {
|
||||||
|
"title": "Training Model",
|
||||||
|
"description": "Your model is being trained in the background. You can close this wizard and the training will continue."
|
||||||
|
},
|
||||||
"retryGenerate": "Retry Generation",
|
"retryGenerate": "Retry Generation",
|
||||||
"selectClass": "Select class...",
|
"selectClass": "Select class...",
|
||||||
"none": "None",
|
"none": "None",
|
||||||
|
|||||||
@ -46,6 +46,7 @@ export default function Step3ChooseExamples({
|
|||||||
const [imageClassifications, setImageClassifications] = useState<{
|
const [imageClassifications, setImageClassifications] = useState<{
|
||||||
[imageName: string]: string;
|
[imageName: string]: string;
|
||||||
}>(initialData?.imageClassifications || {});
|
}>(initialData?.imageClassifications || {});
|
||||||
|
const [isTraining, setIsTraining] = useState(false);
|
||||||
|
|
||||||
const { data: trainImages, mutate: refreshTrainImages } = useSWR<string[]>(
|
const { data: trainImages, mutate: refreshTrainImages } = useSWR<string[]>(
|
||||||
hasGenerated ? `classification/${step1Data.modelName}/train` : null,
|
hasGenerated ? `classification/${step1Data.modelName}/train` : null,
|
||||||
@ -186,25 +187,25 @@ export default function Step3ChooseExamples({
|
|||||||
});
|
});
|
||||||
|
|
||||||
// Step 2: Classify each image by moving it to the correct category folder
|
// Step 2: Classify each image by moving it to the correct category folder
|
||||||
for (const [imageName, className] of Object.entries(
|
const categorizePromises = Object.entries(imageClassifications).map(
|
||||||
imageClassifications,
|
([imageName, className]) => {
|
||||||
)) {
|
if (!className) return Promise.resolve();
|
||||||
if (!className) continue;
|
return axios.post(
|
||||||
|
`/classification/${step1Data.modelName}/dataset/categorize`,
|
||||||
await axios.post(
|
{
|
||||||
`/classification/${step1Data.modelName}/dataset/categorize`,
|
training_file: imageName,
|
||||||
{
|
category: className === "none" ? "none" : className,
|
||||||
training_file: imageName,
|
},
|
||||||
category: className === "none" ? "none" : className,
|
);
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
}
|
await Promise.all(categorizePromises);
|
||||||
|
|
||||||
// Step 3: Kick off training
|
// Step 3: Kick off training
|
||||||
await axios.post(`/classification/${step1Data.modelName}/train`);
|
await axios.post(`/classification/${step1Data.modelName}/train`);
|
||||||
|
|
||||||
toast.success(t("wizard.step3.trainingStarted"));
|
toast.success(t("wizard.step3.trainingStarted"));
|
||||||
onClose();
|
setIsTraining(true);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
const axiosError = error as {
|
const axiosError = error as {
|
||||||
response?: { data?: { message?: string; detail?: string } };
|
response?: { data?: { message?: string; detail?: string } };
|
||||||
@ -220,7 +221,7 @@ export default function Step3ChooseExamples({
|
|||||||
t("wizard.step3.errors.classifyFailed", { error: errorMessage }),
|
t("wizard.step3.errors.classifyFailed", { error: errorMessage }),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}, [onClose, imageClassifications, step1Data, step2Data, t]);
|
}, [imageClassifications, step1Data, step2Data, t]);
|
||||||
|
|
||||||
const allImagesClassified = useMemo(() => {
|
const allImagesClassified = useMemo(() => {
|
||||||
if (!unknownImages || unknownImages.length === 0) return false;
|
if (!unknownImages || unknownImages.length === 0) return false;
|
||||||
@ -230,7 +231,22 @@ export default function Step3ChooseExamples({
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="flex flex-col gap-6">
|
<div className="flex flex-col gap-6">
|
||||||
{isGenerating ? (
|
{isTraining ? (
|
||||||
|
<div className="flex flex-col items-center gap-6 py-12">
|
||||||
|
<ActivityIndicator className="size-12" />
|
||||||
|
<div className="text-center">
|
||||||
|
<h3 className="mb-2 text-lg font-medium">
|
||||||
|
{t("wizard.step3.training.title")}
|
||||||
|
</h3>
|
||||||
|
<p className="text-sm text-muted-foreground">
|
||||||
|
{t("wizard.step3.training.description")}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
<Button onClick={onClose} variant="select" className="mt-4">
|
||||||
|
{t("button.close", { ns: "common" })}
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
) : isGenerating ? (
|
||||||
<div className="flex h-[50vh] flex-col items-center justify-center gap-4">
|
<div className="flex h-[50vh] flex-col items-center justify-center gap-4">
|
||||||
<ActivityIndicator className="size-12" />
|
<ActivityIndicator className="size-12" />
|
||||||
<div className="text-center">
|
<div className="text-center">
|
||||||
@ -321,20 +337,22 @@ export default function Step3ChooseExamples({
|
|||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
<div className="flex flex-col gap-3 pt-3 sm:flex-row sm:justify-end sm:gap-4">
|
{!isTraining && (
|
||||||
<Button type="button" onClick={onBack} className="sm:flex-1">
|
<div className="flex flex-col gap-3 pt-3 sm:flex-row sm:justify-end sm:gap-4">
|
||||||
{t("button.back", { ns: "common" })}
|
<Button type="button" onClick={onBack} className="sm:flex-1">
|
||||||
</Button>
|
{t("button.back", { ns: "common" })}
|
||||||
<Button
|
</Button>
|
||||||
type="button"
|
<Button
|
||||||
onClick={handleContinue}
|
type="button"
|
||||||
variant="select"
|
onClick={handleContinue}
|
||||||
className="flex items-center justify-center gap-2 sm:flex-1"
|
variant="select"
|
||||||
disabled={!hasGenerated || isGenerating || !allImagesClassified}
|
className="flex items-center justify-center gap-2 sm:flex-1"
|
||||||
>
|
disabled={!hasGenerated || isGenerating || !allImagesClassified}
|
||||||
{t("button.continue", { ns: "common" })}
|
>
|
||||||
</Button>
|
{t("button.continue", { ns: "common" })}
|
||||||
</div>
|
</Button>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -30,9 +30,12 @@ export default function ModelSelectionView({
|
|||||||
const { t } = useTranslation(["views/classificationModel"]);
|
const { t } = useTranslation(["views/classificationModel"]);
|
||||||
const [page, setPage] = useState<ModelType>("objects");
|
const [page, setPage] = useState<ModelType>("objects");
|
||||||
const [pageToggle, setPageToggle] = useOptimisticState(page, setPage, 100);
|
const [pageToggle, setPageToggle] = useOptimisticState(page, setPage, 100);
|
||||||
const { data: config } = useSWR<FrigateConfig>("config", {
|
const { data: config, mutate: refreshConfig } = useSWR<FrigateConfig>(
|
||||||
revalidateOnFocus: false,
|
"config",
|
||||||
});
|
{
|
||||||
|
revalidateOnFocus: false,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
// data
|
// data
|
||||||
|
|
||||||
@ -71,7 +74,10 @@ export default function ModelSelectionView({
|
|||||||
<>
|
<>
|
||||||
<ClassificationModelWizardDialog
|
<ClassificationModelWizardDialog
|
||||||
open={newModel}
|
open={newModel}
|
||||||
onClose={() => setNewModel(false)}
|
onClose={() => {
|
||||||
|
setNewModel(false);
|
||||||
|
refreshConfig();
|
||||||
|
}}
|
||||||
/>
|
/>
|
||||||
<NoModelsView onCreateModel={() => setNewModel(true)} />;
|
<NoModelsView onCreateModel={() => setNewModel(true)} />;
|
||||||
</>
|
</>
|
||||||
@ -82,7 +88,10 @@ export default function ModelSelectionView({
|
|||||||
<div className="flex size-full flex-col p-2">
|
<div className="flex size-full flex-col p-2">
|
||||||
<ClassificationModelWizardDialog
|
<ClassificationModelWizardDialog
|
||||||
open={newModel}
|
open={newModel}
|
||||||
onClose={() => setNewModel(false)}
|
onClose={() => {
|
||||||
|
setNewModel(false);
|
||||||
|
refreshConfig();
|
||||||
|
}}
|
||||||
/>
|
/>
|
||||||
|
|
||||||
<div className="flex h-12 w-full items-center justify-between">
|
<div className="flex h-12 w-full items-center justify-between">
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user