调试模型训练代码

pull/177/head
actboy 2022-02-09 21:45:36 +08:00
parent 58c54f6e6c
commit 27db49f9de
3 changed files with 17 additions and 4 deletions

4
.gitignore vendored
View File

@ -94,4 +94,6 @@ ENV/
# Project files
.vscode
.vscode
.idea/
src/datasets/PPM-100

View File

@ -139,7 +139,13 @@ class Normalize(object):
class ToTrainArray(object):
def __call__(self, sample):
return [sample['image'], sample['trimap'], sample['gt_matte']]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
image = sample['image'].to(device)
trimap = sample['trimap'].to(device)
gt_matte = sample['gt_matte'].to(device)
return [image, trimap, gt_matte]
# return [sample['image'], sample['trimap'], sample['gt_matte']]
if __name__ == '__main__':
@ -154,7 +160,7 @@ if __name__ == '__main__':
transform = transforms.Compose([
Rescale(512),
ToTensor(),
Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
# Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
mattingDataset = MattingDataset(transform=transform)

View File

@ -79,7 +79,9 @@ class GaussianBlurLayer(nn.Module):
# MODNet Training Functions
# ----------------------------------------------------------------------------------
blurer = GaussianBlurLayer(1, 3) #.cuda()
blurer = GaussianBlurLayer(1, 3) #.cuda
if torch.cuda.is_available():
blurer.cuda()
def supervised_training_iter(
@ -317,6 +319,9 @@ if __name__ == '__main__':
epochs = 40 # total epochs
modnet = torch.nn.DataParallel(MODNet()) #.cuda()
if torch.cuda.is_available():
modnet = modnet.cuda()
optimizer = torch.optim.SGD(modnet.parameters(), lr=lr, momentum=0.9)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * epochs), gamma=0.1)