diff --git a/src/train/dataset.py b/src/train/dataset.py
new file mode 100644
index 0000000..20b3ab1
--- /dev/null
+++ b/src/train/dataset.py
@@ -0,0 +1,55 @@
+import os
+import glob
+import cv2
+from pathlib import Path
+from typing import Tuple, Union
+
+import torch
+import torchvision
+from torch.utils.data import Dataset
+
+from src.train.trimap import makeTrimap
+
+
+class SegDataset(Dataset):
+    """Dataset pairing frame images with masks and derived trimaps.
+
+    For every ``<frame_dir>/frameXXXX.jpg`` the matching mask is expected
+    at ``<mask_dir>/maskXXXX.png``.  Instances are created by PTFDataModule.
+    """
+
+    def __init__(
+        self,
+        frame_dir: Union[str, Path],
+        mask_dir: Union[str, Path],
+    ) -> None:
+        self.frame_dir = Path(frame_dir)
+        self.mask_dir = Path(mask_dir)
+        # sort for a deterministic sample order across file systems
+        self.image_names = sorted(glob.glob(f"{self.frame_dir}/*.jpg"))
+        # "frame1234.jpg" -> "mask1234.png"; os.path.basename is portable,
+        # unlike split('/') which breaks on Windows path separators
+        self.mask_names = [
+            os.path.join(self.mask_dir, "mask" + os.path.basename(p)[5:-4] + ".png")
+            for p in self.image_names
+        ]
+        # ImageNet statistics (MODNet backbone).  NOTE(review): cv2.imread
+        # yields BGR while these are RGB stats -- confirm channel order.
+        self.transform = torchvision.transforms.Compose([
+            torchvision.transforms.ToTensor(),
+            torchvision.transforms.Normalize(
+                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+        ])
+
+    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Return (normalised frame, trimap, ground-truth mask) for one sample."""
+        frame = self.transform(cv2.imread(self.image_names[index]))
+
+        mask = cv2.imread(self.mask_names[index], cv2.IMREAD_GRAYSCALE)
+        trimap = torch.unsqueeze(torch.from_numpy(makeTrimap(mask)).float(), 0)
+        mask = torch.unsqueeze(torch.from_numpy(mask), 0).float()
+
+        return frame, trimap, mask
+
+    def __len__(self) -> int:
+        return len(self.image_names)
diff --git a/src/train/train.py b/src/train/train.py
new file mode 100644
index 0000000..458655d
--- /dev/null
+++ b/src/train/train.py
@@ -0,0 +1,80 @@
+import os
+import argparse
+import logging
+import logging.handlers
+
+import torch
+import torch.nn as nn
+
+import neptune.new as neptune
+
+from src.models.modnet import MODNet
+from src.trainer import supervised_training_iter
+from src.train.dataset import SegDataset
+
+logging.basicConfig(level=logging.DEBUG)
+logger = logging.getLogger(__name__)
+
+
+def parseArgs():
+    """Parse command-line options for a training run."""
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--datasetPath', type=str, required=True, help='path to dataset')
+    parser.add_argument('--modelsPath', type=str, required=True, help='path to save trained MODNet models')
+    parser.add_argument('--pretrainedPath', type=str, help='path of pre-trained MODNet')
+    parser.add_argument('--startEpoch', type=int, default=-1, help='epoch to start with')
+    parser.add_argument('--batchCount', type=int, default=16, help='batches count')
+    return parser.parse_args()
+
+
+args = parseArgs()
+
+batch_size = args.batchCount
+lr = 0.01      # learning rate
+epochs = 40    # total epochs
+
+modnet = MODNet(backbone_pretrained=False)
+modnet = nn.DataParallel(modnet)
+
+if args.pretrainedPath is not None:
+    # NOTE(review): consider map_location= here so CPU-only hosts can resume
+    modnet.load_state_dict(torch.load(args.pretrainedPath))
+
+optimizer = torch.optim.SGD(modnet.parameters(), lr=lr, momentum=0.9)
+# decay LR 10x every quarter of the schedule; last_epoch resumes the decay
+# state when restarting from --startEpoch
+lr_scheduler = torch.optim.lr_scheduler.StepLR(
+    optimizer, step_size=int(0.25 * epochs), gamma=0.1, last_epoch=args.startEpoch)
+
+dataset = SegDataset(os.path.join(args.datasetPath, "images"),
+                     os.path.join(args.datasetPath, "masks"))
+
+dataloader = torch.utils.data.DataLoader(
+    dataset,
+    batch_size=batch_size,
+    shuffle=True,
+    pin_memory=True,
+)
+
+# TODO(review): supply Neptune credentials via environment, not source
+project = ''
+api_token = ''
+neptuneRun = neptune.init(project=project,
+                          api_token=api_token,
+                          source_files=[])
+
+# Fix: start at startEpoch + 1 when resuming (default -1 keeps the old
+# range(0, epochs) behaviour); previously a resumed run re-ran epoch 0
+# and overwrote the earlier model_epoch*.ckpt checkpoints.
+for epoch in range(args.startEpoch + 1, epochs):
+    for idx, (image, trimap, gt_matte) in enumerate(dataloader):
+        semantic_loss, detail_loss, matte_loss, semantic_iou = supervised_training_iter(
+            modnet, optimizer, image, trimap, gt_matte, semantic_scale=1)
+        if idx % 100 == 0:
+            logger.info(f'idx: {idx}, semantic_loss: {semantic_loss:.5f}, detail_loss: {detail_loss:.5f}, matte_loss: {matte_loss:.5f}, semantic_iou: {semantic_iou:.5f}')
+    logger.info(f'Epoch: {epoch}, semantic_loss: {semantic_loss:.5f}, detail_loss: {detail_loss:.5f}, matte_loss: {matte_loss:.5f}, semantic_iou: {semantic_iou:.5f}')
+
+    neptuneRun["training/epoch/semantic_loss"].log(semantic_loss)
+    neptuneRun["training/epoch/detail_loss"].log(detail_loss)
+    neptuneRun["training/epoch/matte_loss"].log(matte_loss)
+    neptuneRun["training/epoch/semantic_iou"].log(semantic_iou)
+
+    modelPath = os.path.join(args.modelsPath, f"model_epoch{epoch}.ckpt")
+    torch.save(modnet.state_dict(), modelPath)
+    logger.info(f"model saved to {modelPath}")
+    lr_scheduler.step()
+
+torch.save(modnet.state_dict(), os.path.join(args.modelsPath, "model.ckpt"))
+
+neptuneRun.stop()
diff --git a/src/train/trimap.py b/src/train/trimap.py
new file mode 100644
index 0000000..e54d3e9
--- /dev/null
+++ b/src/train/trimap.py
@@ -0,0 +1,19 @@
+import numpy as np
+import cv2
+
+
+def makeEdgeMask(mask, width):
+    """Return the band around mask edges: dilation minus erosion (width x width kernel)."""
+    kernel = np.ones((width, width), np.uint8)
+    erosion = cv2.erode(mask, kernel, iterations=1)
+    dilation = cv2.dilate(mask, kernel, iterations=1)
+    return dilation - erosion
+
+
+def makeTrimap(mask, width=5):
+    """Build a trimap from a binary mask: fg/bg kept, 0.5 in the edge band."""
+    edgeMask = makeEdgeMask(mask, width)
+    # np.float was removed in NumPy 1.24 (alias of builtin float = float64)
+    trimap = mask.astype(np.float64)
+    # NOTE(review): assumes a 0/1 mask; a 0/255 mask yields a 255-valued band -- confirm
+    trimap[edgeMask == 1] = 0.5
+    return trimap