MODNet/TorchScript/export_torchscript.py

43 lines
1.3 KiB
Python
Raw Blame History

This file contains invisible Unicode characters!

This file contains invisible Unicode characters that may be processed differently from what appears below. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to reveal hidden characters.

import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from . import modnet_torchscript
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--ckpt-path', type=str, help='path of pre-trained MODNet')
parser.add_argument('--out-dir', type=str, required=True, help='path for saving the TorchScript model')
args = parser.parse_args()
# check input arguments
if not os.path.exists(args.ckpt_path):
print('Cannot find checkpoint path: {0}'.format(args.ckpt_path))
exit()
if not os.path.exists(args.out_dir):
os.mkdir(args.out_dir)
# create MODNet and load the pre-trained ckpt
modnet = MODNet(backbone_pretrained=True)
# modnet = nn.DataParallel(modnet).cuda()
modnet = modnet.cuda()
ckpt = torch.load(args.ckpt)
# if use more than one GPU
if 'module.' in ckpt.keys():
ckpt = OrderedDict()
for k, v in ckpt.items():
k = k.replace('module.', '')
ckpt[k] = v
modnet.load_state_dict(ckpt)
modnet.eval()
scripted_model = torch.jit.script(modnet)
torch.jit.save(scripted_model, os.path.join(args.out_dir,'modnet.pt'))