Don't track shared memory in frame tracker

This commit is contained in:
Nicolas Mowen 2024-11-14 20:45:20 -07:00
parent ed9c67804a
commit 20ea7d4cce

View File

@ -3,8 +3,10 @@
import datetime
import logging
import subprocess as sp
import threading
from abc import ABC, abstractmethod
from multiprocessing import shared_memory
from multiprocessing import resource_tracker as _mprt
from multiprocessing import shared_memory as _mpshm
from string import printable
from typing import AnyStr, Optional
@ -731,32 +733,56 @@ class FrameManager(ABC):
pass
class DictFrameManager(FrameManager):
def __init__(self):
self.frames = {}
class SharedMemory(_mpshm.SharedMemory):
# https://github.com/python/cpython/issues/82300#issuecomment-2169035092
def create(self, name, size) -> AnyStr:
mem = bytearray(size)
self.frames[name] = mem
return mem
__lock = threading.Lock()
def get(self, name, shape):
mem = self.frames[name]
return np.ndarray(shape, dtype=np.uint8, buffer=mem)
def __init__(
self,
name: Optional[str] = None,
create: bool = False,
size: int = 0,
*,
track: bool = True,
) -> None:
self._track = track
def close(self, name):
pass
# if tracking, normal init will suffice
if track:
return super().__init__(name=name, create=create, size=size)
def delete(self, name):
del self.frames[name]
# lock so that other threads don't attempt to use the
# register function during this time
with self.__lock:
# temporarily disable registration during initialization
orig_register = _mprt.register
_mprt.register = self.__tmp_register
# initialize; ensure original register function is
# re-instated
try:
super().__init__(name=name, create=create, size=size)
finally:
_mprt.register = orig_register
@staticmethod
def __tmp_register(*args, **kwargs) -> None:
return
def unlink(self) -> None:
if _mpshm._USE_POSIX and self._name:
_mpshm._posixshmem.shm_unlink(self._name)
if self._track:
_mprt.unregister(self._name, "shared_memory")
class SharedMemoryFrameManager(FrameManager):
def __init__(self):
self.shm_store: dict[str, shared_memory.SharedMemory] = {}
self.shm_store: dict[str, SharedMemory] = {}
def create(self, name: str, size) -> AnyStr:
shm = shared_memory.SharedMemory(name=name, create=True, size=size)
shm = SharedMemory(name=name, create=True, size=size, track=False)
self.shm_store[name] = shm
return shm.buf
@ -765,7 +791,7 @@ class SharedMemoryFrameManager(FrameManager):
if name in self.shm_store:
shm = self.shm_store[name]
else:
shm = shared_memory.SharedMemory(name=name)
shm = SharedMemory(name=name)
self.shm_store[name] = shm
return np.ndarray(shape, dtype=np.uint8, buffer=shm.buf)
except FileNotFoundError:
@ -788,7 +814,7 @@ class SharedMemoryFrameManager(FrameManager):
del self.shm_store[name]
else:
try:
shm = shared_memory.SharedMemory(name=name)
shm = SharedMemory(name=name)
shm.close()
shm.unlink()
except FileNotFoundError: