mirror of https://github.com/ZHKKKe/MODNet.git
[*] two step train draft
parent 5b02db5b2b
commit 4725324d53

@@ -25,24 +25,14 @@ def parseArgs():
     args = parser.parse_args()
     return args
 
-args = parseArgs()
+def train(modnet, datasetPath: str, batch_size: int, startEpoch: int, modelsPath: str):
 
-batch_size = args.batchCount
     lr = 0.01 # learn rate
     epochs = 40 # total epochs
 
-modnet = MODNet(backbone_pretrained=False)
-modnet = nn.DataParallel(modnet)
-
-if args.pretrainedPath is not None:
-    modnet.load_state_dict(
-        torch.load(args.pretrainedPath)
-    )
-
     optimizer = torch.optim.SGD(modnet.parameters(), lr=lr, momentum=0.9)
-lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * epochs), gamma=0.1, last_epoch=args.startEpoch)
+    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * epochs), gamma=0.1, last_epoch=startEpoch)
 
-dataset = SegDataset(os.path.join(args.datasetPath, "images"), os.path.join(args.datasetPath, "masks"))
+    dataset = SegDataset(os.path.join(datasetPath, "images"), os.path.join(datasetPath, "masks"))
 
     dataloader = torch.utils.data.DataLoader(
         dataset,
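Note: the hunk above only shows the tail of parseArgs(). Judging from the attributes the script reads (args.datasetPath, args.batchCount, args.startEpoch, args.pretrainedPath, args.modelsPath), the parser presumably looks something like the sketch below; the types, defaults, and help strings are my assumptions, only the flag names come from the code.

import argparse

def parseArgs():
    # Hypothetical reconstruction: flag names are taken from their usage in
    # this file; types, defaults, and help text are guesses.
    parser = argparse.ArgumentParser()
    parser.add_argument("--datasetPath", type=str, required=True,
                        help="dataset root containing images/ and masks/")
    parser.add_argument("--batchCount", type=int, default=16,
                        help="batch size passed to the DataLoader")
    parser.add_argument("--startEpoch", type=int, default=-1,
                        help="-1 starts a fresh LR schedule")
    parser.add_argument("--pretrainedPath", type=str, default=None,
                        help="optional checkpoint to initialize from")
    parser.add_argument("--modelsPath", type=str, required=True,
                        help="directory for per-epoch checkpoints")
    args = parser.parse_args()
    return args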
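Note: StepLR(..., last_epoch=startEpoch) only works for a startEpoch other than -1 if every optimizer param group already contains an initial_lr entry; otherwise PyTorch raises a KeyError, because that key is normally written when a scheduler is resumed via load_state_dict. A minimal guard, not part of this commit, reusing the optimizer and startEpoch from the code above:

# Seed initial_lr so that last_epoch=startEpoch (> -1) does not raise a KeyError.
for group in optimizer.param_groups:
    group.setdefault("initial_lr", group["lr"])

lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=int(0.25 * epochs), gamma=0.1, last_epoch=startEpoch)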
@@ -51,16 +41,15 @@ dataloader = torch.utils.data.DataLoader(
         pin_memory=True
     )
 
-project = '<YOUR_WORKSPACE/YOUR_PROJECT>'
-api_token = '<YOURR_API_TOKEN>'
+    project = 'stask/modnet'
+    api_token = 'eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiJmMmU1ZDZlZC00OWQ5LTQ0ODUtYmExMi0zN2M3MTA5ZmM4ZDcifQ=='
     neptuneRun = neptune.init(project = project,
                               api_token = api_token,
                               source_files=[])
 
     for epoch in range(0, epochs):
         for idx, (image, trimap, gt_matte) in enumerate(dataloader):
-        semantic_loss, detail_loss, matte_loss, semantic_iou = supervised_training_iter(modnet, optimizer, image, trimap, gt_matte) # , semantic_scale=1, detail_scale=10, matte_scale=1)
+            semantic_loss, detail_loss, matte_loss, semantic_iou = supervised_training_iter(modnet, optimizer, image, trimap, gt_matte, semantic_scale=1, detail_scale=10, matte_scale=1)
             if idx % 100 == 0:
                 logger.info(f'idx: {idx}, semantic_loss: {semantic_loss:.5f}, detail_loss: {detail_loss:.5f}, matte_loss: {matte_loss:.5f}, semantic_iou: {semantic_iou:.5f}')
         logger.info(f'Epoch: {epoch}, semantic_loss: {semantic_loss:.5f}, detail_loss: {detail_loss:.5f}, matte_loss: {matte_loss:.5f}, semantic_iou: {semantic_iou:.5f}')
 
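Note: the scale arguments move from a trailing comment into the actual call here. If the upstream src/trainer.py is unchanged in this fork, it declares supervised_training_iter(modnet, optimizer, image, trimap, gt_matte, semantic_scale=10.0, detail_scale=10.0, matte_scale=1.0) and returns three losses, so semantic_scale=1 is a genuine change from the default while the other two merely restate it, and the fourth return value (semantic_iou) suggests the trainer was patched. A standalone smoke test of one iteration could look like this; the tensor shapes, the CUDA requirement, and the random inputs are assumptions:

import torch
import torch.nn as nn
from src.models.modnet import MODNet
from src.trainer import supervised_training_iter

modnet = nn.DataParallel(MODNet(backbone_pretrained=False)).cuda()
optimizer = torch.optim.SGD(modnet.parameters(), lr=0.01, momentum=0.9)

image = torch.rand(2, 3, 512, 512).cuda()                            # RGB batch
trimap = (torch.randint(0, 3, (2, 1, 512, 512)).float() / 2).cuda()  # values in {0, 0.5, 1}
gt_matte = torch.rand(2, 1, 512, 512).cuda()                         # ground-truth alpha

# Four return values in this fork; upstream returns three losses.
losses = supervised_training_iter(modnet, optimizer, image, trimap, gt_matte,
                                  semantic_scale=1, detail_scale=10, matte_scale=1)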
@@ -70,11 +59,31 @@ for epoch in range(0, epochs):
         neptuneRun["training/epoch/matte_loss"].log(matte_loss)
         neptuneRun["training/epoch/semantic_iou"].log(semantic_iou)
 
-    modelPath = os.path.join(args.modelsPath, f"model_epoch{epoch}.ckpt")
+        modelPath = os.path.join(modelsPath, f"model_epoch{epoch}.ckpt")
         torch.save(modnet.state_dict(), modelPath)
         logger.info(f"model saved to {modelPath}")
         lr_scheduler.step()
 
-torch.save(modnet.state_dict(), os.path.join(args.modelsPath, "model.ckpt"))
+    torch.save(modnet.state_dict(), os.path.join(modelsPath, "model.ckpt"))
 
     neptuneRun.stop()
+
+
+def tune(modnet, datasetPath: str, batch_size: int, modelsPath: str):
+    pass
+
+
+def twoStepTrain(datasetPath: str, batch_size: int, startEpoch, pretrainedPath, modelsPath):
+    modnet = MODNet(backbone_pretrained=True)
+    modnet = nn.DataParallel(modnet)
+
+    if pretrainedPath is not None:
+        modnet.load_state_dict(
+            torch.load(pretrainedPath)
+        )
+
+    train(modnet, datasetPath, batch_size, startEpoch, modelsPath)
+    tune(modnet, datasetPath, batch_size, modelsPath)
+
+
+args = parseArgs()
+
+twoStepTrain(args.datasetPath, args.batchCount, args.startEpoch, args.pretrainedPath, args.modelsPath)
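Note: tune() lands as an empty stub, presumably the second step of the two-step plan. Upstream MODNet ships soc_adaptation_iter in src/trainer.py for exactly this self-supervised SOC stage, so the stub could grow into something like the sketch below; the Adam hyper-parameters follow my reading of the upstream docstring, while the epoch count, the unlabeled-image loader, and the final save are assumptions:

import copy
import os

import torch
from src.trainer import soc_adaptation_iter

def tune(modnet, datasetPath: str, batch_size: int, modelsPath: str):
    # SOC adapts the trained model on unlabeled images: a frozen copy of the
    # network provides the consistency targets for each epoch.
    optimizer = torch.optim.Adam(modnet.parameters(), lr=0.00001, betas=(0.9, 0.99))
    dataloader = ...  # loader yielding unlabeled image batches from datasetPath
    for epoch in range(10):
        backup_modnet = copy.deepcopy(modnet)
        for idx, image in enumerate(dataloader):
            soc_semantic_loss, soc_detail_loss = soc_adaptation_iter(
                modnet, backup_modnet, optimizer, image)
    torch.save(modnet.state_dict(), os.path.join(modelsPath, "model_soc.ckpt"))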
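Note: every checkpoint in this file is saved from a DataParallel-wrapped model, so the state-dict keys carry a module. prefix; that is why twoStepTrain wraps modnet before calling load_state_dict. Loading one of these checkpoints into a bare MODNet would need the prefix stripped, for example:

import torch
from src.models.modnet import MODNet

state = torch.load("model.ckpt", map_location="cpu")
# Drop the "module." prefix that nn.DataParallel adds to every key.
state = {k[len("module."):] if k.startswith("module.") else k: v
         for k, v in state.items()}

modnet = MODNet(backbone_pretrained=False)
modnet.load_state_dict(state)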