mirror of https://github.com/ZHKKKe/MODNet.git
[*] image resizing in dataset
parent
b6f9f63d83
commit
4e2f7ddfbe
|
|
@ -2,7 +2,7 @@ import os
|
|||
import glob
|
||||
import cv2
|
||||
from pathlib import Path
|
||||
from typing import Callable, List, Optional, Tuple, Union, Dict
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
|
|
@ -16,34 +16,44 @@ class SegDataset(Dataset):
|
|||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
frame_dir: Union[str, Path],
|
||||
mask_dir: Union[str, Path]
|
||||
self,
|
||||
frame_dir: Union[str, Path],
|
||||
mask_dir: Union[str, Path]
|
||||
) -> None:
|
||||
|
||||
self.frame_dir = Path(frame_dir)
|
||||
self.mask_dir = Path(mask_dir)
|
||||
self.image_names = glob.glob(f"{self.frame_dir}/*.jpg")
|
||||
self.mask_names = [os.path.join(self.mask_dir,"mask"+x.split('/')[-1][:-4][5:]+".png") for x in self.image_names]
|
||||
print(self.mask_names)
|
||||
self.mask_names = [os.path.join(self.mask_dir,(x.split('/')[-1])[:-4]+".png") for x in self.image_names]
|
||||
#print(self.mask_names)
|
||||
self.transform = torchvision.transforms.Compose([
|
||||
torchvision.transforms.ToTensor(),
|
||||
torchvision.transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
|
||||
|
||||
torchvision.transforms.ToPILImage(),
|
||||
torchvision.transforms.Resize((512,512)),
|
||||
torchvision.transforms.ToTensor(),
|
||||
torchvision.transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
])
|
||||
self.transform2 = torchvision.transforms.Compose([
|
||||
torchvision.transforms.ToPILImage(),
|
||||
torchvision.transforms.Resize((512,512)),
|
||||
torchvision.transforms.ToTensor()
|
||||
])
|
||||
|
||||
def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
frame_pth = self.image_names[index]
|
||||
mask_pth = self.mask_names[index]
|
||||
|
||||
frame = cv2.imread(frame_pth)
|
||||
frame = cv2.imread(frame_pth)
|
||||
frame = self.transform(frame)
|
||||
|
||||
mask = cv2.imread(mask_pth,cv2.IMREAD_GRAYSCALE)
|
||||
mask = cv2.imread(mask_pth,cv2.IMREAD_GRAYSCALE)
|
||||
trimap = torch.from_numpy(makeTrimap(mask)).float()
|
||||
trimap = torch.unsqueeze(trimap,0)
|
||||
trimap = torch.unsqueeze(trimap,0)
|
||||
mask = torch.from_numpy(mask)
|
||||
mask = torch.unsqueeze(mask,0).float()
|
||||
|
||||
mask = self.transform2(mask)
|
||||
trimap = self.transform2(trimap)
|
||||
|
||||
return frame, trimap, mask
|
||||
|
||||
def __len__(self):
|
||||
|
|
|
|||
|
|
@ -60,7 +60,7 @@ neptuneRun = neptune.init(project = project,
|
|||
|
||||
for epoch in range(0, epochs):
|
||||
for idx, (image, trimap, gt_matte) in enumerate(dataloader):
|
||||
semantic_loss, detail_loss, matte_loss, semantic_iou = supervised_training_iter(modnet, optimizer, image, trimap, gt_matte, semantic_scale=1)
|
||||
semantic_loss, detail_loss, matte_loss, semantic_iou = supervised_training_iter(modnet, optimizer, image, trimap, gt_matte) # , semantic_scale=1, detail_scale=10, matte_scale=1)
|
||||
if idx % 100 == 0:
|
||||
logger.info(f'idx: {idx}, semantic_loss: {semantic_loss:.5f}, detail_loss: {detail_loss:.5f}, matte_loss: {matte_loss:.5f}, semantic_iou: {semantic_iou:.5f}')
|
||||
logger.info(f'Epoch: {epoch}, semantic_loss: {semantic_loss:.5f}, detail_loss: {detail_loss:.5f}, matte_loss: {matte_loss:.5f}, semantic_iou: {semantic_iou:.5f}')
|
||||
|
|
|
|||
Loading…
Reference in New Issue