mirror of https://github.com/ZHKKKe/MODNet.git
完善模型训练代码
parent
5f673d5a34
commit
58c54f6e6c
|
|
@ -0,0 +1,196 @@
|
|||
from torch.utils.data import Dataset, DataLoader
|
||||
import numpy as np
|
||||
import random
|
||||
import cv2
|
||||
from glob import glob
|
||||
import torch
|
||||
from torchvision.transforms import functional as F
|
||||
from torchvision import transforms
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class MattingDataset(Dataset):
|
||||
def __init__(self,
|
||||
dataset_root_dir='src/datasets/PPM-100',
|
||||
transform=None):
|
||||
image_path = dataset_root_dir + '/image/*'
|
||||
matte_path = dataset_root_dir + '/matte/*'
|
||||
image_file_name_list = glob(image_path)
|
||||
matte_file_name_list = glob(matte_path)
|
||||
|
||||
self.image_file_name_list = sorted(image_file_name_list)
|
||||
self.matte_file_name_list = sorted(matte_file_name_list)
|
||||
for img, mat in zip(self.image_file_name_list, self.matte_file_name_list):
|
||||
img_name = img.split('/')[-1]
|
||||
mat_name = mat.split('/')[-1]
|
||||
assert img_name == mat_name
|
||||
|
||||
self.transform = transform
|
||||
|
||||
def __len__(self):
|
||||
return len(self.image_file_name_list)
|
||||
|
||||
def __getitem__(self, index):
|
||||
image_file_name = self.image_file_name_list[index]
|
||||
matte_file_name = self.matte_file_name_list[index]
|
||||
|
||||
image = Image.open(image_file_name)
|
||||
matte = Image.open(matte_file_name)
|
||||
# matte = matte.convert('RGB')
|
||||
trimap = self.gen_trimap(matte)
|
||||
|
||||
data = {'image': image, 'trimap': trimap, 'gt_matte': matte}
|
||||
|
||||
if self.transform:
|
||||
data = self.transform(data)
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
def gen_trimap(matte):
|
||||
"""
|
||||
根据归matte生成归一化的trimap
|
||||
"""
|
||||
matte = np.array(matte)
|
||||
k_size = random.choice(range(2, 5))
|
||||
iterations = np.random.randint(5, 15)
|
||||
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
|
||||
(k_size, k_size))
|
||||
dilated = cv2.dilate(matte, kernel, iterations=iterations)
|
||||
eroded = cv2.erode(matte, kernel, iterations=iterations)
|
||||
|
||||
trimap = np.zeros(matte.shape)
|
||||
trimap.fill(128)
|
||||
trimap[eroded > 254.5] = 255
|
||||
trimap[dilated < 0.5] = 0
|
||||
trimap = Image.fromarray(np.uint8(trimap))
|
||||
return trimap
|
||||
|
||||
|
||||
class Rescale(object):
|
||||
"""Rescale the image in a sample to a given size.
|
||||
|
||||
Args:
|
||||
output_size (tuple or int): Desired output size. If tuple, output is
|
||||
matched to output_size. If int, smaller of image edges is matched
|
||||
to output_size keeping aspect ratio the same.
|
||||
"""
|
||||
|
||||
def __init__(self, output_size):
|
||||
assert isinstance(output_size, (int, tuple))
|
||||
self.output_size = output_size
|
||||
|
||||
def __call__(self, sample):
|
||||
image, trimap, gt_matte = sample['image'], sample['trimap'], sample['gt_matte']
|
||||
|
||||
# w, h = image.size
|
||||
# if h > w:
|
||||
# new_h, new_w = self.output_size * h / w, self.output_size
|
||||
# else:
|
||||
# new_h, new_w = self.output_size, self.output_size * w / h
|
||||
|
||||
# new_h, new_w = int(new_h), int(new_w)
|
||||
new_h, new_w = int(self.output_size), int(self.output_size)
|
||||
|
||||
new_img = F.resize(image, (new_h, new_w))
|
||||
new_trimap = F.resize(trimap, (new_h, new_w))
|
||||
new_gt_matte = F.resize(gt_matte, (new_h, new_w))
|
||||
|
||||
return {'image': new_img, 'trimap': new_trimap, 'gt_matte': new_gt_matte}
|
||||
|
||||
|
||||
class ToTensor(object):
|
||||
"""Convert ndarrays in sample to Tensors."""
|
||||
|
||||
def __call__(self, sample):
|
||||
image, trimap, gt_matte = sample['image'], sample['trimap'], sample['gt_matte']
|
||||
image = F.pil_to_tensor(image)
|
||||
trimap = F.pil_to_tensor(trimap)
|
||||
gt_matte = F.pil_to_tensor(gt_matte)
|
||||
return {'image': image,
|
||||
'trimap': trimap,
|
||||
'gt_matte': gt_matte}
|
||||
|
||||
|
||||
class ConvertImageDtype(object):
|
||||
def __call__(self, sample):
|
||||
image, trimap, gt_matte = sample['image'], sample['trimap'], sample['gt_matte']
|
||||
image = F.convert_image_dtype(image, torch.float)
|
||||
trimap = F.convert_image_dtype(trimap, torch.float)
|
||||
gt_matte = F.convert_image_dtype(gt_matte, torch.float)
|
||||
|
||||
return {'image': image, 'trimap': trimap, 'gt_matte': gt_matte}
|
||||
|
||||
|
||||
class Normalize(object):
|
||||
def __init__(self, mean, std, inplace=False):
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
self.inplace = inplace
|
||||
|
||||
def __call__(self, sample):
|
||||
image, trimap, gt_matte = sample['image'], sample['trimap'], sample['gt_matte']
|
||||
image = image.type(torch.FloatTensor)
|
||||
image = F.normalize(image, self.mean, self.std, self.inplace)
|
||||
sample['image'] = image
|
||||
|
||||
sample['trimap'] = trimap / 255 # 归一化
|
||||
return sample
|
||||
|
||||
|
||||
class ToTrainArray(object):
|
||||
def __call__(self, sample):
|
||||
return [sample['image'], sample['trimap'], sample['gt_matte']]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
# test MattingDataset.gen_trimap
|
||||
matte = Image.open('src/datasets/PPM-100/matte/6146816_556eaff97f_o.jpg')
|
||||
matte = matte.convert('RGB')
|
||||
trimap = MattingDataset.gen_trimap(matte)
|
||||
trimap.save('test_trimap.png')
|
||||
|
||||
# test MattingDataset
|
||||
transform = transforms.Compose([
|
||||
Rescale(512),
|
||||
ToTensor(),
|
||||
Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||
])
|
||||
mattingDataset = MattingDataset(transform=transform)
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
fig = plt.figure()
|
||||
|
||||
for i in range(len(mattingDataset)):
|
||||
sample = mattingDataset[i]
|
||||
print(mattingDataset.image_file_name_list[i])
|
||||
# print(sample)
|
||||
print(i, sample['image'].shape, sample['trimap'].shape, sample['gt_matte'].shape)
|
||||
|
||||
# break
|
||||
|
||||
ax = plt.subplot(4, 3, 3 * i + 1)
|
||||
plt.tight_layout()
|
||||
ax.set_title('image #{}'.format(i))
|
||||
ax.axis('off')
|
||||
img = transforms.ToPILImage()(sample['image'])
|
||||
plt.imshow(img)
|
||||
|
||||
ax = plt.subplot(4, 3, 3 * i + 2)
|
||||
plt.tight_layout()
|
||||
ax.set_title('gt_matte #{}'.format(i))
|
||||
ax.axis('off')
|
||||
img = transforms.ToPILImage()(sample['gt_matte'])
|
||||
plt.imshow(img)
|
||||
|
||||
ax = plt.subplot(4, 3, 3 * i + 3)
|
||||
plt.tight_layout()
|
||||
ax.set_title('trimap #{}'.format(i))
|
||||
ax.axis('off')
|
||||
img = transforms.ToPILImage()(sample['trimap'])
|
||||
plt.imshow(img)
|
||||
|
||||
if i == 3:
|
||||
plt.show()
|
||||
break
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
torch==1.10.2
|
||||
scipy==1.7.3
|
||||
numpy==1.21.5
|
||||
opencv-python==4.5.5.62
|
||||
matplotlib==3.5.1
|
||||
torchvision==0.11.3
|
||||
Pillow==9.0.1
|
||||
|
|
@ -7,7 +7,6 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
__all__ = [
|
||||
'supervised_training_iter',
|
||||
'soc_adaptation_iter',
|
||||
|
|
@ -80,7 +79,7 @@ class GaussianBlurLayer(nn.Module):
|
|||
# MODNet Training Functions
|
||||
# ----------------------------------------------------------------------------------
|
||||
|
||||
blurer = GaussianBlurLayer(1, 3).cuda()
|
||||
blurer = GaussianBlurLayer(1, 3) #.cuda()
|
||||
|
||||
|
||||
def supervised_training_iter(
|
||||
|
|
@ -297,3 +296,39 @@ def soc_adaptation_iter(
|
|||
return soc_semantic_loss, soc_detail_loss
|
||||
|
||||
# ----------------------------------------------------------------------------------
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
from matting_dataset import MattingDataset, Rescale, ToTensor, Normalize, ToTrainArray, ConvertImageDtype
|
||||
from torchvision import transforms
|
||||
from torch.utils.data import DataLoader
|
||||
from models.modnet import MODNet
|
||||
transform = transforms.Compose([
|
||||
Rescale(512),
|
||||
ToTensor(),
|
||||
ConvertImageDtype(),
|
||||
Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
||||
ToTrainArray()
|
||||
])
|
||||
mattingDataset = MattingDataset(transform=transform)
|
||||
|
||||
bs = 16 # batch size
|
||||
lr = 0.01 # learn rate
|
||||
epochs = 40 # total epochs
|
||||
|
||||
modnet = torch.nn.DataParallel(MODNet()) #.cuda()
|
||||
optimizer = torch.optim.SGD(modnet.parameters(), lr=lr, momentum=0.9)
|
||||
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * epochs), gamma=0.1)
|
||||
|
||||
dataloader = DataLoader(mattingDataset,
|
||||
batch_size=bs,
|
||||
shuffle=True,
|
||||
num_workers=0)
|
||||
|
||||
for epoch in range(0, epochs):
|
||||
for idx, (image, trimap, gt_matte) in enumerate(dataloader):
|
||||
semantic_loss, detail_loss, matte_loss = \
|
||||
supervised_training_iter(modnet, optimizer, image, trimap, gt_matte)
|
||||
break
|
||||
lr_scheduler.step()
|
||||
break
|
||||
|
|
|
|||
Loading…
Reference in New Issue