From 97c441a8cb79647b74ccea978aaf3b1ec3ba9afc Mon Sep 17 00:00:00 2001 From: Daniel Manzke Date: Mon, 10 Jan 2022 17:53:34 +0100 Subject: [PATCH] (#161) - Make the file `demo/image_matting/colab/inference.py` able to run on CPU Make the file `demo/image_matting/colab/inference.py` able to run on CPU --- demo/image_matting/colab/inference.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/demo/image_matting/colab/inference.py b/demo/image_matting/colab/inference.py index c4f280b..f9a15e0 100644 --- a/demo/image_matting/colab/inference.py +++ b/demo/image_matting/colab/inference.py @@ -44,8 +44,14 @@ if __name__ == '__main__': # create MODNet and load the pre-trained ckpt modnet = MODNet(backbone_pretrained=False) - modnet = nn.DataParallel(modnet).cuda() - modnet.load_state_dict(torch.load(args.ckpt_path)) + modnet = nn.DataParallel(modnet) + + if torch.cuda.is_available(): + modnet = modnet.cuda() + weights = torch.load(args.ckpt_path) + else: + weights = torch.load(args.ckpt_path, map_location=torch.device('cpu')) + modnet.load_state_dict(weights) modnet.eval() # inference images @@ -90,7 +96,7 @@ if __name__ == '__main__': im = F.interpolate(im, size=(im_rh, im_rw), mode='area') # inference - _, _, matte = modnet(im.cuda(), True) + _, _, matte = modnet(im.cuda() if torch.cuda.is_available() else im, True) # resize and save matte matte = F.interpolate(matte, size=(im_h, im_w), mode='area')