import os.path as osp
import random

import torch.utils.data as Data
import torchvision.transforms as transforms
from PIL import Image, ImageOps, ImageFilter


class IRSTD_Dataset(Data.Dataset):
    """Infrared small-target detection dataset: loads image/mask pairs listed in a split file."""

    def __init__(self, args, mode='train'):
        dataset_dir = args.dataset_dir
        if mode == 'train':
            txtfile = 'trainval.txt'
        elif mode == 'val':
            txtfile = 'test.txt'
        else:
            raise ValueError("Unknown mode: {}".format(mode))
        self.list_dir = osp.join(dataset_dir, txtfile)
        self.imgs_dir = osp.join(dataset_dir, 'images')
        self.label_dir = osp.join(dataset_dir, 'masks')

        # One sample name per line in the split file.
        self.names = []
        with open(self.list_dir, 'r') as f:
            self.names += [line.strip() for line in f.readlines()]

        self.mode = mode
        self.crop_size = args.crop_size
        self.base_size = args.base_size
        # ImageNet mean/std normalization, applied to the image only (not the mask).
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])

    def __getitem__(self, i):
        name = self.names[i]
        img_path = osp.join(self.imgs_dir, name + '.png')
        label_path = osp.join(self.label_dir, name + '.png')

        img = Image.open(img_path).convert('RGB')
        mask = Image.open(label_path)

        if self.mode == 'train':
            img, mask = self._sync_transform(img, mask)
        elif self.mode == 'val':
            img, mask = self._testval_sync_transform(img, mask)
        else:
            raise ValueError("Unknown self.mode")

        img, mask = self.transform(img), transforms.ToTensor()(mask)
        return img, mask

    def __len__(self):
        return len(self.names)

    def _sync_transform(self, img, mask):
        # random horizontal mirror
        if random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
        crop_size = self.crop_size
        # random scale: the long edge is drawn from [0.5, 2.0] * base_size
        long_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0))
        w, h = img.size
        if h > w:
            oh = long_size
            ow = int(1.0 * w * long_size / h + 0.5)
            short_size = ow
        else:
            ow = long_size
            oh = int(1.0 * h * long_size / w + 0.5)
            short_size = oh
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        # pad up to crop_size if the scaled image is too small
        if short_size < crop_size:
            padh = crop_size - oh if oh < crop_size else 0
            padw = crop_size - ow if ow < crop_size else 0
            img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
            mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0)
        # random crop of crop_size x crop_size
        w, h = img.size
        x1 = random.randint(0, w - crop_size)
        y1 = random.randint(0, h - crop_size)
        img = img.crop((x1, y1, x1 + crop_size, y1 + crop_size))
        mask = mask.crop((x1, y1, x1 + crop_size, y1 + crop_size))
        # gaussian blur as in PSP
        if random.random() < 0.5:
            img = img.filter(ImageFilter.GaussianBlur(
                radius=random.random()))
        return img, mask

    def _testval_sync_transform(self, img, mask):
        # deterministic resize to base_size x base_size for evaluation
        base_size = self.base_size
        img = img.resize((base_size, base_size), Image.BILINEAR)
        mask = mask.resize((base_size, base_size), Image.NEAREST)
        return img, mask