Spaces:
Running
on
Zero
Running
on
Zero
| import os, sys | |
| import json | |
| import numpy as np | |
| # import collections.abc | |
| # sys.modules['collections'].Mapping = collections.abc.Mapping | |
| import networkx as nx | |
| from torch.utils.data import Dataset | |
| from my_utils.refs import cat_ref, sem_ref, joint_ref, data_mode_ref | |
| from collections import deque | |
def build_graph(tree, K=32):
    '''
    Build the graph representation of an articulation tree.

    Args:
        tree: list of node dicts; each node provides 'id', 'parent'
            (-1 for the root) and 'children' (list of child ids). Node ids are
            used directly as row/column indices, so they must be < K.
        K: the maximum number of nodes in the graph (side of the adjacency matrix).
    Returns:
        dict with
            'adj': (K, K) float32 adjacency matrix recording the 1-ring
                relationship (parent + children) of every node; the root has
                no parent, so it gets a self-loop instead.
            'parents': int8 array of parent ids (-1 for the root), listed in
                the order of `tree` (not indexed by node id).
            'tree_list': list of {'id', 'parent_id'} dicts, in the order of `tree`.
    '''
    adj = np.zeros((K, K), dtype=np.float32)
    parents = []
    tree_list = []
    for node in tree:
        node_id = node['id']
        parent_id = node['parent']
        tree_list.append({'id': node_id, 'parent_id': parent_id})
        # 1-ring relationship: an edge to the parent, or a self-loop for the root
        if parent_id != -1:
            adj[node_id, parent_id] = 1
        else:
            adj[node_id, node_id] = 1
        parents.append(parent_id)
        for child_id in node['children']:
            adj[node_id, child_id] = 1
    return {
        'adj': adj,
        'parents': np.array(parents, dtype=np.int8),
        'tree_list': tree_list,
    }
| from collections import defaultdict | |
| from functools import cmp_to_key | |
def bfs_tree_simple(tree_list):
    '''
    Breadth-first ordering of the nodes, visiting children in list order.

    Args:
        tree_list: list of {'id', 'parent_id'} dicts; parent_id == -1 marks a root.
    Returns:
        list where entry k is the BFS rank of tree_list[k]. The traversal starts
        from the first root only; nodes that are never reached (and every node
        when there is no root) keep rank 0.
    '''
    ranks = [0] * len(tree_list)

    # positions of each parent's children, kept in list order
    kids = defaultdict(list)
    for pos, entry in enumerate(tree_list):
        kids[entry['parent_id']].append(pos)

    root_pos = next(
        (pos for pos, entry in enumerate(tree_list) if entry['parent_id'] == -1),
        None,
    )
    if root_pos is None:
        return ranks

    ranks[root_pos] = 0
    next_rank = 1
    frontier = deque([tree_list[root_pos]['id']])  # queue of node ids to expand
    while frontier:
        parent = frontier.popleft()
        for pos in kids.get(parent, ()):
            frontier.append(tree_list[pos]['id'])
            ranks[pos] = next_rank
            next_rank += 1
    return ranks
def bfs_tree(tree_list, aabb_list, epsilon=1e-3):
    '''
    Breadth-first ordering of the nodes, visiting siblings in the spatial
    order of their bounding-box centers (x first, then y, then z).

    Args:
        tree_list: list of {'id', 'parent_id'} dicts; parent_id == -1 marks a root.
        aabb_list: per-node boxes aligned with tree_list; the center of box k on
            axis i is taken as (box[i] + box[i + 3]) / 2.
        epsilon: coordinate differences up to this value count as ties and
            defer to the next axis.
    Returns:
        list where entry k is the BFS rank of tree_list[k]. Only the first root
        is traversed; unreached nodes keep rank 0, and an all-zero list is
        returned when there is no root.
    '''
    ranks = [0] * len(tree_list)

    # map each parent id to the positions (indices into tree_list) of its children
    children_of = defaultdict(list)
    for pos, entry in enumerate(tree_list):
        children_of[entry['parent_id']].append(pos)

    root_pos = next(
        (pos for pos, entry in enumerate(tree_list) if entry['parent_id'] == -1),
        None,
    )
    if root_pos is None:
        return ranks

    def center(pos):
        box = aabb_list[pos]
        return [(box[axis] + box[axis + 3]) / 2 for axis in range(3)]

    def spatial_cmp(lhs, rhs):
        # compare axis by axis; differences within epsilon fall through to the next axis
        c_lhs, c_rhs = center(lhs), center(rhs)
        for axis in range(3):
            if abs(c_lhs[axis] - c_rhs[axis]) > epsilon:
                return -1 if c_lhs[axis] < c_rhs[axis] else 1
        # all axes tie: sorted() is stable, so the original list order is kept
        return 0

    ranks[root_pos] = 0
    next_rank = 1
    pending = deque([root_pos])  # queue of positions to expand
    while pending:
        pos = pending.popleft()
        siblings = children_of.get(tree_list[pos]['id'], [])
        for child_pos in sorted(siblings, key=cmp_to_key(spatial_cmp)):
            ranks[child_pos] = next_rank
            next_rank += 1
            pending.append(child_pos)
    return ranks
