diff --git a/.gitignore b/.gitignore index 134f2ac..2c4e7f9 100644 --- a/.gitignore +++ b/.gitignore @@ -94,4 +94,6 @@ ENV/ # Project files -.vscode \ No newline at end of file +.vscode +.idea/ +src/datasets/PPM-100 \ No newline at end of file diff --git a/src/matting_dataset.py b/src/matting_dataset.py index 82ca78a..5eeba9b 100644 --- a/src/matting_dataset.py +++ b/src/matting_dataset.py @@ -139,7 +139,13 @@ class Normalize(object): class ToTrainArray(object): def __call__(self, sample): - return [sample['image'], sample['trimap'], sample['gt_matte']] + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + image = sample['image'].to(device) + trimap = sample['trimap'].to(device) + gt_matte = sample['gt_matte'].to(device) + + return [image, trimap, gt_matte] + # return [sample['image'], sample['trimap'], sample['gt_matte']] if __name__ == '__main__': @@ -154,7 +160,7 @@ if __name__ == '__main__': transform = transforms.Compose([ Rescale(512), ToTensor(), - Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + # Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) mattingDataset = MattingDataset(transform=transform) diff --git a/src/trainer.py b/src/trainer.py index 102fe53..02292c9 100644 --- a/src/trainer.py +++ b/src/trainer.py @@ -79,7 +79,9 @@ class GaussianBlurLayer(nn.Module): # MODNet Training Functions # ---------------------------------------------------------------------------------- -blurer = GaussianBlurLayer(1, 3) #.cuda() +blurer = GaussianBlurLayer(1, 3) #.cuda +if torch.cuda.is_available(): + blurer.cuda() def supervised_training_iter( @@ -317,6 +319,9 @@ if __name__ == '__main__': epochs = 40 # total epochs modnet = torch.nn.DataParallel(MODNet()) #.cuda() + if torch.cuda.is_available(): + modnet = modnet.cuda() + optimizer = torch.optim.SGD(modnet.parameters(), lr=lr, momentum=0.9) lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * epochs), gamma=0.1)