File size: 6,413 Bytes
c28dddb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
import os, sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
import numpy as np
from PIL import Image
from my_utils.refs import joint_ref, sem_ref

def rescale_axis(jtype, axis_d, axis_o, box_center):
    '''
    Relocate a joint axis relative to the part bounding box for rendering.

    Args:
    - jtype (int): joint type code
    - axis_d (np.array): axis direction
    - axis_o (np.array): axis origin
    - box_center (np.array): bounding box center

    Returns:
    - origin (list): relocated axis origin
    - direction (list): axis direction, unchanged except for joint types
      0 and 1, where both outputs are all zeros
    '''
    # joint types 0 and 1 get a zeroed-out axis (two independent lists)
    if jtype in (0, 1):
        return [0., 0., 0.], [0., 0., 0.]

    if jtype in (3, 4):
        # joint types 3 and 4: anchor the axis at the bounding box center
        origin = box_center
    else:
        # remaining types: move the origin along the axis to the point closest
        # to the box center (a true projection when axis_d is unit length)
        along_axis = np.dot(axis_d, box_center - axis_o)
        origin = axis_o + along_axis * axis_d

    return origin.tolist(), axis_d.tolist()

def make_white_background(src_img):
    '''Composite the input RGBA image onto an opaque white canvas and return the RGB result.'''
    src_img.load()  # force the image data to be loaded before touching its bands
    alpha_band = src_img.split()[3]  # band index 3 is the alpha channel of an RGBA image
    canvas = Image.new("RGB", src_img.size, (255, 255, 255))
    # the alpha band acts as the paste mask, so transparent pixels stay white
    canvas.paste(src_img, mask=alpha_band)
    return canvas

def build_graph(tree, K=32):
    '''
    Build the padded adjacency matrix and the parent list of a node tree.

    Args:
        tree: list of node dicts, each with 'id', 'parent' (-1 for the root)
              and 'children' (list of node ids)
        K: the maximum number of nodes in the graph (side of the adjacency matrix)
    Returns:
        dict with
        adj: (K, K) float32 adjacency matrix recording the 1-ring relationship
             (parent + children) of every node; a root node is linked to itself
        parents: int8 array holding the parent id of each node, in the order
                 the nodes appear in `tree` (-1 for a root)
    '''
    adj = np.zeros((K, K), dtype=np.float32)
    parent_ids = []

    for node in tree:
        nid = node['id']
        is_root = node['parent'] == -1

        # upward link: to the parent, or a self-loop when there is no parent
        upward = nid if is_root else node['parent']
        adj[nid, upward] = 1
        parent_ids.append(-1 if is_root else node['parent'])

        # downward links: one entry per child
        for cid in node['children']:
            adj[nid, cid] = 1

    return dict(adj=adj, parents=np.array(parent_ids, dtype=np.int8))

def load_input_from(pred_file, K=32):
    '''
    Assemble the model input (placeholder data + conditioning dict) from a predicted graph.

    Args:
        pred_file: dict whose 'diffuse_tree' entry is the list of nodes passed to build_graph
        K: the maximum number of nodes in the graph
    Returns:
        data: all-False boolean placeholder of shape (K*5, 6)
        cond: dict with the attention masks ('attr_mask', 'key_pad_mask', 'adj_mask'),
              the constant 'cat', and auxiliary info ('adj', 'parents', 'n_nodes')
    '''
    def expand_per_node(mat):
        # every node occupies 5 consecutive rows/columns in the masks
        return mat.repeat(5, axis=0).repeat(5, axis=1)

    tree = pred_file['diffuse_tree']
    n_nodes = len(tree)
    n_used = n_nodes * 5  # rows/columns taken by real (non-padding) nodes
    graph = build_graph(tree, K)

    # attr mask (for Local Attention): block-diagonal, one 5x5 block per node
    attr_mask = expand_per_node(np.eye(K, K, dtype=bool))

    # key padding mask (for Global Attention): True on the columns of real nodes
    key_pad_mask = np.zeros((K*5, K*5), dtype=bool)
    key_pad_mask[:, :n_used] = True

    # adj mask (for Graph Relation Attention): expanded adjacency, with the
    # rows belonging to padding nodes set to True everywhere
    adj_mask = expand_per_node(graph['adj'].astype(bool))
    adj_mask[n_used:, :] = True

    # parent ids padded with zeros up to K entries
    parents = np.zeros(K, dtype=np.int8)
    parents[:n_nodes] = graph['parents']

    cond = {  # conditional information and auxiliary data
        'attr_mask': attr_mask,
        'key_pad_mask': key_pad_mask,
        'adj_mask': adj_mask,
        'cat': 2,
        'adj': graph['adj'],
        'parents': parents,
        'n_nodes': n_nodes,
    }

    # placeholder input
    data = np.zeros((K*5, 6), dtype=bool)
    return data, cond

