# Hugging Face Spaces page-badge text ("Spaces: Running on Zero") — scrape
# artifact from the hosting page, not part of the program.
| import os, sys | |
| sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) | |
| import json | |
| import numpy as np | |
| from PIL import Image | |
| import torchvision.transforms as T | |
| from dataset.base_dataset import BaseDataset | |
| import random | |
| from tqdm import tqdm | |
| import imageio | |
| import torch | |
def make_white_background(src_img):
    '''Composite the given RGBA image onto an opaque white RGB background.'''
    src_img.load()
    canvas = Image.new("RGB", src_img.size, (255, 255, 255))
    # Use the alpha channel as the paste mask so transparent pixels stay white.
    canvas.paste(src_img, mask=src_img.getchannel("A"))
    return canvas
class MyDataset(BaseDataset):
    """
    Dataset for training and testing on the PartNet-Mobility and ACD datasets
    (with our preprocessing). The GT graph is given.

    Each sample is a ``(data, cond, feat)`` triple, where ``feat`` stacks the
    DINOv2 patch features of the first and second frame of one view
    (256 patches each -> 512 rows of dim 768, float16).
    """

    def __init__(self, hparams, model_ids, mode="train", json_name="object.json"):
        """
        Args:
            hparams: experiment config; must expose `root`, `json_root`,
                `frame_mode` (and `n_views_per_model` for training), and
                support dict-like `.get(key, default)`.
            model_ids: candidate model ids; filtered via `_filter_models`.
            mode: "train", "val" or "test". Any mode other than "train"
                uses the non-train caching path (fixed views 18/19).
            json_name: file name of the per-model GT json.
        """
        self.hparams = hparams
        self.json_name = json_name
        self.model_ids = self._filter_models(model_ids)
        self.mode = mode
        self.map_cat = False
        self.get_acd_mapping()
        # Run without GT graphs only when both test flags are enabled.
        self.no_GT = bool(
            self.hparams.get("test_no_GT", False)
            and self.hparams.get("test_pred_G", False)
        )
        # Predicted graphs are only ever used at test time.
        self.pred_G = (
            False
            if mode in ("train", "val")
            else self.hparams.get("test_pred_G", False)
        )
        if mode == "test" and "acd" in hparams.test_which:
            self.map_cat = True
        self.files = self._cache_data()
        print(f"[INFO] {mode} dataset: {len(self)} data samples loaded.")

    def _select_frame_feats(self, view_feat):
        """Return (first_frame, second_frame) features for one view.

        The second frame depends on `hparams.frame_mode`:
        - "last_frame":         second-to-last entry of the feature stack
        - "random_state_frame": last entry of the feature stack

        Raises:
            NotImplementedError: for any other frame mode.
        """
        first = view_feat[0]
        if self.hparams.frame_mode == 'last_frame':
            second = view_feat[-2]
        elif self.hparams.frame_mode == 'random_state_frame':
            second = view_feat[-1]
        else:
            raise NotImplementedError("Please provide correct frame mode: last_frame | random_state_frame")
        return first, second

    def _cache_data_train(self):
        """Cache GT jsons and per-view DINOv2 patch features for training.

        Returns a dict with the flattened view count, GT json per model,
        the feature tensor, and a view-index -> model-index mapping.
        Raw images are not loaded for training.
        """
        json_data_root = self.hparams.json_root
        data_root = self.hparams.root
        # number of views per model and in total
        n_views_per_model = self.hparams.n_views_per_model
        n_views = n_views_per_model * len(self.model_ids)
        json_files = []      # GT json per model
        model_mappings = []  # view index -> index of the model in json_files
        # first + second frame patch features per view (256 patches each)
        feats = np.empty((n_views, 512, 768), dtype=np.float16)
        imgs = None          # input images are not required in training
        i = 0  # running view index
        for j, model_id in enumerate(self.model_ids):
            print(model_id)
            # per-model GT graph
            with open(os.path.join(json_data_root, model_id, self.json_name), "r") as f:
                json_files.append(json.load(f))
            filenames = os.listdir(os.path.join(data_root, model_id, 'features'))
            filenames = [f for f in filenames if 'high_res' not in f]
            # NOTE(review): assumes every model has at least n_views_per_model
            # feature files; fewer would leave rows of `feats` uninitialized
            # while model_mappings still counts n_views_per_model — confirm.
            filenames = filenames[:n_views_per_model]
            for filename in filenames:
                view_feat = np.load(os.path.join(data_root, model_id, 'features', filename))
                first_feat, second_feat = self._select_frame_feats(view_feat)
                feats[i, :256, :] = first_feat.astype(np.float16)
                feats[i, 256:, :] = second_feat.astype(np.float16)
                i += 1
            model_mappings += [j] * n_views_per_model
        return {
            "len": n_views,
            "gt_files": json_files,
            "features": feats,
            "obj_masks": None,
            "model_mappings": model_mappings,
            "imgs": imgs,
        }

    def _cache_data_non_train(self):
        """Cache GT jsons, features, and input images for val/test.

        Uses the two fixed views '18' and '19' per model. Images are loaded
        from the animation video (first frame) and, depending on frame_mode,
        from the video's last frame or a separate 'random' PNG.
        """
        n_views_per_model = 2  # fixed views 18 and 19
        n_views = n_views_per_model * len(self.model_ids)
        gt_files = []
        pred_files = []  # predicted graphs (populated elsewhere when used)
        model_mappings = []  # view index -> index of the model in gt_files
        feats = np.empty((n_views, 512, 768), dtype=np.float16)
        # input images for visualization, stored as uint8 RGB
        first_imgs = np.empty((n_views, 128, 128, 3), dtype=np.uint8)
        second_imgs = np.empty((n_views, 128, 128, 3), dtype=np.uint8)
        # resize -> center-crop -> resize, matching the feature preprocessing
        transform = T.Compose(
            [
                T.Resize(256, interpolation=T.InterpolationMode.BICUBIC),
                T.CenterCrop(224),
                T.Resize(128, interpolation=T.InterpolationMode.BICUBIC),
            ]
        )
        i = 0  # running view index
        desc = f'Loading {self.mode} data'
        for j, model_id in tqdm(enumerate(self.model_ids), total=len(self.model_ids), desc=desc):
            with open(os.path.join(self.hparams.json_root, model_id, self.json_name), "r") as f:
                gt_files.append(json.load(f))
            for filename in ['18.npy', '19.npy']:
                view_feat = np.load(os.path.join(self.hparams.root, model_id, 'features', filename))
                first_feat, second_feat = self._select_frame_feats(view_feat)
                feats[i, :256, :] = first_feat.astype(np.float16)
                feats[i, 256:, :] = second_feat.astype(np.float16)
                video_path = os.path.join(self.hparams.root, model_id, 'imgs', 'animation_' + filename.replace('.npy', '.mp4'))
                reader = imageio.get_reader(video_path)
                # close the reader even if decoding a frame raises
                try:
                    frames = [frame for frame in reader]
                finally:
                    reader.close()
                first_img = Image.fromarray(frames[0])
                if first_img.mode == 'RGBA':
                    first_img = make_white_background(first_img)
                # BUGFIX: was dtype=np.int8, which overflows every pixel value
                # above 127 before landing in this uint8 buffer.
                first_imgs[i] = np.asarray(transform(first_img), dtype=np.uint8)
                if self.hparams.frame_mode == 'last_frame':
                    second_img = Image.fromarray(frames[-1])
                else:  # 'random_state_frame' (mode already validated above)
                    second_img_path = video_path.replace('animation', 'random').replace('.mp4', '.png')
                    second_img = Image.open(second_img_path)
                if second_img.mode == 'RGBA':
                    second_img = make_white_background(second_img)
                # BUGFIX: same int8 -> uint8 overflow as above.
                second_imgs[i] = np.asarray(transform(second_img), dtype=np.uint8)
                i += 1
            model_mappings += [j] * n_views_per_model
        return {
            "len": n_views,
            "gt_files": gt_files,
            "pred_files": pred_files,
            "features": feats,
            "model_mappings": model_mappings,
            "imgs": [first_imgs, second_imgs],
        }

    def _cache_data(self):
        """Cache data from disk, dispatching on the dataset mode."""
        if self.mode == "train":
            return self._cache_data_train()
        return self._cache_data_non_train()

    def _get_item_train_val(self, index):
        """Prepare one (data, cond) pair from the GT graph for train/val."""
        model_i = self.files["model_mappings"][index]
        gt_file = self.files["gt_files"][model_i]
        data, cond = self._prepare_input_GT(
            file=gt_file, model_id=self.model_ids[model_i]
        )
        if self.mode == "val":
            # side-by-side first/second input image for visualization
            img_first = self.files["imgs"][0][index]
            img_last = self.files["imgs"][1][index]
            cond["img"] = np.concatenate([img_first, img_last], axis=1)
        return data, cond

    def _get_item_test(self, index):
        """Prepare one (data, cond) pair for testing.

        Uses the GT graph unless `hparams.G_dir` is set, in which case the
        predicted graph json ('18.json'/'19.json', matching the view) is
        loaded from that directory.
        """
        model_i = self.files["model_mappings"][index]
        gt_file = None if self.no_GT else self.files["gt_files"][model_i]
        if self.hparams.get('G_dir', None) is None:
            data, cond = self._prepare_input_GT(file=gt_file, model_id=self.model_ids[model_i])
        else:
            # even indices are view 18, odd indices view 19
            filename = '18.json' if index % 2 == 0 else '19.json'
            pred_file_path = os.path.join(self.hparams.G_dir, self.model_ids[model_i], filename)
            with open(pred_file_path, "r") as f:
                pred_file = json.load(f)
            data, cond = self._prepare_input(model_id=self.model_ids[model_i], pred_file=pred_file, gt_file=gt_file)
        # side-by-side first/second input image for visualization
        img_first = self.files["imgs"][0][index]
        img_last = self.files["imgs"][1][index]
        cond["img"] = np.concatenate([img_first, img_last], axis=1)
        return data, cond

    def __getitem__(self, index):
        """Return (data, cond, feat) for the view at `index`."""
        feat = self.files["features"][index]
        if self.mode == "test":
            data, cond = self._get_item_test(index)
        else:
            data, cond = self._get_item_train_val(index)
        return data, cond, feat

    def __len__(self):
        return self.files["len"]
