[*] image resizing in dataset

pull/170/head
Kapulkin Stanislav 2022-02-08 19:24:30 +03:00
parent b6f9f63d83
commit 4e2f7ddfbe
2 changed files with 23 additions and 13 deletions

View File

@@ -2,7 +2,7 @@ import os
import glob import glob
import cv2 import cv2
from pathlib import Path from pathlib import Path
from typing import Callable, List, Optional, Tuple, Union, Dict from typing import Tuple, Union
import torch import torch
import torchvision import torchvision
@@ -24,12 +24,19 @@ class SegDataset(Dataset):
self.frame_dir = Path(frame_dir) self.frame_dir = Path(frame_dir)
self.mask_dir = Path(mask_dir) self.mask_dir = Path(mask_dir)
self.image_names = glob.glob(f"{self.frame_dir}/*.jpg") self.image_names = glob.glob(f"{self.frame_dir}/*.jpg")
self.mask_names = [os.path.join(self.mask_dir,"mask"+x.split('/')[-1][:-4][5:]+".png") for x in self.image_names] self.mask_names = [os.path.join(self.mask_dir,(x.split('/')[-1])[:-4]+".png") for x in self.image_names]
print(self.mask_names) #print(self.mask_names)
self.transform = torchvision.transforms.Compose([ self.transform = torchvision.transforms.Compose([
torchvision.transforms.ToPILImage(),
torchvision.transforms.Resize((512,512)),
torchvision.transforms.ToTensor(), torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) torchvision.transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
self.transform2 = torchvision.transforms.Compose([
torchvision.transforms.ToPILImage(),
torchvision.transforms.Resize((512,512)),
torchvision.transforms.ToTensor()
])
def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]: def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
frame_pth = self.image_names[index] frame_pth = self.image_names[index]
@@ -44,6 +51,9 @@ class SegDataset(Dataset):
mask = torch.from_numpy(mask) mask = torch.from_numpy(mask)
mask = torch.unsqueeze(mask,0).float() mask = torch.unsqueeze(mask,0).float()
mask = self.transform2(mask)
trimap = self.transform2(trimap)
return frame, trimap, mask return frame, trimap, mask
def __len__(self): def __len__(self):

View File

@@ -60,7 +60,7 @@ neptuneRun = neptune.init(project = project,
for epoch in range(0, epochs): for epoch in range(0, epochs):
for idx, (image, trimap, gt_matte) in enumerate(dataloader): for idx, (image, trimap, gt_matte) in enumerate(dataloader):
semantic_loss, detail_loss, matte_loss, semantic_iou = supervised_training_iter(modnet, optimizer, image, trimap, gt_matte, semantic_scale=1) semantic_loss, detail_loss, matte_loss, semantic_iou = supervised_training_iter(modnet, optimizer, image, trimap, gt_matte) # , semantic_scale=1, detail_scale=10, matte_scale=1)
if idx % 100 == 0: if idx % 100 == 0:
logger.info(f'idx: {idx}, semantic_loss: {semantic_loss:.5f}, detail_loss: {detail_loss:.5f}, matte_loss: {matte_loss:.5f}, semantic_iou: {semantic_iou:.5f}') logger.info(f'idx: {idx}, semantic_loss: {semantic_loss:.5f}, detail_loss: {detail_loss:.5f}, matte_loss: {matte_loss:.5f}, semantic_iou: {semantic_iou:.5f}')
logger.info(f'Epoch: {epoch}, semantic_loss: {semantic_loss:.5f}, detail_loss: {detail_loss:.5f}, matte_loss: {matte_loss:.5f}, semantic_iou: {semantic_iou:.5f}') logger.info(f'Epoch: {epoch}, semantic_loss: {semantic_loss:.5f}, detail_loss: {detail_loss:.5f}, matte_loss: {matte_loss:.5f}, semantic_iou: {semantic_iou:.5f}')