mirror of https://github.com/ZHKKKe/MODNet.git
调试模型训练代码
parent
58c54f6e6c
commit
27db49f9de
|
|
@ -95,3 +95,5 @@ ENV/
|
||||||
|
|
||||||
# Project files
|
# Project files
|
||||||
.vscode
|
.vscode
|
||||||
|
.idea/
|
||||||
|
src/datasets/PPM-100
|
||||||
|
|
@ -139,7 +139,13 @@ class Normalize(object):
|
||||||
|
|
||||||
class ToTrainArray(object):
|
class ToTrainArray(object):
|
||||||
def __call__(self, sample):
|
def __call__(self, sample):
|
||||||
return [sample['image'], sample['trimap'], sample['gt_matte']]
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
image = sample['image'].to(device)
|
||||||
|
trimap = sample['trimap'].to(device)
|
||||||
|
gt_matte = sample['gt_matte'].to(device)
|
||||||
|
|
||||||
|
return [image, trimap, gt_matte]
|
||||||
|
# return [sample['image'], sample['trimap'], sample['gt_matte']]
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
@ -154,7 +160,7 @@ if __name__ == '__main__':
|
||||||
transform = transforms.Compose([
|
transform = transforms.Compose([
|
||||||
Rescale(512),
|
Rescale(512),
|
||||||
ToTensor(),
|
ToTensor(),
|
||||||
Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
# Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||||
])
|
])
|
||||||
mattingDataset = MattingDataset(transform=transform)
|
mattingDataset = MattingDataset(transform=transform)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -79,7 +79,9 @@ class GaussianBlurLayer(nn.Module):
|
||||||
# MODNet Training Functions
|
# MODNet Training Functions
|
||||||
# ----------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------
|
||||||
|
|
||||||
blurer = GaussianBlurLayer(1, 3) #.cuda()
|
blurer = GaussianBlurLayer(1, 3) #.cuda
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
blurer.cuda()
|
||||||
|
|
||||||
|
|
||||||
def supervised_training_iter(
|
def supervised_training_iter(
|
||||||
|
|
@ -317,6 +319,9 @@ if __name__ == '__main__':
|
||||||
epochs = 40 # total epochs
|
epochs = 40 # total epochs
|
||||||
|
|
||||||
modnet = torch.nn.DataParallel(MODNet()) #.cuda()
|
modnet = torch.nn.DataParallel(MODNet()) #.cuda()
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
modnet = modnet.cuda()
|
||||||
|
|
||||||
optimizer = torch.optim.SGD(modnet.parameters(), lr=lr, momentum=0.9)
|
optimizer = torch.optim.SGD(modnet.parameters(), lr=lr, momentum=0.9)
|
||||||
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * epochs), gamma=0.1)
|
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * epochs), gamma=0.1)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue