
PyTorch model to ONNX to TensorRT: model conversion and inference deployment

mzgong

1. Converting a .pth model to ONNX

    import os
    import sys
    import torch
    import numpy as np

    from feat.model import ResNet  # import your own model class

    def load_checkpoint(checkpoint_file, model):
        """Loads the checkpoint from the given file and applies its weights to the model."""
        err_str = "Checkpoint '{}' not found"
        assert os.path.exists(checkpoint_file), err_str.format(checkpoint_file)
        checkpoint = torch.load(checkpoint_file, map_location="cpu")
        # NOTE: the state-dict key depends on how the checkpoint was saved;
        # "model_state" is an assumption here, adjust to your own format.
        model.load_state_dict(checkpoint["model_state"])
        return checkpoint["epoch"]

    if __name__ == '__main__':
        os.environ['CUDA_VISIBLE_DEVICES'] = '0'   # select the GPU to run on

        model_filename = 'resnet_epoch_17.pyth'

        # init model
        model = ResNet()
        load_checkpoint(model_filename, model)
        model = model.cuda()
        model.eval()

        onnx_name = 'resnet.onnx'                # output ONNX file
        example = torch.randn((1, 3, 224, 224))  # model input size
        example = example.cuda()
        input_names = ["input"]
        output_names = ["outputs"]
        dynamic_axes = {"input": {0: "batch_size"}, "outputs": {0: "batch_size"}}

        # Convert and save the model. Pass dynamic_axes=None instead if you
        # want the exported model fixed to the example's batch size.
        torch.onnx.export(model, example, onnx_name, opset_version=12,
                          input_names=input_names, output_names=output_names,
                          dynamic_axes=dynamic_axes)
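Before moving on, it can help to sanity-check the exported file with onnx's built-in checker (a minimal sketch, reusing the resnet.onnx filename from above):

    import onnx

    onnx_model = onnx.load('resnet.onnx')
    onnx.checker.check_model(onnx_model)  # raises if the graph is structurally invalid
    print(onnx.helper.printable_graph(onnx_model.graph))  # human-readable graph dump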

2. Testing ONNX model accuracy

    import os
    import sys
    import torch
    import numpy as np
    import onnxruntime
    import time

    from feat.model import ResNet           # same model class as in section 1
    # assumes the section 1 script was saved as torch2onnx.py
    from torch2onnx import load_checkpoint

    if __name__ == '__main__':
        os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # select the GPU to run on

        model_filename = 'resnet_epoch_17.pyth'
        onnx_name = 'resnet.onnx'

        # init model
        model = ResNet()
        load_checkpoint(model_filename, model)
        model = model.cuda()
        model.eval()

        session = onnxruntime.InferenceSession(onnx_name, providers=['CUDAExecutionProvider'])
        img = np.random.randn(1, 3, 224, 224).astype(np.float32)  # random input
        t1 = time.time()
        onnx_preds = session.run(None, {"input": img})
        print("onnx preds result: ", onnx_preds)
        t2 = time.time()
        pth_preds = model(torch.from_numpy(img).cuda())
        print("pth preds result: ", pth_preds)
        t3 = time.time()
        print("onnx cost time:", t2 - t1, "pth cost time:", t3 - t2)

Compare the printed outputs and confirm the two results match:

    onnx preds res: [array([[-0.13128008,  0.04037811,  0.0529038 ,  0.101323  , -0.03352938,
             0.03099938,  0.06380229, -0.03544223, -0.03368076,  0.06361518,
            -0.00668521, -0.01996843, -0.0132075 , -0.03448019,  0.17793381,
             0.08131739,  0.10232763, -0.09122676,  0.01173838,  0.03181053,
            -0.05899123,  0.01569226, -0.04734752, -0.12551421,  0.00686131,
            -0.00749457, -0.03729884,  0.05349742,  0.0304895 ,  0.02956274,
             0.00393172,  0.00196273,  0.01296113, -0.03985897, -0.06289426,
            -0.0825834 , -0.28903952,  0.02842386, -0.1718263 , -0.05555207,
            -0.03707219,  0.10904352,  0.06582819,  0.04960179,  0.01508415,
             0.05469472,  0.28663486,  0.1183752 , -0.06070469, -0.05200525,
            -0.03477468, -0.06193898, -0.04432139,  0.0843045 , -0.12080704,
             0.00163073, -0.08544722,  0.11994477,  0.02619292,  0.05066012,
            -0.00332941, -0.1488586 ,  0.07936171,  0.06203181, -0.0645356 ,
            -0.07661135, -0.05883927, -0.00459472, -0.06721105, -0.02880175,
            -0.00337263, -0.00927516,  0.03289868,  0.10054352, -0.09545278,
            -0.0216963 ,  0.11413048, -0.04580398,  0.02614305, -0.08269466,
             0.01835637,  0.17654261,  0.0573773 , -0.06440263,  0.01176349,
             0.00998674,  0.02840159,  0.14086637, -0.02473863,  0.05228964,
            -0.03329878, -0.02751228, -0.04788758,  0.1546051 ,  0.05838795,
            -0.02351469, -0.01315547, -0.13732813, -0.08146078,  0.01943143,
            -0.08991284,  0.14222968, -0.14729632,  0.24547395, -0.05293949,
             0.04446511,  0.05436133, -0.09403729, -0.0900671 ,  0.04516568,
             0.10035874, -0.03281724,  0.19480802, -0.11344203, -0.02487336,
            -0.08126407, -0.00491623,  0.04313428, -0.10474856, -0.11427435,
            -0.01765379, -0.04613522,  0.08338863,  0.00564523,  0.14067101,
             0.05428562,  0.12530491, -0.2503076 ]], dtype=float32)]
    pth preds res: tensor([[-0.1313,  0.0404,  0.0529,  0.1013, -0.0335,  0.0310,  0.0638, -0.0354,
            -0.0337,  0.0636, -0.0067, -0.0200, -0.0132, -0.0345,  0.1779,  0.0813,
             0.1023, -0.0912,  0.0117,  0.0318, -0.0590,  0.0157, -0.0473, -0.1255,
             0.0069, -0.0075, -0.0373,  0.0535,  0.0305,  0.0296,  0.0039,  0.0020,
             0.0130, -0.0399, -0.0629, -0.0826, -0.2890,  0.0284, -0.1718, -0.0556,
            -0.0371,  0.1090,  0.0658,  0.0496,  0.0151,  0.0547,  0.2866,  0.1184,
            -0.0607, -0.0520, -0.0348, -0.0619, -0.0443,  0.0843, -0.1208,  0.0016,
            -0.0854,  0.1199,  0.0262,  0.0507, -0.0033, -0.1489,  0.0794,  0.0620,
            -0.0645, -0.0766, -0.0588, -0.0046, -0.0672, -0.0288, -0.0034, -0.0093,
             0.0329,  0.1005, -0.0955, -0.0217,  0.1141, -0.0458,  0.0261, -0.0827,
             0.0184,  0.1765,  0.0574, -0.0644,  0.0118,  0.0100,  0.0284,  0.1409,
            -0.0247,  0.0523, -0.0333, -0.0275, -0.0479,  0.1546,  0.0584, -0.0235,
            -0.0132, -0.1373, -0.0815,  0.0194, -0.0899,  0.1422, -0.1473,  0.2455,
            -0.0529,  0.0445,  0.0544, -0.0940, -0.0901,  0.0452,  0.1004, -0.0328,
             0.1948, -0.1134, -0.0249, -0.0813, -0.0049,  0.0431, -0.1047, -0.1143,
            -0.0177, -0.0461,  0.0834,  0.0056,  0.1407,  0.0543,  0.1253, -0.2503]],
           device='cuda:0', grad_fn=<DivBackward0>)
    onnx cost time: 0.0062367916107177734 pth cost time: 0.030622243881225586
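Rather than eyeballing the two printouts, the comparison can be automated (a small sketch, assuming onnx_preds and pth_preds from the script above):

    # onnx_preds is a list of numpy arrays; pth_preds is a CUDA tensor
    onnx_out = onnx_preds[0]
    pth_out = pth_preds.detach().cpu().numpy()
    print("max abs diff:", np.abs(onnx_out - pth_out).max())
    assert np.allclose(onnx_out, pth_out, atol=1e-4), "ONNX and PyTorch outputs diverge"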

3. Converting ONNX to TensorRT

    import os
    import tensorrt as trt

    TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
    trt_runtime = trt.Runtime(TRT_LOGGER)

    BASE_DIR = os.path.dirname(os.path.abspath(__file__))

    EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

    os.environ['CUDA_VISIBLE_DEVICES'] = '2'

    def get_engine(input_shape, onnx_file_path="", engine_file_path=""):
        """Attempts to load a serialized engine if available, otherwise builds a new TensorRT engine and saves it."""
        def build_engine():
            """Takes an ONNX file and creates a TensorRT engine to run inference with"""
            with trt.Builder(TRT_LOGGER) as builder, \
                 builder.create_network(EXPLICIT_BATCH) as network, \
                 trt.OnnxParser(network, TRT_LOGGER) as parser, \
                 builder.create_builder_config() as config:
                config.max_workspace_size = 1 << 33  # 8 GiB workspace
                # config.set_flag(trt.BuilderFlag.FP16)  # build in FP16; leave commented out for FP32
                builder.max_batch_size = 1
                # Parse model file
                if not os.path.exists(onnx_file_path):
                    print('ONNX file {} not found, please run torch2onnx first to generate it.'.format(onnx_file_path))
                    exit(0)
                print('Loading ONNX file from path {}...'.format(onnx_file_path))
                with open(onnx_file_path, 'rb') as model:
                    print('Beginning ONNX file parsing')
                    if not parser.parse(model.read()):
                        print('ERROR: Failed to parse the ONNX file.')
                        for error in range(parser.num_errors):
                            print(parser.get_error(error))
                        return None
                # Force the network input to the requested shape (e.g. batch size 1)
                network.get_input(0).shape = input_shape
                print('Completed parsing of ONNX file')
                print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))
                engine = builder.build_engine(network, config)
                print("Completed creating Engine")
                with open(engine_file_path, "wb") as f:
                    f.write(engine.serialize())
                return engine

        if os.path.exists(engine_file_path):
            # If a serialized engine exists, use it instead of building an engine.
            print("Reading engine from file {}".format(engine_file_path))
            with open(engine_file_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
                return runtime.deserialize_cuda_engine(f.read())
        else:
            return build_engine()

    if __name__ == '__main__':
        onnx_file = 'resnet.onnx'
        engine_file = 'resnet.engine'
        input_shape = [1, 3, 224, 224]
        get_engine(input_shape, onnx_file, engine_file)
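Note that builder.max_batch_size, config.max_workspace_size and builder.build_engine are deprecated in TensorRT 8.x and removed in later releases. A hedged sketch of what the end of build_engine() becomes on those versions (assuming TensorRT >= 8.4; everything else stays as above):

    # inside build_engine(), after parsing succeeds:
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 33)  # replaces max_workspace_size
    plan = builder.build_serialized_network(network, config)  # replaces builder.build_engine()
    with open(engine_file_path, "wb") as f:
        f.write(plan)  # plan is an IHostMemory and supports the buffer protocol
    return trt.Runtime(TRT_LOGGER).deserialize_cuda_engine(plan)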

4. Testing TensorRT model accuracy

    import os
    import sys
    import cv2
    import copy
    import torch
    import numpy as np
    import time
    import onnxruntime
    import pycuda.driver as cuda
    import tensorrt as trt

    os.environ['CUDA_VISIBLE_DEVICES'] = '3'
    TRT_LOGGER = trt.Logger()
    EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

    # Simple helper data class that's a little nicer to use than a 2-tuple.
    class HostDeviceMem(object):
        def __init__(self, host_mem, device_mem):
            self.host = host_mem
            self.device = device_mem

        def __str__(self):
            return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device)

        def __repr__(self):
            return self.__str__()

    def get_engine(engine_file_path):
        with open(engine_file_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
            return runtime.deserialize_cuda_engine(f.read())

    # Allocates all buffers required for an engine, i.e. host/device inputs/outputs.
    def allocate_buffers(engine):
        inputs = []
        outputs = []
        bindings = []
        stream = cuda.Stream()
        for binding in engine:
            size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size
            dtype = trt.nptype(engine.get_binding_dtype(binding))
            # Allocate host and device buffers
            host_mem = cuda.pagelocked_empty(size, dtype)
            device_mem = cuda.mem_alloc(host_mem.nbytes)
            # Append the device buffer to device bindings.
            bindings.append(int(device_mem))
            # Append to the appropriate list.
            if engine.binding_is_input(binding):
                inputs.append(HostDeviceMem(host_mem, device_mem))
            else:
                outputs.append(HostDeviceMem(host_mem, device_mem))
        return inputs, outputs, bindings, stream

    # This function is generalized for multiple inputs/outputs for full dimension networks.
    # inputs and outputs are expected to be lists of HostDeviceMem objects.
    def do_inference_v2(context, bindings, inputs, outputs, stream):
        # Transfer input data to the GPU.
        [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]
        # Run inference.
        context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
        # Transfer predictions back from the GPU.
        [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs]
        # Synchronize the stream
        stream.synchronize()
        # Return only the host outputs.
        return [out.host for out in outputs]

    if __name__ == '__main__':
        onnx_name = 'resnet.onnx'
        trt_name = 'resnet.engine'

        session = onnxruntime.InferenceSession(onnx_name, providers=['CUDAExecutionProvider'])

        import pycuda.autoprimaryctx  # initializes the CUDA context pycuda needs
        engine = get_engine(trt_name)
        context = engine.create_execution_context()
        inputs, outputs, bindings, stream = allocate_buffers(engine)

        img = cv2.imread('test.jpg')
        img = cv2.resize(img, (224, 224))
        img = img.transpose([2, 0, 1]).astype(np.float32)  # HWC -> CHW
        img = np.expand_dims(img, axis=0)
        t1 = time.time()
        onnx_preds = session.run(None, {"input": img})
        # print("onnx_preds: ", onnx_preds)
        t2 = time.time()

        inputs[0].host = np.ascontiguousarray(img)
        trt_outputs = do_inference_v2(context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream)
        data = copy.deepcopy(trt_outputs[0])
        # print("preds: ", data)
        t3 = time.time()
        print("onnx: ", t2 - t1, " trt: ", t3 - t2)

5. Errors

Error 1

ERROR: Failed to parse the ONNX file.

In node 84 (importConv): UNSUPPORTED_NODE: Assertion failed: inputs.at(2).is_weights() && "The bias tensor is required to be an initializer for the Conv operator."

Solution:

pip install onnx-simplifier

then re-save the ONNX model through simplify:

    import onnx
    from onnxsim import simplify

    onnx_model = onnx.load('resnet.onnx')
    model_simp, check = simplify(onnx_model)
    assert check, "Simplified ONNX model could not be validated"

    onnx.save(model_simp, 'resnet_sim.onnx')
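After simplifying, point the section 3 converter at the new file (reusing get_engine from that script):

    get_engine([1, 3, 224, 224], 'resnet_sim.onnx', 'resnet_sim.engine')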

Error 2

ValueError: ndarray is not contiguous

Solution:

The array is not contiguous in memory; run it through np.ascontiguousarray(img) before copying it into the input buffer:

inputs[0].host = np.ascontiguousarray(img)
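A quick illustration of why the transpose in section 4 triggers this (a standalone sketch, not from the original post):

    import numpy as np

    img = np.zeros((224, 224, 3), dtype=np.float32)
    chw = img.transpose([2, 0, 1])      # a strided view, not a copy
    print(chw.flags['C_CONTIGUOUS'])    # False -> pycuda cannot copy it into pinned memory
    fixed = np.ascontiguousarray(chw)   # materializes a contiguous copy
    print(fixed.flags['C_CONTIGUOUS'])  # True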

Error 3

Error Code 1: Myelin (Compiled against cuBLASLt 11.11.3.0 but running against cuBLASLt 11.4.1.0.)

Solution:

When TensorRT and PyTorch are imported into the same process, they pull in different versions of libmyelin.so (built against different cuBLASLt releases), so the two cannot be used together in one process. The same conflict can occur when TensorRT and onnxruntime are used together.
