Use batching for state classification generation

Nicolas Mowen 2025-12-15 12:41:39 -07:00
parent 0f36422b35
commit 603d9f7d27


@@ -499,6 +499,10 @@ def _extract_keyframes(
     """
     Extract keyframes from recordings at specified timestamps and crop to specified regions.
 
+    This implementation batches work by running multiple ffmpeg snapshot commands
+    concurrently, which significantly reduces total runtime compared to
+    processing each timestamp serially.
+
     Args:
         ffmpeg_path: Path to ffmpeg binary
         timestamps: List of timestamp dicts from _select_balanced_timestamps
@@ -508,15 +512,21 @@ def _extract_keyframes(
     Returns:
         List of paths to successfully extracted and cropped keyframe images
     """
 
-    keyframe_paths = []
+    from concurrent.futures import ThreadPoolExecutor, as_completed
 
-    for idx, ts_info in enumerate(timestamps):
+    if not timestamps:
+        return []
+
+    # Limit the number of concurrent ffmpeg processes so we don't overload the host.
+    max_workers = min(5, len(timestamps))
+
+    def _process_timestamp(idx: int, ts_info: dict) -> tuple[int, str | None]:
         camera = ts_info["camera"]
         timestamp = ts_info["timestamp"]
 
         if camera not in camera_crops:
             logger.warning(f"No crop coordinates for camera {camera}")
-            continue
+            return idx, None
 
         norm_x1, norm_y1, norm_x2, norm_y2 = camera_crops[camera]
@@ -533,7 +543,7 @@ def _extract_keyframes(
                 .get()
             )
         except Exception:
-            continue
+            return idx, None
 
         relative_time = timestamp - recording.start_time
@@ -547,38 +557,57 @@ def _extract_keyframes(
                 height=None,
             )
 
-            if image_data:
-                nparr = np.frombuffer(image_data, np.uint8)
-                img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
+            if not image_data:
+                return idx, None
 
-                if img is not None:
-                    height, width = img.shape[:2]
+            nparr = np.frombuffer(image_data, np.uint8)
+            img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
 
-                    x1 = int(norm_x1 * width)
-                    y1 = int(norm_y1 * height)
-                    x2 = int(norm_x2 * width)
-                    y2 = int(norm_y2 * height)
+            if img is None:
+                return idx, None
 
-                    x1_clipped = max(0, min(x1, width))
-                    y1_clipped = max(0, min(y1, height))
-                    x2_clipped = max(0, min(x2, width))
-                    y2_clipped = max(0, min(y2, height))
+            height, width = img.shape[:2]
 
-                    if x2_clipped > x1_clipped and y2_clipped > y1_clipped:
-                        cropped = img[y1_clipped:y2_clipped, x1_clipped:x2_clipped]
-                        resized = cv2.resize(cropped, (224, 224))
+            x1 = int(norm_x1 * width)
+            y1 = int(norm_y1 * height)
+            x2 = int(norm_x2 * width)
+            y2 = int(norm_y2 * height)
 
-                        output_path = os.path.join(output_dir, f"frame_{idx:04d}.jpg")
-                        cv2.imwrite(output_path, resized)
-                        keyframe_paths.append(output_path)
+            x1_clipped = max(0, min(x1, width))
+            y1_clipped = max(0, min(y1, height))
+            x2_clipped = max(0, min(x2, width))
+            y2_clipped = max(0, min(y2, height))
+
+            if x2_clipped <= x1_clipped or y2_clipped <= y1_clipped:
+                return idx, None
+
+            cropped = img[y1_clipped:y2_clipped, x1_clipped:x2_clipped]
+            resized = cv2.resize(cropped, (224, 224))
+
+            output_path = os.path.join(output_dir, f"frame_{idx:04d}.jpg")
+            cv2.imwrite(output_path, resized)
+            return idx, output_path
         except Exception as e:
             logger.debug(
                 f"Failed to extract frame from {recording.path} at {relative_time}s: {e}"
             )
-            continue
+            return idx, None
 
-    return keyframe_paths
+    keyframes_with_index: list[tuple[int, str]] = []
+
+    with ThreadPoolExecutor(max_workers=max_workers) as executor:
+        future_to_idx = {
+            executor.submit(_process_timestamp, idx, ts_info): idx
+            for idx, ts_info in enumerate(timestamps)
+        }
+
+        for future in as_completed(future_to_idx):
+            _, path = future.result()
+            if path:
+                keyframes_with_index.append((future_to_idx[future], path))
+
+    keyframes_with_index.sort(key=lambda item: item[0])
+    return [path for _, path in keyframes_with_index]
 
 
 def _select_distinct_images(
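
For reference, the fan-out/fan-in pattern this commit uses — submit one indexed task per timestamp to a bounded ThreadPoolExecutor, collect results as they complete, then restore submission order — can be shown in isolation. This is a minimal sketch, assuming a hypothetical extract_frame stand-in for the per-timestamp ffmpeg/crop work; it is not Frigate's actual helper:

from concurrent.futures import ThreadPoolExecutor, as_completed


def extract_frame(idx: int, ts_info: dict) -> tuple[int, str | None]:
    # Hypothetical stand-in for the per-timestamp work in the commit
    # (ffmpeg snapshot, decode, crop, write); returns (idx, path) on
    # success and (idx, None) on any failure.
    return idx, ts_info.get("path")


def extract_all(timestamps: list[dict]) -> list[str]:
    if not timestamps:
        return []

    # Cap concurrency so a long timestamp list cannot spawn an unbounded
    # number of ffmpeg processes at once.
    max_workers = min(5, len(timestamps))

    results: list[tuple[int, str]] = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(extract_frame, idx, ts)
            for idx, ts in enumerate(timestamps)
        ]
        # as_completed yields futures in completion order, not submission
        # order, so each result carries its index for re-sorting below.
        for future in as_completed(futures):
            idx, path = future.result()
            if path:
                results.append((idx, path))

    # Restore the original timestamp order before returning paths.
    results.sort(key=lambda item: item[0])
    return [path for _, path in results]

A thread pool (rather than multiprocessing) is a reasonable fit here because each task spends most of its time waiting on an ffmpeg subprocess and disk I/O rather than executing Python bytecode, so the GIL is not the bottleneck.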