diff --git a/src/train/train.py b/src/train/train.py
index 57e666b..7bc554b 100644
--- a/src/train/train.py
+++ b/src/train/train.py
@@ -25,56 +25,65 @@ def parseArgs():
     args = parser.parse_args()
     return args
 
-args = parseArgs()
+def train(modnet, datasetPath: str, batch_size: int, startEpoch: int, modelsPath: str):
+    lr = 0.01  # learn rate
+    epochs = 40  # total epochs
 
-batch_size = args.batchCount
-lr = 0.01  # learn rate
-epochs = 40  # total epochs
+    optimizer = torch.optim.SGD(modnet.parameters(), lr=lr, momentum=0.9)
+    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * epochs), gamma=0.1, last_epoch=startEpoch)
 
-modnet = MODNet(backbone_pretrained=False)
-modnet = nn.DataParallel(modnet)
+    dataset = SegDataset(os.path.join(datasetPath, "images"), os.path.join(datasetPath, "masks"))
 
-if args.pretrainedPath is not None:
-    modnet.load_state_dict(
-        torch.load(args.pretrainedPath)
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        batch_size=batch_size,
+        shuffle=True,
+        pin_memory=True
     )
 
-optimizer = torch.optim.SGD(modnet.parameters(), lr=lr, momentum=0.9)
-lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * epochs), gamma=0.1, last_epoch=args.startEpoch)
+    project = 'stask/modnet'
+    api_token = os.getenv('NEPTUNE_API_TOKEN')  # read the secret from the environment; never commit API tokens
+    neptuneRun = neptune.init(project = project,
+                              api_token = api_token,
+                              source_files=[])
 
-dataset = SegDataset(os.path.join(args.datasetPath, "images"), os.path.join(args.datasetPath, "masks"))
+    for epoch in range(0, epochs):
+        for idx, (image, trimap, gt_matte) in enumerate(dataloader):
+            semantic_loss, detail_loss, matte_loss, semantic_iou = supervised_training_iter(modnet, optimizer, image, trimap, gt_matte, semantic_scale=1, detail_scale=10, matte_scale=1)
+            if idx % 100 == 0:
+                logger.info(f'idx: {idx}, semantic_loss: {semantic_loss:.5f}, detail_loss: {detail_loss:.5f}, matte_loss: {matte_loss:.5f}, semantic_iou: {semantic_iou:.5f}')
+        logger.info(f'Epoch: {epoch}, semantic_loss: {semantic_loss:.5f}, detail_loss: {detail_loss:.5f}, matte_loss: {matte_loss:.5f}, semantic_iou: {semantic_iou:.5f}')
+
+        neptuneRun["training/epoch/semantic_loss"].log(semantic_loss)
+        neptuneRun["training/epoch/detail_loss"].log(detail_loss)
+        neptuneRun["training/epoch/matte_loss"].log(matte_loss)
+        neptuneRun["training/epoch/semantic_iou"].log(semantic_iou)
 
-dataloader = torch.utils.data.DataLoader(
-    dataset,
-    batch_size=batch_size,
-    shuffle=True,
-    pin_memory=True
-)
+        modelPath = os.path.join(modelsPath, f"model_epoch{epoch}.ckpt")
+        torch.save(modnet.state_dict(), modelPath)
+        logger.info(f"model saved to {modelPath}")
+        lr_scheduler.step()
+    torch.save(modnet.state_dict(), os.path.join(modelsPath, "model.ckpt"))
 
-project = ''
-api_token = ''
-neptuneRun = neptune.init(project = project,
-    api_token = api_token,
-    source_files=[])
+    neptuneRun.stop()
 
-for epoch in range(0, epochs):
-    for idx, (image, trimap, gt_matte) in enumerate(dataloader):
-        semantic_loss, detail_loss, matte_loss, semantic_iou = supervised_training_iter(modnet, optimizer, image, trimap, gt_matte)  # , semantic_scale=1, detail_scale=10, matte_scale=1)
-        if idx % 100 == 0:
-            logger.info(f'idx: {idx}, semantic_loss: {semantic_loss:.5f}, detail_loss: {detail_loss:.5f}, matte_loss: {matte_loss:.5f}, semantic_iou: {semantic_iou:.5f}')
-    logger.info(f'Epoch: {epoch}, semantic_loss: {semantic_loss:.5f}, detail_loss: {detail_loss:.5f}, matte_loss: {matte_loss:.5f}, semantic_iou: {semantic_iou:.5f}')
-
-    neptuneRun["training/epoch/semantic_loss"].log(semantic_loss)
-    neptuneRun["training/epoch/detail_loss"].log(detail_loss)
-    neptuneRun["training/epoch/matte_loss"].log(matte_loss)
-    neptuneRun["training/epoch/semantic_iou"].log(semantic_iou)
+def tune(modnet, datasetPath: str, batch_size: int, modelsPath: str):
+    pass  # second step of twoStepTrain; fine-tuning not implemented yet
 
-    modelPath = os.path.join(args.modelsPath, f"model_epoch{epoch}.ckpt")
-    torch.save(modnet.state_dict(), modelPath)
-    logger.info(f"model saved to {modelPath}")
-    lr_scheduler.step()
+def twoStepTrain(datasetPath: str, batch_size: int, startEpoch, pretrainedPath, modelsPath):
 
-torch.save(modnet.state_dict(), os.path.join(args.modelsPath, "model.ckpt"))
+    modnet = MODNet(backbone_pretrained=True)
+    modnet = nn.DataParallel(modnet)
 
-neptuneRun.stop()
+    if pretrainedPath is not None:
+        modnet.load_state_dict(
+            torch.load(pretrainedPath)
+        )
+
+    train(modnet, datasetPath, batch_size, startEpoch, modelsPath)
+    tune(modnet, datasetPath, batch_size, modelsPath)
+
+args = parseArgs()
+
+twoStepTrain(args.datasetPath, args.batchCount, args.startEpoch, args.pretrainedPath, args.modelsPath)
\ No newline at end of file
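
The `tune` function above is called as the second step of `twoStepTrain` but is currently an empty stub. A minimal sketch of what it could contain, assuming the fine-tuning step is MODNet's self-supervised SOC adaptation and that the trainer module exposes `soc_adaptation_iter` alongside `supervised_training_iter` (as in upstream MODNet); the epoch count, optimizer settings, and checkpoint name are illustrative assumptions, not values from this patch, and the sketch relies on the file's existing imports (torch, os, SegDataset, logger):

import copy

def tune(modnet, datasetPath: str, batch_size: int, modelsPath: str):
    # Hypothetical hyperparameters for the SOC stage; adjust as needed.
    soc_epochs = 10
    optimizer = torch.optim.Adam(modnet.parameters(), lr=0.00001, betas=(0.9, 0.99))

    dataset = SegDataset(os.path.join(datasetPath, "images"), os.path.join(datasetPath, "masks"))
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        pin_memory=True
    )

    for epoch in range(soc_epochs):
        # SOC regularizes the model against a frozen copy of itself,
        # so snapshot the current weights at the start of each epoch.
        backup_modnet = copy.deepcopy(modnet)
        for idx, (image, trimap, gt_matte) in enumerate(dataloader):
            # The SOC step is self-supervised: it consumes only the image,
            # not the trimap or the ground-truth matte.
            soc_semantic_loss, soc_detail_loss = soc_adaptation_iter(modnet, backup_modnet, optimizer, image)
        logger.info(f'SOC epoch: {epoch}, soc_semantic_loss: {soc_semantic_loss:.5f}, soc_detail_loss: {soc_detail_loss:.5f}')

    torch.save(modnet.state_dict(), os.path.join(modelsPath, "model_soc.ckpt"))

Keeping `tune` separate from `train` mirrors MODNet's two-stage recipe: supervised training on labeled mattes first, then label-free adaptation, so the second stage could also be run on data without ground-truth mattes.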