# ---- src/eval.py (new file) ----
"""Evaluate a MODNet model on an image/matte dataset using MSE and MAD."""
from glob import glob

import numpy as np
import torch
import torch.nn as nn
from PIL import Image

from src.infer import predit_matte
from src.models.modnet import MODNet


def cal_mad(pred, gt):
    """Return the mean absolute difference between ``pred`` and ``gt``.

    Both arguments are combined with NumPy arithmetic, so they must be
    array-likes of broadcast-compatible shapes.
    """
    return np.mean(np.abs(pred - gt))


def cal_mse(pred, gt):
    """Return the mean squared error between ``pred`` and ``gt``.

    Both arguments are combined with NumPy arithmetic, so they must be
    array-likes of broadcast-compatible shapes.
    """
    return np.mean((pred - gt) ** 2)


def load_eval_dataset(dataset_root_dir='src/datasets/PPM-100'):
    """List the image and matte files of an evaluation dataset.

    Args:
        dataset_root_dir: Directory containing ``image/`` and ``matte/``
            sub-directories.

    Returns:
        A tuple ``(image_file_name_list, matte_file_name_list)``. Each list
        is sorted, so images and mattes pair up by position provided their
        file names sort in the same order.
    """
    image_file_name_list = sorted(glob(dataset_root_dir + '/image/*'))
    matte_file_name_list = sorted(glob(dataset_root_dir + '/matte/*'))
    return image_file_name_list, matte_file_name_list


def eval(modnet: MODNet, dataset):  # noqa: A001 - public name kept for callers
    """Compute the average MSE and MAD of ``modnet`` over ``dataset``.

    Args:
        modnet: Model handed to ``predit_matte`` for each image.
        dataset: Pair ``(image_paths, matte_paths)`` as returned by
            ``load_eval_dataset``; entries are matched by position.

    Returns:
        A tuple ``(mse, mad)`` averaged over all pairs, or ``(0.0, 0.0)``
        when the dataset is empty.

    Raises:
        ValueError: If the number of images and mattes differ.
    """
    image_paths = list(dataset[0])
    matte_paths = list(dataset[1])
    # zip() would silently drop the surplus entries and the sorted lists
    # would most likely be misaligned, giving meaningless metrics.
    if len(image_paths) != len(matte_paths):
        raise ValueError(
            f'dataset has {len(image_paths)} images but '
            f'{len(matte_paths)} mattes'
        )

    total_mse = 0.0
    total_mad = 0.0
    cnt = 0

    for im_pth, mt_pth in zip(image_paths, matte_paths):
        # Context managers release the file handles once the pixel data has
        # been converted to arrays.
        with Image.open(im_pth) as im:
            pd_matte = predit_matte(modnet, im)

        with Image.open(mt_pth) as gt_im:
            gt_matte = np.asarray(gt_im) / 255

        # The prediction is 2-D; a matte stored with several channels cannot
        # be compared against it directly.
        # NOTE(review): assumes the first channel carries the alpha values
        # (true for grayscale saved as RGB) - confirm for other encodings.
        if gt_matte.ndim == 3:
            gt_matte = gt_matte[:, :, 0]

        total_mse += cal_mse(pd_matte, gt_matte)
        total_mad += cal_mad(pd_matte, gt_matte)
        cnt += 1

    if cnt == 0:
        return 0.0, 0.0
    return total_mse / cnt, total_mad / cnt


def main():
    """Load the pre-trained checkpoint and print MSE/MAD on PPM-100."""
    # create MODNet and load the pre-trained ckpt
    modnet = MODNet(backbone_pretrained=False)
    modnet = nn.DataParallel(modnet)

    ckp_pth = 'pretrained/modnet_photographic_portrait_matting.ckpt'
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    if torch.cuda.is_available():
        modnet = modnet.cuda()
        weights = torch.load(ckp_pth)
    else:
        weights = torch.load(ckp_pth, map_location=torch.device('cpu'))
    modnet.load_state_dict(weights)

    dataset = load_eval_dataset('src/datasets/PPM-100')
    mse, mad = eval(modnet, dataset)
    print(f'mse: {mse:6f}, mad: {mad:6f}')


