From d594e9d9a64d71f60498f160c12c360bd02a8246 Mon Sep 17 00:00:00 2001 From: Josh Hawkins <32435876+hawkeye217@users.noreply.github.com> Date: Thu, 4 Jun 2026 16:19:45 -0500 Subject: [PATCH] Add script to investigate accuracy of classification training set (#23324) * add classification testing script * add caveat --- testing-scripts/object_dataset.py | 1022 +++++++++++++++++++++++++++++ 1 file changed, 1022 insertions(+) create mode 100644 testing-scripts/object_dataset.py diff --git a/testing-scripts/object_dataset.py b/testing-scripts/object_dataset.py new file mode 100644 index 0000000000..dd4071c441 --- /dev/null +++ b/testing-scripts/object_dataset.py @@ -0,0 +1,1022 @@ +""" +Object classification investigation script. + +Standalone replica of Frigate's custom object classification inference pipeline +(see frigate/data_processing/real_time/custom_classification.py and +frigate/util/classification.py) for analyzing a training dataset outside the +running service. Useful for: + + - Diagnosing why a class produces false positives / misidentifications + - Finding the training images that the deployed model itself misclassifies + (these are the worst offenders — usually mislabeled or low-quality crops) + - Inspecting borderline-correct images that sit near the decision boundary + - Spotting class-pair confusion (which classes get mixed up) + +Layout: + - Core pipeline: load_tflite, preprocess_for_inference, classify_image — + all mirroring CustomObjectClassificationProcessor exactly + - Default run: scan the dataset, classify every image with the deployed + model.tflite, report misclassified + borderline images per class, and + print a confusion matrix + - Optional diagnostics (flags): image-quality breakdown, scoring an + unlabeled "negative" folder, cross-class contamination analysis (find + training images in class A that visually look like class B and pull + inference toward A), and copying worst offenders out for review + +Recommended workflow when troubleshooting misclassifications: + + 1. Run the basic scan first (no extra flags). Read top-down: + - Class balance ratio. If > 3x, balance counts before anything else. + The dominant class will absorb borderline predictions otherwise. + - Per-class accuracy. Any class < 50% needs attention. + - Confusion matrix. If multiple classes all over-predict the same + class (e.g. Buddy->Rex, Bailey->Rex, none->Rex), you have + feature collapse, not "a few bad photos." Don't bother with + contamination analysis yet — fix the collapse first. + + 2. Check for "degenerate blob" upsampling. Look at the SHAPE column on + worst-offender rows. If most misclassified crops are < 80x80, the + small originals are being stretched 3-7x to fit the 224x224 model + input. Upsampled crops collapse to a similar region of feature space + regardless of identity — the model can't tell them apart and defaults + them to whichever class has the most of them. + + Fix: quarantine every image where min(w, h) < 80 (or 100 for a + stricter cut) and retrain. This works when the named class has + plenty of non-small examples to fall back on AND the small crops + are mostly degenerate blobs (target unrecognizable at that size). + + CAVEAT — sometimes small crops ARE the signal, not the noise: if + your target naturally appears small at the camera distance (cats + indoors, distant subjects, wide-FOV setups), the small crops in + the named class ARE the typical inference-time input. Removing them + leaves the model unable to recognize the target at its natural + detection size, and accuracy on the named class collapses after + retraining. If that happens — named-class accuracy drops sharply + after size cut + retrain — restore the quarantine and switch to + visual review of just the misclassified small crops instead of + bulk size filtering. The size threshold is a tool for "tons of + accidental tiny blobs polluting a class with otherwise large + examples," not a universal cleanup. + + 3. Verify the "none" class exists and is healthy. Without a strong + "none" class, every unknown crop at inference gets forced into one of + your real classes — the model has no "I don't know" option. Aim for: + - Count similar to your other classes (don't let it be the smallest) + - Images >= 100x100, well-framed + - Visual variety: other dogs/objects, partial views, empty scenes, + not just one type of negative + + 4. Look for cross-class duplicates from the same Frigate event. If the + same timestamp prefix appears across multiple class folders (e.g. + "1772052999.x" present in Buddy AND Bailey AND Rex AND none), those + crops came from one moment in time. Either they were extracted from a + multi-object frame and labeled inconsistently, or they're near- + duplicates of one scene cropped slightly differently. Inspect them as + a group and decide together. + + 5. Only after (1)-(4) are clean, run --confuses : for + targeted contamination analysis. The "ringleaders" section at the + bottom is the actionable part: a short list of images appearing + repeatedly as nearest neighbors of the wrong class. Those are the + few photos doing most of the damage. + + 6. Stop deleting when the contamination delta column shows ALL negative + values for the source class. That means dataset images in + are already visually distinct from in fixed-backbone + embedding space — the trained model just hasn't learned to use that + separation. The fix from that point is to ADD more training data for + the underperforming class, not delete more. Aim for at least 20 well- + framed images per class. + +The dataset must be the same layout Frigate trains from: + //dataset//*.{webp,png,jpg,jpeg} + +The model must already be trained: + //model.tflite + //labelmap.txt + +Command-line examples (mirror the workflow steps above): + + One-time setup — download the ImageNet-pretrained MobileNetV2 backbone + that --confuses uses for model-independent embeddings: + + curl -L -o /config/model_cache/mobilenetv2-7.onnx \\ + https://github.com/onnx/models/raw/main/validated/vision/classification/mobilenet/model/mobilenetv2-7.onnx + + Step 1 — Basic scan. Always start here. Reads class balance, accuracy, + confusion matrix, and per-class worst offenders: + + python3 object_dataset.py --name "" --top-n 25 + + Step 2 — Same scan plus image-quality stats (blur, brightness, aspect + distortion) for correct vs misclassified rows. Use when you suspect + systematic quality issues are driving the misses: + + python3 object_dataset.py --name "" --top-n 25 --quality + + Step 2 (cleanup) — Quarantine crops below 80x80 (the upsampling-blob + fix). Mirrors the class folder structure so individual images can be + restored. Change `threshold = 80` to 64 (looser) or 100 (stricter): + + python3 - <<'EOF' + import cv2 + from pathlib import Path + dataset = Path("/media/frigate/clips//dataset") + quarantine = Path("/media/frigate/clips//quarantine_small") + threshold = 80 + moved = 0 + for cls_dir in sorted(dataset.iterdir()): + if not cls_dir.is_dir(): + continue + for img_path in sorted(cls_dir.iterdir()): + if img_path.suffix.lower() not in (".png", ".jpg", ".jpeg", ".webp"): + continue + img = cv2.imread(str(img_path)) + if img is None: + continue + h, w = img.shape[:2] + if min(h, w) < threshold: + dest_dir = quarantine / cls_dir.name + dest_dir.mkdir(parents=True, exist_ok=True) + img_path.rename(dest_dir / img_path.name) + moved += 1 + print(f"moved {moved} images") + EOF + + Revert any quarantine directory (puts everything back into dataset/): + + python3 - <<'EOF' + from pathlib import Path + quarantine = Path("/media/frigate/clips//quarantine_small") + dataset = Path("/media/frigate/clips//dataset") + for cls_dir in sorted(quarantine.iterdir()): + if not cls_dir.is_dir(): + continue + target = dataset / cls_dir.name + target.mkdir(parents=True, exist_ok=True) + for img_path in sorted(cls_dir.iterdir()): + img_path.rename(target / img_path.name) + EOF + + Step 4 — Inspect a same-timestamp cluster across all classes (replace + TIMESTAMP with the prefix you saw in worst-offenders, e.g. "1772052999"): + + mkdir -p /tmp/timestamp_cluster + cd "/media/frigate/clips//dataset" + for f in */*TIMESTAMP*; do + cls=$(dirname "$f"); fn=$(basename "$f") + cp "$f" "/tmp/timestamp_cluster/${cls}__${fn}" + done + + Step 5 — Cross-class contamination. Lists specific images that + look like , plus a ringleader summary of the few worst offenders. + Also copies all misclassified images into a flat browse-able folder + bucketed by (true_class)__as__(predicted_class): + + python3 object_dataset.py --name "" \\ + --embedding-model /config/model_cache/mobilenetv2-7.onnx \\ + --confuses Rex:Buddy --top-n 15 \\ + --save-misclassified /tmp/_offenders + + Or let the script pick the worst-confused class pair from the matrix: + + python3 object_dataset.py --name "" \\ + --embedding-model /config/model_cache/mobilenetv2-7.onnx \\ + --confuses auto + + Score an unlabeled folder of runtime crops against the trained model — + useful for analyzing why specific inference-time misfires happened. + Prints full per-class probability vectors and threshold-pass status: + + python3 object_dataset.py --name "" \\ + --negative /path/to/runtime_misfires --threshold 0.8 + +Full flag reference: + python3 object_dataset.py \\ + --name \\ + [--clips-dir /media/frigate/clips] \\ + [--model-cache /config/model_cache] \\ + [--threshold 0.8] [--top-n 15] \\ + [--quality] [--negative ] [--save-misclassified ] \\ + [--confuses :] [--embedding-model ] +""" + +from __future__ import annotations + +import argparse +import os +import shutil +import sys +from dataclasses import dataclass + +import cv2 +import numpy as np + +try: + from tflite_runtime.interpreter import Interpreter +except ModuleNotFoundError: + from ai_edge_litert.interpreter import Interpreter + +CLASSIFIER_INPUT_SIZE = 224 +IMAGE_EXTS = (".webp", ".png", ".jpg", ".jpeg") + + +# --------------------------------------------------------------------------- +# Replicated Frigate pipeline +# --------------------------------------------------------------------------- + + +def load_tflite(model_path: str) -> tuple[Interpreter, list[dict], list[dict]]: + """Mirror CustomObjectClassificationProcessor.__build_detector.""" + interpreter = Interpreter(model_path=model_path, num_threads=2) + interpreter.allocate_tensors() + return ( + interpreter, + interpreter.get_input_details(), + interpreter.get_output_details(), + ) + + +def load_labelmap(path: str) -> dict[int, str]: + """Mirror util.builtin.load_labels(prefill=0, indexed=False).""" + with open(path, "r", encoding="utf-8") as f: + lines = [line.strip() for line in f.readlines() if line.strip()] + return {idx: line for idx, line in enumerate(lines)} + + +def preprocess_for_inference(image_bgr: np.ndarray) -> np.ndarray: + """Mirror the inference preprocessing in process_frame. + + Frigate decodes the camera frame YUV->RGB, crops, then cv2.resize to + 224x224, and passes the uint8 array directly to the int8-quantized + interpreter. On disk we read BGR via cv2.imread, so we must convert + to RGB to match the channel order the model was trained on. + """ + rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB) + resized = cv2.resize(rgb, (CLASSIFIER_INPUT_SIZE, CLASSIFIER_INPUT_SIZE)) + return resized + + +class MobileNetEmbedder: + """ImageNet-pretrained MobileNetV2 backbone via cv2.dnn. + + Used as a model-independent visual embedder for cross-class contamination + analysis. The user's trained classifier may have memorized contaminating + training images and place them inside the right class in its own embedding + space — a fixed external backbone keeps the analysis honest. + + Expects the standard ONNX Model Zoo MobileNetV2-7 file (PyTorch-style + preprocessing: ImageNet mean/std on /255 input). Output is 1000-d ImageNet + logits; L2-normalized for cosine-similarity comparisons. + """ + + IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32) + IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32) + + def __init__(self, model_path: str): + if not os.path.exists(model_path): + raise FileNotFoundError(model_path) + self.net = cv2.dnn.readNetFromONNX(model_path) + + def embed(self, image_bgr: np.ndarray) -> np.ndarray: + rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB) + resized = cv2.resize(rgb, (224, 224)).astype(np.float32) / 255.0 + normalized = (resized - self.IMAGENET_MEAN) / self.IMAGENET_STD + # NHWC -> NCHW + blob = np.transpose(normalized, (2, 0, 1))[np.newaxis, :, :, :] + self.net.setInput(blob) + out = self.net.forward().squeeze().astype(np.float32) + norm = float(np.linalg.norm(out)) + return out / norm if norm > 0 else out + + +def classify_image( + interpreter: Interpreter, + input_details: list[dict], + output_details: list[dict], + image_bgr: np.ndarray, +) -> np.ndarray: + """Mirror _classify_object's tensor flow. + + Returns the per-class probability vector (length = num_classes) after + the exact `probs = res / res.sum(axis=0)` renormalization Frigate uses + on the int8-quantized output. + """ + resized = preprocess_for_inference(image_bgr) + tensor = np.expand_dims(resized, axis=0) + interpreter.set_tensor(input_details[0]["index"], tensor) + interpreter.invoke() + res = interpreter.get_tensor(output_details[0]["index"])[0].astype(np.float32) + total = res.sum(axis=0) + if total <= 0: + # Defensive: all zeros from a degenerate quantization step. + return np.full_like(res, 1.0 / len(res)) + return res / total + + +# --------------------------------------------------------------------------- +# Sample loading +# --------------------------------------------------------------------------- + + +@dataclass +class ImageSample: + path: str + true_label: str | None # None for the unlabeled negative folder + shape: tuple[int, int] + probs: np.ndarray + pred_idx: int + pred_label: str + pred_score: float + true_idx: int | None + true_score: float | None + + +def laplacian_variance(image_bgr: np.ndarray) -> float: + gray = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2GRAY) + return float(cv2.Laplacian(gray, cv2.CV_64F).var()) + + +def mean_brightness(image_bgr: np.ndarray) -> float: + return float(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2GRAY).mean()) + + +def aspect_distortion(shape: tuple[int, int]) -> float: + """How far the crop is from square; |1 - max(w,h)/min(w,h)|. + + A wide or tall crop gets squashed to 224x224 by the inference resize, + which can be a hidden source of misclassification. + """ + w, h = shape + if w <= 0 or h <= 0: + return float("inf") + return float(max(w, h) / min(w, h) - 1.0) + + +def iter_dataset(dataset_dir: str) -> list[tuple[str, str]]: + """Yield (class_name, image_path) pairs from the dataset directory.""" + pairs: list[tuple[str, str]] = [] + if not os.path.isdir(dataset_dir): + return pairs + for cls in sorted(os.listdir(dataset_dir)): + cls_dir = os.path.join(dataset_dir, cls) + if not os.path.isdir(cls_dir) or cls.startswith("."): + continue + for name in sorted(os.listdir(cls_dir)): + if name.startswith("."): + continue + if not name.lower().endswith(IMAGE_EXTS): + continue + pairs.append((cls, os.path.join(cls_dir, name))) + return pairs + + +def classify_folder( + folder: str, + interpreter: Interpreter, + input_details: list[dict], + output_details: list[dict], + labelmap: dict[int, str], + label_to_idx: dict[str, int], + true_label: str | None = None, +) -> list[ImageSample]: + """Classify every image directly under `folder`. Used for the negative set.""" + samples: list[ImageSample] = [] + if not os.path.isdir(folder): + return samples + for name in sorted(os.listdir(folder)): + if name.startswith(".") or not name.lower().endswith(IMAGE_EXTS): + continue + path = os.path.join(folder, name) + img = cv2.imread(path) + if img is None: + print(f" [skip unreadable] {name}") + continue + probs = classify_image(interpreter, input_details, output_details, img) + pred_idx = int(np.argmax(probs)) + true_idx = label_to_idx.get(true_label) if true_label is not None else None + true_score = float(probs[true_idx]) if true_idx is not None else None + samples.append( + ImageSample( + path=path, + true_label=true_label, + shape=(img.shape[1], img.shape[0]), + probs=probs, + pred_idx=pred_idx, + pred_label=labelmap[pred_idx], + pred_score=float(probs[pred_idx]), + true_idx=true_idx, + true_score=true_score, + ) + ) + return samples + + +def classify_dataset( + dataset_dir: str, + interpreter: Interpreter, + input_details: list[dict], + output_details: list[dict], + labelmap: dict[int, str], + label_to_idx: dict[str, int], +) -> list[ImageSample]: + samples: list[ImageSample] = [] + pairs = iter_dataset(dataset_dir) + for cls, path in pairs: + img = cv2.imread(path) + if img is None: + print(f" [skip unreadable] {cls}/{os.path.basename(path)}") + continue + probs = classify_image(interpreter, input_details, output_details, img) + pred_idx = int(np.argmax(probs)) + true_idx = label_to_idx.get(cls) + true_score = float(probs[true_idx]) if true_idx is not None else None + samples.append( + ImageSample( + path=path, + true_label=cls, + shape=(img.shape[1], img.shape[0]), + probs=probs, + pred_idx=pred_idx, + pred_label=labelmap[pred_idx], + pred_score=float(probs[pred_idx]), + true_idx=true_idx, + true_score=true_score, + ) + ) + return samples + + +# --------------------------------------------------------------------------- +# Baseline analyses (always run) +# --------------------------------------------------------------------------- + + +def summarize_dataset(samples: list[ImageSample], labelmap: dict[int, str]) -> None: + """Per-class counts, accuracy, mean confidence on the true class.""" + print("\n" + "=" * 78) + print(f"DATASET OVERVIEW ({len(samples)} images)") + print("=" * 78) + + by_class: dict[str, list[ImageSample]] = {} + for s in samples: + by_class.setdefault(s.true_label or "", []).append(s) + + print( + f"\n{'class':<20} {'count':>6} {'acc':>6} {'mean_p_true':>12} " + f"{'min_p_true':>10} {'mislabeled':>11}" + ) + for cls in sorted(by_class): + rows = by_class[cls] + correct = sum(1 for r in rows if r.pred_label == cls) + mean_pt = ( + np.mean([r.true_score for r in rows if r.true_score is not None]) + if any(r.true_score is not None for r in rows) + else float("nan") + ) + min_pt = ( + np.min([r.true_score for r in rows if r.true_score is not None]) + if any(r.true_score is not None for r in rows) + else float("nan") + ) + acc = correct / len(rows) if rows else 0.0 + bad = len(rows) - correct + print( + f"{cls:<20} {len(rows):>6} {acc:>6.2%} {mean_pt:>12.3f} " + f"{min_pt:>10.3f} {bad:>11}" + ) + + # Class balance — large skew can hide poor minority-class accuracy in the totals. + counts = [len(by_class[c]) for c in by_class] + if counts: + print( + f"\nClass balance: min={min(counts)} max={max(counts)} " + f"ratio={max(counts) / max(1, min(counts)):.1f}x" + ) + + +def confusion_matrix(samples: list[ImageSample], labelmap: dict[int, str]) -> None: + print("\n" + "=" * 78) + print("CONFUSION MATRIX (rows = true class, cols = predicted class)") + print("=" * 78) + + classes = [labelmap[i] for i in sorted(labelmap)] + idx = {c: i for i, c in enumerate(classes)} + mat = np.zeros((len(classes), len(classes)), dtype=int) + for s in samples: + if s.true_label is None or s.true_label not in idx: + continue + mat[idx[s.true_label], s.pred_idx] += 1 + + col_w = max(8, max(len(c) for c in classes) + 1) + header = " " * (col_w + 2) + "".join(f"{c[: col_w - 1]:>{col_w}}" for c in classes) + print("\n" + header) + for i, cls in enumerate(classes): + row = "".join(f"{mat[i, j]:>{col_w}}" for j in range(len(classes))) + print(f" {cls[: col_w - 1]:<{col_w}}{row}") + + # Top class-pair confusions, in both directions. + pairs: list[tuple[str, str, int]] = [] + for i, src in enumerate(classes): + for j, dst in enumerate(classes): + if i != j and mat[i, j] > 0: + pairs.append((src, dst, int(mat[i, j]))) + pairs.sort(key=lambda r: -r[2]) + if pairs: + print("\nTop class-pair confusions:") + for src, dst, n in pairs[:10]: + print(f" {n:>4} {src} -> {dst}") + + +def worst_offenders( + samples: list[ImageSample], + labelmap: dict[int, str], + top_n: int, + quality: bool, +) -> list[ImageSample]: + """Print the worst-offender images grouped by class. + + Two buckets per class: + (a) Misclassified — predicted label differs from folder. Sorted by the + confidence in the WRONG class (highest first). These are the most + confidently wrong images, the strongest candidates for relabeling + or deletion. + (b) Borderline-correct — predicted label matches but p_true is low. + These sit near the decision boundary; they're not actively wrong + but they make the class harder to learn cleanly. + + Returns the union of (a) lists across classes, for optional copying. + """ + print("\n" + "=" * 78) + print(f"WORST OFFENDERS (top {top_n} per class)") + print("=" * 78) + + by_class: dict[str, list[ImageSample]] = {} + for s in samples: + if s.true_label is None: + continue + by_class.setdefault(s.true_label, []).append(s) + + all_misclassified: list[ImageSample] = [] + for cls in sorted(by_class): + rows = by_class[cls] + miscls = [r for r in rows if r.pred_label != cls] + miscls.sort(key=lambda r: -r.pred_score) + all_misclassified.extend(miscls[:top_n]) + + print(f"\n-- class '{cls}': {len(miscls)}/{len(rows)} misclassified --") + if miscls: + print( + f"{'p_pred':>7} {'pred':<18} {'p_true':>7} {'shape':>11}" + + (" blur bright aspect " if quality else " ") + + "name" + ) + for r in miscls[:top_n]: + shape = f"{r.shape[0]}x{r.shape[1]}" + extra = "" + if quality: + img = cv2.imread(r.path) + blur = laplacian_variance(img) if img is not None else float("nan") + bright = mean_brightness(img) if img is not None else float("nan") + aspect = aspect_distortion(r.shape) + extra = f" {blur:5.0f} {bright:6.1f} {aspect:6.2f} " + pt = r.true_score if r.true_score is not None else float("nan") + print( + f"{r.pred_score:7.3f} {r.pred_label:<18} " + f"{pt:7.3f} {shape:>11}{extra}{os.path.basename(r.path)}" + ) + + # Borderline-correct: labeled right but the model isn't confident. + correct = [r for r in rows if r.pred_label == cls and r.true_score is not None] + correct.sort(key=lambda r: r.true_score or 0.0) + borderline = correct[: max(5, top_n // 3)] + if borderline: + print("\n borderline-correct (lowest p_true while still labeled right):") + for r in borderline: + # Second-best class names the neighbor that's pulling on this image. + if len(r.probs) > 1: + second = int(np.argsort(-r.probs)[1]) + second_lbl = labelmap[second] + second_p = float(r.probs[second]) + else: + second_lbl = "-" + second_p = 0.0 + print( + f" p_true={r.true_score:.3f} " + f"p_2nd={second_p:.3f} ({second_lbl}) " + f"{os.path.basename(r.path)}" + ) + + return all_misclassified + + +# --------------------------------------------------------------------------- +# Optional diagnostics +# --------------------------------------------------------------------------- + + +def quality_summary(samples: list[ImageSample]) -> None: + """Compare image-quality stats for correct vs misclassified images. + + Helps answer: are the worst offenders systematically blurrier / darker / + more squashed than the rest of the class? If so, the fix is to tighten + the data-collection criteria, not just delete individual images. + """ + print("\n" + "=" * 78) + print("IMAGE QUALITY — correct vs misclassified") + print("=" * 78) + + rows: list[tuple[str, bool, float, float, float]] = [] + for s in samples: + if s.true_label is None: + continue + img = cv2.imread(s.path) + if img is None: + continue + blur = laplacian_variance(img) + bright = mean_brightness(img) + aspect = aspect_distortion(s.shape) + rows.append((s.true_label, s.pred_label == s.true_label, blur, bright, aspect)) + + if not rows: + print(" (no readable images)") + return + + correct = [r for r in rows if r[1]] + wrong = [r for r in rows if not r[1]] + + def stats(name: str, getter, group: list) -> None: + if not group: + print(f" {name:<14} (no samples)") + return + vals = np.array([getter(r) for r in group]) + print( + f" {name:<14} n={len(vals):>4} " + f"mean={vals.mean():8.2f} median={np.median(vals):8.2f} " + f"p10={np.percentile(vals, 10):8.2f} p90={np.percentile(vals, 90):8.2f}" + ) + + print("\nBlur (laplacian variance — higher = sharper):") + stats("correct", lambda r: r[2], correct) + stats("misclassified", lambda r: r[2], wrong) + print("\nBrightness (0..255):") + stats("correct", lambda r: r[3], correct) + stats("misclassified", lambda r: r[3], wrong) + print("\nAspect distortion (0 = square; higher = more squashed by 224x224):") + stats("correct", lambda r: r[4], correct) + stats("misclassified", lambda r: r[4], wrong) + + +def summarize_negative( + neg_samples: list[ImageSample], + threshold: float, + labelmap: dict[int, str], +) -> None: + """Score an unlabeled folder of runtime crops against the model. + + Equivalent to face_dataset.py's negative-set analysis: each image is + classified, and we print its full probability vector plus whether it + would clear the configured threshold. High-confidence predictions on + crops the user knows are wrong indicate the training set is leaking + a representative image into the wrong class. + """ + print("\n" + "=" * 78) + print(f"NEGATIVE SET ANALYSIS ({len(neg_samples)} images, threshold={threshold})") + print("=" * 78) + + classes = [labelmap[i] for i in sorted(labelmap)] + print(f"\n{'pass':>4} {'score':>6} {'pred':<18} full prob vector / name") + for s in neg_samples: + passes = "yes" if s.pred_score >= threshold else "no" + full = " ".join(f"{c}={float(s.probs[i]):.2f}" for i, c in enumerate(classes)) + print( + f"{passes:>4} {s.pred_score:6.3f} {s.pred_label:<18} " + f"{full} :: {os.path.basename(s.path)}" + ) + + +def pick_worst_confusion_pair( + samples: list[ImageSample], + labelmap: dict[int, str], +) -> tuple[int, str | None, str | None]: + """Return (count, source, target) for the most-confused class pair.""" + classes = [labelmap[i] for i in sorted(labelmap)] + pairs: list[tuple[int, str, str]] = [] + for src in classes: + for tgt in classes: + if src == tgt: + continue + n = sum(1 for s in samples if s.true_label == src and s.pred_label == tgt) + if n > 0: + pairs.append((n, src, tgt)) + pairs.sort(reverse=True) + return pairs[0] if pairs else (0, None, None) + + +def cross_class_contamination( + samples: list[ImageSample], + source_class: str, + target_class: str, + label_to_idx: dict[str, int], + embedder: MobileNetEmbedder, + top_n: int, +) -> None: + """Find training images in source_class that visually look like target_class. + + Generalizes face_dataset.py's contamination_analysis to N classes. Uses a + fixed ImageNet backbone (NOT the user's trained classifier) so that + contaminators which the trained model has memorized into the source class + still surface — the trained model's own embedding would hide them. + + Three sections: + 1. Source-image culprits ranked by `cos(img, target_centroid) - + cos(img, source_centroid)`. Positive delta = the image looks more + like the target class than its own class — prime relabeling + candidates. + 2. For each target image, the top-3 nearest source training images. + Shows the visual chain of confusion image-by-image. + 3. Ringleader summary: source images that appear most often as a top-3 + neighbor across the target set. These few photos are responsible + for the bulk of the confusion. + """ + src = [s for s in samples if s.true_label == source_class] + tgt = [s for s in samples if s.true_label == target_class] + + if not src or not tgt: + print( + f"\nERROR: need both classes populated; got {len(src)} " + f"'{source_class}' and {len(tgt)} '{target_class}'" + ) + return + + print("\n" + "=" * 78) + print( + f"CROSS-CLASS CONTAMINATION '{source_class}' leaning toward '{target_class}'" + ) + print(" (model-independent embeddings via ImageNet MobileNetV2)") + print("=" * 78) + + target_idx = label_to_idx.get(target_class) + source_idx = label_to_idx.get(source_class) + + print(f"\nEmbedding {len(src) + len(tgt)} images...") + src_embs = np.stack([embedder.embed(cv2.imread(s.path)) for s in src]) + tgt_embs = np.stack([embedder.embed(cv2.imread(s.path)) for s in tgt]) + + src_centroid = src_embs.mean(axis=0) + src_centroid /= np.linalg.norm(src_centroid) + 1e-9 + tgt_centroid = tgt_embs.mean(axis=0) + tgt_centroid /= np.linalg.norm(tgt_centroid) + 1e-9 + + src_to_src = src_embs @ src_centroid + src_to_tgt = src_embs @ tgt_centroid + delta = src_to_tgt - src_to_src + + print(f"\n-- '{source_class}' images sorted by '{target_class}'-likeness --") + print(f" positive delta = visually closer to '{target_class}' centroid") + print( + f" p_{target_class} = trained model's probability for '{target_class}' " + f"on this image\n" + ) + delta_label = "delta" + tgt_cos_label = f"cos_{target_class}"[:12] + src_cos_label = f"cos_{source_class}"[:12] + p_tgt_label = f"p_{target_class}"[:10] + print( + f" {delta_label:>7} {tgt_cos_label:>12} {src_cos_label:>12} " + f"{p_tgt_label:>10} name" + ) + order = np.argsort(-delta) + for i in order[:top_n]: + s = src[i] + p_tgt = float(s.probs[target_idx]) if target_idx is not None else float("nan") + print( + f" {delta[i]:+7.3f} {src_to_tgt[i]:12.3f} {src_to_src[i]:12.3f} " + f"{p_tgt:10.3f} {os.path.basename(s.path)}" + ) + + print(f"\n-- nearest '{source_class}' neighbors for each '{target_class}' image --") + neighbor_counts: dict[str, int] = {} + src_paths = [os.path.basename(s.path) for s in src] + for i, t in enumerate(tgt): + sims = src_embs @ tgt_embs[i] + top3 = np.argsort(-sims)[:3] + p_src = float(t.probs[source_idx]) if source_idx is not None else float("nan") + marker = " <5} name") + ranked = sorted(neighbor_counts.items(), key=lambda r: -r[1]) + for name, count in ranked[:top_n]: + print(f" {count:>5} {name}") + + +def save_misclassified(samples: list[ImageSample], out_dir: str) -> None: + """Copy misclassified images to /__as__/. + + Lets you browse the worst offenders in a file manager and bulk-delete or + relabel them without poking through the original dataset tree. + """ + print("\n" + "=" * 78) + print(f"SAVING MISCLASSIFIED IMAGES -> {out_dir}") + print("=" * 78) + count = 0 + for s in samples: + if s.true_label is None or s.pred_label == s.true_label: + continue + bucket = os.path.join(out_dir, f"{s.true_label}__as__{s.pred_label}") + os.makedirs(bucket, exist_ok=True) + score_tag = f"{int(round(s.pred_score * 100)):03d}" + dest = os.path.join(bucket, f"{score_tag}_{os.path.basename(s.path)}") + try: + shutil.copy2(s.path, dest) + count += 1 + except OSError as err: + print(f" [copy failed] {s.path}: {err}") + print(f" copied {count} images into {out_dir}") + + +# --------------------------------------------------------------------------- +# main +# --------------------------------------------------------------------------- + + +def main() -> int: + ap = argparse.ArgumentParser( + description=( + "Analyze a Frigate object-classification training dataset against its " + "deployed TFLite model." + ), + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + ap.add_argument( + "--name", + required=True, + help="Classification model name (matches the key in classification.custom.)", + ) + ap.add_argument( + "--clips-dir", + default="/media/frigate/clips", + help="Frigate clips directory; dataset is read from //dataset", + ) + ap.add_argument( + "--model-cache", + default="/config/model_cache", + help="Frigate model_cache; model is read from //model.tflite", + ) + ap.add_argument( + "--threshold", + type=float, + default=0.8, + help="Score threshold (matches model_config.threshold; default 0.8)", + ) + ap.add_argument( + "--top-n", + type=int, + default=15, + help="Worst-offender images to show per class", + ) + ap.add_argument( + "--quality", + action="store_true", + help="Include blur/brightness/aspect stats for correct vs misclassified", + ) + ap.add_argument( + "--negative", + default=None, + help="Score an unlabeled folder of crops against the model", + ) + ap.add_argument( + "--save-misclassified", + default=None, + help="Copy every misclassified image into this directory for review", + ) + ap.add_argument( + "--confuses", + default=None, + help=( + "Cross-class contamination analysis. Format ':', " + "e.g. 'rex:buddy' to find Rex training images that look like " + "Buddy. Use 'auto' to pick the worst pair from the confusion matrix. " + "Requires --embedding-model." + ), + ) + ap.add_argument( + "--embedding-model", + default=None, + help=( + "Path to ONNX MobileNetV2 file for model-independent embeddings " + "(required by --confuses). Download once with: curl -L -o " + "/config/model_cache/mobilenetv2-7.onnx https://github.com/onnx/" + "models/raw/main/validated/vision/classification/mobilenet/model/" + "mobilenetv2-7.onnx" + ), + ) + args = ap.parse_args() + + dataset_dir = os.path.join(args.clips_dir, args.name, "dataset") + model_path = os.path.join(args.model_cache, args.name, "model.tflite") + labelmap_path = os.path.join(args.model_cache, args.name, "labelmap.txt") + + for required in (dataset_dir, model_path, labelmap_path): + if not os.path.exists(required): + print(f"ERROR: required path not found: {required}") + return 1 + + print(f"Loading model from {model_path}") + interpreter, input_details, output_details = load_tflite(model_path) + labelmap = load_labelmap(labelmap_path) + label_to_idx = {v: k for k, v in labelmap.items()} + print(f" labels: {sorted(labelmap.values())}") + + print(f"\nScanning dataset at {dataset_dir} ...") + samples = classify_dataset( + dataset_dir, interpreter, input_details, output_details, labelmap, label_to_idx + ) + if not samples: + print("no images found — aborting") + return 1 + print(f" classified {len(samples)} images") + + summarize_dataset(samples, labelmap) + confusion_matrix(samples, labelmap) + misclassified = worst_offenders(samples, labelmap, args.top_n, args.quality) + + if args.quality: + quality_summary(samples) + + if args.negative: + print(f"\nLoading negatives from {args.negative} ...") + neg = classify_folder( + args.negative, + interpreter, + input_details, + output_details, + labelmap, + label_to_idx, + true_label=None, + ) + if neg: + summarize_negative(neg, args.threshold, labelmap) + + if args.confuses: + if not args.embedding_model: + print( + "\nERROR: --confuses requires --embedding-model (path to ONNX " + "MobileNetV2). See --help for the download command." + ) + return 1 + try: + embedder = MobileNetEmbedder(args.embedding_model) + except (FileNotFoundError, cv2.error) as err: + print(f"\nERROR: failed to load embedding model: {err}") + return 1 + + if args.confuses == "auto": + n, src, tgt = pick_worst_confusion_pair(samples, labelmap) + if src is None: + print( + "\nNo misclassifications in dataset — " + "nothing to investigate via --confuses auto" + ) + else: + print(f"\nAuto-picked worst confusion: {src} -> {tgt} ({n} cases)") + cross_class_contamination( + samples, src, tgt, label_to_idx, embedder, args.top_n + ) + else: + if ":" not in args.confuses: + print("\nERROR: --confuses expects ':' or 'auto'") + return 1 + src, tgt = args.confuses.split(":", 1) + if src not in label_to_idx or tgt not in label_to_idx: + print( + f"\nERROR: class names must be in the labelmap " + f"({sorted(label_to_idx)})" + ) + return 1 + cross_class_contamination( + samples, src, tgt, label_to_idx, embedder, args.top_n + ) + + if args.save_misclassified: + save_misclassified(misclassified, args.save_misclassified) + + return 0 + + +if __name__ == "__main__": + sys.exit(main())