mirror of https://github.com/ZHKKKe/MODNet.git
模型训练代码
parent
27db49f9de
commit
4424ab5b72
|
|
@ -1,4 +1,4 @@
|
|||
from torch.utils.data import Dataset, DataLoader
|
||||
from torch.utils.data import Dataset
|
||||
import numpy as np
|
||||
import random
|
||||
import cv2
|
||||
|
|
@ -36,35 +36,13 @@ class MattingDataset(Dataset):
|
|||
|
||||
image = Image.open(image_file_name)
|
||||
matte = Image.open(matte_file_name)
|
||||
# matte = matte.convert('RGB')
|
||||
trimap = self.gen_trimap(matte)
|
||||
|
||||
data = {'image': image, 'trimap': trimap, 'gt_matte': matte}
|
||||
data = {'image': image, 'gt_matte': matte}
|
||||
|
||||
if self.transform:
|
||||
data = self.transform(data)
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
def gen_trimap(matte):
|
||||
"""
|
||||
根据归matte生成归一化的trimap
|
||||
"""
|
||||
matte = np.array(matte)
|
||||
k_size = random.choice(range(2, 5))
|
||||
iterations = np.random.randint(5, 15)
|
||||
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
|
||||
(k_size, k_size))
|
||||
dilated = cv2.dilate(matte, kernel, iterations=iterations)
|
||||
eroded = cv2.erode(matte, kernel, iterations=iterations)
|
||||
|
||||
trimap = np.zeros(matte.shape)
|
||||
trimap.fill(128)
|
||||
trimap[eroded > 254.5] = 255
|
||||
trimap[dilated < 0.5] = 0
|
||||
trimap = Image.fromarray(np.uint8(trimap))
|
||||
return trimap
|
||||
|
||||
|
||||
class Rescale(object):
|
||||
"""Rescale the image in a sample to a given size.
|
||||
|
|
@ -80,22 +58,41 @@ class Rescale(object):
|
|||
self.output_size = output_size
|
||||
|
||||
def __call__(self, sample):
|
||||
image, trimap, gt_matte = sample['image'], sample['trimap'], sample['gt_matte']
|
||||
image, gt_matte = sample['image'], sample['gt_matte']
|
||||
|
||||
# w, h = image.size
|
||||
# if h > w:
|
||||
# new_h, new_w = self.output_size * h / w, self.output_size
|
||||
# else:
|
||||
# new_h, new_w = self.output_size, self.output_size * w / h
|
||||
|
||||
# new_h, new_w = int(new_h), int(new_w)
|
||||
new_h, new_w = int(self.output_size), int(self.output_size)
|
||||
|
||||
new_img = F.resize(image, (new_h, new_w))
|
||||
new_trimap = F.resize(trimap, (new_h, new_w))
|
||||
new_gt_matte = F.resize(gt_matte, (new_h, new_w))
|
||||
|
||||
return {'image': new_img, 'trimap': new_trimap, 'gt_matte': new_gt_matte}
|
||||
return {'image': new_img,'gt_matte': new_gt_matte}
|
||||
|
||||
|
||||
class GenTrimap(object):
|
||||
def __call__(self, sample):
|
||||
gt_matte = sample['gt_matte']
|
||||
trimap = self.gen_trimap(gt_matte)
|
||||
sample['trimap'] = trimap
|
||||
return sample
|
||||
|
||||
@staticmethod
|
||||
def gen_trimap(matte):
|
||||
"""
|
||||
根据归matte生成归一化的trimap
|
||||
"""
|
||||
matte = np.array(matte)
|
||||
k_size = random.choice(range(2, 5))
|
||||
iterations = np.random.randint(5, 15)
|
||||
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
|
||||
(k_size, k_size))
|
||||
dilated = cv2.dilate(matte, kernel, iterations=iterations)
|
||||
eroded = cv2.erode(matte, kernel, iterations=iterations)
|
||||
|
||||
trimap = np.zeros(matte.shape)
|
||||
trimap.fill(0.5)
|
||||
trimap[eroded > 254.5] = 1.0
|
||||
trimap[dilated < 0.5] = 0.0
|
||||
trimap = Image.fromarray(trimap)
|
||||
return trimap
|
||||
|
||||
|
||||
class ToTensor(object):
|
||||
|
|
@ -106,6 +103,8 @@ class ToTensor(object):
|
|||
image = F.pil_to_tensor(image)
|
||||
trimap = F.pil_to_tensor(trimap)
|
||||
gt_matte = F.pil_to_tensor(gt_matte)
|
||||
|
||||
|
||||
return {'image': image,
|
||||
'trimap': trimap,
|
||||
'gt_matte': gt_matte}
|
||||
|
|
@ -132,8 +131,6 @@ class Normalize(object):
|
|||
image = image.type(torch.FloatTensor)
|
||||
image = F.normalize(image, self.mean, self.std, self.inplace)
|
||||
sample['image'] = image
|
||||
|
||||
sample['trimap'] = trimap / 255 # 归一化
|
||||
return sample
|
||||
|
||||
|
||||
|
|
@ -152,13 +149,16 @@ if __name__ == '__main__':
|
|||
|
||||
# test MattingDataset.gen_trimap
|
||||
matte = Image.open('src/datasets/PPM-100/matte/6146816_556eaff97f_o.jpg')
|
||||
matte = matte.convert('RGB')
|
||||
trimap = MattingDataset.gen_trimap(matte)
|
||||
trimap.save('test_trimap.png')
|
||||
trimap1 = GenTrimap().gen_trimap(matte)
|
||||
trimap1 = np.array(trimap1) * 255
|
||||
trimap1 = np.uint8(trimap1)
|
||||
trimap1 = Image.fromarray(trimap1)
|
||||
trimap1.save('test_trimap.png')
|
||||
|
||||
# test MattingDataset
|
||||
transform = transforms.Compose([
|
||||
Rescale(512),
|
||||
GenTrimap(),
|
||||
ToTensor(),
|
||||
# Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||
])
|
||||
|
|
|
|||
|
|
@ -301,12 +301,15 @@ def soc_adaptation_iter(
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
from matting_dataset import MattingDataset, Rescale, ToTensor, Normalize, ToTrainArray, ConvertImageDtype
|
||||
from matting_dataset import MattingDataset, Rescale, \
|
||||
ToTensor, Normalize, ToTrainArray, \
|
||||
ConvertImageDtype, GenTrimap
|
||||
from torchvision import transforms
|
||||
from torch.utils.data import DataLoader
|
||||
from models.modnet import MODNet
|
||||
transform = transforms.Compose([
|
||||
Rescale(512),
|
||||
GenTrimap(),
|
||||
ToTensor(),
|
||||
ConvertImageDtype(),
|
||||
Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
||||
|
|
@ -314,7 +317,7 @@ if __name__ == '__main__':
|
|||
])
|
||||
mattingDataset = MattingDataset(transform=transform)
|
||||
|
||||
bs = 16 # batch size
|
||||
bs = 4 # batch size
|
||||
lr = 0.01 # learn rate
|
||||
epochs = 40 # total epochs
|
||||
|
||||
|
|
@ -327,13 +330,28 @@ if __name__ == '__main__':
|
|||
|
||||
dataloader = DataLoader(mattingDataset,
|
||||
batch_size=bs,
|
||||
shuffle=True,
|
||||
num_workers=0)
|
||||
shuffle=True)
|
||||
|
||||
for epoch in range(0, epochs):
|
||||
for idx, (image, trimap, gt_matte) in enumerate(dataloader):
|
||||
semantic_loss, detail_loss, matte_loss = \
|
||||
supervised_training_iter(modnet, optimizer, image, trimap, gt_matte)
|
||||
break
|
||||
lr_scheduler.step()
|
||||
break
|
||||
if epoch % 4 == 0 and epoch > 1:
|
||||
torch.save({
|
||||
'epoch': epoch,
|
||||
'model_state_dict': modnet.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'loss': {'semantic_loss': semantic_loss, 'detail_loss': detail_loss, 'matte_loss': matte_loss},
|
||||
}, f'pretrained/modnet_custom_portrait_matting_{epoch+1}.ckpt')
|
||||
lr_scheduler.step()
|
||||
print(f'semantic_loss: {semantic_loss:f}, detail_loss: {detail_loss:f}, matte_loss: {matte_loss:f}')
|
||||
if epoch == 4:
|
||||
break
|
||||
|
||||
torch.save({
|
||||
'epoch': epochs,
|
||||
'model_state_dict': modnet.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'loss': {'semantic_loss': semantic_loss, 'detail_loss': detail_loss, 'matte_loss': matte_loss},
|
||||
}, f'pretrained/modnet_custom_portrait_matting_last_epoch.ckpt')
|
||||
|
|
|
|||
Loading…
Reference in New Issue