# DIPO/dataset/utils.py
# xinjie.wang
# init commit
# c28dddb
import os, sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
import numpy as np
from PIL import Image
from my_utils.refs import joint_ref, sem_ref
def rescale_axis(jtype, axis_d, axis_o, box_center):
    '''
    Function to relocate the joint axis so that it renders next to the part.
    Args:
        - jtype (int): joint type code
        - axis_d (np.array): axis direction
        - axis_o (np.array): axis origin
        - box_center (np.array): bounding box center
    Returns:
        - center (list): relocated axis origin
        - axis_d (list): axis direction (values unchanged, converted to a list)
    '''
    # type codes 0 and 1 carry no axis: both origin and direction are zeroed
    if jtype in (0, 1):
        return [0., 0., 0.], [0., 0., 0.]

    if jtype in (3, 4):
        # type codes 3 and 4: the origin is simply moved to the box center
        origin = box_center
    else:
        # every other type: project the box center onto the axis line
        # (exact projection assumes axis_d is unit length -- TODO confirm)
        along = np.dot(axis_d, box_center - axis_o)
        origin = axis_o + along * axis_d

    return origin.tolist(), axis_d.tolist()
def make_white_background(src_img):
    '''Paste the input RGBA image onto a white RGB canvas of the same size and return the canvas.'''
    src_img.load()
    # band index 3 of an RGBA image is the alpha channel; it is used as the paste mask
    alpha_band = src_img.split()[3]
    white = (255, 255, 255)
    canvas = Image.new("RGB", src_img.size, white)
    canvas.paste(src_img, mask=alpha_band)
    return canvas
def build_graph(tree, K=32):
    '''
    Function to build the graph description from the node list.
    Args:
        tree: list of node dicts, each providing 'id', 'parent' and 'children'
        K: the maximum number of nodes in the graph (side length of the adjacency matrix)
    Returns:
        dict with
        adj: (K, K) float32 adjacency matrix recording the 1-ring relationship
             (parent + children) of every node; a node whose parent is -1 gets
             a self-loop in place of the parent entry
        parents: int8 array with the parent id of each node, in the order the
             nodes appear in `tree`
    '''
    adjacency = np.zeros((K, K), dtype=np.float32)
    parent_ids = []
    for entry in tree:
        row = entry['id']
        parent = entry['parent']
        # the root (parent == -1) points at itself, every other node at its parent
        col = row if parent == -1 else parent
        adjacency[row, col] = 1
        parent_ids.append(parent)
        for child in entry['children']:
            adjacency[row, child] = 1
    return {
        'adj': adjacency,
        'parents': np.array(parent_ids, dtype=np.int8)
    }
def load_input_from(pred_file, K=32):
    '''
    Function to parse the input item from a file containing the predicted graph.
    Args:
        pred_file: dict whose 'diffuse_tree' entry is the list of predicted nodes
        K: the maximum number of nodes in the graph
    Returns:
        data: all-False boolean placeholder array of shape (K*5, 6)
        cond: dict with the attention masks ('attr_mask', 'key_pad_mask', 'adj_mask'),
              the category id 'cat' and the auxiliary entries 'adj', 'parents', 'n_nodes'
    '''
    n_attr = 5  # each node is expanded to 5 consecutive rows/columns in every mask
    diffuse_tree = pred_file['diffuse_tree']
    n_nodes = len(diffuse_tree)
    graph = build_graph(diffuse_tree, K)
    n_used = n_nodes * n_attr

    def _expand(mask):
        # blow a (K, K) node-level mask up to (K*5, K*5)
        return mask.repeat(n_attr, axis=0).repeat(n_attr, axis=1)

    # attr mask (for Local Attention): block-diagonal, one 5x5 block per node
    attr_mask = _expand(np.eye(K, K, dtype=bool))

    # key padding mask (for Global Attention): True on the columns of existing nodes
    key_pad_mask = np.zeros((K * n_attr, K * n_attr), dtype=bool)
    key_pad_mask[:, :n_used] = 1

    # adj mask (for Graph Relation Attention): expanded adjacency,
    # with the rows of the padded nodes set to True everywhere
    adj_mask = _expand(graph['adj'].astype(bool))
    adj_mask[n_used:, :] = 1

    # parents padded with zeros up to K entries
    padded_parents = np.zeros(K, dtype=np.int8)
    padded_parents[:n_nodes] = graph['parents']

    # placeholder
    data = np.zeros((K * n_attr, 6), dtype=bool)

    # conditional information and auxiliary data
    cond = {
        'attr_mask': attr_mask,
        'key_pad_mask': key_pad_mask,
        'adj_mask': adj_mask,
        'cat': 2,
        'adj': graph['adj'],
        'parents': padded_parents,
        'n_nodes': n_nodes,
    }
    return data, cond
def convert_data_range(x):
    '''
    postprocessing: convert the raw model output to the original range, following CAGE
    Args:
        x: raw output array; it is reshaped to (K, 30), i.e. 30 values per node
    Returns:
        dict with per-node arrays "center", "size", "type", "axis_d", "axis_o",
        "range" and "label"
    '''
    rows = x.reshape(-1, 30)  # (K, 30)

    # columns 0-5: bounding box corners -> center and (lower-bounded) size
    box_hi, box_lo = rows[:, 0:3], rows[:, 3:6]
    center = (box_hi + box_lo) / 2.0
    size = (box_hi - box_lo).clip(min=5e-3)

    # columns 6-11: joint type, averaged over its 6 copies and mapped to the codes 1..5
    type_raw = rows[:, 6:12].mean(axis=1)
    j_type = ((type_raw + 0.5) * 5).clip(min=1.0, max=5.0).round()

    # columns 12-14: axis direction, normalised (eps avoids division by zero)
    direction = rows[:, 12:15]
    length = np.linalg.norm(direction, axis=1, keepdims=True)
    axis_d = direction / (length + np.finfo(float).eps)

    # columns 15-17: axis origin, passed through as is
    axis_o = rows[:, 15:18]

    # columns 18-23: three copies of the 2-value joint range, averaged and clipped to [-1, 1];
    # only the first value is scaled by 360, the second one is kept as is
    j_range = (rows[:, 18:20] + rows[:, 20:22] + rows[:, 22:24]) / 3
    j_range = j_range.clip(min=-1.0, max=1.0)
    j_range[:, 0] *= 360

    # columns 24-29: semantic label, averaged over its 6 copies and mapped to the codes 0..7
    label_raw = rows[:, 24:30].mean(axis=1)
    label = ((label_raw + 0.8) * 5).clip(min=0.0, max=7.0).round()

    return {
        "center": center,
        "size": size,
        "type": j_type,
        "axis_d": axis_d,
        "axis_o": axis_o,
        "range": j_range,
        "label": label,
    }
def parse_tree(data, n_nodes, par, adj):
    '''
    Function to assemble the json-style node list from the converted model output.
    Args:
        data: dict of per-node arrays with the keys "label", "center", "size",
              "type", "axis_d", "axis_o" and "range"
        n_nodes: number of nodes to emit (indices 0 .. n_nodes-1)
        par: parent index of every node
        adj: adjacency matrix; entries equal to 1 in row i mark the neighbours of node i
    Returns:
        tree: list with one dict per node holding "id", "name", "parent",
              "children", "aabb" and "joint"
    '''
    nodes = []
    for idx in range(n_nodes):
        name = sem_ref["bwd"][int(data["label"][idx].item())]
        parent = int(par[idx])
        # children are all neighbours of the node except its parent
        neighbours = np.nonzero(adj[idx] == 1)[0]
        children = [int(nb) for nb in neighbours if nb != par[idx]]

        node = {
            "id": idx,
            "name": name,
            "parent": parent,
            "children": children,
            "aabb": {
                "center": data["center"][idx].tolist(),
                "size": data["size"][idx].tolist(),
            },
        }

        # a node labelled 'base' is always fixed, whatever type was predicted
        if name == 'base':
            joint_type = 'fixed'
        else:
            joint_type = joint_ref["bwd"][int(data["type"][idx].item())]
        joint = {"type": joint_type}

        # the motion range depends on the joint type; unknown types get no "range" key
        if joint_type == "fixed":
            joint["range"] = [0.0, 0.0]
        elif joint_type == "revolute":
            joint["range"] = [0.0, float(data["range"][idx][0])]
        elif joint_type == "continuous":
            joint["range"] = [0.0, 360.0]
        elif joint_type in ("prismatic", "screw"):
            joint["range"] = [0.0, float(data["range"][idx][1])]

        # relocate the axis to visualize well; note this uses the raw predicted
        # type code, also for nodes that were forced to 'fixed' above
        origin, direction = rescale_axis(
            int(data["type"][idx].item()),
            data["axis_d"][idx],
            data["axis_o"][idx],
            data["center"][idx],
        )
        joint["axis"] = {"direction": direction, "origin": origin}

        node["joint"] = joint
        nodes.append(node)
    return nodes
def convert_json(x, c, prefix=''):
    '''
    Function to convert the raw model output and its conditions into the json-style object dict.
    Args:
        x: raw model output, handed to convert_data_range
        c: dict of conditions; the entries f"{prefix}n_nodes", f"{prefix}parents" and
           f"{prefix}adj" are read, and f"{prefix}obj_cat" if present. Only the item
           at index 0 of each entry is used (presumably batched torch tensors --
           TODO confirm against the caller)
        prefix: string prepended to every key looked up in c
    Returns:
        out: dict with "meta" (holding "obj_cat" when available) and "diffuse_tree"
             (the node list built by parse_tree)
    '''
    out = {"meta": {}, "diffuse_tree": []}
    n_nodes = c[f"{prefix}n_nodes"][0].item()
    par = c[f"{prefix}parents"][0].cpu().numpy().tolist()
    # .numpy() on a CPU tensor shares memory with the tensor, so work on a copy:
    # otherwise clearing the diagonal below would silently modify the caller's c
    adj = c[f"{prefix}adj"][0].cpu().numpy().copy()
    np.fill_diagonal(adj, 0)  # remove self-loop for the root node
    if f"{prefix}obj_cat" in c:
        out["meta"]["obj_cat"] = c[f"{prefix}obj_cat"][0]
    # convert the data to original range
    data = convert_data_range(x)
    # parse the tree
    out["diffuse_tree"] = parse_tree(data, n_nodes, par, adj)
    return out