模型训练代码

pull/177/head
actboy 2022-02-10 00:06:40 +08:00
parent 27db49f9de
commit 4424ab5b72
2 changed files with 64 additions and 46 deletions

View File

@ -1,4 +1,4 @@
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import Dataset
import numpy as np
import random
import cv2
@ -36,35 +36,13 @@ class MattingDataset(Dataset):
image = Image.open(image_file_name)
matte = Image.open(matte_file_name)
# matte = matte.convert('RGB')
trimap = self.gen_trimap(matte)
data = {'image': image, 'trimap': trimap, 'gt_matte': matte}
data = {'image': image, 'gt_matte': matte}
if self.transform:
data = self.transform(data)
return data
@staticmethod
def gen_trimap(matte):
"""
根据归matte生成归一化的trimap
"""
matte = np.array(matte)
k_size = random.choice(range(2, 5))
iterations = np.random.randint(5, 15)
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
(k_size, k_size))
dilated = cv2.dilate(matte, kernel, iterations=iterations)
eroded = cv2.erode(matte, kernel, iterations=iterations)
trimap = np.zeros(matte.shape)
trimap.fill(128)
trimap[eroded > 254.5] = 255
trimap[dilated < 0.5] = 0
trimap = Image.fromarray(np.uint8(trimap))
return trimap
class Rescale(object):
"""Rescale the image in a sample to a given size.
@ -80,22 +58,41 @@ class Rescale(object):
self.output_size = output_size
def __call__(self, sample):
image, trimap, gt_matte = sample['image'], sample['trimap'], sample['gt_matte']
image, gt_matte = sample['image'], sample['gt_matte']
# w, h = image.size
# if h > w:
# new_h, new_w = self.output_size * h / w, self.output_size
# else:
# new_h, new_w = self.output_size, self.output_size * w / h
# new_h, new_w = int(new_h), int(new_w)
new_h, new_w = int(self.output_size), int(self.output_size)
new_img = F.resize(image, (new_h, new_w))
new_trimap = F.resize(trimap, (new_h, new_w))
new_gt_matte = F.resize(gt_matte, (new_h, new_w))
return {'image': new_img, 'trimap': new_trimap, 'gt_matte': new_gt_matte}
return {'image': new_img,'gt_matte': new_gt_matte}
class GenTrimap(object):
def __call__(self, sample):
gt_matte = sample['gt_matte']
trimap = self.gen_trimap(gt_matte)
sample['trimap'] = trimap
return sample
@staticmethod
def gen_trimap(matte):
"""
根据归matte生成归一化的trimap
"""
matte = np.array(matte)
k_size = random.choice(range(2, 5))
iterations = np.random.randint(5, 15)
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
(k_size, k_size))
dilated = cv2.dilate(matte, kernel, iterations=iterations)
eroded = cv2.erode(matte, kernel, iterations=iterations)
trimap = np.zeros(matte.shape)
trimap.fill(0.5)
trimap[eroded > 254.5] = 1.0
trimap[dilated < 0.5] = 0.0
trimap = Image.fromarray(trimap)
return trimap
class ToTensor(object):
@ -106,6 +103,8 @@ class ToTensor(object):
image = F.pil_to_tensor(image)
trimap = F.pil_to_tensor(trimap)
gt_matte = F.pil_to_tensor(gt_matte)
return {'image': image,
'trimap': trimap,
'gt_matte': gt_matte}
@ -132,8 +131,6 @@ class Normalize(object):
image = image.type(torch.FloatTensor)
image = F.normalize(image, self.mean, self.std, self.inplace)
sample['image'] = image
sample['trimap'] = trimap / 255 # 归一化
return sample
@ -152,13 +149,16 @@ if __name__ == '__main__':
# test MattingDataset.gen_trimap
matte = Image.open('src/datasets/PPM-100/matte/6146816_556eaff97f_o.jpg')
matte = matte.convert('RGB')
trimap = MattingDataset.gen_trimap(matte)
trimap.save('test_trimap.png')
trimap1 = GenTrimap().gen_trimap(matte)
trimap1 = np.array(trimap1) * 255
trimap1 = np.uint8(trimap1)
trimap1 = Image.fromarray(trimap1)
trimap1.save('test_trimap.png')
# test MattingDataset
transform = transforms.Compose([
Rescale(512),
GenTrimap(),
ToTensor(),
# Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

View File

@ -301,12 +301,15 @@ def soc_adaptation_iter(
if __name__ == '__main__':
from matting_dataset import MattingDataset, Rescale, ToTensor, Normalize, ToTrainArray, ConvertImageDtype
from matting_dataset import MattingDataset, Rescale, \
ToTensor, Normalize, ToTrainArray, \
ConvertImageDtype, GenTrimap
from torchvision import transforms
from torch.utils.data import DataLoader
from models.modnet import MODNet
transform = transforms.Compose([
Rescale(512),
GenTrimap(),
ToTensor(),
ConvertImageDtype(),
Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
@ -314,7 +317,7 @@ if __name__ == '__main__':
])
mattingDataset = MattingDataset(transform=transform)
bs = 16 # batch size
bs = 4 # batch size
lr = 0.01 # learn rate
epochs = 40 # total epochs
@ -327,13 +330,28 @@ if __name__ == '__main__':
dataloader = DataLoader(mattingDataset,
batch_size=bs,
shuffle=True,
num_workers=0)
shuffle=True)
for epoch in range(0, epochs):
for idx, (image, trimap, gt_matte) in enumerate(dataloader):
semantic_loss, detail_loss, matte_loss = \
supervised_training_iter(modnet, optimizer, image, trimap, gt_matte)
break
if epoch % 4 == 0 and epoch > 1:
torch.save({
'epoch': epoch,
'model_state_dict': modnet.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': {'semantic_loss': semantic_loss, 'detail_loss': detail_loss, 'matte_loss': matte_loss},
}, f'pretrained/modnet_custom_portrait_matting_{epoch+1}.ckpt')
lr_scheduler.step()
print(f'semantic_loss: {semantic_loss:f}, detail_loss: {detail_loss:f}, matte_loss: {matte_loss:f}')
if epoch == 4:
break
torch.save({
'epoch': epochs,
'model_state_dict': modnet.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': {'semantic_loss': semantic_loss, 'detail_loss': detail_loss, 'matte_loss': matte_loss},
}, f'pretrained/modnet_custom_portrait_matting_last_epoch.ckpt')