From 20ea7d4cce275c9ec3ae17075d2c42da8ddd703d Mon Sep 17 00:00:00 2001 From: Nicolas Mowen Date: Thu, 14 Nov 2024 20:45:20 -0700 Subject: [PATCH] Don't track shared memory in frame tracker --- frigate/util/image.py | 64 ++++++++++++++++++++++++++++++------------- 1 file changed, 45 insertions(+), 19 deletions(-) diff --git a/frigate/util/image.py b/frigate/util/image.py index 484737f71..3170f6814 100644 --- a/frigate/util/image.py +++ b/frigate/util/image.py @@ -3,8 +3,10 @@ import datetime import logging import subprocess as sp +import threading from abc import ABC, abstractmethod -from multiprocessing import shared_memory +from multiprocessing import resource_tracker as _mprt +from multiprocessing import shared_memory as _mpshm from string import printable from typing import AnyStr, Optional @@ -731,32 +733,56 @@ class FrameManager(ABC): pass -class DictFrameManager(FrameManager): - def __init__(self): - self.frames = {} +class SharedMemory(_mpshm.SharedMemory): + # https://github.com/python/cpython/issues/82300#issuecomment-2169035092 - def create(self, name, size) -> AnyStr: - mem = bytearray(size) - self.frames[name] = mem - return mem + __lock = threading.Lock() - def get(self, name, shape): - mem = self.frames[name] - return np.ndarray(shape, dtype=np.uint8, buffer=mem) + def __init__( + self, + name: Optional[str] = None, + create: bool = False, + size: int = 0, + *, + track: bool = True, + ) -> None: + self._track = track - def close(self, name): - pass + # if tracking, normal init will suffice + if track: + return super().__init__(name=name, create=create, size=size) - def delete(self, name): - del self.frames[name] + # lock so that other threads don't attempt to use the + # register function during this time + with self.__lock: + # temporarily disable registration during initialization + orig_register = _mprt.register + _mprt.register = self.__tmp_register + + # initialize; ensure original register function is + # re-instated + try: + super().__init__(name=name, create=create, size=size) + finally: + _mprt.register = orig_register + + @staticmethod + def __tmp_register(*args, **kwargs) -> None: + return + + def unlink(self) -> None: + if _mpshm._USE_POSIX and self._name: + _mpshm._posixshmem.shm_unlink(self._name) + if self._track: + _mprt.unregister(self._name, "shared_memory") class SharedMemoryFrameManager(FrameManager): def __init__(self): - self.shm_store: dict[str, shared_memory.SharedMemory] = {} + self.shm_store: dict[str, SharedMemory] = {} def create(self, name: str, size) -> AnyStr: - shm = shared_memory.SharedMemory(name=name, create=True, size=size) + shm = SharedMemory(name=name, create=True, size=size, track=False) self.shm_store[name] = shm return shm.buf @@ -765,7 +791,7 @@ class SharedMemoryFrameManager(FrameManager): if name in self.shm_store: shm = self.shm_store[name] else: - shm = shared_memory.SharedMemory(name=name) + shm = SharedMemory(name=name) self.shm_store[name] = shm return np.ndarray(shape, dtype=np.uint8, buffer=shm.buf) except FileNotFoundError: @@ -788,7 +814,7 @@ class SharedMemoryFrameManager(FrameManager): del self.shm_store[name] else: try: - shm = shared_memory.SharedMemory(name=name) + shm = SharedMemory(name=name) shm.close() shm.unlink() except FileNotFoundError: