From 4424ab5b7279658dc0a0c03ab35084f2c0359893 Mon Sep 17 00:00:00 2001
From: actboy
Date: Thu, 10 Feb 2022 00:06:40 +0800
Subject: [PATCH] Model training code
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/matting_dataset.py | 80 +++++++++++++++++++++++++-------------------------
 src/trainer.py         | 30 +++++++++++++++----
 2 files changed, 64 insertions(+), 46 deletions(-)

diff --git a/src/matting_dataset.py b/src/matting_dataset.py
index 5eeba9b..b411cfd 100644
--- a/src/matting_dataset.py
+++ b/src/matting_dataset.py
@@ -1,4 +1,4 @@
-from torch.utils.data import Dataset, DataLoader
+from torch.utils.data import Dataset
 import numpy as np
 import random
 import cv2
@@ -36,35 +36,13 @@ class MattingDataset(Dataset):
         image = Image.open(image_file_name)
         matte = Image.open(matte_file_name)

-        # matte = matte.convert('RGB')
-        trimap = self.gen_trimap(matte)
-        data = {'image': image, 'trimap': trimap, 'gt_matte': matte}
+        data = {'image': image, 'gt_matte': matte}
         if self.transform:
             data = self.transform(data)
         return data

-    @staticmethod
-    def gen_trimap(matte):
-        """
-        Generate a normalized trimap from the matte.
-        """
-        matte = np.array(matte)
-        k_size = random.choice(range(2, 5))
-        iterations = np.random.randint(5, 15)
-        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
-                                           (k_size, k_size))
-        dilated = cv2.dilate(matte, kernel, iterations=iterations)
-        eroded = cv2.erode(matte, kernel, iterations=iterations)
-
-        trimap = np.zeros(matte.shape)
-        trimap.fill(128)
-        trimap[eroded > 254.5] = 255
-        trimap[dilated < 0.5] = 0
-        trimap = Image.fromarray(np.uint8(trimap))
-        return trimap
-

 class Rescale(object):
     """Rescale the image in a sample to a given size.
@@ -80,22 +58,41 @@ class Rescale(object):
         self.output_size = output_size

     def __call__(self, sample):
-        image, trimap, gt_matte = sample['image'], sample['trimap'], sample['gt_matte']
+        image, gt_matte = sample['image'], sample['gt_matte']

-        # w, h = image.size
-        # if h > w:
-        #     new_h, new_w = self.output_size * h / w, self.output_size
-        # else:
-        #     new_h, new_w = self.output_size, self.output_size * w / h
-
-        # new_h, new_w = int(new_h), int(new_w)
         new_h, new_w = int(self.output_size), int(self.output_size)
-
         new_img = F.resize(image, (new_h, new_w))
-        new_trimap = F.resize(trimap, (new_h, new_w))
         new_gt_matte = F.resize(gt_matte, (new_h, new_w))

-        return {'image': new_img, 'trimap': new_trimap, 'gt_matte': new_gt_matte}
+        return {'image': new_img, 'gt_matte': new_gt_matte}
+
+
+class GenTrimap(object):
+    def __call__(self, sample):
+        gt_matte = sample['gt_matte']
+        trimap = self.gen_trimap(gt_matte)
+        sample['trimap'] = trimap
+        return sample
+
+    @staticmethod
+    def gen_trimap(matte):
+        """
+        Generate a normalized trimap from the matte.
+        """
+        matte = np.array(matte)
+        k_size = random.choice(range(2, 5))
+        iterations = np.random.randint(5, 15)
+        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
+                                           (k_size, k_size))
+        dilated = cv2.dilate(matte, kernel, iterations=iterations)
+        eroded = cv2.erode(matte, kernel, iterations=iterations)
+
+        trimap = np.zeros(matte.shape)
+        trimap.fill(0.5)
+        trimap[eroded > 254.5] = 1.0
+        trimap[dilated < 0.5] = 0.0
+        trimap = Image.fromarray(trimap)
+        return trimap


 class ToTensor(object):
@@ -106,6 +103,8 @@ class ToTensor(object):
         image = F.pil_to_tensor(image)
         trimap = F.pil_to_tensor(trimap)
         gt_matte = F.pil_to_tensor(gt_matte)
+
+        return {'image': image, 'trimap': trimap, 'gt_matte': gt_matte}


@@ -132,8 +131,6 @@ class Normalize(object):
         image = image.type(torch.FloatTensor)
         image = F.normalize(image, self.mean, self.std, self.inplace)
         sample['image'] = image
-
-        sample['trimap'] = trimap / 255  # normalize
         return sample


@@ -152,13 +149,16 @@ if __name__ == '__main__':

     # test MattingDataset.gen_trimap
     matte = Image.open('src/datasets/PPM-100/matte/6146816_556eaff97f_o.jpg')
-    matte = matte.convert('RGB')
-    trimap = MattingDataset.gen_trimap(matte)
-    trimap.save('test_trimap.png')
+    trimap1 = GenTrimap().gen_trimap(matte)
+    trimap1 = np.array(trimap1) * 255
+    trimap1 = np.uint8(trimap1)
+    trimap1 = Image.fromarray(trimap1)
+    trimap1.save('test_trimap.png')

     # test MattingDataset
     transform = transforms.Compose([
         Rescale(512),
+        GenTrimap(),
         ToTensor(),
         # Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
     ])
diff --git a/src/trainer.py b/src/trainer.py
index 02292c9..aa71ef0 100644
--- a/src/trainer.py
+++ b/src/trainer.py
@@ -301,12 +301,15 @@ def soc_adaptation_iter(


 if __name__ == '__main__':
-    from matting_dataset import MattingDataset, Rescale, ToTensor, Normalize, ToTrainArray, ConvertImageDtype
+    from matting_dataset import MattingDataset, Rescale, \
+        ToTensor, Normalize, ToTrainArray, \
+        ConvertImageDtype, GenTrimap
     from torchvision import transforms
     from torch.utils.data import DataLoader
     from models.modnet import MODNet
     transform = transforms.Compose([
         Rescale(512),
+        GenTrimap(),
         ToTensor(),
         ConvertImageDtype(),
         Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
@@ -314,7 +317,7 @@ if __name__ == '__main__':
     ])
     mattingDataset = MattingDataset(transform=transform)

-    bs = 16  # batch size
+    bs = 4  # batch size
     lr = 0.01  # learn rate
     epochs = 40  # total epochs

@@ -327,13 +330,28 @@ if __name__ == '__main__':

     dataloader = DataLoader(mattingDataset,
                             batch_size=bs,
-                            shuffle=True,
-                            num_workers=0)
+                            shuffle=True)

     for epoch in range(0, epochs):
         for idx, (image, trimap, gt_matte) in enumerate(dataloader):
             semantic_loss, detail_loss, matte_loss = \
                 supervised_training_iter(modnet, optimizer, image, trimap, gt_matte)
             break
-        lr_scheduler.step()
-        break
+        if epoch % 4 == 0 and epoch > 1:
+            torch.save({
+                'epoch': epoch,
+                'model_state_dict': modnet.state_dict(),
+                'optimizer_state_dict': optimizer.state_dict(),
+                'loss': {'semantic_loss': semantic_loss, 'detail_loss': detail_loss, 'matte_loss': matte_loss},
+            }, f'pretrained/modnet_custom_portrait_matting_{epoch+1}.ckpt')
+        lr_scheduler.step()
+        print(f'semantic_loss: {semantic_loss:f}, detail_loss: {detail_loss:f}, matte_loss: {matte_loss:f}')
+        if epoch == 4:
+            break
+
+    torch.save({
+        'epoch': epochs,
+        'model_state_dict': modnet.state_dict(),
+        'optimizer_state_dict': optimizer.state_dict(),
+        'loss': {'semantic_loss': semantic_loss, 'detail_loss': detail_loss, 'matte_loss': matte_loss},
+    }, f'pretrained/modnet_custom_portrait_matting_last_epoch.ckpt')
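
Note on the new GenTrimap transform: gen_trimap now fills the unknown band with 0.5 and the eroded/dilated regions with 1.0/0.0, so the trimap is already normalized when it reaches the network. A quick standalone check, mirroring the __main__ test in matting_dataset.py, can confirm that only those three values appear. This is just a sketch; it assumes src/ is on the import path and reuses the same PPM-100 matte as the test block.

    import numpy as np
    from PIL import Image

    from matting_dataset import GenTrimap

    # Same matte file as the __main__ test in matting_dataset.py.
    matte = Image.open('src/datasets/PPM-100/matte/6146816_556eaff97f_o.jpg')

    # gen_trimap erodes/dilates the matte, fills 0.5 for the unknown band,
    # then marks foreground as 1.0 and background as 0.0.
    trimap = GenTrimap().gen_trimap(matte)
    print(np.unique(np.array(trimap)))  # expected: values drawn from {0.0, 0.5, 1.0}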
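Since the training loop now writes periodic checkpoints plus a final one, a resume sketch may be worth keeping next to it. This is only an illustration, not part of the patch: the MODNet() construction and the SGD optimizer are stand-ins for whatever trainer.py actually builds (only lr=0.01 is taken from the code above), and the checkpoint name is just one of the files the loop would produce.

    import torch
    from models.modnet import MODNet

    # Hypothetical path: the first periodic checkpoint written by the loop above
    # (epoch 4 is saved as ..._{epoch+1}.ckpt).
    ckpt_path = 'pretrained/modnet_custom_portrait_matting_5.ckpt'

    modnet = MODNet()  # assumed constructor; build the model exactly as trainer.py does
    optimizer = torch.optim.SGD(modnet.parameters(), lr=0.01)  # assumed; mirror trainer.py

    checkpoint = torch.load(ckpt_path, map_location='cpu')
    modnet.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    start_epoch = checkpoint['epoch'] + 1  # continue after the saved epoch
    print(f"resuming at epoch {start_epoch}, last losses: {checkpoint['loss']}")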