if __name__ == '__main__':
    main()
# ---- src/infer.py (new file) ----
"""Run MODNet on a single image and return the predicted matte."""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

from src.models.modnet import MODNet


def predit_matte(modnet: MODNet, im: Image.Image):
    """Predict a matte for ``im`` with ``modnet``.

    The model is switched to eval mode as a side effect.

    Args:
        modnet: Model to run; called as ``modnet(tensor, True)`` and expected
            to return a 3-tuple whose last element is the matte tensor.
        im: Input image. Grayscale input is replicated to three channels and
            the fourth channel of 4-channel input is dropped.

    Returns:
        A 2-D NumPy array with the same height and width as ``im``.
    """
    # define hyper-parameters
    ref_size = 512
    # Network input sides are rounded down to a multiple of this value.
    # NOTE(review): presumably the model's total down-sampling factor.
    size_multiple = 32

    # define image to tensor transform
    im_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    modnet.eval()

    # unify image channels to 3
    im = np.asarray(im)
    if im.ndim == 2:
        im = im[:, :, None]
    if im.shape[2] == 1:
        im = np.repeat(im, 3, axis=2)
    elif im.shape[2] == 4:
        im = im[:, :, 0:3]

    # convert image to PyTorch tensor and add mini-batch dim
    im = im_transform(Image.fromarray(im))
    im = im[None, :, :, :]

    # resize image for input: rescale so the shorter side equals ref_size
    # unless ref_size already lies between the two sides
    _, _, im_h, im_w = im.shape
    if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
        if im_w >= im_h:
            im_rh = ref_size
            im_rw = int(im_w / im_h * ref_size)
        else:
            im_rw = ref_size
            im_rh = int(im_h / im_w * ref_size)
    else:
        im_rh = im_h
        im_rw = im_w

    # Round down to a multiple of size_multiple, but never to zero: a side
    # shorter than size_multiple would otherwise give an empty target size
    # and make F.interpolate fail.
    im_rw = max(size_multiple, im_rw - im_rw % size_multiple)
    im_rh = max(size_multiple, im_rh - im_rh % size_multiple)
    im = F.interpolate(im, size=(im_rh, im_rw), mode='area')

    # Put the input where the model's weights live instead of guessing from
    # CUDA availability, so a CPU model also works on a CUDA machine.
    first_param = next(modnet.parameters(), None)
    if first_param is not None:
        im = im.to(first_param.device)
    elif torch.cuda.is_available():
        im = im.cuda()

    # inference; no autograd graph is needed for prediction
    with torch.no_grad():
        _, _, matte = modnet(im, True)
        # resize matte back to the original image size
        matte = F.interpolate(matte, size=(im_h, im_w), mode='area')

    return matte[0][0].detach().cpu().numpy()


def main():
    """Load the pre-trained checkpoint and save the matte of a sample image."""
    # create MODNet and load the pre-trained ckpt
    modnet = MODNet(backbone_pretrained=False)
    modnet = nn.DataParallel(modnet)

    ckp_pth = 'pretrained/modnet_photographic_portrait_matting.ckpt'
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    if torch.cuda.is_available():
        modnet = modnet.cuda()
        weights = torch.load(ckp_pth)
    else:
        weights = torch.load(ckp_pth, map_location=torch.device('cpu'))
    modnet.load_state_dict(weights)

    pth = 'src/datasets/PPM-100/image/13179159164_1a4ae8d085_o.jpg'
    # Close the source file as soon as the matte has been computed.
    with Image.open(pth) as img:
        matte = predit_matte(modnet, img)

    prd_img = Image.fromarray(((matte * 255).astype('uint8')), mode='L')
    prd_img.save('test_predic.jpg')


if __name__ == '__main__':
    main()