mirror of https://github.com/ZHKKKe/MODNet.git
模型训练代码
parent
27db49f9de
commit
4424ab5b72
|
|
@ -1,4 +1,4 @@
|
||||||
from torch.utils.data import Dataset, DataLoader
|
from torch.utils.data import Dataset
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import random
|
import random
|
||||||
import cv2
|
import cv2
|
||||||
|
|
@ -36,35 +36,13 @@ class MattingDataset(Dataset):
|
||||||
|
|
||||||
image = Image.open(image_file_name)
|
image = Image.open(image_file_name)
|
||||||
matte = Image.open(matte_file_name)
|
matte = Image.open(matte_file_name)
|
||||||
# matte = matte.convert('RGB')
|
|
||||||
trimap = self.gen_trimap(matte)
|
|
||||||
|
|
||||||
data = {'image': image, 'trimap': trimap, 'gt_matte': matte}
|
data = {'image': image, 'gt_matte': matte}
|
||||||
|
|
||||||
if self.transform:
|
if self.transform:
|
||||||
data = self.transform(data)
|
data = self.transform(data)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@staticmethod
def gen_trimap(matte):
    """Generate a trimap image from a matte.

    The matte is dilated and eroded with a randomly sized elliptical
    kernel for a random number of iterations. Pixels whose eroded value
    exceeds 254.5 are set to 255, pixels whose dilated value is below 0.5
    are set to 0, and every remaining pixel is set to 128.

    Args:
        matte: Matte convertible with ``np.array`` (a PIL image at the
            visible call site). The thresholds assume values in the
            0..255 range -- TODO confirm for other inputs.

    Returns:
        A PIL image built from the ``uint8`` trimap array.
    """
    matte = np.array(matte)

    # Draw both parameters from Python's ``random`` so that a single
    # generator controls this augmentation. NOTE(review): NumPy's global
    # generator was not reseeded per DataLoader worker by older PyTorch
    # releases, which could repeat the same draws in every worker.
    k_size = random.randint(2, 4)        # inclusive bounds: 2, 3 or 4
    iterations = random.randint(5, 14)   # same range as np.random.randint(5, 15)

    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k_size, k_size))
    dilated = cv2.dilate(matte, kernel, iterations=iterations)
    eroded = cv2.erode(matte, kernel, iterations=iterations)

    # Start from the middle value everywhere, then overwrite the two
    # regions that survive the morphological operations.
    trimap = np.full(matte.shape, 128.0)
    trimap[eroded > 254.5] = 255
    trimap[dilated < 0.5] = 0
    return Image.fromarray(trimap.astype(np.uint8))
|
|
||||||
|
|
||||||
|
|
||||||
class Rescale(object):
|
class Rescale(object):
|
||||||
"""Rescale the image in a sample to a given size.
|
"""Rescale the image in a sample to a given size.
|
||||||
|
|
@ -80,22 +58,41 @@ class Rescale(object):
|
||||||
self.output_size = output_size
|
self.output_size = output_size
|
||||||
|
|
||||||
def __call__(self, sample):
|
def __call__(self, sample):
|
||||||
image, trimap, gt_matte = sample['image'], sample['trimap'], sample['gt_matte']
|
image, gt_matte = sample['image'], sample['gt_matte']
|
||||||
|
|
||||||
# w, h = image.size
|
|
||||||
# if h > w:
|
|
||||||
# new_h, new_w = self.output_size * h / w, self.output_size
|
|
||||||
# else:
|
|
||||||
# new_h, new_w = self.output_size, self.output_size * w / h
|
|
||||||
|
|
||||||
# new_h, new_w = int(new_h), int(new_w)
|
|
||||||
new_h, new_w = int(self.output_size), int(self.output_size)
|
new_h, new_w = int(self.output_size), int(self.output_size)
|
||||||
|
|
||||||
new_img = F.resize(image, (new_h, new_w))
|
new_img = F.resize(image, (new_h, new_w))
|
||||||
new_trimap = F.resize(trimap, (new_h, new_w))
|
|
||||||
new_gt_matte = F.resize(gt_matte, (new_h, new_w))
|
new_gt_matte = F.resize(gt_matte, (new_h, new_w))
|
||||||
|
|
||||||
return {'image': new_img, 'trimap': new_trimap, 'gt_matte': new_gt_matte}
|
return {'image': new_img,'gt_matte': new_gt_matte}
|
||||||
|
|
||||||
|
|
||||||
|
class GenTrimap(object):
    """Transform that derives a trimap from ``sample['gt_matte']``.

    Calling an instance stores the generated trimap under
    ``sample['trimap']`` and returns the same sample dict.
    """

    def __call__(self, sample):
        """Add a ``'trimap'`` entry to ``sample`` and return it.

        Args:
            sample: Dict holding at least the key ``'gt_matte'``.

        Returns:
            The same dict object, updated in place with ``'trimap'``.
        """
        sample['trimap'] = self.gen_trimap(sample['gt_matte'])
        return sample

    @staticmethod
    def gen_trimap(matte):
        """Generate a normalized trimap from a matte.

        The matte is dilated and eroded with a randomly sized elliptical
        kernel for a random number of iterations. Pixels whose eroded
        value exceeds 254.5 become 1.0, pixels whose dilated value is
        below 0.5 become 0.0, and every remaining pixel becomes 0.5.

        Args:
            matte: Matte convertible with ``np.array``. The thresholds
                assume values in the 0..255 range -- TODO confirm for
                other inputs.

        Returns:
            A PIL image built from the floating-point trimap array.
        """
        matte = np.array(matte)

        # Draw both parameters from Python's ``random`` so that a single
        # generator controls this augmentation. NOTE(review): NumPy's
        # global generator was not reseeded per DataLoader worker by older
        # PyTorch releases, which could repeat the same draws in every
        # worker.
        k_size = random.randint(2, 4)        # inclusive bounds: 2, 3 or 4
        iterations = random.randint(5, 14)   # same range as np.random.randint(5, 15)

        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
                                           (k_size, k_size))
        dilated = cv2.dilate(matte, kernel, iterations=iterations)
        eroded = cv2.erode(matte, kernel, iterations=iterations)

        # Start from the middle value everywhere, then overwrite the two
        # regions that survive the morphological operations.
        trimap = np.full(matte.shape, 0.5)
        trimap[eroded > 254.5] = 1.0
        trimap[dilated < 0.5] = 0.0
        return Image.fromarray(trimap)
