From 4e2f7ddfbec151195a70cf091597a69f2064ca92 Mon Sep 17 00:00:00 2001
From: Kapulkin Stanislav
Date: Tue, 8 Feb 2022 19:24:30 +0300
Subject: [PATCH] [*] image resizing in dataset

---
 src/train/dataset.py | 34 ++++++++++++++++++++++------------
 src/train/train.py   |  2 +-
 2 files changed, 23 insertions(+), 13 deletions(-)

diff --git a/src/train/dataset.py b/src/train/dataset.py
index 20b3ab1..d1aca4c 100644
--- a/src/train/dataset.py
+++ b/src/train/dataset.py
@@ -2,7 +2,7 @@ import os
 import glob
 import cv2
 from pathlib import Path
-from typing import Callable, List, Optional, Tuple, Union, Dict
+from typing import Tuple, Union
 
 import torch
 import torchvision
@@ -16,34 +16,44 @@ class SegDataset(Dataset):
     """
     def __init__(
-        self,
-        frame_dir: Union[str, Path],
-        mask_dir: Union[str, Path]
+        self,
+        frame_dir: Union[str, Path],
+        mask_dir: Union[str, Path]
     ) -> None:
         self.frame_dir = Path(frame_dir)
         self.mask_dir = Path(mask_dir)
         self.image_names = glob.glob(f"{self.frame_dir}/*.jpg")
-        self.mask_names = [os.path.join(self.mask_dir,"mask"+x.split('/')[-1][:-4][5:]+".png") for x in self.image_names]
-        print(self.mask_names)
+        self.mask_names = [os.path.join(self.mask_dir,(x.split('/')[-1])[:-4]+".png") for x in self.image_names]
+        #print(self.mask_names)
 
         self.transform = torchvision.transforms.Compose([
-            torchvision.transforms.ToTensor(),
-            torchvision.transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
-
+            torchvision.transforms.ToPILImage(),
+            torchvision.transforms.Resize((512,512)),
+            torchvision.transforms.ToTensor(),
+            torchvision.transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        ])
+        self.transform2 = torchvision.transforms.Compose([
+            torchvision.transforms.ToPILImage(),
+            torchvision.transforms.Resize((512,512)),
+            torchvision.transforms.ToTensor()
+        ])
 
     def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
         frame_pth = self.image_names[index]
         mask_pth = self.mask_names[index]
 
-        frame = cv2.imread(frame_pth)
+        frame = cv2.imread(frame_pth)
         frame = self.transform(frame)
 
-        mask = cv2.imread(mask_pth,cv2.IMREAD_GRAYSCALE)
+        mask = cv2.imread(mask_pth,cv2.IMREAD_GRAYSCALE)
         trimap = torch.from_numpy(makeTrimap(mask)).float()
-        trimap = torch.unsqueeze(trimap,0)
+        trimap = torch.unsqueeze(trimap,0)
 
         mask = torch.from_numpy(mask)
         mask = torch.unsqueeze(mask,0).float()
 
+        mask = self.transform2(mask)
+        trimap = self.transform2(trimap)
+
         return frame, trimap, mask
 
     def __len__(self):
diff --git a/src/train/train.py b/src/train/train.py
index 458655d..57e666b 100644
--- a/src/train/train.py
+++ b/src/train/train.py
@@ -60,7 +60,7 @@ neptuneRun = neptune.init(project = project,
 
 for epoch in range(0, epochs):
     for idx, (image, trimap, gt_matte) in enumerate(dataloader):
-        semantic_loss, detail_loss, matte_loss, semantic_iou = supervised_training_iter(modnet, optimizer, image, trimap, gt_matte, semantic_scale=1)
+        semantic_loss, detail_loss, matte_loss, semantic_iou = supervised_training_iter(modnet, optimizer, image, trimap, gt_matte) # , semantic_scale=1, detail_scale=10, matte_scale=1)
         if idx % 100 == 0:
             logger.info(f'idx: {idx}, semantic_loss: {semantic_loss:.5f}, detail_loss: {detail_loss:.5f}, matte_loss: {matte_loss:.5f}, semantic_iou: {semantic_iou:.5f}')
         logger.info(f'Epoch: {epoch}, semantic_loss: {semantic_loss:.5f}, detail_loss: {detail_loss:.5f}, matte_loss: {matte_loss:.5f}, semantic_iou: {semantic_iou:.5f}')
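
Review note (outside the patch): below is a minimal smoke test for the resized pipeline. It is a sketch, not code from this repository; "data/frames" and "data/masks" are placeholder paths, it assumes it is run from src/train so that dataset.py is importable, and it assumes at least four samples on disk.

    from torch.utils.data import DataLoader

    from dataset import SegDataset

    dataset = SegDataset(frame_dir="data/frames", mask_dir="data/masks")
    loader = DataLoader(dataset, batch_size=4, shuffle=True)

    # With Resize((512, 512)) applied to frame, trimap, and mask alike,
    # images of mixed resolutions now collate into fixed-size batches.
    frame, trimap, mask = next(iter(loader))
    assert frame.shape == (4, 3, 512, 512)
    assert trimap.shape == (4, 1, 512, 512)
    assert mask.shape == (4, 1, 512, 512)

One caveat: ToTensor only rescales 8-bit PIL images to [0, 1]. Since mask is cast to float before ToPILImage, it travels through transform2 as a 32-bit float image and keeps its 0-255 values, so a division by 255 may still be needed if the matting loss expects mattes in [0, 1].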
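
Possible follow-up, sketched under the assumption that the installed torchvision is 0.9 or newer (where InterpolationMode exists): Resize defaults to bilinear interpolation, which blends the trimap's discrete levels and the mask's hard edges into intermediate values along region borders. A nearest-neighbor variant of transform2 would keep them intact:

    import torchvision
    from torchvision.transforms import InterpolationMode

    # Nearest-neighbor resizing keeps label-like images (trimap, binary
    # mask) at their original discrete values instead of averaging them
    # at region borders.
    transform2 = torchvision.transforms.Compose([
        torchvision.transforms.ToPILImage(),
        torchvision.transforms.Resize((512,512), interpolation=InterpolationMode.NEAREST),
        torchvision.transforms.ToTensor()
    ])

Minor: now that the mask-path mapping is fixed, the commented-out #print(self.mask_names) could be dropped rather than kept.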