mirror of https://github.com/ZHKKKe/MODNet.git
[*] two step train draft
parent
5b02db5b2b
commit
4725324d53
|
|
@ -25,42 +25,31 @@ def parseArgs():
|
|||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
args = parseArgs()
|
||||
def train(modnet, datasetPath: str, batch_size: int, startEpoch: int, modelsPath: str):
|
||||
lr = 0.01 # learn rate
|
||||
epochs = 40 # total epochs
|
||||
|
||||
batch_size = args.batchCount
|
||||
lr = 0.01 # learn rate
|
||||
epochs = 40 # total epochs
|
||||
optimizer = torch.optim.SGD(modnet.parameters(), lr=lr, momentum=0.9)
|
||||
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * epochs), gamma=0.1, last_epoch=startEpoch)
|
||||
|
||||
modnet = MODNet(backbone_pretrained=False)
|
||||
modnet = nn.DataParallel(modnet)
|
||||
dataset = SegDataset(os.path.join(datasetPath, "images"), os.path.join(datasetPath, "masks"))
|
||||
|
||||
if args.pretrainedPath is not None:
|
||||
modnet.load_state_dict(
|
||||
torch.load(args.pretrainedPath)
|
||||
)
|
||||
|
||||
optimizer = torch.optim.SGD(modnet.parameters(), lr=lr, momentum=0.9)
|
||||
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * epochs), gamma=0.1, last_epoch=args.startEpoch)
|
||||
|
||||
dataset = SegDataset(os.path.join(args.datasetPath, "images"), os.path.join(args.datasetPath, "masks"))
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
pin_memory=True
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
project = '<YOUR_WORKSPACE/YOUR_PROJECT>'
|
||||
api_token = '<YOURR_API_TOKEN>'
|
||||
neptuneRun = neptune.init(project = project,
|
||||
project = 'stask/modnet'
|
||||
api_token = 'eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiJmMmU1ZDZlZC00OWQ5LTQ0ODUtYmExMi0zN2M3MTA5ZmM4ZDcifQ=='
|
||||
neptuneRun = neptune.init(project = project,
|
||||
api_token = api_token,
|
||||
source_files=[])
|
||||
|
||||
for epoch in range(0, epochs):
|
||||
for epoch in range(0, epochs):
|
||||
for idx, (image, trimap, gt_matte) in enumerate(dataloader):
|
||||
semantic_loss, detail_loss, matte_loss, semantic_iou = supervised_training_iter(modnet, optimizer, image, trimap, gt_matte) # , semantic_scale=1, detail_scale=10, matte_scale=1)
|
||||
semantic_loss, detail_loss, matte_loss, semantic_iou = supervised_training_iter(modnet, optimizer, image, trimap, gt_matte, semantic_scale=1, detail_scale=10, matte_scale=1)
|
||||
if idx % 100 == 0:
|
||||
logger.info(f'idx: {idx}, semantic_loss: {semantic_loss:.5f}, detail_loss: {detail_loss:.5f}, matte_loss: {matte_loss:.5f}, semantic_iou: {semantic_iou:.5f}')
|
||||
logger.info(f'Epoch: {epoch}, semantic_loss: {semantic_loss:.5f}, detail_loss: {detail_loss:.5f}, matte_loss: {matte_loss:.5f}, semantic_iou: {semantic_iou:.5f}')
|
||||
|
|
@ -70,11 +59,31 @@ for epoch in range(0, epochs):
|
|||
neptuneRun["training/epoch/matte_loss"].log(matte_loss)
|
||||
neptuneRun["training/epoch/semantic_iou"].log(semantic_iou)
|
||||
|
||||
modelPath = os.path.join(args.modelsPath, f"model_epoch{epoch}.ckpt")
|
||||
modelPath = os.path.join(modelsPath, f"model_epoch{epoch}.ckpt")
|
||||
torch.save(modnet.state_dict(), modelPath)
|
||||
logger.info(f"model saved to {modelPath}")
|
||||
lr_scheduler.step()
|
||||
|
||||
torch.save(modnet.state_dict(), os.path.join(args.modelsPath, "model.ckpt"))
|
||||
torch.save(modnet.state_dict(), os.path.join(modelsPath, "model.ckpt"))
|
||||
|
||||
neptuneRun.stop()
|
||||
neptuneRun.stop()
|
||||
|
||||
def tune(modnet, datasetPath: str, batch_size: int, modelsPath: str):
|
||||
pass
|
||||
|
||||
def twoStepTrain(datasetPath: str, batch_size: int, startEpoch, pretrainedPath, modelsPath):
|
||||
|
||||
modnet = MODNet(backbone_pretrained=True)
|
||||
modnet = nn.DataParallel(modnet)
|
||||
|
||||
if pretrainedPath is not None:
|
||||
modnet.load_state_dict(
|
||||
torch.load(pretrainedPath)
|
||||
)
|
||||
|
||||
train(modnet, datasetPath, batch_size, startEpoch, modelsPath)
|
||||
tune(modnet, datasetPath, batch_size, modelsPath)
|
||||
|
||||
args = parseArgs()
|
||||
|
||||
train(args.datasetPath, args.batchCount, args.startEpoch, args.pretraindPath, args.modelsPath)
|
||||
Loading…
Reference in New Issue