removed unnecessary flag

pull/161/head
Daniel Manzke 2022-01-03 20:46:20 +01:00
parent 9956841ea4
commit 73f4b49457
1 changed files with 1 additions and 2 deletions

View File

@ -45,12 +45,11 @@ if __name__ == '__main__':
# create MODNet and load the pre-trained ckpt
modnet = MODNet(backbone_pretrained=False)
modnet = nn.DataParallel(modnet)
cpu_mode = False
if torch.cuda.is_available():
modnet = modnet.cuda()
weights = torch.load(args.ckpt_path)
else:
cpu_mode = True
weights = torch.load(args.ckpt_path, map_location=torch.device('cpu'))
modnet.load_state_dict(weights)
modnet.eval()