if __name__ == '__main__':
    # Smoke-test driver: build a small validation dataset from hard-coded
    # paths and step through a few samples interactively.
    from types import SimpleNamespace

    class EnhancedNamespace(SimpleNamespace):
        """SimpleNamespace with a dict-like .get(), so it can stand in for hparams."""

        def get(self, key, default=None):
            return getattr(self, key, default)

    hparams = {
        "name": "dm_singapo",
        "json_root": "/home/users/ruiqi.wu/singapo/",  # root directory of the dataset
        "batch_size": 20,       # batch size for training
        "num_workers": 8,       # number of workers for data loading
        "K": 32,                # maximum number of nodes (parts) in the graph (object)
        "split_file": "/home/users/ruiqi.wu/singapo/data/data_split.json",
        "n_views_per_model": 5,
        "root": "/home/users/ruiqi.wu/manipulate_3d_generate/data/blender_version",
        "frame_mode": "last_frame",
    }
    hparams = EnhancedNamespace(**hparams)

    with open(hparams.split_file, "r") as f:
        splits = json.load(f)
    train_ids = splits["train"]
    # keep only non-augmented models whose features exist on disk
    val_ids = [i for i in train_ids if "augmented" not in i]
    val_ids = [val_id for val_id in val_ids if os.path.exists(os.path.join(hparams.root, val_id, "features"))]

    # NOTE(review): mode "valid" is neither "train" nor "val" — it takes the
    # non-train caching path but skips the val-only image concat in
    # _get_item_train_val. Confirm "val" wasn't intended.
    dataset = MyDataset(hparams, model_ids=val_ids[:20], mode="valid")
    for i in range(20):
        # idiomatic indexing (was dataset.__getitem__(i))
        data, cond, feat = dataset[i]
        # deliberate interactive breakpoint for inspecting each sample
        import ipdb
        ipdb.set_trace()