mirror of https://github.com/ZHKKKe/MODNet.git
[+] train loop scripts
parent
5f673d5a34
commit
b6f9f63d83
|
|
@ -0,0 +1,50 @@
|
|||
import os
|
||||
import glob
|
||||
import cv2
|
||||
from pathlib import Path
|
||||
from typing import Callable, List, Optional, Tuple, Union, Dict
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from src.train.trimap import makeTrimap
|
||||
|
||||
class SegDataset(Dataset):
|
||||
"""A custom Dataset(torch.utils.data) implement three functions: __init__, __len__, and __getitem__.
|
||||
Datasets are created from PTFDataModule.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
frame_dir: Union[str, Path],
|
||||
mask_dir: Union[str, Path]
|
||||
) -> None:
|
||||
|
||||
self.frame_dir = Path(frame_dir)
|
||||
self.mask_dir = Path(mask_dir)
|
||||
self.image_names = glob.glob(f"{self.frame_dir}/*.jpg")
|
||||
self.mask_names = [os.path.join(self.mask_dir,"mask"+x.split('/')[-1][:-4][5:]+".png") for x in self.image_names]
|
||||
print(self.mask_names)
|
||||
self.transform = torchvision.transforms.Compose([
|
||||
torchvision.transforms.ToTensor(),
|
||||
torchvision.transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
|
||||
|
||||
|
||||
def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
frame_pth = self.image_names[index]
|
||||
mask_pth = self.mask_names[index]
|
||||
|
||||
frame = cv2.imread(frame_pth)
|
||||
frame = self.transform(frame)
|
||||
|
||||
mask = cv2.imread(mask_pth,cv2.IMREAD_GRAYSCALE)
|
||||
trimap = torch.from_numpy(makeTrimap(mask)).float()
|
||||
trimap = torch.unsqueeze(trimap,0)
|
||||
mask = torch.from_numpy(mask)
|
||||
mask = torch.unsqueeze(mask,0).float()
|
||||
|
||||
return frame, trimap, mask
|
||||
|
||||
def __len__(self):
|
||||
return len(self.image_names)
|
||||
|
|
@ -0,0 +1,80 @@
|
|||
import os
|
||||
import argparse
|
||||
import logging
|
||||
import logging.handlers
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import neptune.new as neptune
|
||||
|
||||
from src.models.modnet import MODNet
|
||||
from src.trainer import supervised_training_iter
|
||||
from src.train.dataset import SegDataset
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def parseArgs():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--datasetPath', type=str, required=True, help='path to dataset')
|
||||
parser.add_argument('--modelsPath', type=str, required=True, help='path to save trained MODNet models')
|
||||
parser.add_argument('--pretrainedPath', type=str, help='path of pre-trained MODNet')
|
||||
parser.add_argument('--startEpoch', type=int, default=-1, help='epoch to start with')
|
||||
parser.add_argument('--batchCount', type=int, default=16, help='batches count')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
args = parseArgs()
|
||||
|
||||
batch_size = args.batchCount
|
||||
lr = 0.01 # learn rate
|
||||
epochs = 40 # total epochs
|
||||
|
||||
modnet = MODNet(backbone_pretrained=False)
|
||||
modnet = nn.DataParallel(modnet)
|
||||
|
||||
if args.pretrainedPath is not None:
|
||||
modnet.load_state_dict(
|
||||
torch.load(args.pretrainedPath)
|
||||
)
|
||||
|
||||
optimizer = torch.optim.SGD(modnet.parameters(), lr=lr, momentum=0.9)
|
||||
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * epochs), gamma=0.1, last_epoch=args.startEpoch)
|
||||
|
||||
dataset = SegDataset(os.path.join(args.datasetPath, "images"), os.path.join(args.datasetPath, "masks"))
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
pin_memory=True
|
||||
)
|
||||
|
||||
|
||||
project = '<YOUR_WORKSPACE/YOUR_PROJECT>'
|
||||
api_token = '<YOURR_API_TOKEN>'
|
||||
neptuneRun = neptune.init(project = project,
|
||||
api_token = api_token,
|
||||
source_files=[])
|
||||
|
||||
for epoch in range(0, epochs):
|
||||
for idx, (image, trimap, gt_matte) in enumerate(dataloader):
|
||||
semantic_loss, detail_loss, matte_loss, semantic_iou = supervised_training_iter(modnet, optimizer, image, trimap, gt_matte, semantic_scale=1)
|
||||
if idx % 100 == 0:
|
||||
logger.info(f'idx: {idx}, semantic_loss: {semantic_loss:.5f}, detail_loss: {detail_loss:.5f}, matte_loss: {matte_loss:.5f}, semantic_iou: {semantic_iou:.5f}')
|
||||
logger.info(f'Epoch: {epoch}, semantic_loss: {semantic_loss:.5f}, detail_loss: {detail_loss:.5f}, matte_loss: {matte_loss:.5f}, semantic_iou: {semantic_iou:.5f}')
|
||||
|
||||
neptuneRun["training/epoch/semantic_loss"].log(semantic_loss)
|
||||
neptuneRun["training/epoch/detail_loss"].log(detail_loss)
|
||||
neptuneRun["training/epoch/matte_loss"].log(matte_loss)
|
||||
neptuneRun["training/epoch/semantic_iou"].log(semantic_iou)
|
||||
|
||||
modelPath = os.path.join(args.modelsPath, f"model_epoch{epoch}.ckpt")
|
||||
torch.save(modnet.state_dict(), modelPath)
|
||||
logger.info(f"model saved to {modelPath}")
|
||||
lr_scheduler.step()
|
||||
|
||||
torch.save(modnet.state_dict(), os.path.join(args.modelsPath, "model.ckpt"))
|
||||
|
||||
neptuneRun.stop()
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
import numpy as np
|
||||
import cv2
|
||||
|
||||
def makeEdgeMask(mask, width):
|
||||
kernel = np.ones((width,width), np.uint8)
|
||||
|
||||
erosion = cv2.erode(mask, kernel, iterations = 1)
|
||||
dilation = cv2.dilate(mask, kernel, iterations = 1)
|
||||
|
||||
return dilation - erosion
|
||||
|
||||
def makeTrimap(mask, width = 5):
|
||||
edgeMask = makeEdgeMask(mask, width)
|
||||
trimap = mask.astype(np.float)
|
||||
trimap[edgeMask == 1] = 0.5
|
||||
return trimap
|
||||
Loading…
Reference in New Issue