# Mirror of https://github.com/ZHKKKe/MODNet.git
# (115 lines, 3.5 KiB, Python)
import os
|
|
import cv2
|
|
import argparse
|
|
import numpy as np
|
|
from PIL import Image
|
|
from tqdm import tqdm
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torchvision.transforms as transforms
|
|
|
|
from src.models.modnet import MODNet
|
|
|
|
|
|
# Per-channel statistics used by Normalize: (x - 0.5) / 0.5 maps the [0, 1]
# values produced by ToTensor onto [-1, 1] for each of the three channels.
_NORM_MEAN = _NORM_STD = (0.5,) * 3

# Preprocessing applied to every RGB frame before it is fed to the network:
# PIL image -> float tensor -> per-channel normalization.
torch_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])
|
|
|
|
|
|
def _inference_size(h, w, ref_size=512, multiple=32):
    """Return the ``(rh, rw)`` resolution at which an ``h x w`` frame is matted.

    The shorter side is scaled to ``ref_size`` while keeping the aspect ratio,
    then both sides are rounded down to a multiple of ``multiple`` (presumably
    so the size survives the network's downsampling stages -- confirm against
    MODNet).
    """
    if w >= h:
        rh, rw = ref_size, int(w / h * ref_size)
    else:
        rh, rw = int(h / w * ref_size), ref_size
    return rh - rh % multiple, rw - rw % multiple


def matting(video, result, alpha_matte=False, fps=30, model=None, use_gpu=None):
    """Run MODNet matting on every frame of a video and write the result video.

    Each frame is converted BGR -> RGB, resized so its shorter side is about
    512 px (both sides rounded down to a multiple of 32), normalized with
    ``torch_transforms`` and passed through the model. The per-frame output
    (the matte itself, or the frame composited over white) is resized back to
    the input resolution and written with the ``mp4v`` codec.

    Args:
        video: path of the input video (anything ``cv2.VideoCapture`` opens).
        result: path of the output video file.
        alpha_matte: if True, write the matte scaled to 0..255 (255 where the
            matte is 1); otherwise write ``matte * frame + (1 - matte) * white``.
        fps: frame rate recorded in the output file.
        model: matting network to call; defaults to the module-level
            ``modnet`` set up by the ``__main__`` block.
        use_gpu: move input tensors to CUDA before inference; defaults to the
            module-level ``GPU`` flag set up by the ``__main__`` block.

    Raises:
        SystemExit: with status 1 if the first frame cannot be read.
    """
    # video capture
    vc = cv2.VideoCapture(video)
    rval, frame = vc.read() if vc.isOpened() else (False, None)
    if not rval:
        vc.release()
        print('Failed to read the video: {0}'.format(video))
        # Exit with a non-zero status so shells/callers see the failure
        # (a bare exit() reports success).
        raise SystemExit(1)

    if model is None:
        model = modnet  # module-level global defined by the __main__ block
    if use_gpu is None:
        use_gpu = GPU  # module-level global defined by the __main__ block

    num_frame = int(vc.get(cv2.CAP_PROP_FRAME_COUNT))
    h, w = frame.shape[:2]
    rh, rw = _inference_size(h, w)
    # White background at inference resolution (3 channels, matching the
    # cvtColor(BGR2RGB) output), built once and reused for every frame.
    white = np.full((rh, rw, 3), 255.0)

    # video writer
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video_writer = cv2.VideoWriter(result, fourcc, fps, (w, h))

    print('Start matting...')
    try:
        # The container's frame count only sizes the progress bar: it may not
        # match the number of decodable frames (or be non-positive when
        # unknown). The loop runs until read() stops returning frames, so the
        # output is never truncated and None is never passed to cvtColor.
        with tqdm(total=num_frame if num_frame > 0 else None) as t:
            while rval:
                frame_np = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                # interpolation must be a keyword: the third positional
                # parameter of cv2.resize is the output array ``dst``.
                frame_np = cv2.resize(frame_np, (rw, rh), interpolation=cv2.INTER_AREA)

                frame_tensor = torch_transforms(Image.fromarray(frame_np))
                frame_tensor = frame_tensor[None, :, :, :]  # add batch dim
                if use_gpu:
                    frame_tensor = frame_tensor.cuda()

                with torch.no_grad():
                    # Second argument presumably selects inference mode --
                    # confirm against MODNet.forward.
                    _, _, matte_tensor = model(frame_tensor, True)

                # Assumes a single-channel (1, 1, rh, rw) matte -- TODO confirm;
                # repeat to 3 channels and move to HWC to match frame_np.
                matte_tensor = matte_tensor.repeat(1, 3, 1, 1)
                matte_np = matte_tensor[0].data.cpu().numpy().transpose(1, 2, 0)
                if alpha_matte:
                    view_np = matte_np * white
                else:
                    view_np = matte_np * frame_np + (1 - matte_np) * white
                view_np = cv2.cvtColor(view_np.astype(np.uint8), cv2.COLOR_RGB2BGR)
                view_np = cv2.resize(view_np, (w, h))
                video_writer.write(view_np)

                t.update(1)
                rval, frame = vc.read()
    finally:
        # Close both the output and the input even if inference fails
        # part-way through the video.
        video_writer.release()
        vc.release()

    print('Save the result video to {0}'.format(result))
|
|
|
|
|
|
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--video', type=str, required=True, help='input video file')
    parser.add_argument('--result-type', type=str, default='fg', choices=['fg', 'matte'],
                        help='matte - save the alpha matte; fg - save the foreground')
    parser.add_argument('--fps', type=int, default=30, help='fps of the result video')
    parser.add_argument('--ckpt-path', type=str,
                        default='./pretrained/modnet_webcam_portrait_matting.ckpt',
                        help='path of the pre-trained MODNet checkpoint')

    print('Get CMD Arguments...')
    args = parser.parse_args()

    if not os.path.exists(args.video):
        print('Cannot find the input video: {0}'.format(args.video))
        # Non-zero status: a bare exit() would report success.
        raise SystemExit(1)

    print('Load pre-trained MODNet...')
    pretrained_ckpt = args.ckpt_path
    if not os.path.exists(pretrained_ckpt):
        print('Cannot find the pre-trained checkpoint: {0}'.format(pretrained_ckpt))
        raise SystemExit(1)

    # `modnet` and `GPU` are module-level globals on purpose: matting() falls
    # back to them when no model / device flag is passed explicitly.
    modnet = MODNet(backbone_pretrained=False)
    # Wrapped before loading on both CPU and GPU, presumably because the
    # checkpoint keys carry DataParallel's 'module.' prefix -- confirm.
    modnet = nn.DataParallel(modnet)

    GPU = torch.cuda.is_available()
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    if GPU:
        print('Use GPU...')
        modnet = modnet.cuda()
        state_dict = torch.load(pretrained_ckpt)
    else:
        print('Use CPU...')
        state_dict = torch.load(pretrained_ckpt, map_location=torch.device('cpu'))
    modnet.load_state_dict(state_dict)
    modnet.eval()

    # Output goes next to the input: <video stem>_<result-type>.mp4
    result = os.path.splitext(args.video)[0] + '_{0}.mp4'.format(args.result_type)
    alpha_matte = args.result_type == 'matte'
    matting(args.video, result, alpha_matte, args.fps)
|