MODNet/src/train/train.py

89 lines
3.7 KiB
Python

import os
import argparse
import logging
import logging.handlers
import torch
import torch.nn as nn
import neptune.new as neptune
from src.models.modnet import MODNet
from src.trainer import supervised_training_iter
from src.train.dataset import SegDataset
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
def parseArgs():
parser = argparse.ArgumentParser()
parser.add_argument('--datasetPath', type=str, required=True, help='path to dataset')
parser.add_argument('--modelsPath', type=str, required=True, help='path to save trained MODNet models')
parser.add_argument('--pretrainedPath', type=str, help='path of pre-trained MODNet')
parser.add_argument('--startEpoch', type=int, default=-1, help='epoch to start with')
parser.add_argument('--batchCount', type=int, default=16, help='batches count')
args = parser.parse_args()
return args
def train(modnet, datasetPath: str, batch_size: int, startEpoch: int, modelsPath: str):
lr = 0.01 # learn rate
epochs = 40 # total epochs
optimizer = torch.optim.SGD(modnet.parameters(), lr=lr, momentum=0.9)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * epochs), gamma=0.1, last_epoch=startEpoch)
dataset = SegDataset(os.path.join(datasetPath, "images"), os.path.join(datasetPath, "masks"))
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
pin_memory=True
)
project = 'motionlearning/modnet'
api_token = 'eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiI2NmJhMzJiMC1iY2M0LTQ5NGUtOGI0OS0yYzA4ZmNiNTRiMmEifQ=='
neptuneRun = neptune.init(project = project,
api_token = api_token,
source_files=[])
for epoch in range(0, epochs):
for idx, (image, trimap, gt_matte) in enumerate(dataloader):
semantic_loss, detail_loss, matte_loss, semantic_iou = supervised_training_iter(modnet, optimizer, image, trimap, gt_matte, semantic_scale=1, detail_scale=10, matte_scale=1)
if idx % 100 == 0:
logger.info(f'idx: {idx}, semantic_loss: {semantic_loss:.5f}, detail_loss: {detail_loss:.5f}, matte_loss: {matte_loss:.5f}, semantic_iou: {semantic_iou:.5f}')
logger.info(f'Epoch: {epoch}, semantic_loss: {semantic_loss:.5f}, detail_loss: {detail_loss:.5f}, matte_loss: {matte_loss:.5f}, semantic_iou: {semantic_iou:.5f}')
neptuneRun["training/epoch/semantic_loss"].log(semantic_loss)
neptuneRun["training/epoch/detail_loss"].log(detail_loss)
neptuneRun["training/epoch/matte_loss"].log(matte_loss)
neptuneRun["training/epoch/semantic_iou"].log(semantic_iou)
modelPath = os.path.join(modelsPath, f"model_epoch{epoch}.ckpt")
torch.save(modnet.state_dict(), modelPath)
logger.info(f"model saved to {modelPath}")
lr_scheduler.step()
torch.save(modnet.state_dict(), os.path.join(modelsPath, "model.ckpt"))
neptuneRun.stop()
def tune(modnet, datasetPath: str, batch_size: int, modelsPath: str):
pass
def twoStepTrain(datasetPath: str, batch_size: int, startEpoch, pretrainedPath, modelsPath):
modnet = MODNet(backbone_pretrained=True)
modnet = nn.DataParallel(modnet)
if pretrainedPath is not None:
modnet.load_state_dict(
torch.load(pretrainedPath)
)
train(modnet, datasetPath, batch_size, startEpoch, modelsPath)
tune(modnet, datasetPath, batch_size, modelsPath)
args = parseArgs()
train(args.datasetPath, args.batchCount, args.startEpoch, args.pretraindPath, args.modelsPath)