class BaseDataset(Dataset):
    '''
    Base dataset that turns articulated-object json data into padded node
    attribute arrays plus the attention masks and graph info in `cond`.

    Subclasses implement __getitem__ / __len__ and are expected to set the
    attributes read here but not assigned in this class: `json_name`, `mode`
    and `map_cat` (and call get_acd_mapping() when `category_mapping` is needed).
    `hparams` is read by attribute (K, json_root, mode_num) and via
    .get('test_label_free'), so it is presumably a dict-like config object.
    '''

    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams

    def _filter_models(self, models_ids):
        '''
        Keep only the models whose diffuse tree has at most hparams.K nodes.

        Args:
            models_ids: model ids; each names a sub-directory of hparams.json_root
                that contains the json file named by self.json_name.
        Returns:
            list of the ids that pass the filter, in input order.
        '''
        json_data_root = self.hparams.json_root
        max_nodes = self.hparams.K
        total = len(models_ids)
        filtered = []
        for i, model_id in enumerate(models_ids):
            if i % 100 == 0:
                print(f'Checking model {i}/{len(models_ids)}')
            path = os.path.join(json_data_root, model_id, self.json_name)
            # json is UTF-8 by spec; do not depend on the platform's default encoding
            with open(path, 'r', encoding='utf-8') as f:
                json_file = json.load(f)
            if len(json_file['diffuse_tree']) <= max_nodes:
                filtered.append(model_id)
        del total  # kept only for readability of the progress message above
        return filtered

    def get_acd_mapping(self):
        '''
        Set self.category_mapping: ACD object category -> category name that is
        then looked up in cat_ref.
        '''
        self.category_mapping = {
            'armoire': 'StorageFurniture',
            'bookcase': 'StorageFurniture',
            'chest_of_drawers': 'StorageFurniture',
            'desk': 'Table',
            'dishwasher': 'Dishwasher',
            'hanging_cabinet': 'StorageFurniture',
            'kitchen_cabinet': 'StorageFurniture',
            'microwave': 'Microwave',
            'nightstand': 'StorageFurniture',
            'oven': 'Oven',
            'refrigerator': 'Refrigerator',
            'sink_cabinet': 'StorageFurniture',
            'tv_stand': 'StorageFurniture',
            'washer': 'WashingMachine',
            'table': 'Table',
            'cabinet': 'StorageFurniture',
        }

    @staticmethod
    def _inverse_permutation(order):
        '''
        Return the inverse of `order` as a list, i.e. result[order[i]] == i.
        '''
        inverse = [0] * len(order)
        for i, target in enumerate(order):
            inverse[target] = i
        return inverse

    def _random_permute(self, graph, nodes):
        '''
        Randomly permute the nodes and update the graph accordingly.

        Args:
            graph: dict with 'adj' and 'parents' (as produced by build_graph)
            nodes: 2-D array with one row of attributes per node
        Returns:
            graph_permuted: dict with the permuted 'adj' and 'parents'
            nodes_permuted: the node rows in their new order
        '''
        order = np.random.permutation(len(nodes))
        return self._permute_by_order(graph, nodes, order)

    def _permute_by_order(self, graph, nodes, order):
        '''
        Permute the nodes by `order` (old node i moves to position order[i])
        and update the graph accordingly.

        Args:
            graph: dict with 'adj' and 'parents' (as produced by build_graph)
            nodes: 2-D array with one row of attributes per node, or None
            order: permutation of the node indices
        Returns:
            graph_permuted: dict with the permuted 'adj' and 'parents'
            nodes_permuted: the node rows in their new order (None if nodes is None)
        '''
        graph_permuted = self._reorder_nodes(graph, order)
        if nodes is None:
            return graph_permuted, None
        nodes_permuted = nodes[self._inverse_permutation(order), :]
        return graph_permuted, nodes_permuted

    def _prepare_node_data(self, node):
        '''
        Encode one tree node as a flat float32 vector of 6-wide attribute groups:
            [aabb_max(3), aabb_min(3)], joint_type x6, [axis_dir(3), axis_ori(3)],
            joint_range(2) x3, label x6, and - unless hparams.mode_num == 5 -
            an all-ones node_label group; shape (36,) or (30,).

        Raises:
            ValueError: if the joint type is not fixed / revolute / continuous /
                prismatic / screw.
        '''
        joint = node['joint']
        joint_kind = joint['type']
        # semantic label: sem_ref index scaled as index / 5 - 0.8, shape (1,)
        # (the range -0.8..0.8 holds if sem_ref indices span 0..8 - confirm)
        label = np.array([sem_ref['fwd'][node['name']]], dtype=np.float32) / 5. - 0.8
        # joint type: joint_ref index scaled as index / 5 - 0.5, shape (1,)
        # NOTE(review): with this offset the range is presumably -0.5..0.5, not -0.8..0.8
        joint_type = np.array([joint_ref['fwd'][joint_kind] / 5.], dtype=np.float32) - 0.5
        # axis-aligned bounding box given as center + size (presumably normalised to [-1, 1])
        aabb_center = np.array(node['aabb']['center'], dtype=np.float32)  # (3,)
        aabb_size = np.array(node['aabb']['size'], dtype=np.float32)  # (3,)
        aabb_max = aabb_center + aabb_size / 2
        aabb_min = aabb_center - aabb_size / 2
        # joint axis and range
        if joint_kind == 'fixed':
            axis_dir = np.zeros((3,), dtype=np.float32)
            axis_ori = aabb_center
            joint_range = np.zeros((2,), dtype=np.float32)
        else:
            zero = np.zeros((1,), dtype=np.float32)
            if joint_kind in ('revolute', 'continuous'):
                # rotational upper limit (divided by 360) in slot 0
                upper = np.array([joint['range'][1]], dtype=np.float32) / 360.
                joint_range = np.concatenate([upper, zero], axis=0)  # (2,)
            elif joint_kind in ('prismatic', 'screw'):
                # translational upper limit in slot 1
                upper = np.array([joint['range'][1]], dtype=np.float32)
                joint_range = np.concatenate([zero, upper], axis=0)  # (2,)
            else:
                # previously fell through and failed later with UnboundLocalError
                raise ValueError(f"Unsupported joint type '{joint_kind}' for node '{node['name']}'")
            axis_dir = np.array(joint['axis']['direction'], dtype=np.float32) * 0.7  # (3,)
            # make the axis point in the positive direction: if more components are
            # negative than positive, flip the axis together with the joint range
            if np.sum(axis_dir > 0) < np.sum(-axis_dir > 0):
                axis_dir = -axis_dir
                joint_range = -joint_range
            axis_ori = np.array(joint['axis']['origin'], dtype=np.float32)  # (3,)
            # sliding parts other than doors use the box center as the axis origin
            if joint_kind in ('prismatic', 'screw') and node['name'] != 'door':
                axis_ori = aabb_center
        parts = [aabb_max, aabb_min, joint_type.repeat(6), axis_dir, axis_ori,
                 joint_range.repeat(3), label.repeat(6)]
        if self.hparams.mode_num != 5:
            # all-ones group; padded empty nodes are all zeros (see _prepare_input_GT),
            # so this presumably flags real nodes
            parts.append(np.ones(6, dtype=np.float32))
        return np.concatenate(parts, axis=0)

    def _reorder_nodes(self, graph, order):
        '''
        Relabel the graph so that old node i becomes node order[i]; nodes
        len(order)..hparams.K-1 (padding) keep their positions.

        Args:
            graph: dict with 'adj' ((K, K) adjacency matrix) and 'parents'
                (one entry per real node)
            order: permutation of range(len(order))
        Returns:
            dict with the relabelled 'adj' (float32 ndarray) and 'parents'
            (parents[p] is the entry of the old node that moved to position p)
        '''
        N = len(order)
        mapping = {i: order[i] for i in range(N)}
        mapping.update({i: i for i in range(N, self.hparams.K)})
        G = nx.from_numpy_array(graph['adj'], create_using=nx.Graph)
        G_ = nx.relabel_nodes(G, mapping)
        # depending on the networkx/scipy versions todense() may yield np.matrix;
        # normalise to a plain ndarray
        new_adj = np.asarray(nx.adjacency_matrix(G_, G.nodes).todense())
        exchange = self._inverse_permutation(order)
        return {
            'adj': new_adj.astype(np.float32),
            # NOTE(review): entries are moved to their new positions, but their values
            # are still the old parent ids (not remapped through `order`) - confirm intended
            'parents': graph['parents'][exchange],
        }

    @staticmethod
    def _build_attention_masks(adj, n_nodes, K, mode_num):
        '''
        Build the (K*mode_num, K*mode_num) boolean masks shared by the input parsers.

        Returns:
            attr_mask: block diagonal; each node attends only to its own attribute groups
                (for Local Attention)
            key_pad_mask: only the first n_nodes*mode_num keys are valid (for Global Attention)
            adj_mask: graph adjacency expanded per attribute group; rows of padded nodes
                are all True (for Graph Relation Attention)
        '''
        attr_mask = np.eye(K, K, dtype=bool)
        attr_mask = attr_mask.repeat(mode_num, axis=0).repeat(mode_num, axis=1)
        pad_mask = np.zeros((K * mode_num, K * mode_num), dtype=bool)
        pad_mask[:, :n_nodes * mode_num] = 1
        adj_mask = adj.astype(bool)
        adj_mask = adj_mask.repeat(mode_num, axis=0).repeat(mode_num, axis=1)
        adj_mask[n_nodes * mode_num:, :] = 1
        return attr_mask, pad_mask, adj_mask

    def _resolve_category(self, obj_cat):
        '''
        Resolve a raw object category to its cat_ref id (non-ACD path).

        Tries obj_cat directly, then obj_cat mapped through self.category_mapping
        (when get_acd_mapping() has been called), and otherwise falls back to id 2.
        '''
        cat_id = cat_ref.get(obj_cat, None)
        if cat_id is None:
            mapped = getattr(self, 'category_mapping', {}).get(obj_cat, None)
            if mapped is not None:
                cat_id = cat_ref.get(mapped, None)
        if cat_id is None:
            cat_id = 2  # NOTE(review): hard-coded fallback id; its meaning depends on cat_ref
        return cat_id

    def _prepare_input_GT(self, file, model_id):
        '''
        Parse a ground-truth json dict into the input item for CAGE training.

        Args:
            file: parsed json with 'diffuse_tree' (list of nodes) and 'meta'['obj_cat']
            model_id: identifier stored as cond['name']
        Returns:
            data: (K * mode_num, 6) float32 node attributes, zero-padded to K nodes
            cond: dict with masks, graph info, category and auxiliary data
        '''
        tree = file['diffuse_tree']
        K = self.hparams.K  # max number of nodes
        cond = {}  # conditional information and auxiliary data
        cond['parents'] = np.zeros(K, dtype=np.int8)
        # prepare node data: one (36,) or (30,) vector per node
        nodes = np.array([self._prepare_node_data(node) for node in tree], dtype=np.float32)
        n_nodes = len(nodes)
        # prepare graph
        graph = build_graph(tree, K)
        if self.mode == 'train':  # perturb the node order for training
            graph, nodes = self._random_permute(graph, nodes)
        # pad the nodes to K with empty (all-zero) nodes
        if n_nodes < K:
            padding = np.zeros((K - n_nodes, nodes.shape[1]), dtype=np.float32)
            data = np.concatenate([nodes, padding], axis=0)  # (K, 36) or (K, 30)
        else:
            data = nodes
        mode_num = data.shape[1] // 6
        data = data.reshape(K * mode_num, 6)  # (K * n_attr, 6)
        cond['attr_mask'], cond['key_pad_mask'], cond['adj_mask'] = \
            self._build_attention_masks(graph['adj'], n_nodes, K, mode_num)
        # object category
        obj_cat = file['meta']['obj_cat']
        if self.map_cat:  # for ACD dataset
            cond['cat'] = cat_ref[self.category_mapping[obj_cat]]
        else:
            cond['cat'] = self._resolve_category(obj_cat)
        # auxiliary info
        cond['name'] = model_id
        cond['adj'] = graph['adj']
        cond['parents'][:n_nodes] = graph['parents']
        cond['n_nodes'] = n_nodes
        cond['obj_cat'] = obj_cat
        return data, cond

    def _prepare_input(self, model_id, pred_file, gt_file=None):
        '''
        Parse an input item from pred_file (a given/predicted diffuse tree);
        the returned `data` is an all-zero placeholder.

        Args:
            model_id: identifier stored as cond['name']
            pred_file: parsed json with 'diffuse_tree' and, unless
                hparams.test_label_free is set, 'meta'['obj_cat']
            gt_file: currently unused
        Returns:
            data: (K * 5, 6) float32 zeros
            cond: dict with masks, graph info, category and auxiliary data
        '''
        K = self.hparams.K  # max number of nodes
        # NOTE(review): fixed at 5 attribute groups here, while _prepare_node_data
        # emits 6 unless hparams.mode_num == 5 - confirm this matches the model
        mode_num = 5
        cond = {}  # conditional information and auxiliary data
        n_nodes = len(pred_file['diffuse_tree'])
        pred_graph = build_graph(pred_file['diffuse_tree'], K)
        # dummy GT data
        data = np.zeros((K * mode_num, 6), dtype=np.float32)
        cond['attr_mask'], cond['key_pad_mask'], cond['adj_mask'] = \
            self._build_attention_masks(pred_graph['adj'], n_nodes, K, mode_num)
        # placeholder category, replaced below when the category is given
        cond['cat'] = cat_ref['StorageFurniture']
        cond['obj_cat'] = 'StorageFurniture'
        # if object category is given as input
        if not self.hparams.get('test_label_free', False):
            assert 'meta' in pred_file, 'meta not found in the json file.'
            assert 'obj_cat' in pred_file['meta'], 'obj_cat not found in the metadata of the json file.'
            category = pred_file['meta']['obj_cat']
            if self.map_cat:  # for ACD dataset
                category = self.category_mapping[category]
            cond['cat'] = cat_ref[category]
            cond['obj_cat'] = category
        # auxiliary info
        cond['name'] = model_id
        cond['adj'] = pred_graph['adj']
        cond['parents'] = np.zeros(K, dtype=np.int8)
        cond['parents'][:n_nodes] = pred_graph['parents']
        cond['n_nodes'] = n_nodes
        return data, cond

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError