diff --git a/src/trainer.py b/src/trainer.py index bff6c82..21ce5e6 100644 --- a/src/trainer.py +++ b/src/trainer.py @@ -93,9 +93,12 @@ def supervised_training_iter( modnet (torch.nn.Module): instance of MODNet optimizer (torch.optim.Optimizer): optimizer for supervised training image (torch.autograd.Variable): input RGB image + its pixel values should be normalized trimap (torch.autograd.Variable): trimap used to calculate the losses - NOTE: foreground=1, background=0, unknown=0.5 + its pixel values can be 0, 0.5, or 1 + (foreground=1, background=0, unknown=0.5) gt_matte (torch.autograd.Variable): ground truth alpha matte + its pixel values are between [0, 1] semantic_scale (float): scale of the semantic loss NOTE: please adjust according to your dataset detail_scale (float): scale of the detail loss @@ -184,6 +187,7 @@ def soc_adaptation_iter( backup_modnet (torch.nn.Module): backup of the trained MODNet optimizer (torch.optim.Optimizer): optimizer for self-supervised SOC image (torch.autograd.Variable): input RGB image + its pixel values should be normalized soc_semantic_scale (float): scale of the SOC semantic loss NOTE: please adjust according to your dataset soc_detail_scale (float): scale of the SOC detail loss