diff --git a/.gitignore b/.gitignore
index d64f109..134f2ac 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,5 +1,6 @@
# Temporary directories and files
*.ckpt
+*.onnx
# Byte-compiled / optimized / DLL files
__pycache__/
diff --git a/README.md b/README.md
index 7c9cab4..1bdadbe 100644
--- a/README.md
+++ b/README.md
@@ -20,6 +20,7 @@ WebCam Video Demo [Offline]
diff --git a/demo/image_matting/Inference_with_ONNX/export_modnet_onnx.py b/demo/image_matting/Inference_with_ONNX/export_modnet_onnx.py
deleted file mode 100644
--- a/demo/image_matting/Inference_with_ONNX/export_modnet_onnx.py
+++ /dev/null
@@ -1,54 +0,0 @@
-"""
-Export onnx model of MODNet
-Arguments:
- --ckpt-path --> Path of last checkpoint to load
- --output-path --> path of onnx model to be saved
-
-example:
-python export_modnet_onnx.py \
- --ckpt-path=modnet_photographic_portrait_matting.ckpt \
- --output-path=modnet.onnx
-
-output:
-ONNX model with dynamic input shape: (batch_size, 3, height, width) &
- output shape: (batch_size, 1, height, width)
-"""
-import os
-import argparse
-import torch
-import torch.nn as nn
-from torch.autograd import Variable
-from src.models.onnx_modnet import MODNet
-
-
-
-if __name__ == '__main__':
-    # define cmd arguments
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--ckpt-path', type=str, required=True, help='path of pre-trained MODNet')
-    parser.add_argument('--output-path', type=str, required=True, help='path of output onnx model')
-    args = parser.parse_args()
-
-    # check input arguments
-    if not os.path.exists(args.ckpt_path):
-        print('Cannot find checkpoint path: {0}'.format(args.ckpt_path))
-        exit()
-
-    # define model & load checkpoint
-    modnet = MODNet(backbone_pretrained=False)
-    modnet = nn.DataParallel(modnet).cuda()
-    state_dict = torch.load(args.ckpt_path)
-    modnet.load_state_dict(state_dict)
-    modnet.eval()
-
-    # prepare dummy_input
-    batch_size = 1
-    height = 512
-    width = 512
-    dummy_input = Variable(torch.randn(batch_size, 3, height, width)).cuda()
-
-    # export to onnx model
-    torch.onnx.export(modnet.module, dummy_input, args.output_path, export_params = True, opset_version=11,
-                      input_names = ['input'], output_names = ['output'],
-                      dynamic_axes = {'input': {0:'batch_size', 2:'height', 3:'width'},
-                                      'output': {0: 'batch_size', 2: 'height', 3: 'width'}})
diff --git a/onnx/README.md b/onnx/README.md
new file mode 100644
index 0000000..ff6189f
--- /dev/null
+++ b/onnx/README.md
@@ -0,0 +1,61 @@
+## MODNet - ONNX Model
+
+This ONNX version of MODNet is provided by [@manthan3C273](https://github.com/manthan3C273) from the community.
+Please note that this ONNX export requires a newer PyTorch version than the official MODNet code does (torch==1.7.1 is recommended).
+
+You can try **MODNet - Image Matting Demo (ONNX version)** in [this Colab](https://colab.research.google.com/drive/1P3cWtg8fnmu9karZHYDAtmm1vj1rgA-f?usp=sharing).
+You can also download the ONNX version of the official **Image Matting Model** from [this link](https://drive.google.com/file/d/1cgycTQlYXpTh26gB9FTnthE7AvruV8hd/view?usp=sharing).
+
+To export the ONNX version of MODNet (assuming you are in the project root directory):
+1. Download the pre-trained **Image Matting Model** from this [link](https://drive.google.com/drive/folders/1umYmlCulvIFNaqPjwod1SayFmSRHziyR?usp=sharing) and put the model into the folder `MODNet/pretrained/`.
+
+2. Install all dependencies by:
+    ```
+    pip install -r onnx/requirements.txt
+    ```
+
+3. Export the ONNX version of MODNet by:
+    ```shell
+    python -m onnx.export_onnx \
+        --ckpt-path=pretrained/modnet_photographic_portrait_matting.ckpt \
+        --output-path=pretrained/modnet_photographic_portrait_matting.onnx
+    ```
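+
+    To quickly confirm that the exported graph really has dynamic spatial axes, a check along the following lines may help (a sketch, assuming `onnxruntime` is installed; step 2 should provide it):
+    ```python
+    import onnxruntime
+
+    session = onnxruntime.InferenceSession(
+        'pretrained/modnet_photographic_portrait_matting.onnx',
+        providers=['CPUExecutionProvider'])
+    print(session.get_inputs()[0].shape)  # expect ['batch_size', 3, 'height', 'width']
+    ```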
+
+4. Run inference with the ONNX model by:
+    ```shell
+    python -m onnx.inference_onnx \
+        --image-path=$FILENAME_OF_INPUT_IMAGE$ \
+        --output-path=$FILENAME_OF_OUTPUT_MATTE$ \
+        --model-path=pretrained/modnet_photographic_portrait_matting.onnx
+    ```
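+
+For use in your own code, a minimal raw-inference sketch with `onnxruntime` could look as follows. This is a sketch under assumptions rather than the reference implementation (hypothetical file names; the input is fixed to 512 x 512, so the matte comes back at that size instead of the source resolution); `onnx/inference_onnx.py` is the full version:
+
+```python
+import cv2
+import numpy as np
+import onnxruntime
+
+# load the image and normalize it to [-1, 1] in NCHW layout, as MODNet expects
+im = cv2.cvtColor(cv2.imread('demo.jpg'), cv2.COLOR_BGR2RGB)
+im = cv2.resize(im, (512, 512))  # height and width must be divisible by 32
+im = (im.astype(np.float32) - 127.5) / 127.5
+im = np.transpose(im, (2, 0, 1))[None, :, :, :]
+
+# run the exported model and save the predicted alpha matte
+session = onnxruntime.InferenceSession(
+    'pretrained/modnet_photographic_portrait_matting.onnx',
+    providers=['CPUExecutionProvider'])
+matte = session.run(['output'], {'input': im})[0][0][0]  # (512, 512), values in [0, 1]
+cv2.imwrite('matte.png', (matte * 255).astype(np.uint8))
+```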
diff --git a/onnx/__init__.py b/onnx/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/onnx/export_onnx.py b/onnx/export_onnx.py
new file mode 100644
index 0000000..a9cb864
--- /dev/null
+++ b/onnx/export_onnx.py
@@ -0,0 +1,63 @@
+"""
+Export ONNX model of MODNet with:
+ input shape: (batch_size, 3, height, width)
+ output shape: (batch_size, 1, height, width)
+
+Arguments:
+ --ckpt-path: path of the checkpoint that will be converted
+ --output-path: path for saving the ONNX model
+
+Example:
+ python export_onnx.py \
+ --ckpt-path=modnet_photographic_portrait_matting.ckpt \
+ --output-path=modnet_photographic_portrait_matting.onnx
+"""
+
+import os
+import argparse
+
+import torch
+import torch.nn as nn
+from torch.autograd import Variable
+
+from . import modnet_onnx
+
+
+if __name__ == '__main__':
+    # define cmd arguments
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--ckpt-path', type=str, required=True, help='path of the checkpoint that will be converted')
+    parser.add_argument('--output-path', type=str, required=True, help='path for saving the ONNX model')
+    args = parser.parse_args()
+
+    # check input arguments
+    if not os.path.exists(args.ckpt_path):
+        print('Cannot find checkpoint path: {0}'.format(args.ckpt_path))
+        exit()
+
+    # define model & load checkpoint
+    modnet = modnet_onnx.MODNet(backbone_pretrained=False)
+    modnet = nn.DataParallel(modnet).cuda()
+    state_dict = torch.load(args.ckpt_path)
+    modnet.load_state_dict(state_dict)
+    modnet.eval()
+
+    # prepare dummy_input
+    batch_size = 1
+    height = 512
+    width = 512
+    dummy_input = Variable(torch.randn(batch_size, 3, height, width)).cuda()
+
+    # export to onnx model (opset 11 is kept from the original export script)
+    torch.onnx.export(
+        modnet.module, dummy_input, args.output_path, export_params=True,
+        opset_version=11, input_names=['input'], output_names=['output'],
+        dynamic_axes={'input': {0: 'batch_size', 2: 'height', 3: 'width'},
+                      'output': {0: 'batch_size', 2: 'height', 3: 'width'}})
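+
+    # Optional sanity check (a sketch, not part of the original script):
+    # try loading the exported model with onnxruntime. Note that `import onnx`
+    # would resolve to this local `onnx` package under `python -m onnx.export_onnx`,
+    # so the onnx checker is not used here.
+    import onnxruntime
+    onnxruntime.InferenceSession(args.output_path, providers=['CPUExecutionProvider'])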
diff --git a/demo/image_matting/Inference_with_ONNX/inference_onnx.py b/onnx/inference_onnx.py
similarity index 73%
rename from demo/image_matting/Inference_with_ONNX/inference_onnx.py
rename to onnx/inference_onnx.py
index cccfa23..d1bd47e 100644
--- a/demo/image_matting/Inference_with_ONNX/inference_onnx.py
+++ b/onnx/inference_onnx.py
@@ -1,43 +1,40 @@
"""
-Inference with onnxruntime
+Run inference with the ONNX model of MODNet
Arguments:
- --image-path --> path to single input image
- --output-path --> paht to save generated matte
- --model-path --> path to onnx model file
+    --image-path: path of the input image (a file)
+    --output-path: path for saving the predicted alpha matte (a file)
+    --model-path: path of the ONNX model
-example:
+Example:
python inference_onnx.py \
-    --image-path=demo.jpg \
-    --output-path=matte.png \
-    --model-path=modnet.onnx
-
-Optional:
-Generate transparent image without background
+    --image-path=demo.jpg --output-path=matte.png --model-path=modnet.onnx
"""
+
import os
-import argparse
import cv2
+import argparse
import numpy as np
+from PIL import Image
+
import onnx
import onnxruntime
-from onnx import helper
-from PIL import Image
+
 if __name__ == '__main__':
     # define cmd arguments
     parser = argparse.ArgumentParser()
-    parser.add_argument('--image-path', type=str, help='path of input image')
-    parser.add_argument('--output-path', type=str, help='path of output image')
-    parser.add_argument('--model-path', type=str, help='path of onnx model')
+    parser.add_argument('--image-path', type=str, help='path of the input image (a file)')
+    parser.add_argument('--output-path', type=str, help='path for saving the predicted alpha matte (a file)')
+    parser.add_argument('--model-path', type=str, help='path of the ONNX model')
     args = parser.parse_args()
 
     # check input arguments
     if not os.path.exists(args.image_path):
-        print('Cannot find input path: {0}'.format(args.image_path))
+        print('Cannot find the input image: {0}'.format(args.image_path))
         exit()
     if not os.path.exists(args.model_path):
-        print('Cannot find model path: {0}'.format(args.model_path))
+        print('Cannot find the ONNX model: {0}'.format(args.model_path))
         exit()
 
     ref_size = 512
@@ -105,12 +102,3 @@ if __name__ == '__main__':
     matte = cv2.resize(matte, dsize=(im_w, im_h), interpolation = cv2.INTER_AREA)
     cv2.imwrite(args.output_path, matte)
-
- ##############################################
- # Optional - save png image without background
- ##############################################
-
- # im_PIL = Image.open(args.image_path)
- # matte = Image.fromarray(matte)
- # im_PIL.putalpha(matte) # add alpha channel to keep transparency
- # im_PIL.save('without_background.png')
\ No newline at end of file
diff --git a/src/models/onnx_modnet.py b/onnx/modnet_onnx.py
similarity index 95%
rename from src/models/onnx_modnet.py
rename to onnx/modnet_onnx.py
index 6fa5a41..1f48341 100644
--- a/src/models/onnx_modnet.py
+++ b/onnx/modnet_onnx.py
@@ -1,18 +1,16 @@
"""
-This file is a modified version of the original file modnet.py without
-"pred_semantic" and "pred_details" as these both returns None when "inference = True"
+This file contains a modified version of the original file `modnet.py` without
+`pred_semantic` and `pred_details`, as they both return None when `inference=True`.
-And it does not contain "inference" argument which will make it easier to
-convert checkpoint into onnx model.
-
-Refer: 'demo/image_matting/inference_with_ONNX/export_modnet_onnx.py' to export model.
+It also does not have the `inference` argument, which makes it easier to
+convert the checkpoint to an ONNX model.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
-from .backbones import SUPPORTED_BACKBONES
+from src.models.backbones import SUPPORTED_BACKBONES
#------------------------------------------------------------------------------
diff --git a/demo/image_matting/Inference_with_ONNX/requirements.txt b/onnx/requirements.txt
similarity index 100%
rename from demo/image_matting/Inference_with_ONNX/requirements.txt
rename to onnx/requirements.txt