frigate/engine.py

88 lines
2.9 KiB
Python
Raw Normal View History

import ctypes
import argparse
import sys
import os
import tensorrt as trt
# Shared TensorRT logger, passed to the builder and to the plugin registry in
# this module. Only messages at WARNING severity or above are reported.
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
def model_input_shape():
    """Return the fixed shape registered for the UFF "Input" node.

    The value is always the tuple ``(3, 300, 300)``.
    NOTE(review): presumably channels-first (C, H, W) order for a 3-channel
    300x300 image; confirm against the exported model.
    """
    channels, height, width = 3, 300, 300
    return (channels, height, width)
def build_engine(uff_model_path, trt_engine_datatype=trt.DataType.FLOAT, batch_size=1):
    """Build a TensorRT CUDA engine from a UFF model file.

    The network input is registered as "Input" with the shape returned by
    ``model_input_shape()``, and "MarkOutput_0" is registered as the output.

    NOTE(review): this uses the implicit-batch builder attributes
    (``max_workspace_size``, ``max_batch_size``, ``fp16_mode``,
    ``build_cuda_engine``). Confirm the installed TensorRT version still
    provides them.

    Args:
        uff_model_path: Path to the preprocessed UFF model.
        trt_engine_datatype: ``trt.DataType.HALF`` enables FP16 mode. Any
            other value leaves the builder's precision setting untouched.
        batch_size: Maximum batch size the engine is built for.

    Returns:
        The engine object produced by ``builder.build_cuda_engine``.

    Raises:
        RuntimeError: If the UFF parser rejects the model, or if TensorRT
            fails to build an engine from the parsed network.
    """
    with trt.Builder(TRT_LOGGER) as builder, \
            builder.create_network() as network, \
            trt.UffParser() as parser:
        # Allow TensorRT up to 1 << 30 bytes (1 GiB) of builder workspace.
        builder.max_workspace_size = 1 << 30
        builder.max_batch_size = batch_size
        if trt_engine_datatype == trt.DataType.HALF:
            builder.fp16_mode = True
        parser.register_input("Input", model_input_shape())
        parser.register_output("MarkOutput_0")
        # UffParser.parse signals failure via its boolean return value rather
        # than raising; building from a partially parsed network would yield
        # a broken engine or none at all.
        if not parser.parse(uff_model_path, network):
            raise RuntimeError(
                "Failed to parse UFF model: {}".format(uff_model_path))
        engine = builder.build_cuda_engine(network)
        # build_cuda_engine returns None instead of raising when the build fails.
        if engine is None:
            raise RuntimeError(
                "TensorRT failed to build an engine from {}".format(uff_model_path))
        return engine
def save_engine(engine, engine_dest_path):
    """Serialize *engine* and write the resulting bytes to *engine_dest_path*.

    Any missing parent directories are created first. A bare file name (one
    with no directory component) is written relative to the current working
    directory. An existing file at the destination is overwritten.

    Args:
        engine: Object whose ``serialize()`` method returns a bytes-like
            buffer, such as the engine returned by ``build_engine``.
        engine_dest_path: Destination file path.

    Raises:
        ValueError: If *engine* is None (for example, after a failed build).
    """
    if engine is None:
        raise ValueError("Cannot save engine: engine is None (did the build fail?)")
    dest_dir = os.path.dirname(engine_dest_path)
    # os.makedirs('') raises FileNotFoundError, so only create a directory
    # when the path actually contains one.
    if dest_dir:
        os.makedirs(dest_dir, exist_ok=True)
    buf = engine.serialize()
    with open(engine_dest_path, 'wb') as f:
        f.write(buf)
def load_engine(trt_runtime, engine_path):
    """Read a serialized engine from disk and deserialize it.

    Args:
        trt_runtime: Runtime object providing ``deserialize_cuda_engine``.
        engine_path: Path to a file previously written by ``save_engine``.

    Returns:
        Whatever ``trt_runtime.deserialize_cuda_engine`` returns for the
        file's raw bytes.
    """
    with open(engine_path, 'rb') as engine_file:
        serialized = engine_file.read()
    return trt_runtime.deserialize_cuda_engine(serialized)
def load_plugins():
    """Register TensorRT's bundled plugins and load the FlattenConcat plugin.

    ``libflattenconcat.so`` is loaded with ``ctypes.CDLL`` by bare name, so
    the dynamic linker must be able to find it. If the library cannot be
    loaded, an error message is printed to stderr and the process exits with
    status 1.
    """
    trt.init_libnvinfer_plugins(TRT_LOGGER, '')
    try:
        ctypes.CDLL('libflattenconcat.so')
    except OSError as e:
        # ctypes reports a missing or unloadable shared library as OSError;
        # other exception types are real bugs and should propagate.
        print("Error: {}\n{}".format(e, "Make sure FlattenConcat custom plugin layer is provided"),
              file=sys.stderr)
        sys.exit(1)
# Maps the --precision command-line value (bits per float) to the TensorRT
# data type passed to build_engine as trt_engine_datatype.
TRT_PRECISION_TO_DATATYPE = {
    16: trt.DataType.HALF,
    32: trt.DataType.FLOAT
}
def main(argv=None):
    """Command-line entry point: build a TensorRT engine from a UFF model and save it.

    Args:
        argv: Argument list to parse. ``None`` means ``sys.argv[1:]``.

    Exits with status 2 on invalid arguments (via ``argparse``) and with
    status 1 if the engine could not be built.
    """
    # Define script command line arguments
    parser = argparse.ArgumentParser(description='Utility to build TensorRT engine prior to inference.')
    parser.add_argument('-i', "--input",
                        dest='uff_model_path', metavar='UFF_MODEL_PATH', required=True,
                        help='preprocessed TensorFlow model in UFF format')
    parser.add_argument('-p', '--precision', type=int, choices=[32, 16], default=32,
                        help='desired TensorRT float precision to build an engine with')
    parser.add_argument('-b', '--batch_size', type=int, default=1,
                        help='max TensorRT engine batch size')
    parser.add_argument("-o", "--output", dest='trt_engine_path',
                        help="path of the output file",
                        default=os.path.join(os.path.dirname(os.path.abspath(__file__)), "engine.buf"))
    args = parser.parse_args(argv)
    # A non-positive max batch size cannot produce a usable engine; reject it
    # before the slow build starts.
    if args.batch_size < 1:
        parser.error('--batch_size must be a positive integer')
    load_plugins()
    # Using supplied .uff file alongside with UffParser build TensorRT engine
    print("Building TensorRT engine. This may take few minutes.")
    trt_engine = build_engine(
        uff_model_path=args.uff_model_path,
        trt_engine_datatype=TRT_PRECISION_TO_DATATYPE[args.precision],
        batch_size=args.batch_size)
    # Guard against a None engine so the user sees a clear error instead of an
    # AttributeError raised from inside save_engine.
    if trt_engine is None:
        print("Error: failed to build TensorRT engine from {}".format(args.uff_model_path),
              file=sys.stderr)
        sys.exit(1)
    # Save the engine to file
    save_engine(trt_engine, args.trt_engine_path)
    print("TensorRT engine saved to {}".format(args.trt_engine_path))


if __name__ == '__main__':
    main()