From 40953eaae8ca55838e046325b257faaff0bbe33f Mon Sep 17 00:00:00 2001
From: YS
Date: Tue, 21 Dec 2021 21:01:35 +0300
Subject: [PATCH] ssd: build the engine with the TensorRT 8 builder API

Builder.max_workspace_size, Builder.fp16_mode and
Builder.build_cuda_engine() are no longer available in the TensorRT 8
Python API. Set the workspace size and the FP16 flag on an
IBuilderConfig instead, build with Builder.build_serialized_network()
and write the returned plan straight to the output file: the plan
already is the serialized engine, so there is no need to deserialize
it only to serialize it again.

Also fail with a clear error when parsing the UFF model or building
the engine fails; build_serialized_network() returns None on failure.
---
 ssd/build_engine.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/ssd/build_engine.py b/ssd/build_engine.py
--- a/ssd/build_engine.py
+++ b/ssd/build_engine.py
@@ -286,18 +286,22 @@ def main():
         text=True,
         debug_mode=DEBUG_UFF)
     with trt.Builder(TRT_LOGGER) as builder, builder.create_network() as network, trt.UffParser() as parser:
-        builder.max_workspace_size = 1 << 28
+        config = builder.create_builder_config()
+        config.max_workspace_size = 1 << 28
         builder.max_batch_size = 1
-        builder.fp16_mode = True
+        config.set_flag(trt.BuilderFlag.FP16)
 
         parser.register_input('Input', INPUT_DIMS)
         parser.register_output('MarkOutput_0')
-        parser.parse(spec['tmp_uff'], network)
-        engine = builder.build_cuda_engine(network)
+        if not parser.parse(spec['tmp_uff'], network):
+            raise RuntimeError('failed to parse %s' % spec['tmp_uff'])
 
-        buf = engine.serialize()
+        # The plan returned here is already the serialized engine.
+        plan = builder.build_serialized_network(network, config)
+        if plan is None:
+            raise RuntimeError('failed to build the TensorRT engine')
         with open(spec['output_bin'], 'wb') as f:
-            f.write(buf)
+            f.write(plan)
 
 
 if __name__ == '__main__':
-- 
2.17.1