mirror of https://github.com/ZHKKKe/MODNet.git
Evaluation and inference
parent 4424ab5b72
commit 48a2b96e71
@@ -0,0 +1,72 @@
import numpy as np
from glob import glob
from src.models.modnet import MODNet
from PIL import Image
from src.infer import predit_matte
import torch.nn as nn
import torch


def cal_mad(pred, gt):
    # mean absolute difference between predicted and ground-truth mattes
    diff = pred - gt
    diff = np.abs(diff)
    mad = np.mean(diff)
    return mad


def cal_mse(pred, gt):
    # mean squared error between predicted and ground-truth mattes
    diff = pred - gt
    diff = diff ** 2
    mse = np.mean(diff)
    return mse


def load_eval_dataset(dataset_root_dir='src/datasets/PPM-100'):
    # collect image/matte file paths; sorting keeps the two lists aligned
    image_path = dataset_root_dir + '/image/*'
    matte_path = dataset_root_dir + '/matte/*'
    image_file_name_list = glob(image_path)
    image_file_name_list = sorted(image_file_name_list)
    matte_file_name_list = glob(matte_path)
    matte_file_name_list = sorted(matte_file_name_list)

    return image_file_name_list, matte_file_name_list


def eval(modnet: MODNet, dataset):
    mse = total_mse = 0.0
    mad = total_mad = 0.0
    cnt = 0

    for im_pth, mt_pth in zip(dataset[0], dataset[1]):
        im = Image.open(im_pth)
        pd_matte = predit_matte(modnet, im)

        gt_matte = Image.open(mt_pth)
        gt_matte = np.asarray(gt_matte) / 255

        total_mse += cal_mse(pd_matte, gt_matte)
        total_mad += cal_mad(pd_matte, gt_matte)

        cnt += 1
    if cnt > 0:
        mse = total_mse / cnt
        mad = total_mad / cnt

    return mse, mad


if __name__ == '__main__':
    # create MODNet and load the pre-trained ckpt
    modnet = MODNet(backbone_pretrained=False)
    modnet = nn.DataParallel(modnet)

    ckp_pth = 'pretrained/modnet_photographic_portrait_matting.ckpt'
    if torch.cuda.is_available():
        modnet = modnet.cuda()
        weights = torch.load(ckp_pth)
    else:
        weights = torch.load(ckp_pth, map_location=torch.device('cpu'))
    modnet.load_state_dict(weights)

    dataset = load_eval_dataset('src/datasets/PPM-100')
    mse, mad = eval(modnet, dataset)
    print(f'mse: {mse:.6f}, mad: {mad:.6f}')
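A quick sanity check for cal_mad and cal_mse is to run them on tiny synthetic arrays whose expected values are easy to work out by hand. The snippet below is a minimal sketch that only assumes numpy and the two functions from the script above; the 2x2 matte values are made up for illustration.

import numpy as np

pred = np.array([[0.0, 0.5], [1.0, 1.0]])  # hypothetical predicted matte, values in [0, 1]
gt = np.array([[0.0, 1.0], [1.0, 0.5]])    # hypothetical ground-truth matte

print(cal_mad(pred, gt))  # mean(|pred - gt|)   = (0 + 0.5 + 0 + 0.5) / 4  = 0.25
print(cal_mse(pred, gt))  # mean((pred - gt)^2) = (0 + 0.25 + 0 + 0.25) / 4 = 0.125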
@@ -0,0 +1,85 @@
from src.models.modnet import MODNet
from PIL import Image
import numpy as np
from torchvision import transforms
import torch
import torch.nn.functional as F
import torch.nn as nn


def predit_matte(modnet: MODNet, im: Image.Image):
    # define image to tensor transform
    im_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]
    )

    # define hyper-parameters
    ref_size = 512

    modnet.eval()

    # unify image channels to 3
    im = np.asarray(im)
    if len(im.shape) == 2:
        im = im[:, :, None]
    if im.shape[2] == 1:
        im = np.repeat(im, 3, axis=2)
    elif im.shape[2] == 4:
        im = im[:, :, 0:3]

    im = Image.fromarray(im)
    # convert image to PyTorch tensor
    im = im_transform(im)

    # add mini-batch dim
    im = im[None, :, :, :]

    # resize image for input: scale the shorter side to ref_size when needed,
    # then snap both sides down to multiples of 32
    im_b, im_c, im_h, im_w = im.shape
    if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
        if im_w >= im_h:
            im_rh = ref_size
            im_rw = int(im_w / im_h * ref_size)
        elif im_w < im_h:
            im_rw = ref_size
            im_rh = int(im_h / im_w * ref_size)
    else:
        im_rh = im_h
        im_rw = im_w

    im_rw = im_rw - im_rw % 32
    im_rh = im_rh - im_rh % 32
    im = F.interpolate(im, size=(im_rh, im_rw), mode='area')

    # inference
    _, _, matte = modnet(im.cuda() if torch.cuda.is_available() else im, True)

    # resize the matte back to the original image size
    matte = F.interpolate(matte, size=(im_h, im_w), mode='area')
    matte = matte[0][0].data.cpu().numpy()
    return matte
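To make the resize step concrete: for a hypothetical 1920x1080 landscape photo (im_h=1080, im_w=1920), both sides exceed ref_size, so the shorter side is scaled to 512 and both sides are then snapped down to multiples of 32 before inference. The lines below only trace that arithmetic; the numbers are illustrative and not part of the script.

im_h, im_w, ref_size = 1080, 1920, 512   # illustrative input size
im_rh = ref_size                         # shorter side -> 512 (already a multiple of 32)
im_rw = int(im_w / im_h * ref_size)      # int(1920 / 1080 * 512) = 910
im_rw = im_rw - im_rw % 32               # 910 - 14 = 896
# the network therefore runs on a 512 x 896 tensor, and the matte is
# interpolated back to 1080 x 1920 at the end of predit_matte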
if __name__ == '__main__':
    # create MODNet and load the pre-trained ckpt
    modnet = MODNet(backbone_pretrained=False)
    modnet = nn.DataParallel(modnet)

    ckp_pth = 'pretrained/modnet_photographic_portrait_matting.ckpt'
    if torch.cuda.is_available():
        modnet = modnet.cuda()
        weights = torch.load(ckp_pth)
    else:
        weights = torch.load(ckp_pth, map_location=torch.device('cpu'))
    modnet.load_state_dict(weights)

    pth = 'src/datasets/PPM-100/image/13179159164_1a4ae8d085_o.jpg'
    img = Image.open(pth)

    matte = predit_matte(modnet, img)
    prd_img = Image.fromarray(((matte * 255).astype('uint8')), mode='L')
    prd_img.save('test_predic.jpg')
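Beyond saving the matte as a grayscale image, a common use of the predicted matte is alpha compositing the portrait onto a new background. The sketch below is illustrative only: it reuses modnet, img, and predit_matte from the block above, and the plain white background and the output name composite.png are arbitrary choices, not part of this commit.

# minimal compositing sketch (assumes modnet and img from the block above)
matte = predit_matte(modnet, img)            # (H, W) float matte in [0, 1]
fg = np.asarray(img).astype('float32')       # (H, W, 3) RGB foreground
alpha = matte[:, :, None]                    # add channel dim for broadcasting
bg = np.full_like(fg, 255.0)                 # plain white background
comp = alpha * fg + (1 - alpha) * bg         # standard alpha blend
Image.fromarray(comp.astype('uint8')).save('composite.png')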