mirror of https://github.com/ZHKKKe/MODNet.git
89 lines
3.7 KiB
Python
89 lines
3.7 KiB
Python
import os
|
|
import argparse
|
|
import logging
|
|
import logging.handlers
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
import neptune.new as neptune
|
|
|
|
from src.models.modnet import MODNet
|
|
from src.trainer import supervised_training_iter
|
|
from src.train.dataset import SegDataset
|
|
|
|
logging.basicConfig(level=logging.DEBUG)
|
|
logger = logging.getLogger(__name__)
|
|
|
|
def parseArgs():
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument('--datasetPath', type=str, required=True, help='path to dataset')
|
|
parser.add_argument('--modelsPath', type=str, required=True, help='path to save trained MODNet models')
|
|
parser.add_argument('--pretrainedPath', type=str, help='path of pre-trained MODNet')
|
|
parser.add_argument('--startEpoch', type=int, default=-1, help='epoch to start with')
|
|
parser.add_argument('--batchCount', type=int, default=16, help='batches count')
|
|
args = parser.parse_args()
|
|
return args
|
|
|
|
def train(modnet, datasetPath: str, batch_size: int, startEpoch: int, modelsPath: str):
|
|
lr = 0.01 # learn rate
|
|
epochs = 40 # total epochs
|
|
|
|
optimizer = torch.optim.SGD(modnet.parameters(), lr=lr, momentum=0.9)
|
|
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * epochs), gamma=0.1, last_epoch=startEpoch)
|
|
|
|
dataset = SegDataset(os.path.join(datasetPath, "images"), os.path.join(datasetPath, "masks"))
|
|
|
|
dataloader = torch.utils.data.DataLoader(
|
|
dataset,
|
|
batch_size=batch_size,
|
|
shuffle=True,
|
|
pin_memory=True
|
|
)
|
|
|
|
project = 'motionlearning/modnet'
|
|
api_token = 'eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiI2NmJhMzJiMC1iY2M0LTQ5NGUtOGI0OS0yYzA4ZmNiNTRiMmEifQ=='
|
|
neptuneRun = neptune.init(project = project,
|
|
api_token = api_token,
|
|
source_files=[])
|
|
|
|
for epoch in range(0, epochs):
|
|
for idx, (image, trimap, gt_matte) in enumerate(dataloader):
|
|
semantic_loss, detail_loss, matte_loss, semantic_iou = supervised_training_iter(modnet, optimizer, image, trimap, gt_matte, semantic_scale=1, detail_scale=10, matte_scale=1)
|
|
if idx % 100 == 0:
|
|
logger.info(f'idx: {idx}, semantic_loss: {semantic_loss:.5f}, detail_loss: {detail_loss:.5f}, matte_loss: {matte_loss:.5f}, semantic_iou: {semantic_iou:.5f}')
|
|
logger.info(f'Epoch: {epoch}, semantic_loss: {semantic_loss:.5f}, detail_loss: {detail_loss:.5f}, matte_loss: {matte_loss:.5f}, semantic_iou: {semantic_iou:.5f}')
|
|
|
|
neptuneRun["training/epoch/semantic_loss"].log(semantic_loss)
|
|
neptuneRun["training/epoch/detail_loss"].log(detail_loss)
|
|
neptuneRun["training/epoch/matte_loss"].log(matte_loss)
|
|
neptuneRun["training/epoch/semantic_iou"].log(semantic_iou)
|
|
|
|
modelPath = os.path.join(modelsPath, f"model_epoch{epoch}.ckpt")
|
|
torch.save(modnet.state_dict(), modelPath)
|
|
logger.info(f"model saved to {modelPath}")
|
|
lr_scheduler.step()
|
|
|
|
torch.save(modnet.state_dict(), os.path.join(modelsPath, "model.ckpt"))
|
|
|
|
neptuneRun.stop()
|
|
|
|
def tune(modnet, datasetPath: str, batch_size: int, modelsPath: str):
|
|
pass
|
|
|
|
def twoStepTrain(datasetPath: str, batch_size: int, startEpoch, pretrainedPath, modelsPath):
|
|
|
|
modnet = MODNet(backbone_pretrained=True)
|
|
modnet = nn.DataParallel(modnet)
|
|
|
|
if pretrainedPath is not None:
|
|
modnet.load_state_dict(
|
|
torch.load(pretrainedPath)
|
|
)
|
|
|
|
train(modnet, datasetPath, batch_size, startEpoch, modelsPath)
|
|
tune(modnet, datasetPath, batch_size, modelsPath)
|
|
|
|
args = parseArgs()
|
|
|
|
train(args.datasetPath, args.batchCount, args.startEpoch, args.pretraindPath, args.modelsPath) |