add cpu inference for almost laptop usage

pull/21/head
tkianai 2020-12-14 18:55:38 +08:00
parent c51ece7232
commit b7c07bff64
1 changed files with 13 additions and 4 deletions

View File

@ -1,4 +1,6 @@
import cv2
import argparse
import numpy as np
from PIL import Image
@ -8,6 +10,11 @@ import torchvision.transforms as transforms
from src.models.modnet import MODNet
parser = argparse.ArgumentParser()
parser.add_argument('--ckpt-path', type=str, default="./pretrained/modnet_webcam_portrait_matting.ckpt", help='path of pre-trained MODNet')
parser.add_argument('--cpu', action='store_true', default=False, help="use cpu inferece")
args = parser.parse_args()
torch_transforms = transforms.Compose(
[
@ -17,10 +24,12 @@ torch_transforms = transforms.Compose(
)
print('Load pre-trained MODNet...')
pretrained_ckpt = './pretrained/modnet_webcam_portrait_matting.ckpt'
pretrained_ckpt = args.ckpt_path
modnet = MODNet(backbone_pretrained=False)
modnet = nn.DataParallel(modnet).cuda()
modnet.load_state_dict(torch.load(pretrained_ckpt))
modnet = nn.DataParallel(modnet)
modnet.load_state_dict(torch.load(pretrained_ckpt, map_location='cpu'))
if not args.cpu:
modnet.cuda()
modnet.eval()
print('Init WebCam...')
@ -38,7 +47,7 @@ while(True):
frame_PIL = Image.fromarray(frame_np)
frame_tensor = torch_transforms(frame_PIL)
frame_tensor = frame_tensor[None, :, :, :].cuda()
frame_tensor = frame_tensor[None, :, :, :] if args.cpu else frame_tensor[None, :, :, :].cuda()
with torch.no_grad():
_, _, matte_tensor = modnet(frame_tensor, inference=True)