[*] image resizing in dataset

pull/170/head
Kapulkin Stanislav 2022-02-08 19:24:30 +03:00
parent b6f9f63d83
commit 4e2f7ddfbe
2 changed files with 23 additions and 13 deletions

View File

@@ -2,7 +2,7 @@ import os
import glob
import cv2
from pathlib import Path
from typing import Callable, List, Optional, Tuple, Union, Dict
from typing import Tuple, Union
import torch
import torchvision
@@ -16,34 +16,44 @@ class SegDataset(Dataset):
"""
def __init__(
    self,
    frame_dir: Union[str, Path],
    mask_dir: Union[str, Path]
) -> None:
    """Index paired frame/mask files and build the resize+normalize pipelines.

    Args:
        frame_dir: directory containing input frames as ``*.jpg``.
        mask_dir: directory containing ground-truth masks as ``*.png``
            with the same basename as the corresponding frame.
    """
    self.frame_dir = Path(frame_dir)
    self.mask_dir = Path(mask_dir)
    self.image_names = glob.glob(f"{self.frame_dir}/*.jpg")
    # Pair each frame with the mask of the same basename but a .png extension.
    # Path(x).stem is the portable equivalent of x.split('/')[-1][:-4].
    self.mask_names = [
        os.path.join(self.mask_dir, Path(x).stem + ".png")
        for x in self.image_names
    ]
    # Frames: force a fixed 512x512 input size, then apply ImageNet
    # mean/std normalization (standard for pretrained backbones).
    self.transform = torchvision.transforms.Compose([
        torchvision.transforms.ToPILImage(),
        torchvision.transforms.Resize((512, 512)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225]),
    ])
    # Masks/trimaps: same 512x512 resize but NO normalization.
    # NOTE(review): Resize defaults to bilinear interpolation, which
    # blurs hard mask edges into intermediate values — confirm this is
    # acceptable for the loss, or use NEAREST interpolation for labels.
    self.transform2 = torchvision.transforms.Compose([
        torchvision.transforms.ToPILImage(),
        torchvision.transforms.Resize((512, 512)),
        torchvision.transforms.ToTensor(),
    ])
def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Load one sample and return ``(frame, trimap, mask)`` tensors.

    The frame is resized to 512x512 and ImageNet-normalized; the trimap
    (generated from the mask by ``makeTrimap``) and the mask are resized
    to 512x512 without normalization.

    Args:
        index: position in the dataset's file list.

    Returns:
        Tuple of ``(frame, trimap, mask)`` float tensors.
    """
    frame_pth = self.image_names[index]
    mask_pth = self.mask_names[index]
    # Read each image exactly once (the diff left duplicated imread calls,
    # which doubled disk I/O per sample).
    frame = cv2.imread(frame_pth)
    frame = self.transform(frame)
    mask = cv2.imread(mask_pth, cv2.IMREAD_GRAYSCALE)
    # Derive the trimap from the raw (pre-resize) mask, then add a
    # channel dimension so ToPILImage in transform2 accepts it.
    trimap = torch.from_numpy(makeTrimap(mask)).float()
    trimap = torch.unsqueeze(trimap, 0)
    mask = torch.from_numpy(mask)
    mask = torch.unsqueeze(mask, 0).float()
    mask = self.transform2(mask)
    trimap = self.transform2(trimap)
    return frame, trimap, mask
def __len__(self):

View File

@@ -60,7 +60,7 @@ neptuneRun = neptune.init(project = project,
for epoch in range(0, epochs):
for idx, (image, trimap, gt_matte) in enumerate(dataloader):
semantic_loss, detail_loss, matte_loss, semantic_iou = supervised_training_iter(modnet, optimizer, image, trimap, gt_matte, semantic_scale=1)
semantic_loss, detail_loss, matte_loss, semantic_iou = supervised_training_iter(modnet, optimizer, image, trimap, gt_matte) # , semantic_scale=1, detail_scale=10, matte_scale=1)
if idx % 100 == 0:
logger.info(f'idx: {idx}, semantic_loss: {semantic_loss:.5f}, detail_loss: {detail_loss:.5f}, matte_loss: {matte_loss:.5f}, semantic_iou: {semantic_iou:.5f}')
logger.info(f'Epoch: {epoch}, semantic_loss: {semantic_loss:.5f}, detail_loss: {detail_loss:.5f}, matte_loss: {matte_loss:.5f}, semantic_iou: {semantic_iou:.5f}')