mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-04-15 03:22:10 +03:00
Correctly handle cuda input
This commit is contained in:
parent
43412f6390
commit
226b370cad
@ -87,7 +87,10 @@ class CudaGraphRunner(BaseModelRunner):
|
||||
"""Get the input width of the model."""
|
||||
return self._session.get_inputs()[0].shape[3]
|
||||
|
||||
def run(self, input_name: str, tensor_input: np.ndarray):
|
||||
def run(self, input: dict[str, Any]):
|
||||
# Extract the single tensor input (assuming one input)
|
||||
input_name = list(input.keys())[0]
|
||||
tensor_input = input[input_name]
|
||||
tensor_input = np.ascontiguousarray(tensor_input)
|
||||
|
||||
if not self._captured:
|
||||
|
||||
@ -95,7 +95,7 @@ class ONNXDetector(DetectionApi):
|
||||
if self._cg_runner is not None:
|
||||
try:
|
||||
# Run using CUDA graphs if available
|
||||
tensor_output = self._cg_runner.run(model_input_name, tensor_input)
|
||||
tensor_output = self._cg_runner.run({model_input_name: tensor_input})
|
||||
except Exception as e:
|
||||
logger.warning(f"CUDA Graphs failed, falling back to regular run: {e}")
|
||||
self._cg_runner = None
|
||||
|
||||
Loading…
Reference in New Issue
Block a user