[*] two step train draft

pull/170/head
Kapulkin Stanislav 2022-02-09 18:21:07 +03:00
parent 5b02db5b2b
commit 4725324d53
1 changed files with 49 additions and 40 deletions

View File

@ -25,24 +25,14 @@ def parseArgs():
args = parser.parse_args()
return args
args = parseArgs()
batch_size = args.batchCount
def train(modnet, datasetPath: str, batch_size: int, startEpoch: int, modelsPath: str):
lr = 0.01 # learn rate
epochs = 40 # total epochs
modnet = MODNet(backbone_pretrained=False)
modnet = nn.DataParallel(modnet)
if args.pretrainedPath is not None:
modnet.load_state_dict(
torch.load(args.pretrainedPath)
)
optimizer = torch.optim.SGD(modnet.parameters(), lr=lr, momentum=0.9)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * epochs), gamma=0.1, last_epoch=args.startEpoch)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * epochs), gamma=0.1, last_epoch=startEpoch)
dataset = SegDataset(os.path.join(args.datasetPath, "images"), os.path.join(args.datasetPath, "masks"))
dataset = SegDataset(os.path.join(datasetPath, "images"), os.path.join(datasetPath, "masks"))
dataloader = torch.utils.data.DataLoader(
dataset,
@ -51,16 +41,15 @@ dataloader = torch.utils.data.DataLoader(
pin_memory=True
)
project = '<YOUR_WORKSPACE/YOUR_PROJECT>'
api_token = '<YOURR_API_TOKEN>'
project = 'stask/modnet'
api_token = 'eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiJmMmU1ZDZlZC00OWQ5LTQ0ODUtYmExMi0zN2M3MTA5ZmM4ZDcifQ=='
neptuneRun = neptune.init(project = project,
api_token = api_token,
source_files=[])
for epoch in range(0, epochs):
for idx, (image, trimap, gt_matte) in enumerate(dataloader):
semantic_loss, detail_loss, matte_loss, semantic_iou = supervised_training_iter(modnet, optimizer, image, trimap, gt_matte) # , semantic_scale=1, detail_scale=10, matte_scale=1)
semantic_loss, detail_loss, matte_loss, semantic_iou = supervised_training_iter(modnet, optimizer, image, trimap, gt_matte, semantic_scale=1, detail_scale=10, matte_scale=1)
if idx % 100 == 0:
logger.info(f'idx: {idx}, semantic_loss: {semantic_loss:.5f}, detail_loss: {detail_loss:.5f}, matte_loss: {matte_loss:.5f}, semantic_iou: {semantic_iou:.5f}')
logger.info(f'Epoch: {epoch}, semantic_loss: {semantic_loss:.5f}, detail_loss: {detail_loss:.5f}, matte_loss: {matte_loss:.5f}, semantic_iou: {semantic_iou:.5f}')
@ -70,11 +59,31 @@ for epoch in range(0, epochs):
neptuneRun["training/epoch/matte_loss"].log(matte_loss)
neptuneRun["training/epoch/semantic_iou"].log(semantic_iou)
modelPath = os.path.join(args.modelsPath, f"model_epoch{epoch}.ckpt")
modelPath = os.path.join(modelsPath, f"model_epoch{epoch}.ckpt")
torch.save(modnet.state_dict(), modelPath)
logger.info(f"model saved to {modelPath}")
lr_scheduler.step()
torch.save(modnet.state_dict(), os.path.join(args.modelsPath, "model.ckpt"))
torch.save(modnet.state_dict(), os.path.join(modelsPath, "model.ckpt"))
neptuneRun.stop()
def tune(modnet, datasetPath: str, batch_size: int, modelsPath: str):
pass
def twoStepTrain(datasetPath: str, batch_size: int, startEpoch, pretrainedPath, modelsPath):
modnet = MODNet(backbone_pretrained=True)
modnet = nn.DataParallel(modnet)
if pretrainedPath is not None:
modnet.load_state_dict(
torch.load(pretrainedPath)
)
train(modnet, datasetPath, batch_size, startEpoch, modelsPath)
tune(modnet, datasetPath, batch_size, modelsPath)
args = parseArgs()
train(args.datasetPath, args.batchCount, args.startEpoch, args.pretraindPath, args.modelsPath)