mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-04-15 11:32:09 +03:00
Correctly handle cuda input
This commit is contained in:
parent
43412f6390
commit
226b370cad
@ -87,7 +87,10 @@ class CudaGraphRunner(BaseModelRunner):
|
|||||||
"""Get the input width of the model."""
|
"""Get the input width of the model."""
|
||||||
return self._session.get_inputs()[0].shape[3]
|
return self._session.get_inputs()[0].shape[3]
|
||||||
|
|
||||||
def run(self, input_name: str, tensor_input: np.ndarray):
|
def run(self, input: dict[str, Any]):
|
||||||
|
# Extract the single tensor input (assuming one input)
|
||||||
|
input_name = list(input.keys())[0]
|
||||||
|
tensor_input = input[input_name]
|
||||||
tensor_input = np.ascontiguousarray(tensor_input)
|
tensor_input = np.ascontiguousarray(tensor_input)
|
||||||
|
|
||||||
if not self._captured:
|
if not self._captured:
|
||||||
|
|||||||
@ -95,7 +95,7 @@ class ONNXDetector(DetectionApi):
|
|||||||
if self._cg_runner is not None:
|
if self._cg_runner is not None:
|
||||||
try:
|
try:
|
||||||
# Run using CUDA graphs if available
|
# Run using CUDA graphs if available
|
||||||
tensor_output = self._cg_runner.run(model_input_name, tensor_input)
|
tensor_output = self._cg_runner.run({model_input_name: tensor_input})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"CUDA Graphs failed, falling back to regular run: {e}")
|
logger.warning(f"CUDA Graphs failed, falling back to regular run: {e}")
|
||||||
self._cg_runner = None
|
self._cg_runner = None
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user