mirror of https://github.com/ZHKKKe/MODNet.git
add cpu inference for almost laptop usage
parent
c51ece7232
commit
b7c07bff64
|
|
@ -1,4 +1,6 @@
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
|
import argparse
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
|
@ -8,6 +10,11 @@ import torchvision.transforms as transforms
|
||||||
|
|
||||||
from src.models.modnet import MODNet
|
from src.models.modnet import MODNet
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--ckpt-path', type=str, default="./pretrained/modnet_webcam_portrait_matting.ckpt", help='path of pre-trained MODNet')
|
||||||
|
parser.add_argument('--cpu', action='store_true', default=False, help="use cpu inferece")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
torch_transforms = transforms.Compose(
|
torch_transforms = transforms.Compose(
|
||||||
[
|
[
|
||||||
|
|
@ -17,10 +24,12 @@ torch_transforms = transforms.Compose(
|
||||||
)
|
)
|
||||||
|
|
||||||
print('Load pre-trained MODNet...')
|
print('Load pre-trained MODNet...')
|
||||||
pretrained_ckpt = './pretrained/modnet_webcam_portrait_matting.ckpt'
|
pretrained_ckpt = args.ckpt_path
|
||||||
modnet = MODNet(backbone_pretrained=False)
|
modnet = MODNet(backbone_pretrained=False)
|
||||||
modnet = nn.DataParallel(modnet).cuda()
|
modnet = nn.DataParallel(modnet)
|
||||||
modnet.load_state_dict(torch.load(pretrained_ckpt))
|
modnet.load_state_dict(torch.load(pretrained_ckpt, map_location='cpu'))
|
||||||
|
if not args.cpu:
|
||||||
|
modnet.cuda()
|
||||||
modnet.eval()
|
modnet.eval()
|
||||||
|
|
||||||
print('Init WebCam...')
|
print('Init WebCam...')
|
||||||
|
|
@ -38,7 +47,7 @@ while(True):
|
||||||
|
|
||||||
frame_PIL = Image.fromarray(frame_np)
|
frame_PIL = Image.fromarray(frame_np)
|
||||||
frame_tensor = torch_transforms(frame_PIL)
|
frame_tensor = torch_transforms(frame_PIL)
|
||||||
frame_tensor = frame_tensor[None, :, :, :].cuda()
|
frame_tensor = frame_tensor[None, :, :, :] if args.cpu else frame_tensor[None, :, :, :].cuda()
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
_, _, matte_tensor = modnet(frame_tensor, inference=True)
|
_, _, matte_tensor = modnet(frame_tensor, inference=True)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue