mirror of
https://github.com/blakeblackshear/frigate.git
synced 2025-12-16 10:06:42 +03:00
88 lines
2.9 KiB
Python
88 lines
2.9 KiB
Python
import ctypes
|
|
import argparse
|
|
import sys
|
|
import os
|
|
import tensorrt as trt
|
|
|
|
# TensorRT logger shared by the builder (build_engine) and plugin registration
# (load_plugins); only messages of WARNING severity or higher are reported.
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
|
|
|
|
|
|
def model_input_shape():
    """Return the fixed shape registered for the model's "Input" tensor.

    The tuple is passed straight to ``UffParser.register_input``; it is
    presumably (channels, height, width) for a 300x300 RGB image -- confirm
    against the UFF model being converted.
    """
    shape = (3, 300, 300)
    return shape
|
|
|
|
|
|
def build_engine(uff_model_path, trt_engine_datatype=trt.DataType.FLOAT, batch_size=1):
    """Build a TensorRT CUDA engine from a UFF model file.

    Args:
        uff_model_path: Path to the UFF model to parse.
        trt_engine_datatype: ``trt.DataType.HALF`` turns on the builder's FP16
            mode; any other value leaves the builder's default precision.
        batch_size: Value assigned to ``builder.max_batch_size``.

    Returns:
        The engine produced by ``builder.build_cuda_engine``.

    Raises:
        RuntimeError: If the UFF parser reports a parse failure, or if the
            builder returns no engine.
    """
    with trt.Builder(TRT_LOGGER) as builder, \
            builder.create_network() as network, \
            trt.UffParser() as parser:
        # Allow the builder up to 1 GiB (1 << 30 bytes) of workspace memory.
        builder.max_workspace_size = 1 << 30
        builder.max_batch_size = batch_size
        if trt_engine_datatype == trt.DataType.HALF:
            builder.fp16_mode = True

        parser.register_input("Input", model_input_shape())
        parser.register_output("MarkOutput_0")
        # parse() reports failure through its boolean return value rather than
        # raising; building from a partially populated network would otherwise
        # fail later with a far less helpful error.
        if not parser.parse(uff_model_path, network):
            raise RuntimeError(
                "Failed to parse UFF model: {}".format(uff_model_path))

        engine = builder.build_cuda_engine(network)
        # build_cuda_engine() returns None on failure instead of raising.
        if engine is None:
            raise RuntimeError("Failed to build TensorRT engine")
        return engine
|
|
|
|
|
|
def save_engine(engine, engine_dest_path):
    """Serialize *engine* and write the bytes to *engine_dest_path*.

    Missing parent directories are created. A bare filename (no directory
    component) is written relative to the current working directory.

    Args:
        engine: Object exposing ``serialize()`` (e.g. a TensorRT engine).
        engine_dest_path: Destination file path.

    Raises:
        ValueError: If *engine* is None (for example, after a failed build).
    """
    if engine is None:
        raise ValueError("Cannot save engine: engine is None (did the build fail?)")

    # Serialize first so no directory is created if serialization fails.
    buf = engine.serialize()

    dest_dir = os.path.dirname(engine_dest_path)
    # os.makedirs('') raises FileNotFoundError, so only create a directory
    # when the path actually contains one.
    if dest_dir:
        os.makedirs(dest_dir, exist_ok=True)

    with open(engine_dest_path, 'wb') as f:
        f.write(buf)
|
|
|
|
|
|
def load_engine(trt_runtime, engine_path):
    """Deserialize an engine file that was previously written to disk.

    Args:
        trt_runtime: Runtime object whose ``deserialize_cuda_engine`` receives
            the raw file bytes (e.g. a ``trt.Runtime``).
        engine_path: Path of the serialized engine file.

    Returns:
        Whatever ``trt_runtime.deserialize_cuda_engine`` returns.
    """
    with open(engine_path, 'rb') as engine_file:
        serialized_engine = engine_file.read()
    return trt_runtime.deserialize_cuda_engine(serialized_engine)
|
|
|
|
|
|
def load_plugins(plugin_library='libflattenconcat.so'):
    """Register TensorRT's built-in plugins and load the FlattenConcat plugin.

    Args:
        plugin_library: Name or path of the shared library handed to
            ``ctypes.CDLL``. Defaults to ``'libflattenconcat.so'``.

    If the plugin library cannot be loaded, an error is printed to stderr
    and the process exits with status 1.
    """
    trt.init_libnvinfer_plugins(TRT_LOGGER, '')

    try:
        ctypes.CDLL(plugin_library)
    except OSError as e:
        # ctypes reports a missing or unloadable shared library as OSError;
        # anything else is an unexpected bug and should surface normally.
        print("Error: {}\n{}".format(e, "Make sure FlattenConcat custom plugin layer is provided"),
              file=sys.stderr)
        sys.exit(1)
|
|
|
|
|
|
# Maps the --precision command-line value (in bits) to the TensorRT data type
# passed to build_engine(): 16 -> trt.DataType.HALF, 32 -> trt.DataType.FLOAT.
TRT_PRECISION_TO_DATATYPE = {
    16: trt.DataType.HALF,
    32: trt.DataType.FLOAT
}
|
|
|
|
def _parse_args(argv=None):
    """Parse and validate the script's command-line arguments.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='Utility to build TensorRT engine prior to inference.')
    parser.add_argument('-i', "--input",
                        dest='uff_model_path', metavar='UFF_MODEL_PATH', required=True,
                        help='preprocessed TensorFlow model in UFF format')
    parser.add_argument('-p', '--precision', type=int, choices=[32, 16], default=32,
                        help='desired TensorRT float precision to build an engine with')
    parser.add_argument('-b', '--batch_size', type=int, default=1,
                        help='max TensorRT engine batch size')
    # By default, write engine.buf next to this script.
    parser.add_argument("-o", "--output", dest='trt_engine_path',
                        help="path of the output file",
                        default=os.path.join(os.path.dirname(os.path.abspath(__file__)), "engine.buf"))

    args = parser.parse_args(argv)
    # A max batch size below 1 is meaningless; reject it here rather than
    # after the slow plugin load and engine build.
    if args.batch_size < 1:
        parser.error('--batch_size must be a positive integer')
    return args


def main(argv=None):
    """Build a TensorRT engine from a UFF model and save it to disk.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    args = _parse_args(argv)

    load_plugins()

    # Build the TensorRT engine from the supplied .uff file using UffParser.
    print("Building TensorRT engine. This may take few minutes.")
    trt_engine = build_engine(
        uff_model_path=args.uff_model_path,
        trt_engine_datatype=TRT_PRECISION_TO_DATATYPE[args.precision],
        batch_size=args.batch_size)

    # Save the engine to file.
    save_engine(trt_engine, args.trt_engine_path)
    print("TensorRT engine saved to {}".format(args.trt_engine_path))


if __name__ == '__main__':
    main()
|