webcam demo: add --ckpt-path and --cpu command-line options

Adds an argparse front end to demo/video_matting/webcam.py:

  --ckpt-path  path of the pre-trained MODNet checkpoint (defaults to the
               previously hard-coded ./pretrained/ location)
  --cpu        run inference on the CPU instead of CUDA

The checkpoint is always loaded with map_location='cpu' through the
nn.DataParallel wrapper, so the key names seen by load_state_dict are the same
as before. For CPU inference the wrapper is then dropped: nn.DataParallel
raises a RuntimeError in forward() when a GPU is visible but the wrapped
module's parameters are not on it. Otherwise the model is moved to CUDA as
before. Input frames are moved to CUDA only when --cpu is not given.

Note: the index line is kept from the original patch; its post-image hash
does not reflect the review edits made to the added lines below.

diff --git a/demo/video_matting/webcam.py b/demo/video_matting/webcam.py
index fc09aeb..1bf5438 100644
--- a/demo/video_matting/webcam.py
+++ b/demo/video_matting/webcam.py
@@ -1,4 +1,6 @@
+
 import cv2
+import argparse
 import numpy as np
 from PIL import Image
 
@@ -8,6 +10,11 @@ import torchvision.transforms as transforms
 
 from src.models.modnet import MODNet
 
+parser = argparse.ArgumentParser()
+parser.add_argument('--ckpt-path', type=str, default="./pretrained/modnet_webcam_portrait_matting.ckpt", help='path of pre-trained MODNet')
+parser.add_argument('--cpu', action='store_true', default=False, help="use cpu inference")
+args = parser.parse_args()
+
 
 torch_transforms = transforms.Compose(
     [
@@ -17,10 +24,16 @@ torch_transforms = transforms.Compose(
 )
 
 print('Load pre-trained MODNet...')
-pretrained_ckpt = './pretrained/modnet_webcam_portrait_matting.ckpt'
+pretrained_ckpt = args.ckpt_path
 modnet = MODNet(backbone_pretrained=False)
-modnet = nn.DataParallel(modnet).cuda()
-modnet.load_state_dict(torch.load(pretrained_ckpt))
+modnet = nn.DataParallel(modnet)
+modnet.load_state_dict(torch.load(pretrained_ckpt, map_location='cpu'))
+if args.cpu:
+    # DataParallel rejects CPU-resident parameters whenever a GPU is visible,
+    # so run the bare module for CPU inference.
+    modnet = modnet.module
+else:
+    modnet.cuda()
 modnet.eval()
 
 print('Init WebCam...')
@@ -38,7 +51,7 @@ while(True):
 
     frame_PIL = Image.fromarray(frame_np)
     frame_tensor = torch_transforms(frame_PIL)
-    frame_tensor = frame_tensor[None, :, :, :].cuda()
+    frame_tensor = frame_tensor[None, :, :, :] if args.cpu else frame_tensor[None, :, :, :].cuda()
 
     with torch.no_grad():
         _, _, matte_tensor = modnet(frame_tensor, inference=True)