[+] train loop scripts

pull/170/head
Kapulkin Stanislav 2022-02-07 00:19:06 +03:00
parent 5f673d5a34
commit b6f9f63d83
3 changed files with 146 additions and 0 deletions

50
src/train/dataset.py Normal file
View File

@ -0,0 +1,50 @@
import os
import glob
import cv2
from pathlib import Path
from typing import Callable, List, Optional, Tuple, Union, Dict
import torch
import torchvision
from torch.utils.data import Dataset
from src.train.trimap import makeTrimap
class SegDataset(Dataset):
    """Dataset pairing video frames with segmentation masks and trimaps.

    Expects ``frame_dir`` to contain ``frameXXXX.jpg`` images and
    ``mask_dir`` to contain ``maskXXXX.png`` masks with the same ``XXXX``
    suffix (the first 5 characters of the image stem, i.e. "frame", are
    replaced by "mask").  Each item is ``(frame, trimap, mask)``:

    * ``frame``  -- CHW float tensor, ImageNet-normalized.
    * ``trimap`` -- 1xHxW float tensor built by ``makeTrimap``.
    * ``mask``   -- 1xHxW float tensor of the raw grayscale mask.
    """

    def __init__(
        self,
        frame_dir: Union[str, Path],
        mask_dir: Union[str, Path]
    ) -> None:
        self.frame_dir = Path(frame_dir)
        self.mask_dir = Path(mask_dir)
        # sorted() gives a deterministic sample order across runs/platforms
        # (glob order is filesystem-dependent).
        self.image_names: List[str] = sorted(glob.glob(f"{self.frame_dir}/*.jpg"))
        # "frameXXXX.jpg" -> "maskXXXX.png".  Path.stem replaces the old
        # '/'-split + [:-4] slicing, which was POSIX-only and broke on Windows.
        self.mask_names: List[str] = [
            str(self.mask_dir / f"mask{Path(name).stem[5:]}.png")
            for name in self.image_names
        ]
        self.transform = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        frame_pth = self.image_names[index]
        mask_pth = self.mask_names[index]
        frame = cv2.imread(frame_pth)
        # cv2.imread returns None (no exception) on a missing/corrupt file;
        # fail loudly here instead of with a cryptic error inside transform.
        if frame is None:
            raise FileNotFoundError(f"cannot read frame: {frame_pth}")
        frame = self.transform(frame)
        mask = cv2.imread(mask_pth, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"cannot read mask: {mask_pth}")
        trimap = torch.from_numpy(makeTrimap(mask)).float()
        trimap = torch.unsqueeze(trimap, 0)
        mask = torch.from_numpy(mask)
        mask = torch.unsqueeze(mask, 0).float()
        return frame, trimap, mask

    def __len__(self) -> int:
        return len(self.image_names)

80
src/train/train.py Normal file
View File

