From 58c54f6e6c1773c8fe650d82aa09df26d72a985c Mon Sep 17 00:00:00 2001
From: actboy
Date: Wed, 9 Feb 2022 20:05:26 +0800
Subject: [PATCH] Improve the model training code
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/matting_dataset.py | 196 +++++++++++++++++++++++++++++++++++++++++++++++++
 src/requirements.txt   |   7 ++
 src/trainer.py         |  39 +++++++++-
 3 files changed, 240 insertions(+), 2 deletions(-)
 create mode 100644 src/matting_dataset.py
 create mode 100644 src/requirements.txt

diff --git a/src/matting_dataset.py b/src/matting_dataset.py
new file mode 100644
index 0000000..82ca78a
--- /dev/null
+++ b/src/matting_dataset.py
@@ -0,0 +1,196 @@
+from torch.utils.data import Dataset, DataLoader
+import numpy as np
+import random
+import cv2
+from glob import glob
+import torch
+from torchvision.transforms import functional as F
+from torchvision import transforms
+from PIL import Image
+
+
+class MattingDataset(Dataset):
+    def __init__(self,
+                 dataset_root_dir='src/datasets/PPM-100',
+                 transform=None):
+        image_path = dataset_root_dir + '/image/*'
+        matte_path = dataset_root_dir + '/matte/*'
+        image_file_name_list = glob(image_path)
+        matte_file_name_list = glob(matte_path)
+
+        self.image_file_name_list = sorted(image_file_name_list)
+        self.matte_file_name_list = sorted(matte_file_name_list)
+        for img, mat in zip(self.image_file_name_list, self.matte_file_name_list):
+            img_name = img.split('/')[-1]
+            mat_name = mat.split('/')[-1]
+            assert img_name == mat_name
+
+        self.transform = transform
+
+    def __len__(self):
+        return len(self.image_file_name_list)
+
+    def __getitem__(self, index):
+        image_file_name = self.image_file_name_list[index]
+        matte_file_name = self.matte_file_name_list[index]
+
+        image = Image.open(image_file_name)
+        matte = Image.open(matte_file_name)
+        # matte = matte.convert('RGB')
+        trimap = self.gen_trimap(matte)
+
+        data = {'image': image, 'trimap': trimap, 'gt_matte': matte}
+
+        if self.transform:
+            data = self.transform(data)
+        return data
+
+    @staticmethod
+    def gen_trimap(matte):
+        """
+        Generate a trimap (0 / 128 / 255) from the matte.
+        """
+        matte = np.array(matte)
+        k_size = random.choice(range(2, 5))
+        iterations = np.random.randint(5, 15)
+        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
+                                           (k_size, k_size))
+        dilated = cv2.dilate(matte, kernel, iterations=iterations)
+        eroded = cv2.erode(matte, kernel, iterations=iterations)
+
+        trimap = np.zeros(matte.shape)
+        trimap.fill(128)
+        trimap[eroded > 254.5] = 255
+        trimap[dilated < 0.5] = 0
+        trimap = Image.fromarray(np.uint8(trimap))
+        return trimap
+
+
+class Rescale(object):
+    """Rescale the image, trimap and matte in a sample to a given size.
+
+    Args:
+        output_size (int): Desired output size. The sample is resized to a
+            square of (output_size, output_size), so the original aspect
+            ratio is not preserved.
+ """ + + def __init__(self, output_size): + assert isinstance(output_size, (int, tuple)) + self.output_size = output_size + + def __call__(self, sample): + image, trimap, gt_matte = sample['image'], sample['trimap'], sample['gt_matte'] + + # w, h = image.size + # if h > w: + # new_h, new_w = self.output_size * h / w, self.output_size + # else: + # new_h, new_w = self.output_size, self.output_size * w / h + + # new_h, new_w = int(new_h), int(new_w) + new_h, new_w = int(self.output_size), int(self.output_size) + + new_img = F.resize(image, (new_h, new_w)) + new_trimap = F.resize(trimap, (new_h, new_w)) + new_gt_matte = F.resize(gt_matte, (new_h, new_w)) + + return {'image': new_img, 'trimap': new_trimap, 'gt_matte': new_gt_matte} + + +class ToTensor(object): + """Convert ndarrays in sample to Tensors.""" + + def __call__(self, sample): + image, trimap, gt_matte = sample['image'], sample['trimap'], sample['gt_matte'] + image = F.pil_to_tensor(image) + trimap = F.pil_to_tensor(trimap) + gt_matte = F.pil_to_tensor(gt_matte) + return {'image': image, + 'trimap': trimap, + 'gt_matte': gt_matte} + + +class ConvertImageDtype(object): + def __call__(self, sample): + image, trimap, gt_matte = sample['image'], sample['trimap'], sample['gt_matte'] + image = F.convert_image_dtype(image, torch.float) + trimap = F.convert_image_dtype(trimap, torch.float) + gt_matte = F.convert_image_dtype(gt_matte, torch.float) + + return {'image': image, 'trimap': trimap, 'gt_matte': gt_matte} + + +class Normalize(object): + def __init__(self, mean, std, inplace=False): + self.mean = mean + self.std = std + self.inplace = inplace + + def __call__(self, sample): + image, trimap, gt_matte = sample['image'], sample['trimap'], sample['gt_matte'] + image = image.type(torch.FloatTensor) + image = F.normalize(image, self.mean, self.std, self.inplace) + sample['image'] = image + + sample['trimap'] = trimap / 255 # 归一化 + return sample + + +class ToTrainArray(object): + def __call__(self, sample): + return [sample['image'], sample['trimap'], sample['gt_matte']] + + +if __name__ == '__main__': + + # test MattingDataset.gen_trimap + matte = Image.open('src/datasets/PPM-100/matte/6146816_556eaff97f_o.jpg') + matte = matte.convert('RGB') + trimap = MattingDataset.gen_trimap(matte) + trimap.save('test_trimap.png') + + # test MattingDataset + transform = transforms.Compose([ + Rescale(512), + ToTensor(), + Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + ]) + mattingDataset = MattingDataset(transform=transform) + + import matplotlib.pyplot as plt + + fig = plt.figure() + + for i in range(len(mattingDataset)): + sample = mattingDataset[i] + print(mattingDataset.image_file_name_list[i]) + # print(sample) + print(i, sample['image'].shape, sample['trimap'].shape, sample['gt_matte'].shape) + + # break + + ax = plt.subplot(4, 3, 3 * i + 1) + plt.tight_layout() + ax.set_title('image #{}'.format(i)) + ax.axis('off') + img = transforms.ToPILImage()(sample['image']) + plt.imshow(img) + + ax = plt.subplot(4, 3, 3 * i + 2) + plt.tight_layout() + ax.set_title('gt_matte #{}'.format(i)) + ax.axis('off') + img = transforms.ToPILImage()(sample['gt_matte']) + plt.imshow(img) + + ax = plt.subplot(4, 3, 3 * i + 3) + plt.tight_layout() + ax.set_title('trimap #{}'.format(i)) + ax.axis('off') + img = transforms.ToPILImage()(sample['trimap']) + plt.imshow(img) + + if i == 3: + plt.show() + break diff --git a/src/requirements.txt b/src/requirements.txt new file mode 100644 index 0000000..d1b497d --- /dev/null +++ b/src/requirements.txt @@ -0,0 +1,7 @@ 
+torch==1.10.2
+scipy==1.7.3
+numpy==1.21.5
+opencv-python==4.5.5.62
+matplotlib==3.5.1
+torchvision==0.11.3
+Pillow==9.0.1
\ No newline at end of file
diff --git a/src/trainer.py b/src/trainer.py
index bd3d8be..102fe53 100644
--- a/src/trainer.py
+++ b/src/trainer.py
@@ -7,7 +7,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-
 __all__ = [
     'supervised_training_iter',
     'soc_adaptation_iter',
@@ -80,7 +79,7 @@ class GaussianBlurLayer(nn.Module):
 # MODNet Training Functions
 # ----------------------------------------------------------------------------------
 
-blurer = GaussianBlurLayer(1, 3).cuda()
+blurer = GaussianBlurLayer(1, 3)  # .cuda()
 
 
 def supervised_training_iter(
@@ -297,3 +296,39 @@ def soc_adaptation_iter(
     return soc_semantic_loss, soc_detail_loss
 
 # ----------------------------------------------------------------------------------
+
+
+if __name__ == '__main__':
+    from matting_dataset import MattingDataset, Rescale, ToTensor, Normalize, ToTrainArray, ConvertImageDtype
+    from torchvision import transforms
+    from torch.utils.data import DataLoader
+    from models.modnet import MODNet
+    transform = transforms.Compose([
+        Rescale(512),
+        ToTensor(),
+        ConvertImageDtype(),
+        Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+        ToTrainArray()
+    ])
+    mattingDataset = MattingDataset(transform=transform)
+
+    bs = 16          # batch size
+    lr = 0.01        # learning rate
+    epochs = 40      # total epochs
+
+    modnet = torch.nn.DataParallel(MODNet())  # .cuda()
+    optimizer = torch.optim.SGD(modnet.parameters(), lr=lr, momentum=0.9)
+    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * epochs), gamma=0.1)
+
+    dataloader = DataLoader(mattingDataset,
+                            batch_size=bs,
+                            shuffle=True,
+                            num_workers=0)
+
+    for epoch in range(0, epochs):
+        for idx, (image, trimap, gt_matte) in enumerate(dataloader):
+            semantic_loss, detail_loss, matte_loss = \
+                supervised_training_iter(modnet, optimizer, image, trimap, gt_matte)
+            break
+        lr_scheduler.step()
+        break