mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-05-01 19:17:41 +03:00
Adjust directories
This commit is contained in:
parent
89433a71e0
commit
d014922a99
@ -42,7 +42,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
|
||||
self.model_config = model_config
|
||||
self.requestor = requestor
|
||||
self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name)
|
||||
self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name)
|
||||
self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name, "train")
|
||||
self.interpreter: Interpreter = None
|
||||
self.tensor_input_details: dict[str, Any] = None
|
||||
self.tensor_output_details: dict[str, Any] = None
|
||||
|
||||
@ -9,6 +9,8 @@ from tensorflow.keras import layers, models, optimizers
|
||||
from tensorflow.keras.applications import MobileNetV2
|
||||
from tensorflow.keras.preprocessing.image import ImageDataGenerator
|
||||
|
||||
from frigate.const import CLIPS_DIR, MODEL_CACHE_DIR
|
||||
|
||||
BATCH_SIZE = 16
|
||||
EPOCHS = 50
|
||||
LEARNING_RATE = 0.001
|
||||
@ -35,9 +37,10 @@ def generate_representative_dataset_factory(dataset_dir: str):
|
||||
|
||||
|
||||
@staticmethod
|
||||
def train_classification_model(model_dir: str) -> bool:
|
||||
def train_classification_model(model_name: str) -> bool:
|
||||
"""Train a classification model."""
|
||||
dataset_dir = os.path.join(model_dir, "dataset")
|
||||
dataset_dir = os.path.join(CLIPS_DIR, model_name, "dataset")
|
||||
model_dir = os.path.join(MODEL_CACHE_DIR, model_name)
|
||||
num_classes = len(
|
||||
[
|
||||
d
|
||||
|
||||
Loading…
Reference in New Issue
Block a user