onnx2trt 踩坑总结
Parsing Error
Gather
- `Assertion failed: !(data->getType() == nvinfer1::DataType::kINT32 && nbDims == 1) && "Cannot perform gather on a shape tensor!"`
-
file: torch/nn/functional.py
```python
# 源码:
if torch._C._get_tracing_state():
    return [(torch.floor(input.size(i + 2) * torch.tensor(float(scale_factors[i])))) for i in range(dim)]
else:
    return [int(math.floor(int(input.size(i + 2)) * scale_factors[i])) for i in range(dim)]

# 修改为:
return [int(math.floor(input.size(i + 2) * scale_factors[i])) for i in range(dim)]
```
-
file: custom
`x=x.view(x.size(0),-1)` -> `x=x.flatten(1)`

```python
# section 1
# 源码
class ShapeModel(nn.Module):
    def __init__(self):
        super(ShapeModel, self).__init__()

    def forward(self, x):
        return x.shape

# 修改为
class ShapeModel(nn.Module):
    def __init__(self):
        super(ShapeModel, self).__init__()

    def forward(self, x):
        return torch.tensor(x.shape)

# section 2
# 源码
class ResizeModel(nn.Module):
    def __init__(self):
        super(ResizeModel, self).__init__()

    def forward(self, x):
        return F.interpolate(x, scale_factor=(2, 2), mode='nearest')

# 修改为
class ResizeModel(nn.Module):
    def __init__(self):
        super(ResizeModel, self).__init__()

    def forward(self, x):
        sh = torch.tensor(x.shape)
        return F.interpolate(x, size=(sh[2] * 2, sh[3] * 2), mode='nearest')
```
Upsample
```python
#return F.upsample(x, size=(x.shape[2] * 2, x.shape[3] * 2), mode='bilinear', align_corners=True)
# RuntimeError: ONNX symbolic expected a constant value in the trace
#return F.interpolate(x, size=(x.shape[2] * 2, x.shape[3] * 2), mode='bilinear', align_corners=True)
# RuntimeError: ONNX symbolic expected a constant value in the trace
#return F.upsample(x, size=(600, 600), mode='bilinear', align_corners=False)
# UserWarning: nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.
#return F.interpolate(x, size=(600, 600), mode='bilinear', align_corners=True)
# UserWarning: ONNX export failed on upsample_bilinear2d because align_corners == True not supported
# RuntimeError: ONNX export failed: Couldn't export operator aten::upsample_bilinear2d
return F.interpolate(x, size=(600, 600), mode='bilinear', align_corners=False) #no warning, all clear
```