diff --git a/demo/image_matting/colab/inference.py b/demo/image_matting/colab/inference.py index d7423cf..f9a15e0 100644 --- a/demo/image_matting/colab/inference.py +++ b/demo/image_matting/colab/inference.py @@ -45,12 +45,11 @@ if __name__ == '__main__': # create MODNet and load the pre-trained ckpt modnet = MODNet(backbone_pretrained=False) modnet = nn.DataParallel(modnet) - cpu_mode = False + if torch.cuda.is_available(): modnet = modnet.cuda() weights = torch.load(args.ckpt_path) else: - cpu_mode = True weights = torch.load(args.ckpt_path, map_location=torch.device('cpu')) modnet.load_state_dict(weights) modnet.eval()