frigate/converters/ssd_mobilenet_v2_coco/assets/0001-fix-trt.patch

53 lines
1.6 KiB
Diff
Raw Normal View History

From 40953eaae8ca55838e046325b257faaff0bbe33f Mon Sep 17 00:00:00 2001
From: YS <ys@gm.com>
Date: Tue, 21 Dec 2021 21:01:35 +0300
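Subject: [PATCH] fix trt: build the SSD engine through IBuilderConfig

Move max_workspace_size and the FP16 flag from the Builder onto an
IBuilderConfig, and build with build_serialized_network() followed by
Runtime.deserialize_cuda_engine() instead of build_cuda_engine(), since
the builder-level attributes and build_cuda_engine() are not available
in the TensorRT 8 Python API.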
---
ssd/build_engine.py | 13 ++++++++-----
1 file changed, 8 insertions(+), 5 deletions(-)
diff --git a/ssd/build_engine.py b/ssd/build_engine.py
index 65729a9..e4a55c8 100644
--- a/ssd/build_engine.py
+++ b/ssd/build_engine.py
@@ -17,7 +17,6 @@ import uff
import tensorrt as trt
import graphsurgeon as gs
-
DIR_NAME = os.path.dirname(__file__)
LIB_FILE = os.path.abspath(os.path.join(DIR_NAME, 'libflattenconcat.so'))
MODEL_SPECS = {
@@ -286,19 +285,23 @@ def main():
text=True,
debug_mode=DEBUG_UFF)
with trt.Builder(TRT_LOGGER) as builder, builder.create_network() as network, trt.UffParser() as parser:
- builder.max_workspace_size = 1 << 28
+ config = builder.create_builder_config()
+ config.max_workspace_size = 1 << 28
builder.max_batch_size = 1
- builder.fp16_mode = True
+ config.set_flag(trt.BuilderFlag.FP16)
parser.register_input('Input', INPUT_DIMS)
parser.register_output('MarkOutput_0')
parser.parse(spec['tmp_uff'], network)
- engine = builder.build_cuda_engine(network)
+
+ plan = builder.build_serialized_network(network, config)
+
+ with trt.Runtime(TRT_LOGGER) as runtime:
+ engine = runtime.deserialize_cuda_engine(plan)
buf = engine.serialize()
with open(spec['output_bin'], 'wb') as f:
f.write(buf)
-
if __name__ == '__main__':
main()
--
2.17.1