mirror of https://github.com/ZHKKKe/MODNet.git
Improve the model training code
parent 5f673d5a34
commit 58c54f6e6c
@ -0,0 +1,196 @@
from torch.utils.data import Dataset, DataLoader
import numpy as np
import random
import cv2
from glob import glob
import torch
from torchvision.transforms import functional as F
from torchvision import transforms
from PIL import Image


class MattingDataset(Dataset):
    def __init__(self,
                 dataset_root_dir='src/datasets/PPM-100',
                 transform=None):
        image_path = dataset_root_dir + '/image/*'
        matte_path = dataset_root_dir + '/matte/*'
        image_file_name_list = glob(image_path)
        matte_file_name_list = glob(matte_path)

        self.image_file_name_list = sorted(image_file_name_list)
        self.matte_file_name_list = sorted(matte_file_name_list)
        # Every image must be paired with a matte of the same file name.
        for img, mat in zip(self.image_file_name_list, self.matte_file_name_list):
            img_name = img.split('/')[-1]
            mat_name = mat.split('/')[-1]
            assert img_name == mat_name

        self.transform = transform

    def __len__(self):
        return len(self.image_file_name_list)

    def __getitem__(self, index):
        image_file_name = self.image_file_name_list[index]
        matte_file_name = self.matte_file_name_list[index]

        image = Image.open(image_file_name)
        matte = Image.open(matte_file_name)
        # matte = matte.convert('RGB')
        trimap = self.gen_trimap(matte)

        data = {'image': image, 'trimap': trimap, 'gt_matte': matte}

        if self.transform:
            data = self.transform(data)
        return data

    @staticmethod
    def gen_trimap(matte):
        """
        Generate a trimap (0 = background, 128 = unknown, 255 = foreground)
        from the ground-truth matte; it is normalized later by the Normalize transform.
        """
        matte = np.array(matte)
        k_size = random.choice(range(2, 5))
        iterations = np.random.randint(5, 15)
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
                                           (k_size, k_size))
        dilated = cv2.dilate(matte, kernel, iterations=iterations)
        eroded = cv2.erode(matte, kernel, iterations=iterations)

        trimap = np.zeros(matte.shape)
        trimap.fill(128)
        trimap[eroded > 254.5] = 255
        trimap[dilated < 0.5] = 0
        trimap = Image.fromarray(np.uint8(trimap))
        return trimap
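
    # The generated trimap uses three labels: 0 = background, 128 = unknown band
    # (its width depends on the random kernel size and iteration count above), and
    # 255 = foreground. The Normalize transform defined below maps these to roughly
    # 0, 0.5 and 1. A quick check (a sketch; `matte` stands for any PIL matte image):
    #   np.unique(np.array(MattingDataset.gen_trimap(matte)))  # -> [  0, 128, 255]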


class Rescale(object):
    """Rescale the image in a sample to a given size.

    Args:
        output_size (int): Desired output size. The sample is currently resized
            to a square of (output_size, output_size); the commented-out code
            below instead matches the smaller image edge to output_size while
            keeping the aspect ratio.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image, trimap, gt_matte = sample['image'], sample['trimap'], sample['gt_matte']

        # w, h = image.size
        # if h > w:
        #     new_h, new_w = self.output_size * h / w, self.output_size
        # else:
        #     new_h, new_w = self.output_size, self.output_size * w / h
        # new_h, new_w = int(new_h), int(new_w)
        new_h, new_w = int(self.output_size), int(self.output_size)

        new_img = F.resize(image, (new_h, new_w))
        new_trimap = F.resize(trimap, (new_h, new_w))
        new_gt_matte = F.resize(gt_matte, (new_h, new_w))

        return {'image': new_img, 'trimap': new_trimap, 'gt_matte': new_gt_matte}
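
# A sketch of an aspect-preserving variant, reconstructed from the commented-out
# branch in Rescale above (an illustrative addition; it is not used in this commit).
class RescaleKeepAspect(object):
    """Resize so that the smaller image edge matches output_size, keeping the aspect ratio."""

    def __init__(self, output_size):
        assert isinstance(output_size, int)
        self.output_size = output_size

    def __call__(self, sample):
        image, trimap, gt_matte = sample['image'], sample['trimap'], sample['gt_matte']
        w, h = image.size
        if h > w:
            new_h, new_w = int(self.output_size * h / w), self.output_size
        else:
            new_h, new_w = self.output_size, int(self.output_size * w / h)
        return {'image': F.resize(image, (new_h, new_w)),
                'trimap': F.resize(trimap, (new_h, new_w)),
                'gt_matte': F.resize(gt_matte, (new_h, new_w))}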


class ToTensor(object):
    """Convert the PIL images in a sample to uint8 tensors (values stay in 0-255)."""

    def __call__(self, sample):
        image, trimap, gt_matte = sample['image'], sample['trimap'], sample['gt_matte']
        image = F.pil_to_tensor(image)
        trimap = F.pil_to_tensor(trimap)
        gt_matte = F.pil_to_tensor(gt_matte)
        return {'image': image,
                'trimap': trimap,
                'gt_matte': gt_matte}


class ConvertImageDtype(object):
    """Convert the sample tensors to float, scaling uint8 values to [0, 1]."""

    def __call__(self, sample):
        image, trimap, gt_matte = sample['image'], sample['trimap'], sample['gt_matte']
        image = F.convert_image_dtype(image, torch.float)
        trimap = F.convert_image_dtype(trimap, torch.float)
        gt_matte = F.convert_image_dtype(gt_matte, torch.float)

        return {'image': image, 'trimap': trimap, 'gt_matte': gt_matte}
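
# A rough one-step equivalent of ToTensor followed by ConvertImageDtype, if preferred
# (a sketch; `pil_image` is illustrative and this call is not used in this file):
#   tensor = F.to_tensor(pil_image)  # PIL image -> float32 tensor scaled to [0, 1]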


class Normalize(object):
    def __init__(self, mean, std, inplace=False):
        self.mean = mean
        self.std = std
        self.inplace = inplace

    def __call__(self, sample):
        image, trimap, gt_matte = sample['image'], sample['trimap'], sample['gt_matte']
        image = image.type(torch.FloatTensor)
        image = F.normalize(image, self.mean, self.std, self.inplace)
        sample['image'] = image

        # Normalize the trimap to [0, 1]; divide only while it is still uint8 (0-255),
        # since ConvertImageDtype may already have scaled it.
        if trimap.dtype == torch.uint8:
            trimap = trimap / 255
        sample['trimap'] = trimap
        return sample


class ToTrainArray(object):
    """Unpack a sample dict into the [image, trimap, gt_matte] list used by the training loop."""

    def __call__(self, sample):
        return [sample['image'], sample['trimap'], sample['gt_matte']]


if __name__ == '__main__':

    # test MattingDataset.gen_trimap
    matte = Image.open('src/datasets/PPM-100/matte/6146816_556eaff97f_o.jpg')
    matte = matte.convert('RGB')
    trimap = MattingDataset.gen_trimap(matte)
    trimap.save('test_trimap.png')

    # test MattingDataset
    transform = transforms.Compose([
        Rescale(512),
        ToTensor(),
        Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    mattingDataset = MattingDataset(transform=transform)

    import matplotlib.pyplot as plt

    fig = plt.figure()

    for i in range(len(mattingDataset)):
        sample = mattingDataset[i]
        print(mattingDataset.image_file_name_list[i])
        # print(sample)
        print(i, sample['image'].shape, sample['trimap'].shape, sample['gt_matte'].shape)

        # break

        ax = plt.subplot(4, 3, 3 * i + 1)
        plt.tight_layout()
        ax.set_title('image #{}'.format(i))
        ax.axis('off')
        img = transforms.ToPILImage()(sample['image'])
        plt.imshow(img)

        ax = plt.subplot(4, 3, 3 * i + 2)
        plt.tight_layout()
        ax.set_title('gt_matte #{}'.format(i))
        ax.axis('off')
        img = transforms.ToPILImage()(sample['gt_matte'])
        plt.imshow(img)

        ax = plt.subplot(4, 3, 3 * i + 3)
        plt.tight_layout()
        ax.set_title('trimap #{}'.format(i))
        ax.axis('off')
        img = transforms.ToPILImage()(sample['trimap'])
        plt.imshow(img)

        if i == 3:
            plt.show()
            break
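
    # Sketch of a quick DataLoader batch check (an illustrative addition; assumes the
    # dict-based transforms above, which default_collate stacks key by key):
    loader = DataLoader(mattingDataset, batch_size=2, shuffle=False, num_workers=0)
    batch = next(iter(loader))
    print(batch['image'].shape, batch['trimap'].shape, batch['gt_matte'].shape)
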
@ -0,0 +1,7 @@
torch==1.10.2
scipy==1.7.3
numpy==1.21.5
opencv-python==4.5.5.62
matplotlib==3.5.1
torchvision==0.11.3
Pillow==9.0.1

@ -7,7 +7,6 @@ import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = [
    'supervised_training_iter',
    'soc_adaptation_iter',

@ -80,7 +79,7 @@ class GaussianBlurLayer(nn.Module):
# MODNet Training Functions
# ----------------------------------------------------------------------------------

blurer = GaussianBlurLayer(1, 3) #.cuda()


def supervised_training_iter(

@ -297,3 +296,39 @@ def soc_adaptation_iter(
    return soc_semantic_loss, soc_detail_loss

# ----------------------------------------------------------------------------------


if __name__ == '__main__':
    from matting_dataset import MattingDataset, Rescale, ToTensor, Normalize, ToTrainArray, ConvertImageDtype
    from torchvision import transforms
    from torch.utils.data import DataLoader
    from models.modnet import MODNet

    transform = transforms.Compose([
        Rescale(512),
        ToTensor(),
        ConvertImageDtype(),
        Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ToTrainArray()
    ])
    mattingDataset = MattingDataset(transform=transform)

    bs = 16       # batch size
    lr = 0.01     # learning rate
    epochs = 40   # total epochs

    modnet = torch.nn.DataParallel(MODNet()) #.cuda()
    optimizer = torch.optim.SGD(modnet.parameters(), lr=lr, momentum=0.9)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * epochs), gamma=0.1)

    dataloader = DataLoader(mattingDataset,
                            batch_size=bs,
                            shuffle=True,
                            num_workers=0)

    for epoch in range(0, epochs):
        for idx, (image, trimap, gt_matte) in enumerate(dataloader):
            semantic_loss, detail_loss, matte_loss = \
                supervised_training_iter(modnet, optimizer, image, trimap, gt_matte)
            break  # debug: stop after the first batch
        lr_scheduler.step()
        break  # debug: stop after the first epoch
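
    # Sketch (an assumption, not shown in this commit): once the loop above runs to
    # completion, the weights could be saved with the standard PyTorch API; the file
    # name is illustrative.
    torch.save(modnet.state_dict(), 'modnet_supervised_trained.ckpt')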