|
||||||
|
|
||||||
|
|
||||||
class ToTensor(object):
|
class ToTensor(object):
|
||||||
|
|
@ -106,6 +103,8 @@ class ToTensor(object):
|
||||||
image = F.pil_to_tensor(image)
|
image = F.pil_to_tensor(image)
|
||||||
trimap = F.pil_to_tensor(trimap)
|
trimap = F.pil_to_tensor(trimap)
|
||||||
gt_matte = F.pil_to_tensor(gt_matte)
|
gt_matte = F.pil_to_tensor(gt_matte)
|
||||||
|
|
||||||
|
|
||||||
return {'image': image,
|
return {'image': image,
|
||||||
'trimap': trimap,
|
'trimap': trimap,
|
||||||
'gt_matte': gt_matte}
|
'gt_matte': gt_matte}
|
||||||
|
|
@ -132,8 +131,6 @@ class Normalize(object):
|
||||||
image = image.type(torch.FloatTensor)
|
image = image.type(torch.FloatTensor)
|
||||||
image = F.normalize(image, self.mean, self.std, self.inplace)
|
image = F.normalize(image, self.mean, self.std, self.inplace)
|
||||||
sample['image'] = image
|
sample['image'] = image
|
||||||
|
|
||||||
sample['trimap'] = trimap / 255 # 归一化
|
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -152,13 +149,16 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
# test MattingDataset.gen_trimap
|
# test MattingDataset.gen_trimap
|
||||||
matte = Image.open('src/datasets/PPM-100/matte/6146816_556eaff97f_o.jpg')
|
matte = Image.open('src/datasets/PPM-100/matte/6146816_556eaff97f_o.jpg')
|
||||||
matte = matte.convert('RGB')
|
trimap1 = GenTrimap().gen_trimap(matte)
|
||||||
trimap = MattingDataset.gen_trimap(matte)
|
trimap1 = np.array(trimap1) * 255
|
||||||
trimap.save('test_trimap.png')
|
trimap1 = np.uint8(trimap1)
|
||||||
|
trimap1 = Image.fromarray(trimap1)
|
||||||
|
trimap1.save('test_trimap.png')
|
||||||
|
|
||||||
# test MattingDataset
|
# test MattingDataset
|
||||||
transform = transforms.Compose([
|
transform = transforms.Compose([
|
||||||
Rescale(512),
|
Rescale(512),
|
||||||
|
GenTrimap(),
|
||||||
ToTensor(),
|
ToTensor(),
|
||||||
# Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
# Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||||
])
|
])
|
||||||
|
|
|
||||||
|
|
@ -301,12 +301,15 @@ def soc_adaptation_iter(
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
from matting_dataset import MattingDataset, Rescale, ToTensor, Normalize, ToTrainArray, ConvertImageDtype
|
from matting_dataset import MattingDataset, Rescale, \
|
||||||
|
ToTensor, Normalize, ToTrainArray, \
|
||||||
|
ConvertImageDtype, GenTrimap
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from models.modnet import MODNet
|
from models.modnet import MODNet
|
||||||
transform = transforms.Compose([
|
transform = transforms.Compose([
|
||||||
Rescale(512),
|
Rescale(512),
|
||||||
|
GenTrimap(),
|
||||||
ToTensor(),
|
ToTensor(),
|
||||||
ConvertImageDtype(),
|
ConvertImageDtype(),
|
||||||
Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
||||||
|
|
@ -314,7 +317,7 @@ if __name__ == '__main__':
|
||||||
])
|
])
|
||||||
mattingDataset = MattingDataset(transform=transform)
|
mattingDataset = MattingDataset(transform=transform)
|
||||||
|
|
||||||
bs = 16 # batch size
|
bs = 4 # batch size
|
||||||
lr = 0.01 # learn rate
|
lr = 0.01 # learn rate
|
||||||
epochs = 40 # total epochs
|
epochs = 40 # total epochs
|
||||||
|
|
||||||
|
|
@ -327,13 +330,28 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
dataloader = DataLoader(mattingDataset,
|
dataloader = DataLoader(mattingDataset,
|
||||||
batch_size=bs,
|
batch_size=bs,
|
||||||
shuffle=True,
|
shuffle=True)
|
||||||
num_workers=0)
|
|
||||||
|
|
||||||
for epoch in range(0, epochs):
|
for epoch in range(0, epochs):
|
||||||
for idx, (image, trimap, gt_matte) in enumerate(dataloader):
|
for idx, (image, trimap, gt_matte) in enumerate(dataloader):
|
||||||
semantic_loss, detail_loss, matte_loss = \
|
semantic_loss, detail_loss, matte_loss = \
|
||||||
supervised_training_iter(modnet, optimizer, image, trimap, gt_matte)
|
supervised_training_iter(modnet, optimizer, image, trimap, gt_matte)
|
||||||
break
|
break
|
||||||
lr_scheduler.step()
|
if epoch % 4 == 0 and epoch > 1:
|
||||||
break
|
torch.save({
|
||||||
|
'epoch': epoch,
|
||||||
|
'model_state_dict': modnet.state_dict(),
|
||||||
|
'optimizer_state_dict': optimizer.state_dict(),
|
||||||
|
'loss': {'semantic_loss': semantic_loss, 'detail_loss': detail_loss, 'matte_loss': matte_loss},
|
||||||
|
}, f'pretrained/modnet_custom_portrait_matting_{epoch+1}.ckpt')
|
||||||
|
lr_scheduler.step()
|
||||||
|
print(f'semantic_loss: {semantic_loss:f}, detail_loss: {detail_loss:f}, matte_loss: {matte_loss:f}')
|
||||||
|
if epoch == 4:
|
||||||
|
break
|
||||||
|
|
||||||
|
torch.save({
|
||||||
|
'epoch': epochs,
|
||||||
|
'model_state_dict': modnet.state_dict(),
|
||||||
|
'optimizer_state_dict': optimizer.state_dict(),
|
||||||
|
'loss': {'semantic_loss': semantic_loss, 'detail_loss': detail_loss, 'matte_loss': matte_loss},
|
||||||
|
}, f'pretrained/modnet_custom_portrait_matting_last_epoch.ckpt')
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue