allow to use MODNet with CPU only

pull/161/head
Daniel Manzke 2022-01-03 20:43:28 +01:00
parent b258b36481
commit 9956841ea4
1 changed files with 10 additions and 3 deletions

View File

@ -44,8 +44,15 @@ if __name__ == '__main__':
# create MODNet and load the pre-trained ckpt # create MODNet and load the pre-trained ckpt
modnet = MODNet(backbone_pretrained=False) modnet = MODNet(backbone_pretrained=False)
modnet = nn.DataParallel(modnet).cuda() modnet = nn.DataParallel(modnet)
modnet.load_state_dict(torch.load(args.ckpt_path)) cpu_mode = False
if torch.cuda.is_available():
modnet = modnet.cuda()
weights = torch.load(args.ckpt_path)
else:
cpu_mode = True
weights = torch.load(args.ckpt_path, map_location=torch.device('cpu'))
modnet.load_state_dict(weights)
modnet.eval() modnet.eval()
# inference images # inference images
@ -90,7 +97,7 @@ if __name__ == '__main__':
im = F.interpolate(im, size=(im_rh, im_rw), mode='area') im = F.interpolate(im, size=(im_rh, im_rw), mode='area')
# inference # inference
_, _, matte = modnet(im.cuda(), True) _, _, matte = modnet(im.cuda() if torch.cuda.is_available() else im, True)
# resize and save matte # resize and save matte
matte = F.interpolate(matte, size=(im_h, im_w), mode='area') matte = F.interpolate(matte, size=(im_h, im_w), mode='area')