mirror of https://github.com/ZHKKKe/MODNet.git
评价与推理
parent
4424ab5b72
commit
48a2b96e71
|
|
@ -0,0 +1,72 @@
|
|||
import numpy as np
|
||||
from glob import glob
|
||||
from src.models.modnet import MODNet
|
||||
from PIL import Image
|
||||
from src.infer import predit_matte
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
|
||||
|
||||
def cal_mad(pred, gt):
|
||||
diff = pred - gt
|
||||
diff = np.abs(diff)
|
||||
mad = np.mean(diff)
|
||||
return mad
|
||||
|
||||
|
||||
def cal_mse(pred, gt):
|
||||
diff = pred - gt
|
||||
diff = diff ** 2
|
||||
mse = np.mean(diff)
|
||||
return mse
|
||||
|
||||
|
||||
def load_eval_dataset(dataset_root_dir='src/datasets/PPM-100'):
|
||||
image_path = dataset_root_dir + '/image/*'
|
||||
matte_path = dataset_root_dir + '/matte/*'
|
||||
image_file_name_list = glob(image_path)
|
||||
image_file_name_list = sorted(image_file_name_list)
|
||||
matte_file_name_list = glob(matte_path)
|
||||
matte_file_name_list = sorted(matte_file_name_list)
|
||||
|
||||
return image_file_name_list, matte_file_name_list
|
||||
|
||||
|
||||
def eval(modnet: MODNet, dataset):
|
||||
mse = total_mse = 0.0
|
||||
mad = total_mad = 0.0
|
||||
cnt = 0
|
||||
|
||||
for im_pth, mt_pth in zip(dataset[0], dataset[1]):
|
||||
im = Image.open(im_pth)
|
||||
pd_matte = predit_matte(modnet, im)
|
||||
|
||||
gt_matte = Image.open(mt_pth)
|
||||
gt_matte = np.asarray(gt_matte) / 255
|
||||
|
||||
total_mse += cal_mse(pd_matte, gt_matte)
|
||||
total_mad += cal_mad(pd_matte, gt_matte)
|
||||
|
||||
cnt += 1
|
||||
if cnt > 0:
|
||||
mse = total_mse / cnt
|
||||
mad = total_mad / cnt
|
||||
|
||||
return mse, mad
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# create MODNet and load the pre-trained ckpt
|
||||
modnet = MODNet(backbone_pretrained=False)
|
||||
modnet = nn.DataParallel(modnet)
|
||||
|
||||
ckp_pth = 'pretrained/modnet_photographic_portrait_matting.ckpt'
|
||||
if torch.cuda.is_available():
|
||||
modnet = modnet.cuda()
|
||||
weights = torch.load(ckp_pth)
|
||||
else:
|
||||
weights = torch.load(ckp_pth, map_location=torch.device('cpu'))
|
||||
modnet.load_state_dict(weights)
|
||||
dataset = load_eval_dataset('src/datasets/PPM-100')
|
||||
mse, mad = eval(modnet, dataset)
|
||||
print(f'mse: {mse:6f}, mad: {mad:6f}')
|
||||
|
|
@ -0,0 +1,85 @@
|
|||
from src.models.modnet import MODNet
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from torchvision import transforms
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def predit_matte(modnet: MODNet, im: Image):
|
||||
# define image to tensor transform
|
||||
im_transform = transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||
]
|
||||
)
|
||||
|
||||
# define hyper-parameters
|
||||
ref_size = 512
|
||||
|
||||
modnet.eval()
|
||||
|
||||
# unify image channels to 3
|
||||
im = np.asarray(im)
|
||||
if len(im.shape) == 2:
|
||||
im = im[:, :, None]
|
||||
if im.shape[2] == 1:
|
||||
im = np.repeat(im, 3, axis=2)
|
||||
elif im.shape[2] == 4:
|
||||
im = im[:, :, 0:3]
|
||||
|
||||
im = Image.fromarray(im)
|
||||
# convert image to PyTorch tensor
|
||||
im = im_transform(im)
|
||||
|
||||
# add mini-batch dim
|
||||
im = im[None, :, :, :]
|
||||
|
||||
# resize image for input
|
||||
im_b, im_c, im_h, im_w = im.shape
|
||||
if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
|
||||
if im_w >= im_h:
|
||||
im_rh = ref_size
|
||||
im_rw = int(im_w / im_h * ref_size)
|
||||
elif im_w < im_h:
|
||||
im_rw = ref_size
|
||||
im_rh = int(im_h / im_w * ref_size)
|
||||
else:
|
||||
im_rh = im_h
|
||||
im_rw = im_w
|
||||
|
||||
im_rw = im_rw - im_rw % 32
|
||||
im_rh = im_rh - im_rh % 32
|
||||
im = F.interpolate(im, size=(im_rh, im_rw), mode='area')
|
||||
|
||||
# inference
|
||||
_, _, matte = modnet(im.cuda() if torch.cuda.is_available() else im, True)
|
||||
|
||||
# resize and save matte
|
||||
matte = F.interpolate(matte, size=(im_h, im_w), mode='area')
|
||||
matte = matte[0][0].data.cpu().numpy()
|
||||
return matte
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# create MODNet and load the pre-trained ckpt
|
||||
modnet = MODNet(backbone_pretrained=False)
|
||||
modnet = nn.DataParallel(modnet)
|
||||
|
||||
ckp_pth = 'pretrained/modnet_photographic_portrait_matting.ckpt'
|
||||
if torch.cuda.is_available():
|
||||
modnet = modnet.cuda()
|
||||
weights = torch.load(ckp_pth)
|
||||
else:
|
||||
weights = torch.load(ckp_pth, map_location=torch.device('cpu'))
|
||||
modnet.load_state_dict(weights)
|
||||
|
||||
pth = 'src/datasets/PPM-100/image/13179159164_1a4ae8d085_o.jpg'
|
||||
img = Image.open(pth)
|
||||
|
||||
matte = predit_matte(modnet, img)
|
||||
prd_img = Image.fromarray(((matte * 255).astype('uint8')), mode='L')
|
||||
prd_img.save('test_predic.jpg')
|
||||
|
||||
Loading…
Reference in New Issue