def convert_data_range(x):
    '''
    Postprocessing: convert the raw model output to the original range, following CAGE.

    The input is viewed as rows of 30 values, laid out as
    [0:3] aabb max, [3:6] aabb min, [6:12] joint type (averaged),
    [12:15] axis direction, [15:18] axis origin,
    [18:24] joint range (three 2-value copies, averaged), [24:30] label (averaged).

    Returns a dict with keys "center", "size", "type", "axis_d", "axis_o",
    "range" and "label", each holding one entry per row.
    '''
    def decode_code(cols, shift, lo, hi):
        # average the repeated columns, then shift, scale by 5, clip and round
        return ((np.mean(cols, axis=1) + shift) * 5).clip(min=lo, max=hi).round()

    rows = x.reshape(-1, 30)  # one row of 30 values per node

    # bounding box: corners -> center + size (size is floored at 5e-3)
    box_hi, box_lo = rows[:, 0:3], rows[:, 3:6]
    box_center = (box_hi + box_lo) / 2.0
    box_size = (box_hi - box_lo).clip(min=5e-3)

    # axis direction, normalized; eps guards against division by zero
    direction = rows[:, 12:15]
    length = np.linalg.norm(direction, axis=1, keepdims=True)
    direction = direction / (length + np.finfo(float).eps)

    # joint range: mean of the three copies, clipped to [-1, 1];
    # the first column is then scaled by 360, the second is kept as-is
    limits = (rows[:, 18:20] + rows[:, 20:22] + rows[:, 22:24]) / 3
    limits = limits.clip(min=-1.0, max=1.0)
    limits[:, 0] *= 360

    return {
        "center": box_center,
        "size": box_size,
        "type": decode_code(rows[:, 6:12], 0.5, 1.0, 5.0),
        "axis_d": direction,
        "axis_o": rows[:, 15:18],
        "range": limits,
        "label": decode_code(rows[:, 24:30], 0.8, 0.0, 7.0),
    }

def parse_tree(data, n_nodes, par, adj):
    '''
    Convert the decoded per-node arrays into a list of JSON-style node dicts.

    Args:
        data: dict as returned by convert_data_range
        n_nodes: number of nodes to emit (the first n_nodes rows of `data`)
        par: parent id of each node, indexable by node index
        adj: adjacency matrix; entries equal to 1 in row i mark the neighbors of node i
    Returns:
        list of node dicts with keys "id", "name", "parent", "children", "aabb" and "joint"
    '''
    nodes = []
    for idx in range(n_nodes):
        part_name = sem_ref["bwd"][int(data["label"][idx].item())]

        # neighbors in the adjacency row, minus the parent, are the children
        neighbor_ids = np.where(adj[idx] == 1)[0]
        child_ids = [int(n) for n in neighbor_ids if n != par[idx]]

        # the base part is always fixed, whatever joint type was predicted
        if part_name == 'base':
            joint_name = 'fixed'
        else:
            joint_name = joint_ref["bwd"][int(data["type"][idx].item())]

        joint = {"type": joint_name}
        # motion range depends on the joint type; unknown types get no "range" entry
        if joint_name == "fixed":
            joint["range"] = [0.0, 0.0]
        elif joint_name == "revolute":
            joint["range"] = [0.0, float(data["range"][idx][0])]
        elif joint_name == "continuous":
            joint["range"] = [0.0, 360.0]
        elif joint_name in ("prismatic", "screw"):
            joint["range"] = [0.0, float(data["range"][idx][1])]

        # relocate the axis to visualize well (uses the predicted joint type code,
        # even when the joint was forced to 'fixed' above)
        origin, direction = rescale_axis(
            int(data["type"][idx].item()),
            data["axis_d"][idx],
            data["axis_o"][idx],
            data["center"][idx],
        )
        joint["axis"] = {"direction": direction, "origin": origin}

        nodes.append({
            "id": idx,
            "name": part_name,
            "parent": int(par[idx]),
            "children": child_ids,
            "aabb": {
                "center": data["center"][idx].tolist(),
                "size": data["size"][idx].tolist(),
            },
            "joint": joint,
        })
    return nodes

def convert_json(x, c, prefix=''):
    '''
    Convert a raw model output and its conditioning dict into the JSON-style object description.

    Args:
        x: raw model output, passed as-is to convert_data_range
        c: conditioning dict; the entries f"{prefix}n_nodes", f"{prefix}parents" and
           f"{prefix}adj" are read as batched tensors and only their first element
           is used. An optional f"{prefix}obj_cat" entry is copied to the metadata.
        prefix: string prepended to every key looked up in `c`
    Returns:
        dict with "meta" (holding "obj_cat" when available) and "diffuse_tree"
        (the list of nodes built by parse_tree)
    '''
    out = {"meta": {}, "diffuse_tree": []}
    n_nodes = c[f"{prefix}n_nodes"][0].item()
    par = c[f"{prefix}parents"][0].cpu().numpy().tolist()
    # .numpy() can share memory with the source tensor (when it already lives on
    # the CPU), so work on a copy: the in-place edit below must not alter `c`
    adj = c[f"{prefix}adj"][0].cpu().numpy().copy()
    np.fill_diagonal(adj, 0)  # remove self-loop for the root node
    if f"{prefix}obj_cat" in c:
        out["meta"]["obj_cat"] = c[f"{prefix}obj_cat"][0]

    # convert the data to original range
    data = convert_data_range(x)
    # parse the tree
    out["diffuse_tree"] = parse_tree(data, n_nodes, par, adj)
    return out