mirror of https://github.com/ZHKKKe/MODNet.git
[*] image resizing in dataset
parent
b6f9f63d83
commit
4e2f7ddfbe
|
|
@ -2,7 +2,7 @@ import os
|
||||||
import glob
|
import glob
|
||||||
import cv2
|
import cv2
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, List, Optional, Tuple, Union, Dict
|
from typing import Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torchvision
|
import torchvision
|
||||||
|
|
@ -16,34 +16,44 @@ class SegDataset(Dataset):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
frame_dir: Union[str, Path],
|
frame_dir: Union[str, Path],
|
||||||
mask_dir: Union[str, Path]
|
mask_dir: Union[str, Path]
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
self.frame_dir = Path(frame_dir)
|
self.frame_dir = Path(frame_dir)
|
||||||
self.mask_dir = Path(mask_dir)
|
self.mask_dir = Path(mask_dir)
|
||||||
self.image_names = glob.glob(f"{self.frame_dir}/*.jpg")
|
self.image_names = glob.glob(f"{self.frame_dir}/*.jpg")
|
||||||
self.mask_names = [os.path.join(self.mask_dir,"mask"+x.split('/')[-1][:-4][5:]+".png") for x in self.image_names]
|
self.mask_names = [os.path.join(self.mask_dir,(x.split('/')[-1])[:-4]+".png") for x in self.image_names]
|
||||||
print(self.mask_names)
|
#print(self.mask_names)
|
||||||
self.transform = torchvision.transforms.Compose([
|
self.transform = torchvision.transforms.Compose([
|
||||||
torchvision.transforms.ToTensor(),
|
torchvision.transforms.ToPILImage(),
|
||||||
torchvision.transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
|
torchvision.transforms.Resize((512,512)),
|
||||||
|
torchvision.transforms.ToTensor(),
|
||||||
|
torchvision.transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||||
|
])
|
||||||
|
self.transform2 = torchvision.transforms.Compose([
|
||||||
|
torchvision.transforms.ToPILImage(),
|
||||||
|
torchvision.transforms.Resize((512,512)),
|
||||||
|
torchvision.transforms.ToTensor()
|
||||||
|
])
|
||||||
|
|
||||||
def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
frame_pth = self.image_names[index]
|
frame_pth = self.image_names[index]
|
||||||
mask_pth = self.mask_names[index]
|
mask_pth = self.mask_names[index]
|
||||||
|
|
||||||
frame = cv2.imread(frame_pth)
|
frame = cv2.imread(frame_pth)
|
||||||
frame = self.transform(frame)
|
frame = self.transform(frame)
|
||||||
|
|
||||||
mask = cv2.imread(mask_pth,cv2.IMREAD_GRAYSCALE)
|
mask = cv2.imread(mask_pth,cv2.IMREAD_GRAYSCALE)
|
||||||
trimap = torch.from_numpy(makeTrimap(mask)).float()
|
trimap = torch.from_numpy(makeTrimap(mask)).float()
|
||||||
trimap = torch.unsqueeze(trimap,0)
|
trimap = torch.unsqueeze(trimap,0)
|
||||||
mask = torch.from_numpy(mask)
|
mask = torch.from_numpy(mask)
|
||||||
mask = torch.unsqueeze(mask,0).float()
|
mask = torch.unsqueeze(mask,0).float()
|
||||||
|
|
||||||
|
mask = self.transform2(mask)
|
||||||
|
trimap = self.transform2(trimap)
|
||||||
|
|
||||||
return frame, trimap, mask
|
return frame, trimap, mask
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
|
|
|
||||||
|
|
@ -60,7 +60,7 @@ neptuneRun = neptune.init(project = project,
|
||||||
|
|
||||||
for epoch in range(0, epochs):
|
for epoch in range(0, epochs):
|
||||||
for idx, (image, trimap, gt_matte) in enumerate(dataloader):
|
for idx, (image, trimap, gt_matte) in enumerate(dataloader):
|
||||||
semantic_loss, detail_loss, matte_loss, semantic_iou = supervised_training_iter(modnet, optimizer, image, trimap, gt_matte, semantic_scale=1)
|
semantic_loss, detail_loss, matte_loss, semantic_iou = supervised_training_iter(modnet, optimizer, image, trimap, gt_matte) # , semantic_scale=1, detail_scale=10, matte_scale=1)
|
||||||
if idx % 100 == 0:
|
if idx % 100 == 0:
|
||||||
logger.info(f'idx: {idx}, semantic_loss: {semantic_loss:.5f}, detail_loss: {detail_loss:.5f}, matte_loss: {matte_loss:.5f}, semantic_iou: {semantic_iou:.5f}')
|
logger.info(f'idx: {idx}, semantic_loss: {semantic_loss:.5f}, detail_loss: {detail_loss:.5f}, matte_loss: {matte_loss:.5f}, semantic_iou: {semantic_iou:.5f}')
|
||||||
logger.info(f'Epoch: {epoch}, semantic_loss: {semantic_loss:.5f}, detail_loss: {detail_loss:.5f}, matte_loss: {matte_loss:.5f}, semantic_iou: {semantic_iou:.5f}')
|
logger.info(f'Epoch: {epoch}, semantic_loss: {semantic_loss:.5f}, detail_loss: {detail_loss:.5f}, matte_loss: {matte_loss:.5f}, semantic_iou: {semantic_iou:.5f}')
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue