Merge pull request #36 from yzhou0919/master

add demo/video_matting/offline
pull/78/head
Zhanghan Ke 2020-12-25 12:42:24 +08:00 committed by GitHub
commit 52a17f7335
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 112 additions and 0 deletions

View File

@@ -0,0 +1,6 @@
numpy
Pillow
opencv-python
torch >= 1.0.0
torchvision
tqdm

View File

@@ -0,0 +1,106 @@
import os
import cv2
import numpy as np
from PIL import Image
import argparse
from tqdm import tqdm
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from src.models.modnet import MODNet
# Preprocessing pipeline: PIL image -> float tensor, then normalize each
# RGB channel from [0, 1] to [-1, 1] (mean 0.5, std 0.5), as MODNet expects.
torch_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

print('Load pre-trained MODNet...')
pretrained_ckpt = './pretrained/modnet_webcam_portrait_matting.ckpt'
modnet = MODNet(backbone_pretrained=False)
# Checkpoint was saved from a DataParallel-wrapped model ('module.' key
# prefixes), so the wrapper is needed even for single-device inference.
modnet = nn.DataParallel(modnet)

# FIX: `True if cond else False` is redundant — the comparison is already a bool.
GPU = torch.cuda.device_count() > 0
if GPU:
    print('Use GPU...')
    modnet = modnet.cuda()
    modnet.load_state_dict(torch.load(pretrained_ckpt))
else:
    print('Use CPU...')
    # map_location remaps CUDA-saved weights onto the CPU.
    modnet.load_state_dict(torch.load(pretrained_ckpt, map_location=torch.device('cpu')))
# Inference mode: freezes dropout/batch-norm statistics.
modnet.eval()
def offline_matting(video_path, save_path, alpha_matte=False, fps=30):
    """Matte every frame of a video with the module-level MODNet model.

    Args:
        video_path (str): Path of the input video.
        save_path (str): Path of the output video (.mp4 / mp4v codec).
        alpha_matte (bool): If True, write the alpha matte itself; otherwise
            composite the foreground over a white background.
        fps (int): Frame rate of the output video.
    """
    # video capture
    vc = cv2.VideoCapture(video_path)

    if vc.isOpened():
        rval, frame = vc.read()
    else:
        rval = False

    if not rval:
        vc.release()
        print('Read video {} failed.'.format(video_path))
        exit()

    num_frame = vc.get(cv2.CAP_PROP_FRAME_COUNT)
    h, w = frame.shape[:2]

    # video writer (output keeps the source resolution)
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video_writer = cv2.VideoWriter(save_path, fourcc, fps, (w, h))

    print('Start matting...')
    with tqdm(range(int(num_frame))) as t:
        for c in t:
            frame_np = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # BUG FIX: cv2.resize's third positional parameter is `dst`, not
            # `interpolation` — the flag must be passed by keyword, otherwise
            # INTER_AREA was silently ignored. 672x512 is MODNet's input size.
            frame_np = cv2.resize(frame_np, (672, 512), interpolation=cv2.INTER_AREA)
            frame_PIL = Image.fromarray(frame_np)
            frame_tensor = torch_transforms(frame_PIL)
            frame_tensor = frame_tensor[None, :, :, :]  # add batch dimension
            if GPU:
                frame_tensor = frame_tensor.cuda()

            with torch.no_grad():
                _, _, matte_tensor = modnet(frame_tensor, True)

            # Broadcast the single-channel matte to 3 channels, move to HWC.
            matte_tensor = matte_tensor.repeat(1, 3, 1, 1)
            matte_np = matte_tensor[0].data.cpu().numpy().transpose(1, 2, 0)
            if alpha_matte:
                view_np = matte_np * np.full(frame_np.shape, 255.0)
            else:
                # Alpha-blend foreground over a white background.
                view_np = matte_np * frame_np + (1 - matte_np) * np.full(frame_np.shape, 255.0)
            view_np = cv2.cvtColor(view_np.astype(np.uint8), cv2.COLOR_RGB2BGR)
            view_np = cv2.resize(view_np, (w, h))
            video_writer.write(view_np)

            rval, frame = vc.read()
            # FIX: guard against a failed mid-stream read — the original would
            # pass frame=None to cvtColor on the next iteration and crash.
            if not rval:
                break

    video_writer.release()
    # FIX: the capture was never released (resource leak).
    vc.release()
    print('Save video to {}'.format(save_path))
if __name__ == "__main__":
    # Command-line entry point for the offline video matting demo.
    cli = argparse.ArgumentParser()
    cli.add_argument('--video_path', type=str, default='./sample/video.mp4')
    cli.add_argument('--save_path', type=str, default='./sample/res.mp4', help='Video should be .mp4 format.')
    cli.add_argument('--alpha_matte', action='store_true', default=False, help='If True, save alpha_matte video.')
    cli.add_argument('--fps', type=int, default=30)
    args = cli.parse_args()

    # The mp4v writer needs a .mp4 container; swap the extension if necessary.
    root, ext = os.path.splitext(args.save_path)
    if ext != '.mp4':
        args.save_path = root + '.mp4'

    offline_matting(args.video_path, args.save_path, args.alpha_matte, args.fps)