完善模型训练代码

pull/177/head
actboy 2022-02-09 20:05:26 +08:00
parent 5f673d5a34
commit 58c54f6e6c
3 changed files with 240 additions and 2 deletions

196
src/matting_dataset.py Normal file
View File

@ -0,0 +1,196 @@
import os
import random
from glob import glob

import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms import functional as F
class MattingDataset(Dataset):
    """Dataset pairing images with their alpha mattes and a generated trimap.

    Images are listed from ``<dataset_root_dir>/image/*`` and mattes from
    ``<dataset_root_dir>/matte/*``. Both listings are sorted, and the i-th
    image must carry the same file name as the i-th matte.

    Each item is a dict with the keys ``'image'``, ``'trimap'`` and
    ``'gt_matte'`` holding PIL images. When ``transform`` is given, the dict
    is passed through it and the transform's return value is the item.

    Args:
        dataset_root_dir: Directory containing the ``image`` and ``matte``
            sub-directories.
        transform: Optional callable applied to the sample dict.

    Raises:
        ValueError: If the two directories hold a different number of files
            or if a sorted image/matte pair has different file names.
    """

    def __init__(self,
                 dataset_root_dir='src/datasets/PPM-100',
                 transform=None):
        self.image_file_name_list = sorted(glob(dataset_root_dir + '/image/*'))
        self.matte_file_name_list = sorted(glob(dataset_root_dir + '/matte/*'))

        # zip() below would silently drop the unpaired tail, so compare the
        # lengths first; otherwise __len__ could exceed the matte list.
        n_images = len(self.image_file_name_list)
        n_mattes = len(self.matte_file_name_list)
        if n_images != n_mattes:
            raise ValueError(
                'image/matte count mismatch in {!r}: {} images vs {} mattes'
                .format(dataset_root_dir, n_images, n_mattes))

        for img, mat in zip(self.image_file_name_list,
                            self.matte_file_name_list):
            # basename() honours the platform's path separator.
            if os.path.basename(img) != os.path.basename(mat):
                raise ValueError(
                    'image {!r} is not paired with matte {!r}'.format(img, mat))

        self.transform = transform

    def __len__(self):
        """Return the number of image/matte pairs."""
        return len(self.image_file_name_list)

    def __getitem__(self, index):
        """Load the pair at ``index`` and build its sample.

        Returns:
            The dict ``{'image', 'trimap', 'gt_matte'}``, or the result of
            ``self.transform`` applied to that dict when a transform is set.
        """
        image = Image.open(self.image_file_name_list[index])
        matte = Image.open(self.matte_file_name_list[index])
        # A fresh random trimap is generated on every access.
        trimap = self.gen_trimap(matte)
        data = {'image': image, 'trimap': trimap, 'gt_matte': matte}
        if self.transform:
            data = self.transform(data)
        return data

    @staticmethod
    def gen_trimap(matte):
        """Generate a trimap from a matte using random erosion and dilation.

        The matte is eroded and dilated with an elliptical kernel whose side
        is drawn from 2..4, applied for a number of iterations drawn from
        5..14. Pixels whose eroded value exceeds 254.5 become 255, pixels
        whose dilated value is below 0.5 become 0, and every other pixel is
        set to 128.

        Args:
            matte: PIL image (or anything ``np.array`` accepts) of the matte.

        Returns:
            A PIL image built from a uint8 array with the matte's shape and
            the values 0, 128 and 255.
        """
        matte = np.array(matte)
        k_size = random.choice(range(2, 5))
        iterations = np.random.randint(5, 15)
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
                                           (k_size, k_size))
        dilated = cv2.dilate(matte, kernel, iterations=iterations)
        eroded = cv2.erode(matte, kernel, iterations=iterations)

        trimap = np.full(matte.shape, 128, dtype=np.uint8)
        trimap[eroded > 254.5] = 255
        trimap[dilated < 0.5] = 0
        return Image.fromarray(trimap)
