diff --git a/demo/image_matting/colab/inference.py b/demo/image_matting/colab/inference.py index c4f280b..d7423cf 100644 --- a/demo/image_matting/colab/inference.py +++ b/demo/image_matting/colab/inference.py @@ -44,8 +44,15 @@ if __name__ == '__main__': # create MODNet and load the pre-trained ckpt modnet = MODNet(backbone_pretrained=False) - modnet = nn.DataParallel(modnet).cuda() - modnet.load_state_dict(torch.load(args.ckpt_path)) + modnet = nn.DataParallel(modnet) + cpu_mode = False + if torch.cuda.is_available(): + modnet = modnet.cuda() + weights = torch.load(args.ckpt_path) + else: + cpu_mode = True + weights = torch.load(args.ckpt_path, map_location=torch.device('cpu')) + modnet.load_state_dict(weights) modnet.eval() # inference images @@ -90,7 +97,7 @@ if __name__ == '__main__': im = F.interpolate(im, size=(im_rh, im_rw), mode='area') # inference - _, _, matte = modnet(im.cuda(), True) + _, _, matte = modnet(im.cuda() if torch.cuda.is_available() else im, True) # resize and save matte matte = F.interpolate(matte, size=(im_h, im_w), mode='area')