Here is a link to the official code example: ONNX dynamic input
First we need an input tensor. Suppose the network input has shape batch_size x 1 x 224 x 224:

```python
batch_size = 1  # the concrete value used for tracing; the axis itself is exported as dynamic
x = torch.randn(batch_size, 1, 224, 224, requires_grad=True)

# torch_model is the instantiated model
torch_out = torch_model(x)

# The main export call:
# Export the model
torch.onnx.export(torch_model,               # model being run
                  x,                         # model input (or a tuple for multiple inputs)
                  "super_resolution.onnx",   # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=10,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names=['input'],     # the model's input names
                  output_names=['output'],   # the model's output names
                  dynamic_axes={'input': {0: 'batch_size'},    # variable length axes
                                'output': {0: 'batch_size'}})
```

For example, with multiple inputs:
```python
pillar_x = torch.ones([1, 1, 9918, 100], dtype=torch.float32, device="cuda:0")
pillar_y = torch.ones([1, 1, 9918, 100], dtype=torch.float32, device="cuda:0")
pillar_z = torch.ones([1, 1, 9918, 100], dtype=torch.float32, device="cuda:0")
pillar_i = torch.ones([1, 1, 9918, 100], dtype=torch.float32, device="cuda:0")
```

The actual result is shown below.

That's all for now. I am currently looking into how to use TensorRT (TRT) to run inference with dynamic-shape inputs; that part will be added later.
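While the TRT part is still to be filled in, one common route is to build an engine from the exported ONNX model with `trtexec`, supplying min/opt/max shapes for the dynamic axis as an optimization profile. This is an untested sketch; the file names are placeholders, and the shape ranges are arbitrary examples:

```shell
# min/opt/max shapes define the optimization profile for the
# dynamic "input" axis exported above; file names are placeholders.
trtexec --onnx=super_resolution.onnx \
        --minShapes=input:1x1x224x224 \
        --optShapes=input:8x1x224x224 \
        --maxShapes=input:32x1x224x224 \
        --saveEngine=super_resolution.plan
```

At runtime the actual input shape must then be set on the execution context before each inference, within the min/max range declared here.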