class Rescale(object):
    """Resize the image, trimap and matte of a sample to a fixed size.

    Args:
        output_size (tuple or int): Desired output size. A tuple is taken as
            ``(height, width)``. An int ``s`` resizes every entry to
            ``s x s``; the aspect ratio is not preserved.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def _target_size(self):
        """Return the ``(height, width)`` pair implied by ``output_size``."""
        if isinstance(self.output_size, tuple):
            new_h, new_w = self.output_size
        else:
            new_h = new_w = self.output_size
        return int(new_h), int(new_w)

    def __call__(self, sample):
        """Return a new sample dict whose three entries are resized.

        Args:
            sample: Dict with the keys ``'image'``, ``'trimap'`` and
                ``'gt_matte'``.
        """
        size = self._target_size()
        # NOTE(review): F.resize is called with its default interpolation for
        # the trimap too, which may blend its discrete levels - confirm that
        # this is intended.
        return {'image': F.resize(sample['image'], size),
                'trimap': F.resize(sample['trimap'], size),
                'gt_matte': F.resize(sample['gt_matte'], size)}
class ToTensor(object):
    """Convert the PIL images of a sample to tensors."""

    def __call__(self, sample):
        """Return a new dict with each entry passed through ``F.pil_to_tensor``.

        Args:
            sample: Dict with the keys ``'image'``, ``'trimap'`` and
                ``'gt_matte'``.
        """
        return {key: F.pil_to_tensor(sample[key])
                for key in ('image', 'trimap', 'gt_matte')}
class ConvertImageDtype(object):
    """Convert every tensor of a sample to ``torch.float``."""

    def __call__(self, sample):
        """Return a new dict with each entry converted by ``F.convert_image_dtype``.

        Args:
            sample: Dict with the keys ``'image'``, ``'trimap'`` and
                ``'gt_matte'``.
        """
        return {key: F.convert_image_dtype(sample[key], torch.float)
                for key in ('image', 'trimap', 'gt_matte')}
class Normalize(object):
    """Normalize the image of a sample and bring its trimap into [0, 1].

    Args:
        mean: Per-channel means forwarded to ``F.normalize``.
        std: Per-channel standard deviations forwarded to ``F.normalize``.
        inplace: Forwarded to ``F.normalize``.
    """

    def __init__(self, mean, std, inplace=False):
        self.mean = mean
        self.std = std
        self.inplace = inplace

    @staticmethod
    def _scale_trimap(trimap):
        """Return ``trimap`` divided by 255 unless it is already a float tensor.

        A float tensor is taken to be scaled already (as it is after
        ``F.convert_image_dtype``); dividing it again would shrink it to
        roughly [0, 1/255].
        """
        if isinstance(trimap, torch.Tensor) and trimap.is_floating_point():
            return trimap
        return trimap / 255

    def __call__(self, sample):
        """Return a copy of ``sample`` with normalized image and scaled trimap.

        Args:
            sample: Dict with at least the keys ``'image'`` and ``'trimap'``.
                Any other entries (e.g. ``'gt_matte'``) are passed through
                unchanged.
        """
        # NOTE(review): the image is only cast to float, not rescaled, so
        # mean/std must match the value range it arrives in - confirm callers.
        image = sample['image'].type(torch.FloatTensor)

        # Shallow copy so the caller's dict is not modified.
        result = dict(sample)
        result['image'] = F.normalize(image, self.mean, self.std, self.inplace)
        result['trimap'] = self._scale_trimap(sample['trimap'])
        return result
class ToTrainArray(object):
    """Flatten a sample dict into the list ``[image, trimap, gt_matte]``."""

    def __call__(self, sample):
        """Return the three sample entries as a list, in a fixed order."""
        ordered_keys = ('image', 'trimap', 'gt_matte')
        return [sample[key] for key in ordered_keys]
if __name__ == '__main__':
    # Manual check 1: generate a trimap from one matte and save it to disk.
    matte = Image.open('src/datasets/PPM-100/matte/6146816_556eaff97f_o.jpg')
    matte = matte.convert('RGB')
    trimap = MattingDataset.gen_trimap(matte)
    trimap.save('test_trimap.png')

    # Manual check 2: run the dataset through a transform pipeline and plot
    # the first few samples.
    # NOTE(review): this pipeline has no ConvertImageDtype step, so Normalize
    # receives a uint8-range image - confirm the displayed image is expected
    # to look right.
    transform = transforms.Compose([
        Rescale(512),
        ToTensor(),
        Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    mattingDataset = MattingDataset(transform=transform)

    import matplotlib.pyplot as plt

    n_rows = 4  # one grid row per sample
    columns = ('image', 'gt_matte', 'trimap')
    to_pil = transforms.ToPILImage()

    plt.figure()
    # Bound the loop by the dataset size so the figure is still shown when
    # fewer than n_rows samples exist.
    for i in range(min(n_rows, len(mattingDataset))):
        sample = mattingDataset[i]
        print(mattingDataset.image_file_name_list[i])
        print(i, sample['image'].shape, sample['trimap'].shape, sample['gt_matte'].shape)

        for col, key in enumerate(columns, start=1):
            ax = plt.subplot(n_rows, len(columns), len(columns) * i + col)
            ax.set_title('{} #{}'.format(key, i))
            ax.axis('off')
            plt.imshow(to_pil(sample[key]))

    plt.tight_layout()
    plt.show()

7
src/requirements.txt Normal file
View File

@ -0,0 +1,7 @@
torch==1.10.2
scipy==1.7.3
numpy==1.21.5
opencv-python==4.5.5.62
matplotlib==3.5.1
torchvision==0.11.3
Pillow==9.0.1

View File

@ -7,7 +7,6 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
__all__ = [ __all__ = [
'supervised_training_iter', 'supervised_training_iter',
'soc_adaptation_iter', 'soc_adaptation_iter',
@ -80,7 +79,7 @@ class GaussianBlurLayer(nn.Module):
# MODNet Training Functions # MODNet Training Functions
# ---------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------
blurer = GaussianBlurLayer(1, 3).cuda() blurer = GaussianBlurLayer(1, 3) #.cuda()
def supervised_training_iter( def supervised_training_iter(
@ -297,3 +296,39 @@ def soc_adaptation_iter(
return soc_semantic_loss, soc_detail_loss return soc_semantic_loss, soc_detail_loss
# ---------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------
if __name__ == '__main__':
    # Smoke run of the supervised training step: builds the dataset and the
    # model, then executes a single training iteration.
    from matting_dataset import (
        MattingDataset, Rescale, ToTensor, Normalize, ToTrainArray,
        ConvertImageDtype,
    )
    from torchvision import transforms
    from torch.utils.data import DataLoader
    from models.modnet import MODNet

    # Hyper-parameters.
    bs = 16       # batch size
    lr = 0.01     # learning rate
    epochs = 40   # total epochs

    # Data: resize -> tensor -> float -> normalize -> [image, trimap, gt_matte].
    transform = transforms.Compose([
        Rescale(512),
        ToTensor(),
        ConvertImageDtype(),
        Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ToTrainArray(),
    ])
    mattingDataset = MattingDataset(transform=transform)
    dataloader = DataLoader(mattingDataset, batch_size=bs, shuffle=True, num_workers=0)

    # Model and optimisation; no .cuda() call is made on the model here.
    modnet = torch.nn.DataParallel(MODNet())
    optimizer = torch.optim.SGD(modnet.parameters(), lr=lr, momentum=0.9)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=int(0.25 * epochs), gamma=0.1)

    for epoch in range(epochs):
        for image, trimap, gt_matte in dataloader:
            semantic_loss, detail_loss, matte_loss = supervised_training_iter(
                modnet, optimizer, image, trimap, gt_matte)
            # Stop after the first batch.
            break
        lr_scheduler.step()
        # Stop after the first epoch.
        break