From 3850cf610602267522cccccd75e26fdacc3983e1 Mon Sep 17 00:00:00 2001 From: Zhanghan Ke Date: Tue, 18 May 2021 03:01:51 +0800 Subject: [PATCH] update SOC code --- src/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/trainer.py b/src/trainer.py index 21ce5e6..bd3d8be 100644 --- a/src/trainer.py +++ b/src/trainer.py @@ -277,12 +277,12 @@ def soc_adaptation_iter( # NOTE: using the formulas in our paper to calculate the following losses has similar results # sub-objectives consistency between `pred_detail` and `pred_backup_detail` (on boundaries only) - backup_detail_loss = boundaries * F.l1_loss(pred_detail, pred_backup_detail) + backup_detail_loss = boundaries * F.l1_loss(pred_detail, pred_backup_detail, reduction='none') backup_detail_loss = torch.sum(backup_detail_loss, dim=(1,2,3)) / torch.sum(boundaries, dim=(1,2,3)) backup_detail_loss = torch.mean(backup_detail_loss) # sub-objectives consistency between pred_matte` and `pred_backup_matte` (on boundaries only) - backup_matte_loss = boundaries * F.l1_loss(pred_matte, pred_backup_matte) + backup_matte_loss = boundaries * F.l1_loss(pred_matte, pred_backup_matte, reduction='none') backup_matte_loss = torch.sum(backup_matte_loss, dim=(1,2,3)) / torch.sum(boundaries, dim=(1,2,3)) backup_matte_loss = torch.mean(backup_matte_loss)