mirror of https://github.com/ZHKKKe/MODNet.git
调试模型训练代码
parent
58c54f6e6c
commit
27db49f9de
|
|
@ -94,4 +94,6 @@ ENV/
|
|||
|
||||
|
||||
# Project files
|
||||
.vscode
|
||||
.vscode
|
||||
.idea/
|
||||
src/datasets/PPM-100
|
||||
|
|
@ -139,7 +139,13 @@ class Normalize(object):
|
|||
|
||||
class ToTrainArray(object):
|
||||
def __call__(self, sample):
|
||||
return [sample['image'], sample['trimap'], sample['gt_matte']]
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
image = sample['image'].to(device)
|
||||
trimap = sample['trimap'].to(device)
|
||||
gt_matte = sample['gt_matte'].to(device)
|
||||
|
||||
return [image, trimap, gt_matte]
|
||||
# return [sample['image'], sample['trimap'], sample['gt_matte']]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
@ -154,7 +160,7 @@ if __name__ == '__main__':
|
|||
transform = transforms.Compose([
|
||||
Rescale(512),
|
||||
ToTensor(),
|
||||
Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||
# Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||
])
|
||||
mattingDataset = MattingDataset(transform=transform)
|
||||
|
||||
|
|
|
|||
|
|
@ -79,7 +79,9 @@ class GaussianBlurLayer(nn.Module):
|
|||
# MODNet Training Functions
|
||||
# ----------------------------------------------------------------------------------
|
||||
|
||||
blurer = GaussianBlurLayer(1, 3) #.cuda()
|
||||
blurer = GaussianBlurLayer(1, 3) #.cuda
|
||||
if torch.cuda.is_available():
|
||||
blurer.cuda()
|
||||
|
||||
|
||||
def supervised_training_iter(
|
||||
|
|
@ -317,6 +319,9 @@ if __name__ == '__main__':
|
|||
epochs = 40 # total epochs
|
||||
|
||||
modnet = torch.nn.DataParallel(MODNet()) #.cuda()
|
||||
if torch.cuda.is_available():
|
||||
modnet = modnet.cuda()
|
||||
|
||||
optimizer = torch.optim.SGD(modnet.parameters(), lr=lr, momentum=0.9)
|
||||
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * epochs), gamma=0.1)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue