mirror of https://github.com/ZHKKKe/MODNet.git
(#161) - Make the file `demo/image_matting/colab/inference.py` can run on CPU
Make the file `demo/image_matting/colab/inference.py` can run on CPUpull/167/head
parent
b258b36481
commit
97c441a8cb
|
|
@ -44,8 +44,14 @@ if __name__ == '__main__':
|
|||
|
||||
# create MODNet and load the pre-trained ckpt
|
||||
modnet = MODNet(backbone_pretrained=False)
|
||||
modnet = nn.DataParallel(modnet).cuda()
|
||||
modnet.load_state_dict(torch.load(args.ckpt_path))
|
||||
modnet = nn.DataParallel(modnet)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
modnet = modnet.cuda()
|
||||
weights = torch.load(args.ckpt_path)
|
||||
else:
|
||||
weights = torch.load(args.ckpt_path, map_location=torch.device('cpu'))
|
||||
modnet.load_state_dict(weights)
|
||||
modnet.eval()
|
||||
|
||||
# inference images
|
||||
|
|
@ -90,7 +96,7 @@ if __name__ == '__main__':
|
|||
im = F.interpolate(im, size=(im_rh, im_rw), mode='area')
|
||||
|
||||
# inference
|
||||
_, _, matte = modnet(im.cuda(), True)
|
||||
_, _, matte = modnet(im.cuda() if torch.cuda.is_available() else im, True)
|
||||
|
||||
# resize and save matte
|
||||
matte = F.interpolate(matte, size=(im_h, im_w), mode='area')
|
||||
|
|
|
|||
Loading…
Reference in New Issue