diff --git a/onnx/export_onnx.py b/onnx/export_onnx.py index b3598f9..a9cb864 100644 --- a/onnx/export_onnx.py +++ b/onnx/export_onnx.py @@ -1,7 +1,7 @@ """ Export ONNX model of MODNet with: input shape: (batch_size, 3, height, width) - output shape: (batch_size, 1, height, width) + output shape: (batch_size, 1, height, width) Arguments: --ckpt-path: path of the checkpoint that will be converted @@ -50,6 +50,6 @@ if __name__ == '__main__': # export to onnx model torch.onnx.export( - modnet.module, dummy_input, args.output_path, export_params = True, - input_names = ['input'], output_names = ['output'], + modnet.module, dummy_input, args.output_path, export_params = True, + input_names = ['input'], output_names = ['output'], dynamic_axes = {'input': {0:'batch_size', 2:'height', 3:'width'}, 'output': {0: 'batch_size', 2: 'height', 3: 'width'}})