Don't track shared memory in frame tracker

This commit is contained in:
Nicolas Mowen 2024-11-14 20:45:20 -07:00
parent ed9c67804a
commit 20ea7d4cce

View File

@ -3,8 +3,10 @@
import datetime import datetime
import logging import logging
import subprocess as sp import subprocess as sp
import threading
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from multiprocessing import shared_memory from multiprocessing import resource_tracker as _mprt
from multiprocessing import shared_memory as _mpshm
from string import printable from string import printable
from typing import AnyStr, Optional from typing import AnyStr, Optional
@ -731,32 +733,56 @@ class FrameManager(ABC):
pass pass
class DictFrameManager(FrameManager): class SharedMemory(_mpshm.SharedMemory):
def __init__(self): # https://github.com/python/cpython/issues/82300#issuecomment-2169035092
self.frames = {}
def create(self, name, size) -> AnyStr: __lock = threading.Lock()
mem = bytearray(size)
self.frames[name] = mem
return mem
def get(self, name, shape): def __init__(
mem = self.frames[name] self,
return np.ndarray(shape, dtype=np.uint8, buffer=mem) name: Optional[str] = None,
create: bool = False,
size: int = 0,
*,
track: bool = True,
) -> None:
self._track = track
def close(self, name): # if tracking, normal init will suffice
pass if track:
return super().__init__(name=name, create=create, size=size)
def delete(self, name): # lock so that other threads don't attempt to use the
del self.frames[name] # register function during this time
with self.__lock:
# temporarily disable registration during initialization
orig_register = _mprt.register
_mprt.register = self.__tmp_register
# initialize; ensure original register function is
# re-instated
try:
super().__init__(name=name, create=create, size=size)
finally:
_mprt.register = orig_register
@staticmethod
def __tmp_register(*args, **kwargs) -> None:
return
def unlink(self) -> None:
if _mpshm._USE_POSIX and self._name:
_mpshm._posixshmem.shm_unlink(self._name)
if self._track:
_mprt.unregister(self._name, "shared_memory")
class SharedMemoryFrameManager(FrameManager): class SharedMemoryFrameManager(FrameManager):
def __init__(self): def __init__(self):
self.shm_store: dict[str, shared_memory.SharedMemory] = {} self.shm_store: dict[str, SharedMemory] = {}
def create(self, name: str, size) -> AnyStr: def create(self, name: str, size) -> AnyStr:
shm = shared_memory.SharedMemory(name=name, create=True, size=size) shm = SharedMemory(name=name, create=True, size=size, track=False)
self.shm_store[name] = shm self.shm_store[name] = shm
return shm.buf return shm.buf
@ -765,7 +791,7 @@ class SharedMemoryFrameManager(FrameManager):
if name in self.shm_store: if name in self.shm_store:
shm = self.shm_store[name] shm = self.shm_store[name]
else: else:
shm = shared_memory.SharedMemory(name=name) shm = SharedMemory(name=name)
self.shm_store[name] = shm self.shm_store[name] = shm
return np.ndarray(shape, dtype=np.uint8, buffer=shm.buf) return np.ndarray(shape, dtype=np.uint8, buffer=shm.buf)
except FileNotFoundError: except FileNotFoundError:
@ -788,7 +814,7 @@ class SharedMemoryFrameManager(FrameManager):
del self.shm_store[name] del self.shm_store[name]
else: else:
try: try:
shm = shared_memory.SharedMemory(name=name) shm = SharedMemory(name=name)
shm.close() shm.close()
shm.unlink() shm.unlink()
except FileNotFoundError: except FileNotFoundError: