update SOC code

develop
Zhanghan Ke 2021-05-18 03:01:51 +08:00 committed by GitHub
parent bf6c3ea3f3
commit 3850cf6106
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 2 deletions

View File

@ -277,12 +277,12 @@ def soc_adaptation_iter(
# NOTE: using the formulas in our paper to calculate the following losses has similar results
# sub-objectives consistency between `pred_detail` and `pred_backup_detail` (on boundaries only)
backup_detail_loss = boundaries * F.l1_loss(pred_detail, pred_backup_detail)
backup_detail_loss = boundaries * F.l1_loss(pred_detail, pred_backup_detail, reduction='none')
backup_detail_loss = torch.sum(backup_detail_loss, dim=(1,2,3)) / torch.sum(boundaries, dim=(1,2,3))
backup_detail_loss = torch.mean(backup_detail_loss)
# sub-objectives consistency between pred_matte` and `pred_backup_matte` (on boundaries only)
backup_matte_loss = boundaries * F.l1_loss(pred_matte, pred_backup_matte)
backup_matte_loss = boundaries * F.l1_loss(pred_matte, pred_backup_matte, reduction='none')
backup_matte_loss = torch.sum(backup_matte_loss, dim=(1,2,3)) / torch.sum(boundaries, dim=(1,2,3))
backup_matte_loss = torch.mean(backup_matte_loss)