@ -0,0 +1,80 @@
import os
import argparse
import logging
import logging.handlers
import torch
import torch.nn as nn
import neptune.new as neptune
from src.models.modnet import MODNet
from src.trainer import supervised_training_iter
from src.train.dataset import SegDataset
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
def parseArgs():
    """Parse and return command-line options for the training run."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--datasetPath', type=str, required=True, help='path to dataset')
    arg_parser.add_argument('--modelsPath', type=str, required=True, help='path to save trained MODNet models')
    arg_parser.add_argument('--pretrainedPath', type=str, help='path of pre-trained MODNet')
    arg_parser.add_argument('--startEpoch', type=int, default=-1, help='epoch to start with')
    arg_parser.add_argument('--batchCount', type=int, default=16, help='batches count')
    return arg_parser.parse_args()
# ---- flat training script: runs on import, as before ----
args = parseArgs()

batch_size = args.batchCount
lr = 0.01      # SGD learning rate
epochs = 40    # total epochs

modnet = MODNet(backbone_pretrained=False)
modnet = nn.DataParallel(modnet)
if args.pretrainedPath is not None:
    modnet.load_state_dict(
        torch.load(args.pretrainedPath)
    )

optimizer = torch.optim.SGD(modnet.parameters(), lr=lr, momentum=0.9)
# Decay LR 10x every quarter of the schedule; last_epoch lets a resumed run
# (--startEpoch) pick up the right learning rate.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * epochs), gamma=0.1, last_epoch=args.startEpoch)

dataset = SegDataset(os.path.join(args.datasetPath, "images"), os.path.join(args.datasetPath, "masks"))
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True
)

# Create the checkpoint directory up front — otherwise the first torch.save
# below fails after an entire epoch of training.
os.makedirs(args.modelsPath, exist_ok=True)

project = '<YOUR_WORKSPACE/YOUR_PROJECT>'
api_token = '<YOUR_API_TOKEN>'
neptuneRun = neptune.init(project = project,
                          api_token = api_token,
                          source_files=[])

for epoch in range(0, epochs):
    # Pre-seed the per-epoch metrics so the summary/logging below is defined
    # even if the dataloader yields no batches (previously a NameError).
    semantic_loss = detail_loss = matte_loss = semantic_iou = float('nan')
    for idx, (image, trimap, gt_matte) in enumerate(dataloader):
        semantic_loss, detail_loss, matte_loss, semantic_iou = supervised_training_iter(modnet, optimizer, image, trimap, gt_matte, semantic_scale=1)
        if idx % 100 == 0:
            logger.info(f'idx: {idx}, semantic_loss: {semantic_loss:.5f}, detail_loss: {detail_loss:.5f}, matte_loss: {matte_loss:.5f}, semantic_iou: {semantic_iou:.5f}')
    # Epoch summary logs the LAST batch's metrics, not an epoch average.
    logger.info(f'Epoch: {epoch}, semantic_loss: {semantic_loss:.5f}, detail_loss: {detail_loss:.5f}, matte_loss: {matte_loss:.5f}, semantic_iou: {semantic_iou:.5f}')
    neptuneRun["training/epoch/semantic_loss"].log(semantic_loss)
    neptuneRun["training/epoch/detail_loss"].log(detail_loss)
    neptuneRun["training/epoch/matte_loss"].log(matte_loss)
    neptuneRun["training/epoch/semantic_iou"].log(semantic_iou)
    # Checkpoint every epoch so a crashed run can resume via --pretrainedPath.
    modelPath = os.path.join(args.modelsPath, f"model_epoch{epoch}.ckpt")
    torch.save(modnet.state_dict(), modelPath)
    logger.info(f"model saved to {modelPath}")
    lr_scheduler.step()

torch.save(modnet.state_dict(), os.path.join(args.modelsPath, "model.ckpt"))
neptuneRun.stop()

16
src/train/trimap.py Normal file
View File

@ -0,0 +1,16 @@
import numpy as np
import cv2
def makeEdgeMask(mask, width):
    """Return a map that is nonzero on the band around the mask boundary.

    Dilates and erodes ``mask`` with a ``width`` x ``width`` square kernel;
    the difference of the two results covers exactly the pixels within the
    morphological boundary band.
    """
    structuring_element = np.ones((width, width), np.uint8)
    eroded = cv2.erode(mask, structuring_element, iterations=1)
    dilated = cv2.dilate(mask, structuring_element, iterations=1)
    return dilated - eroded
def makeTrimap(mask, width = 5):
    """Build a trimap from a segmentation mask.

    Pixels keep their mask value as floats, except pixels in the
    ``width``-wide boundary band (where the edge mask equals 1), which are
    set to the "unknown" value 0.5.

    Args:
        mask: 2-D integer array; presumably binary with values in {0, 1},
            since the band test is ``edgeMask == 1`` — TODO confirm against
            the mask PNGs (cv2 grayscale masks are often 0/255).
        width: side of the square structuring element defining the band.

    Returns:
        Float (float64) array of the same shape as ``mask``.
    """
    edgeMask = makeEdgeMask(mask, width)
    # np.float was deprecated in NumPy 1.20 and removed in 1.24; the builtin
    # float gives the same float64 dtype without the AttributeError.
    trimap = mask.astype(float)
    trimap[edgeMask == 1] = 0.5
    return trimap