# DIPO/dataset/mydataset.py
import os
import sys

# make the repository root importable for the `dataset` package import below
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))

import json

import imageio
import numpy as np
import torchvision.transforms as T
from PIL import Image
from tqdm import tqdm

from dataset.base_dataset import BaseDataset

def make_white_background(src_img):
    '''Composite the input RGBA image onto a white RGB background.'''
src_img.load()
background = Image.new("RGB", src_img.size, (255, 255, 255))
background.paste(src_img, mask=src_img.split()[3]) # 3 is the alpha channel
return background
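
# Usage sketch (assumption: `frame` is a hypothetical (H, W, 4) uint8 RGBA array,
# e.g. a decoded video frame like those read in _cache_data_non_train below):
#   rgb = make_white_background(Image.fromarray(frame))
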
class MyDataset(BaseDataset):
"""
Dataset for training and testing on the PartNet-Mobility and ACD datasets (with our preprocessing).
The GT graph is given.
"""
def __init__(self, hparams, model_ids, mode="train", json_name="object.json"):
self.hparams = hparams
self.json_name = json_name
self.model_ids = self._filter_models(model_ids)
self.mode = mode
self.map_cat = False
self.get_acd_mapping()
        # skip loading GT graphs only when testing without GT on predicted graphs
        self.no_GT = self.hparams.get("test_no_GT", False) and self.hparams.get("test_pred_G", False)
        # predicted graphs are only ever used at test time
        self.pred_G = False if mode in ("train", "val") else self.hparams.get("test_pred_G", False)
        if mode == "test" and "acd" in hparams.test_which:
            self.map_cat = True
self.files = self._cache_data()
print(f"[INFO] {mode} dataset: {len(self)} data samples loaded.")
def _cache_data_train(self):
json_data_root = self.hparams.json_root
data_root = self.hparams.root
# number of views per model and in total
n_views_per_model = self.hparams.n_views_per_model
n_views = n_views_per_model * len(self.model_ids)
# json files for each model
json_files = []
# mapping to the index of the corresponding model in json_files
model_mappings = []
# space for dinov2 patch features
feats = np.empty((n_views, 512, 768), dtype=np.float16)
        # object masks and input images are not needed during training
        imgs = None
# load data for non-aug views
i = 0 # index for views
        for j, model_id in tqdm(enumerate(self.model_ids), total=len(self.model_ids), desc="Loading training data"):
# 3D data
with open(os.path.join(json_data_root, model_id, self.json_name), "r") as f:
json_file = json.load(f)
json_files.append(json_file)
            # deterministic view order; skip the high-resolution variants
            filenames = sorted(os.listdir(os.path.join(data_root, model_id, 'features')))
            filenames = [f for f in filenames if 'high_res' not in f][:self.hparams.n_views_per_model]
            # the preallocated buffers assume exactly n_views_per_model files per model
            assert len(filenames) == self.hparams.n_views_per_model, \
                f"{model_id}: expected {self.hparams.n_views_per_model} feature files, got {len(filenames)}"
for filename in filenames:
view_feat = np.load(os.path.join(data_root, model_id, 'features', filename))
first_frame_feat = view_feat[0]
                # assumed frame layout (inferred from frame_mode): index 0 is the
                # first frame, -2 the last animation frame, -1 the random-state frame
                if self.hparams.frame_mode == 'last_frame':
                    second_frame_feat = view_feat[-2]
                elif self.hparams.frame_mode == 'random_state_frame':
                    second_frame_feat = view_feat[-1]
                else:
                    raise NotImplementedError(
                        f"Unknown frame_mode '{self.hparams.frame_mode}'; expected 'last_frame' or 'random_state_frame'"
                    )
feats[i : i + 1, :256, :] = first_frame_feat.astype(np.float16)
feats[i : i + 1, 256:, :] = second_frame_feat.astype(np.float16)
i = i + 1
model_mappings += [j] * n_views_per_model
return {
"len": n_views,
"gt_files": json_files,
"features": feats,
"obj_masks": None,
"model_mappings": model_mappings,
"imgs": imgs,
}
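
    # Expected on-disk layout per model (inferred from the loaders in this class;
    # an assumption, not an authoritative spec):
    #   <json_root>/<model_id>/object.json            GT articulation graph
    #   <root>/<model_id>/features/<view>.npy         (n_frames, 256, 768) DINOv2 patch features
    #   <root>/<model_id>/imgs/animation_<view>.mp4   rendered animation video
    #   <root>/<model_id>/imgs/random_<view>.png      random articulation state render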
def _cache_data_non_train(self):
# number of views per model and in total
        n_views_per_model = 2  # two held-out views per model ('18' and '19')
n_views = n_views_per_model * len(self.model_ids)
# json files for each model
gt_files = []
pred_files = [] # for predicted graphs
# mapping to the index of the corresponding model in json_files
model_mappings = []
# space for dinov2 patch features
feats = np.empty((n_views, 512, 768), dtype=np.float16)
# space for input images
first_imgs = np.empty((n_views, 128, 128, 3), dtype=np.uint8)
second_imgs = np.empty((n_views, 128, 128, 3), dtype=np.uint8)
# transformation for input images
transform = T.Compose(
[
T.Resize(256, interpolation=T.InterpolationMode.BICUBIC),
T.CenterCrop(224),
T.Resize(128, interpolation=T.InterpolationMode.BICUBIC),
]
)
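        # Note: resize-to-256 then center-crop-224 mirrors common DINOv2-style
        # preprocessing, and the final resize to 128 matches the uint8 image
        # buffers above; these images are only used for visualization.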
i = 0 # index for views
desc = f'Loading {self.mode} data'
for j, model_id in tqdm(enumerate(self.model_ids), total=len(self.model_ids), desc=desc):
with open(os.path.join(self.hparams.json_root, model_id, self.json_name), "r") as f:
json_file = json.load(f)
gt_files.append(json_file)
# filename_dir = os.path.join(self.hparams.root, model_id, 'features')
            # the two held-out evaluation views per model
            for filename in ['18.npy', '19.npy']:
view_feat = np.load(os.path.join(self.hparams.root, model_id, 'features', filename))
first_frame_feat = view_feat[0]
                # frame selection convention mirrors _cache_data_train above
                if self.hparams.frame_mode == 'last_frame':
                    second_frame_feat = view_feat[-2]
                elif self.hparams.frame_mode == 'random_state_frame':
                    second_frame_feat = view_feat[-1]
                else:
                    raise NotImplementedError(
                        f"Unknown frame_mode '{self.hparams.frame_mode}'; expected 'last_frame' or 'random_state_frame'"
                    )
feats[i : i + 1, :256, :] = first_frame_feat.astype(np.float16)
feats[i : i + 1, 256:, :] = second_frame_feat.astype(np.float16)
video_path = os.path.join(self.hparams.root, model_id, 'imgs', 'animation_' + filename.replace('.npy', '.mp4'))
                reader = imageio.get_reader(video_path)
                frames = [frame for frame in reader]
                reader.close()
first_img = Image.fromarray(frames[0])
if first_img.mode == 'RGBA':
first_img = make_white_background(first_img)
                first_img = np.asarray(transform(first_img), dtype=np.uint8)  # uint8: int8 would wrap pixel values > 127
                first_imgs[i] = first_img
if self.hparams.frame_mode == 'last_frame':
second_img = Image.fromarray(frames[-1])
elif self.hparams.frame_mode == 'random_state_frame':
second_img_path = video_path.replace('animation', 'random').replace('.mp4', '.png')
second_img = Image.open(second_img_path)
if second_img.mode == 'RGBA':
second_img = make_white_background(second_img)
                second_img = np.asarray(transform(second_img), dtype=np.uint8)  # uint8, matching the buffer dtype
                second_imgs[i] = second_img
i = i + 1
# mapping to json file
model_mappings += [j] * n_views_per_model
return {
"len": n_views,
"gt_files": gt_files,
"pred_files": pred_files,
"features": feats,
"model_mappings": model_mappings,
"imgs": [first_imgs, second_imgs],
}
def _cache_data(self):
"""
Function to cache data from disk.
"""
if self.mode == "train":
return self._cache_data_train()
else:
return self._cache_data_non_train()
def _get_item_train_val(self, index):
model_i = self.files["model_mappings"][index]
gt_file = self.files["gt_files"][model_i]
data, cond = self._prepare_input_GT(
file=gt_file, model_id=self.model_ids[model_i]
)
if self.mode == "val":
# input image for visualization
img_first = self.files["imgs"][0][index]
img_last = self.files["imgs"][1][index]
cond["img"] = np.concatenate([img_first, img_last], axis=1)
return data, cond
def _get_item_test(self, index):
model_i = self.files["model_mappings"][index]
gt_file = None if self.no_GT else self.files["gt_files"][model_i]
if self.hparams.get('G_dir', None) is None:
data, cond = self._prepare_input_GT(file=gt_file, model_id=self.model_ids[model_i])
else:
            # views are cached in the order ['18.npy', '19.npy'], so even indices
            # correspond to view 18 and odd indices to view 19
            filename = '18.json' if index % 2 == 0 else '19.json'
pred_file_path = os.path.join(self.hparams.G_dir, self.model_ids[model_i], filename)
with open(pred_file_path, "r") as f:
pred_file = json.load(f)
data, cond = self._prepare_input(model_id=self.model_ids[model_i], pred_file=pred_file, gt_file=gt_file)
# input image for visualization
img_first = self.files["imgs"][0][index]
img_last = self.files["imgs"][1][index]
cond["img"] = np.concatenate([img_first, img_last], axis=1)
return data, cond
def __getitem__(self, index):
# input image features
feat = self.files["features"][index]
        # prepare input, GT data and other auxiliary info
if self.mode == "test":
data, cond = self._get_item_test(index)
else:
data, cond = self._get_item_train_val(index)
return data, cond, feat
def __len__(self):
return self.files["len"]
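

# Minimal DataLoader sketch (assumption: BaseDataset subclasses
# torch.utils.data.Dataset; depending on what `data` and `cond` contain,
# a custom collate_fn may be needed):
#   from torch.utils.data import DataLoader
#   loader = DataLoader(dataset, batch_size=hparams.batch_size,
#                       num_workers=hparams.num_workers, shuffle=True)
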
if __name__ == '__main__':
from types import SimpleNamespace
class EnhancedNamespace(SimpleNamespace):
def get(self, key, default=None):
return getattr(self, key, default)
hparams = {
"name": "dm_singapo",
"json_root": "/home/users/ruiqi.wu/singapo/", # root directory of the dataset
"batch_size": 20, # batch size for training
"num_workers": 8, # number of workers for data loading
"K": 32, # maximum number of nodes (parts) in the graph (object)
"split_file": "/home/users/ruiqi.wu/singapo/data/data_split.json",
"n_views_per_model": 5,
"root": "/home/users/ruiqi.wu/manipulate_3d_generate/data/blender_version",
"frame_mode": "last_frame"
}
hparams = EnhancedNamespace(**hparams)
    with open(hparams.split_file, "r") as f:
splits = json.load(f)
train_ids = splits["train"]
val_ids = [i for i in train_ids if "augmented" not in i]
val_ids = [val_id for val_id in val_ids if os.path.exists(os.path.join(hparams.root, val_id, "features"))]
    dataset = MyDataset(hparams, model_ids=val_ids[:20], mode="val")  # "val", not "valid": _get_item_train_val checks mode == "val"
    for i in range(20):
        data, cond, feat = dataset[i]
        print(i, feat.shape, feat.dtype)