mirror of https://github.com/ZHKKKe/MODNet.git
(#161) - Make the file `demo/image_matting/colab/inference.py` can run on CPU
Make the file `demo/image_matting/colab/inference.py` can run on CPUpull/167/head
parent
b258b36481
commit
97c441a8cb
|
|
@ -44,8 +44,14 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
# create MODNet and load the pre-trained ckpt
|
# create MODNet and load the pre-trained ckpt
|
||||||
modnet = MODNet(backbone_pretrained=False)
|
modnet = MODNet(backbone_pretrained=False)
|
||||||
modnet = nn.DataParallel(modnet).cuda()
|
modnet = nn.DataParallel(modnet)
|
||||||
modnet.load_state_dict(torch.load(args.ckpt_path))
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
modnet = modnet.cuda()
|
||||||
|
weights = torch.load(args.ckpt_path)
|
||||||
|
else:
|
||||||
|
weights = torch.load(args.ckpt_path, map_location=torch.device('cpu'))
|
||||||
|
modnet.load_state_dict(weights)
|
||||||
modnet.eval()
|
modnet.eval()
|
||||||
|
|
||||||
# inference images
|
# inference images
|
||||||
|
|
@ -90,7 +96,7 @@ if __name__ == '__main__':
|
||||||
im = F.interpolate(im, size=(im_rh, im_rw), mode='area')
|
im = F.interpolate(im, size=(im_rh, im_rw), mode='area')
|
||||||
|
|
||||||
# inference
|
# inference
|
||||||
_, _, matte = modnet(im.cuda(), True)
|
_, _, matte = modnet(im.cuda() if torch.cuda.is_available() else im, True)
|
||||||
|
|
||||||
# resize and save matte
|
# resize and save matte
|
||||||
matte = F.interpolate(matte, size=(im_h, im_w), mode='area')
|
matte = F.interpolate(matte, size=(im_h, im_w), mode='area')
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue