mirror of https://github.com/ZHKKKe/MODNet.git
[+] train loop scripts
parent 5f673d5a34
commit b6f9f63d83
src/train/dataset.py
@@ -0,0 +1,50 @@
import os
import glob
import cv2
from pathlib import Path
from typing import Tuple, Union

import torch
import torchvision
from torch.utils.data import Dataset

from src.train.trimap import makeTrimap


class SegDataset(Dataset):
    """A custom torch.utils.data.Dataset implementing the three standard methods:
    __init__, __len__, and __getitem__.

    Datasets are created from PTFDataModule.
    """

    def __init__(
        self,
        frame_dir: Union[str, Path],
        mask_dir: Union[str, Path]
    ) -> None:
        self.frame_dir = Path(frame_dir)
        self.mask_dir = Path(mask_dir)
        self.image_names = glob.glob(f"{self.frame_dir}/*.jpg")
        # Pair each frame with its mask, e.g. "frame0001.jpg" -> "mask0001.png":
        # take the basename, drop the ".jpg" extension and the 5-character prefix.
        self.mask_names = [
            os.path.join(self.mask_dir, "mask" + x.split('/')[-1][:-4][5:] + ".png")
            for x in self.image_names
        ]
        # ImageNet mean/std normalization, matching the MODNet backbone.
        self.transform = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        frame_pth = self.image_names[index]
        mask_pth = self.mask_names[index]

        frame = cv2.imread(frame_pth)  # note: OpenCV loads images in BGR channel order
        frame = self.transform(frame)

        mask = cv2.imread(mask_pth, cv2.IMREAD_GRAYSCALE)
        trimap = torch.from_numpy(makeTrimap(mask)).float()
        trimap = torch.unsqueeze(trimap, 0)  # add channel dim: (H, W) -> (1, H, W)
        mask = torch.from_numpy(mask)
        mask = torch.unsqueeze(mask, 0).float()

        return frame, trimap, mask

    def __len__(self):
        return len(self.image_names)
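
A minimal sanity check for the dataset, assuming frames live in data/images and masks in data/masks (hypothetical paths), named per the scheme above, e.g. frame0001.jpg paired with mask0001.png:

from src.train.dataset import SegDataset

dataset = SegDataset("data/images", "data/masks")
print(len(dataset))  # number of frame/mask pairs found

frame, trimap, mask = dataset[0]
print(frame.shape, trimap.shape, mask.shape)  # (3, H, W), (1, H, W), (1, H, W)
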
@@ -0,0 +1,80 @@
import os
import argparse
import logging

import torch
import torch.nn as nn

import neptune.new as neptune

from src.models.modnet import MODNet
from src.trainer import supervised_training_iter
from src.train.dataset import SegDataset

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)


def parseArgs():
    parser = argparse.ArgumentParser()
    parser.add_argument('--datasetPath', type=str, required=True, help='path to dataset')
    parser.add_argument('--modelsPath', type=str, required=True, help='path to save trained MODNet models')
    parser.add_argument('--pretrainedPath', type=str, help='path of pre-trained MODNet')
    parser.add_argument('--startEpoch', type=int, default=-1, help='epoch to start with')
    parser.add_argument('--batchCount', type=int, default=16, help='batch size')
    args = parser.parse_args()
    return args


args = parseArgs()

batch_size = args.batchCount
lr = 0.01      # learning rate
epochs = 40    # total epochs

modnet = MODNet(backbone_pretrained=False)
modnet = nn.DataParallel(modnet)

if args.pretrainedPath is not None:
    modnet.load_state_dict(
        torch.load(args.pretrainedPath)
    )

optimizer = torch.optim.SGD(modnet.parameters(), lr=lr, momentum=0.9)
# Decay the learning rate by 10x every quarter of the training run.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * epochs), gamma=0.1, last_epoch=args.startEpoch)

dataset = SegDataset(os.path.join(args.datasetPath, "images"), os.path.join(args.datasetPath, "masks"))

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True
)


project = '<YOUR_WORKSPACE/YOUR_PROJECT>'
api_token = '<YOUR_API_TOKEN>'
neptuneRun = neptune.init(project=project,
                          api_token=api_token,
                          source_files=[])

for epoch in range(epochs):
    for idx, (image, trimap, gt_matte) in enumerate(dataloader):
        semantic_loss, detail_loss, matte_loss, semantic_iou = supervised_training_iter(modnet, optimizer, image, trimap, gt_matte, semantic_scale=1)
        if idx % 100 == 0:
            logger.info(f'idx: {idx}, semantic_loss: {semantic_loss:.5f}, detail_loss: {detail_loss:.5f}, matte_loss: {matte_loss:.5f}, semantic_iou: {semantic_iou:.5f}')
    logger.info(f'Epoch: {epoch}, semantic_loss: {semantic_loss:.5f}, detail_loss: {detail_loss:.5f}, matte_loss: {matte_loss:.5f}, semantic_iou: {semantic_iou:.5f}')

    # Log the last batch's metrics for this epoch to Neptune.
    neptuneRun["training/epoch/semantic_loss"].log(semantic_loss)
    neptuneRun["training/epoch/detail_loss"].log(detail_loss)
    neptuneRun["training/epoch/matte_loss"].log(matte_loss)
    neptuneRun["training/epoch/semantic_iou"].log(semantic_iou)

    # Checkpoint after every epoch.
    modelPath = os.path.join(args.modelsPath, f"model_epoch{epoch}.ckpt")
    torch.save(modnet.state_dict(), modelPath)
    logger.info(f"model saved to {modelPath}")
    lr_scheduler.step()

torch.save(modnet.state_dict(), os.path.join(args.modelsPath, "model.ckpt"))

neptuneRun.stop()
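
For reference, a typical invocation of this training script; the module path src.train.train is an assumption (the script's file name is not shown in this diff) and the argument values are illustrative. The dataset directory is expected to contain images/ and masks/ subfolders:

python -m src.train.train \
    --datasetPath data \
    --modelsPath checkpoints \
    --batchCount 16
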
src/train/trimap.py
@@ -0,0 +1,16 @@
import numpy as np
import cv2


def makeEdgeMask(mask, width):
    kernel = np.ones((width, width), np.uint8)

    erosion = cv2.erode(mask, kernel, iterations=1)
    dilation = cv2.dilate(mask, kernel, iterations=1)

    # Band around the object boundary: pixels in the dilation but not in
    # the erosion. Assumes a binary {0, 1} mask.
    return dilation - erosion


def makeTrimap(mask, width=5):
    edgeMask = makeEdgeMask(mask, width)
    trimap = mask.astype(np.float32)  # np.float is deprecated; use an explicit dtype
    trimap[edgeMask == 1] = 0.5       # mark the boundary band as "unknown"
    return trimap
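
A quick check of the trimap logic on a synthetic binary mask; note that both functions assume mask values in {0, 1}, since the unknown band is selected with edgeMask == 1:

import numpy as np
from src.train.trimap import makeTrimap

# 16x16 binary mask: an 8x8 square of foreground (1) on background (0).
mask = np.zeros((16, 16), np.uint8)
mask[4:12, 4:12] = 1

trimap = makeTrimap(mask, width=5)
print(np.unique(trimap))  # [0.  0.5 1. ] -- background, unknown band, foreground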