xinjjj committed on
Commit ce34030 · verified · 1 Parent(s): c28dddb

Upload 29 files

Upload large files.

.gitattributes CHANGED
@@ -33,3 +33,25 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ examples/1_open_1.png filter=lfs diff=lfs merge=lfs -text
+ examples/1_open_2.png filter=lfs diff=lfs merge=lfs -text
+ examples/1.png filter=lfs diff=lfs merge=lfs -text
+ examples/close1.png filter=lfs diff=lfs merge=lfs -text
+ examples/close10.png filter=lfs diff=lfs merge=lfs -text
+ examples/close2.png filter=lfs diff=lfs merge=lfs -text
+ examples/close3.png filter=lfs diff=lfs merge=lfs -text
+ examples/close4.png filter=lfs diff=lfs merge=lfs -text
+ examples/close5.png filter=lfs diff=lfs merge=lfs -text
+ examples/close7.png filter=lfs diff=lfs merge=lfs -text
+ examples/close8.png filter=lfs diff=lfs merge=lfs -text
+ examples/close9.jpg filter=lfs diff=lfs merge=lfs -text
+ examples/open1.png filter=lfs diff=lfs merge=lfs -text
+ examples/open10.png filter=lfs diff=lfs merge=lfs -text
+ examples/open2.png filter=lfs diff=lfs merge=lfs -text
+ examples/open3.png filter=lfs diff=lfs merge=lfs -text
+ examples/open4.png filter=lfs diff=lfs merge=lfs -text
+ examples/open5.png filter=lfs diff=lfs merge=lfs -text
+ examples/open6.png filter=lfs diff=lfs merge=lfs -text
+ examples/open7.png filter=lfs diff=lfs merge=lfs -text
+ examples/open8.png filter=lfs diff=lfs merge=lfs -text
+ examples/open9.jpg filter=lfs diff=lfs merge=lfs -text
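
Note: these patterns route the new example images and checkpoint through Git LFS, so a plain clone without LFS only materializes pointer stubs. A minimal sketch for fetching an actual blob from the Hub with huggingface_hub (the repo id below is a hypothetical placeholder, not taken from this commit):

# a minimal sketch, assuming a repo id; hf_hub_download resolves LFS files to real blobs
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
    repo_id="xinjjj/dipo",       # hypothetical repo id, replace with the actual one
    filename="ckpts/dipo.ckpt",  # path as uploaded in this commit
)
print(ckpt_path)  # local cache path of the downloaded checkpoint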
ckpts/dipo.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:493f551499b95af57b5bb6e872d1107a9cf4056fbf151fc45f416f96a919dad6
+ size 24565754
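
The three lines above are the Git LFS pointer format: a spec version, the SHA-256 of the blob, and its byte size. A small hedged sketch for checking whether a checked-out file is still a pointer stub rather than the real ~24.5 MB checkpoint:

# a minimal sketch: detect an LFS pointer stub by its leading spec line
from pathlib import Path

def is_lfs_pointer(path: str) -> bool:
    head = Path(path).read_bytes()[:64]
    return head.startswith(b"version https://git-lfs.github.com/spec/v1")

# usage: a real checkpoint is megabytes, a pointer stub is ~130 bytes
# print(is_lfs_pointer("ckpts/dipo.ckpt"))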
examples/1.png ADDED

Git LFS Details

  • SHA256: c301a718a1401acdc67b1c3ad0a03ce7d44b4fbecdff035e362fbcab4e8146c7
  • Pointer size: 131 Bytes
  • Size of remote file: 107 kB
examples/1_open_1.png ADDED

Git LFS Details

  • SHA256: 500bb1e80e9f140c48c26be0cc911b45f35bae3402dcd4b65222636c5de4209a
  • Pointer size: 132 Bytes
  • Size of remote file: 1.45 MB
examples/1_open_2.png ADDED

Git LFS Details

  • SHA256: d5ddfcb9c0e18bccef98906cf14d2806ed5cd1baf2ae2a9b8d47aa5dee0fc728
  • Pointer size: 132 Bytes
  • Size of remote file: 1.55 MB
examples/close1.png ADDED

Git LFS Details

  • SHA256: aab9059f5a24cfb4e5aa83109dec10e0b60c3305d824ea7ca4b3815b388299fc
  • Pointer size: 132 Bytes
  • Size of remote file: 2.89 MB
examples/close10.png ADDED

Git LFS Details

  • SHA256: 75a59cba68214dbc3c6f3c4f314aaa7e53444f1ca97f71dc2543db1d21cdaa8e
  • Pointer size: 131 Bytes
  • Size of remote file: 324 kB
examples/close2.png ADDED

Git LFS Details

  • SHA256: 021207db2e7bc4edd2d8564ae4b6a471ffbe863b355c2de779671992cb309a88
  • Pointer size: 132 Bytes
  • Size of remote file: 2.31 MB
examples/close3.png ADDED

Git LFS Details

  • SHA256: 9611338fcc37ad15f97740e1c8c55f4f7664f55258ee2c7576fab6c539ae5097
  • Pointer size: 132 Bytes
  • Size of remote file: 2.41 MB
examples/close4.png ADDED

Git LFS Details

  • SHA256: a378ff1c881453ac05e7d89a9d1b8d63c9cece9d2ded4a5802d2eff8337bc4f4
  • Pointer size: 132 Bytes
  • Size of remote file: 2.13 MB
examples/close5.png ADDED

Git LFS Details

  • SHA256: 79e7d13acaa9c8615deb7d8a6590b9e451754f55b2e2ff17cd646048f8092b88
  • Pointer size: 131 Bytes
  • Size of remote file: 581 kB
examples/close6.png ADDED
examples/close7.png ADDED

Git LFS Details

  • SHA256: 924ea038fa387b3740c6369e351a48bc1f8ed01d1b25ed6f290b582e46066533
  • Pointer size: 131 Bytes
  • Size of remote file: 640 kB
examples/close8.png ADDED

Git LFS Details

  • SHA256: 754f8b420729f5b0d680c083f507f0ab9f3506c0ebfb0514305c613de389c04c
  • Pointer size: 131 Bytes
  • Size of remote file: 157 kB
examples/close9.jpg ADDED

Git LFS Details

  • SHA256: dd19189cbaf67063ff6a5e0cdfe1159a2cb1a7f771ccf372bf1d7b2704db4266
  • Pointer size: 131 Bytes
  • Size of remote file: 140 kB
examples/open1.png ADDED

Git LFS Details

  • SHA256: 082b51f5894c82c5be5ebf7d137bb9db5e1ce4a2aed7f4571dab422c6c6e7ba5
  • Pointer size: 132 Bytes
  • Size of remote file: 2.82 MB
examples/open10.png ADDED

Git LFS Details

  • SHA256: 41a9182d705c1482b58ab33f5575936df1d02c77acc3b6b56a316de47b9b3d0a
  • Pointer size: 131 Bytes
  • Size of remote file: 350 kB
examples/open2.png ADDED

Git LFS Details

  • SHA256: ed531e4310d76d4e07c673e76e6ca693ea616f94c0b0cfa13fdf9c83fc0dca0d
  • Pointer size: 132 Bytes
  • Size of remote file: 2.4 MB
examples/open3.png ADDED

Git LFS Details

  • SHA256: 79ecc7922eae1e28d74870213a957cfa8c0f12faaf23f5324acd27f8cff8ce12
  • Pointer size: 132 Bytes
  • Size of remote file: 2.62 MB
examples/open4.png ADDED

Git LFS Details

  • SHA256: 28bcb2b1429fcbc7c9b8d6555d134fd1fbf0f6a7d2d538f2a5f4c72a8d8d3879
  • Pointer size: 132 Bytes
  • Size of remote file: 2.52 MB
examples/open5.png ADDED

Git LFS Details

  • SHA256: 462dbf2f0380151d416752eb62b53f2ebf1bc78591909fee7532eea73fd71963
  • Pointer size: 131 Bytes
  • Size of remote file: 959 kB
examples/open6.png ADDED

Git LFS Details

  • SHA256: 747335a8c3d6bb5558d964eb992a61225e02a6e726e5cb7642f7f0c472bce547
  • Pointer size: 131 Bytes
  • Size of remote file: 288 kB
examples/open7.png ADDED

Git LFS Details

  • SHA256: 7ec08b66d9747070f8450a3787b70db4be008615d7a1e468fd78409759e6fce5
  • Pointer size: 131 Bytes
  • Size of remote file: 682 kB
examples/open8.png ADDED

Git LFS Details

  • SHA256: 8415042543b27900ae583789ae6f6c46c980e245551b5879e728d465f38d9d8f
  • Pointer size: 131 Bytes
  • Size of remote file: 281 kB
examples/open9.jpg ADDED

Git LFS Details

  • SHA256: 2a5e780c5bfbb7fb8ceedb719260212d58caf02891969fd2d193f7658662bdb5
  • Pointer size: 131 Bytes
  • Size of remote file: 312 kB
systems/__init__.py ADDED
@@ -0,0 +1,22 @@
+ systems = {}
+
+
+ def register(name):
+     def decorator(cls):
+         systems[name] = cls
+         return cls
+
+     return decorator
+
+
+ def make(name, config, load_from_checkpoint=None):
+     if load_from_checkpoint is None:
+         system = systems[name](config)
+     else:
+         system = systems[name].load_from_checkpoint(
+             load_from_checkpoint, strict=False, config=config
+         )
+     return system
+
+
+ from . import system_origin
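
This module is a name-to-class registry: register decorates a system class into the systems dict, and make instantiates it, optionally restoring weights from a checkpoint. A small usage sketch (names here are illustrative except "sys_origin", which system_origin.py registers below):

# a minimal sketch of the registry pattern above; MySystem is hypothetical
import systems

@systems.register("my_system")   # adds the class to the systems dict under this name
class MySystem:                  # real systems subclass BaseSystem (a LightningModule)
    def __init__(self, config):
        self.config = config

fresh = systems.make("my_system", config={"lr": 1e-4})  # new instance from config
# resuming goes through LightningModule.load_from_checkpoint, e.g.:
# resumed = systems.make("sys_origin", config, load_from_checkpoint="ckpts/dipo.ckpt")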
systems/base.py ADDED
@@ -0,0 +1,286 @@
+ import os, sys
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
+ import json
+ import math
+ import numpy as np
+ import lightning.pytorch as pl
+ from metrics.iou_cdist import IoU_cDist
+ from my_utils.savermixins import SaverMixin
+ from my_utils.refs import sem_ref, joint_ref
+ from dataset.utils import convert_data_range, parse_tree
+ from my_utils.plot import viz_graph, make_grid, add_text
+ from my_utils.render import draw_boxes_axiss_anim, prepare_meshes
+ from PIL import Image
+
+ class BaseSystem(pl.LightningModule, SaverMixin):
+     def __init__(self, hparams):
+         super().__init__()
+         self.hparams.update(hparams)
+
+     def setup(self, stage: str):
+         # configure the logger dir for images
+         self.hparams.save_dir = os.path.join(self.hparams.exp_dir, 'output', stage)
+         os.makedirs(self.hparams.save_dir, exist_ok=True)
+
+     # --------------------------------- visualization ---------------------------------
+
+     def convert_json(self, x, c, idx, prefix=''):
+         out = {"meta": {}, "diffuse_tree": []}
+
+         n_nodes = c[f"{prefix}n_nodes"][idx].item()
+         par = c[f"{prefix}parents"][idx].cpu().numpy().tolist()
+         adj = c[f"{prefix}adj"][idx].cpu().numpy()
+         np.fill_diagonal(adj, 0)  # remove the self-loop for the root node
+         if f"{prefix}obj_cat" in c:
+             out["meta"]["obj_cat"] = c[f"{prefix}obj_cat"][idx]
+
+         # convert the data back to the original range
+         data = convert_data_range(x.cpu().numpy())
+         # parse the tree
+         out["diffuse_tree"] = parse_tree(data, n_nodes, par, adj)
+         return out
+
+     # def save_val_img(self, pred, gt, cond):
+     #     B = pred.shape[0]
+     #     pred_imgs, gt_imgs, gt_graphs_view = [], [], []
+     #     for b in range(B):
+     #         print(b)
+     #         # convert to human-readable json format
+     #         pred_json = self.convert_json(pred[b], cond, b)
+     #         gt_json = self.convert_json(gt[b], cond, b)
+     #         # visualize bbox and axis
+     #         pred_meshes = prepare_meshes(pred_json)
+     #         bbox_0, bbox_1, axiss = (
+     #             pred_meshes["bbox_0"],
+     #             pred_meshes["bbox_1"],
+     #             pred_meshes["axiss"],
+     #         )
+     #         pred_img = draw_boxes_axiss_anim(
+     #             bbox_0, bbox_1, axiss, mode="graph", resolution=128
+     #         )
+     #         gt_meshes = prepare_meshes(gt_json)
+     #         bbox_0, bbox_1, axiss = (
+     #             gt_meshes["bbox_0"],
+     #             gt_meshes["bbox_1"],
+     #             gt_meshes["axiss"],
+     #         )
+     #         gt_img = draw_boxes_axiss_anim(
+     #             bbox_0, bbox_1, axiss, mode="graph", resolution=128
+     #         )
+     #         # visualize graph
+     #         # gt_graph = viz_graph(gt_json, res=128)
+     #         # gt_graph = add_text(cond["name"][b], gt_graph)
+     #         # GT views
+     #         rgb_view = cond["img"][b].cpu().numpy()
+
+     #         pred_imgs.append(pred_img)
+     #         gt_imgs.append(gt_img)
+     #         gt_graphs_view.append(rgb_view)
+     #         # gt_graphs_view.append(gt_graph)
+
+     #     # save images for generated results
+     #     epoch = str(self.current_epoch).zfill(5)
+     #     # pred_thumbnails = np.concatenate(pred_imgs, axis=1)  # concat batch in width
+
+     #     import ipdb
+     #     ipdb.set_trace()
+     #     # save images for ground truth
+     #     for i in range(math.ceil(len(gt_graphs_view) / 8)):
+     #         start = i * 8
+     #         end = min((i + 1) * 8, len(gt_graphs_view))
+     #         pred_thumbnails = np.concatenate(pred_imgs[start:end], axis=1)
+     #         gt_graph_imgs = np.concatenate(gt_graphs_view[start:end], axis=1)
+     #         gt_thumbnails = np.concatenate(gt_imgs[start:end], axis=1)  # concat batch in width
+     #         grid = np.concatenate([gt_graph_imgs, gt_thumbnails, pred_thumbnails], axis=0)
+     #         self.save_rgb_image(f"new_out_valid_{i}.png", grid)
+
+     def save_test_step(self, pred, gt, cond, batch_idx, res=128):
+         exp_name = self._get_exp_name()
+         model_name = cond["name"][0].replace("/", '@')
+         save_dir = f"{exp_name}/{str(batch_idx)}@{model_name}"
+
+         # input image
+         input_img = cond["img"][0].cpu().numpy()
+         # GT recordings
+         if not self.hparams.get('test_no_GT', False):
+             gt_json = self.convert_json(gt[0], cond, 0)
+             # gt_graph = viz_graph(gt_json, res=256)
+             gt_meshes = prepare_meshes(gt_json)
+             bbox_0, bbox_1, axiss = (
+                 gt_meshes["bbox_0"],
+                 gt_meshes["bbox_1"],
+                 gt_meshes["axiss"],
+             )
+             gt_img = draw_boxes_axiss_anim(bbox_0, bbox_1, axiss, mode="graph", resolution=res)
+         else:
+             # gt_graph = 255 * np.ones((res, res, 3), dtype=np.uint8)
+             gt_img = 255 * np.ones((res, 2 * res, 3), dtype=np.uint8)
+         gt_block = np.concatenate([input_img, gt_img], axis=1)
+
+         # recordings for generated results
+         img_blocks = []
+         for b in range(pred.shape[0]):
+             pred_json = self.convert_json(pred[b], cond, 0)
+             # visualize bbox and axis
+             pred_meshes = prepare_meshes(pred_json)
+             bbox_0, bbox_1, axiss = (
+                 pred_meshes["bbox_0"],
+                 pred_meshes["bbox_1"],
+                 pred_meshes["axiss"],
+             )
+             pred_img = draw_boxes_axiss_anim(
+                 bbox_0, bbox_1, axiss, mode="graph", resolution=res
+             )
+             img_blocks.append(pred_img)
+             self.save_json(f"{save_dir}/{b}/object.json", pred_json)
+         # save images for generated results
+         img_grid = make_grid(img_blocks, cols=5)
+         # visualize the input graph
+         # input_graph = viz_graph(pred_json, res=256)
+
+         # save images
+         # self.save_rgb_image(f"{save_dir}/gt_graph.png", gt_graph)
+         self.save_rgb_image(f"{save_dir}/output.png", img_grid)
+         self.save_rgb_image(f"{save_dir}/gt.png", gt_block)
+         # self.save_rgb_image(f"{save_dir}/input_graph.png", input_graph)
+
+     def _save_html_end(self):
+         exp_name = self._get_exp_name()
+         save_dir = self.get_save_path(exp_name)
+         cases = sorted(os.listdir(save_dir), key=lambda x: int(x.split("@")[0]))
+         html_head = """
+         <!DOCTYPE html>
+         <html lang="en">
+         <head>
+             <meta charset="UTF-8">
+             <meta name="viewport" content="width=device-width, initial-scale=1.0">
+             <title>Test Image Results</title>
+             <style>
+                 table {
+                     width: 100%;
+                     border-collapse: collapse;
+                 }
+                 th, td {
+                     border: 1px solid black;
+                     padding: 8px;
+                     text-align: left;
+                 }
+                 .separator {
+                     border-top: 2px solid black;
+                 }
+             </style>
+         </head>
+         <body>
+         <table>
+
+         """
+         total = len(cases)
+         each = 200  # number of cases per HTML page
+         n_pages = total // each + 1
+         for p in range(n_pages):
+             html_content = html_head
+             for i in range(p * each, min((p + 1) * each, total)):
+                 case = cases[i]
+                 if self.hparams.get("test_no_GT", False):
+                     aid_iou = rid_iou = aid_cdist = rid_cdist = aid_cd = rid_cd = aor = "N/A"
+                 else:
+                     with open(os.path.join(save_dir, case, "metrics.json"), "r") as f:
+                         metrics = json.load(f)["avg"]
+                     aid_iou = round(metrics["AS-IoU"], 4)
+                     rid_iou = round(metrics["RS-IoU"], 4)
+                     aid_cdist = round(metrics["AS-cDist"], 4)
+                     rid_cdist = round(metrics["RS-cDist"], 4)
+                     aid_cd = round(metrics["AS-CD"], 4)
+                     rid_cd = round(metrics["RS-CD"], 4)
+                     aor = metrics["AOR"]
+                     if aor is not None:
+                         aor = round(aor, 4)
+                 html_content += f"""
+                 <tr>
+                     <th>Object ID</th>
+                     <th>Metrics (avg)</th>
+                     <th>Input image + GT object + GT graph</th>
+                     <th>Input graph</th>
+                 </tr>
+                 <tr>
+                     <td rowspan="3">{case}</td>
+                     <td>
+                         [AS-cDist] {aid_cdist}<br>
+                         [RS-cDist] {rid_cdist}<br>
+                         -----------------------<br>
+                         [AS-IoU] {aid_iou}<br>
+                         [RS-IoU] {rid_iou}<br>
+                         -----------------------<br>
+                         [RS-CD] {rid_cd}<br>
+                         [AS-CD] {aid_cd}<br>
+                         -----------------------<br>
+                         [AOR] {aor}<br>
+                     </td>
+                     <td>
+                         <img src="{exp_name}/{case}/gt.png" alt="GT Image" style="height: 128px; width: 384px;">
+                         <img src="{exp_name}/{case}/gt_graph.png" alt="Graph Image" style="height: 128px; width: 384px;">
+                     </td>
+                     <td>
+                         <img src="{exp_name}/{case}/input_graph.png" alt="Graph Image" style="height: 128px; width: 384px;">
+                     </td>
+                 </tr>
+                 <tr><th colspan="3">Generated samples</th></tr>
+                 <tr>
+                     <td colspan="3"><img src="{exp_name}/{case}/output.png" alt="Generated Image" style="height: 384px; width: 1280px;"></td>
+                 </tr>
+                 <tr class="separator"><td colspan="4"></td></tr>
+                 """
+             html_content += """</table></body></html>"""
+             outfile = self.get_save_path(f"{exp_name}_page_{p+1}.html")
+             with open(outfile, "w") as file:
+                 file.write(html_content)
+
+     def val_compute_metrics(self, pred, gt, cond):
+         loss_dict = {}
+         B = pred.shape[0]
+         as_ious = 0.0
+         rs_ious = 0.0
+         as_cdists = 0.0
+         rs_cdists = 0.0
+         for b in range(B):
+             gt_json = self.convert_json(gt[b], cond, b)
+             pred_json = self.convert_json(pred[b], cond, b)
+             scores = IoU_cDist(
+                 pred_json,
+                 gt_json,
+                 num_states=5,
+                 compare_handles=True,
+                 iou_include_base=True,
+             )
+             as_ious += scores['AS-IoU']
+             rs_ious += scores['RS-IoU']
+             as_cdists += scores['AS-cDist']
+             rs_cdists += scores['RS-cDist']
+
+         as_ious /= B
+         rs_ious /= B
+         as_cdists /= B
+         rs_cdists /= B
+
+         loss_dict['val/AS-IoU'] = as_ious
+         loss_dict['val/RS-IoU'] = rs_ious
+         loss_dict['val/AS-cDist'] = as_cdists
+         loss_dict['val/RS-cDist'] = rs_cdists
+
+         return loss_dict
+
+     def _get_exp_name(self):
+         which_ds = self.hparams.get("test_which", 'pm')
+         is_pred_G = self.hparams.get("test_pred_G", False)
+         is_label_free = self.hparams.get("test_label_free", False)
+         guidance_scaler = self.hparams.get("guidance_scaler", 0)
+         # configure the saving directory
+         exp_postfix = f"_w={guidance_scaler}_{which_ds}"
+         if is_pred_G:
+             exp_postfix += "_pred_G"
+         if is_label_free:
+             exp_postfix += "_label_free"
+
+         exp_name = "epoch_" + str(self.current_epoch).zfill(3) + exp_postfix
+         return exp_name
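
For reference, convert_json returns the dict that save_test_step writes to object.json and that viz_graph in systems/plot.py consumes. Only the fields actually accessed in this commit are certain; a hypothetical minimal instance:

# a hypothetical minimal object.json; the exact per-node fields come from
# dataset.utils.parse_tree, which is not part of this commit
example = {
    "meta": {"obj_cat": "StorageFurniture"},  # set only when obj_cat is in the condition dict
    "diffuse_tree": [
        {"id": 0, "children": [1]},  # viz_graph reads 'id' and 'children' per node
        {"id": 1, "children": []},
    ],
}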
systems/dino_dummy.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:67b13dadf868704eb0e5a1b55355d54bce806b7f9d8d877cdf4142f759544bbd
+ size 1572992
systems/plot.py ADDED
@@ -0,0 +1,122 @@
+ import os, sys
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
+ import matplotlib
+ matplotlib.use('Agg')
+ import numpy as np
+ import networkx as nx
+ from io import BytesIO
+ from PIL import Image, ImageDraw
+ from matplotlib import pyplot as plt
+ from sklearn.decomposition import PCA
+ from singapo_utils.refs import graph_color_ref
+
+ def add_text(text, imgarr):
+     '''
+     Add text to an image.
+
+     Args:
+         - text (str): text to add
+         - imgarr (np.array): image array
+
+     Returns:
+         - img (np.array): image array with the text drawn on it
+     '''
+     img = Image.fromarray(imgarr)
+     I = ImageDraw.Draw(img)
+     I.text((10, 10), text, fill='black')
+     return np.asarray(img)
+
+ def get_color(ref, n_nodes):
+     '''
+     Color the nodes.
+
+     Args:
+         - ref (list): list of color references ('rgb(r, g, b)' strings)
+         - n_nodes (int): number of nodes
+
+     Returns:
+         - colors (list): list of colors
+     '''
+     N = len(ref)
+     colors = []
+     for i in range(n_nodes):
+         # parse the 'rgb(r, g, b)' string and normalize to [0, 1]
+         colors.append(np.array([[int(v) for v in ref[i % N][4:-1].split(',')]]) / 255.)
+     return colors
+
+
+ def make_grid(images, cols=5):
+     """
+     Arrange a list of images into an N x cols grid.
+
+     Args:
+         - images (list): list of Numpy arrays representing the images.
+         - cols (int): number of columns for the grid.
+
+     Returns:
+         - grid (numpy array): Numpy array representing the image grid.
+     """
+     # Determine the dimensions of each image
+     img_h, img_w, _ = images[0].shape
+     # Round up so a partial last row still fits on the canvas
+     rows = (len(images) + cols - 1) // cols
+
+     # Initialize a blank canvas
+     grid = np.zeros((rows * img_h, cols * img_w, 3), dtype=images[0].dtype)
+
+     # Place each image onto the grid
+     for idx, img in enumerate(images):
+         y = (idx // cols) * img_h
+         x = (idx % cols) * img_w
+         grid[y: y + img_h, x: x + img_w] = img
+
+     return grid
+
+ def viz_graph(info_dict, res=256):
+     '''
+     Plot the directed graph.
+
+     Args:
+         - info_dict (dict): output json containing the graph information
+         - res (int): resolution of the image
+
+     Returns:
+         - img_arr (np.array): image array
+     '''
+     # build the tree
+     tree = info_dict['diffuse_tree']
+     edges = []
+     for node in tree:
+         edges += [(node['id'], child) for child in node['children']]
+     G = nx.DiGraph()
+     G.add_edges_from(edges)
+
+     # plot the tree
+     plt.figure(figsize=(res / 100, res / 100))
+
+     colors = get_color(graph_color_ref, len(tree))
+     pos = nx.nx_agraph.graphviz_layout(G, prog="twopi", args="")
+     node_order = sorted(G.nodes())
+     nx.draw(G, pos, node_color=colors, nodelist=node_order, edge_color='k', with_labels=False)
+
+     buf = BytesIO()
+     plt.savefig(buf, format="png", dpi=100)
+     buf.seek(0)
+     img = Image.open(buf)
+     img_arr = np.asarray(img)
+     buf.close()
+     plt.clf()
+     plt.close()
+     return img_arr[:, :, :3]
+
+ def viz_patch_feat_pca(feat):
+     # project the patch features down to 3 channels via PCA
+     pca = PCA(n_components=3)
+     pca.fit(feat)
+     feat_pca = pca.transform(feat)
+
+     # normalize each channel to [0, 1]
+     t = np.array(feat_pca)
+     t_min = t.min(axis=0, keepdims=True)
+     t_max = t.max(axis=0, keepdims=True)
+     normalized_t = (t - t_min) / (t_max - t_min)
+
+     # map to uint8 and reshape into a 16x16 RGB image
+     array = (normalized_t * 255).astype(np.uint8)
+     img_array = array.reshape(16, 16, 3)
+     return img_array
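
make_grid is what save_test_step in systems/base.py uses to tile the five generated samples per case; a quick self-contained sketch of the tiling behavior with dummy images:

# a minimal usage sketch of make_grid with dummy 128x128 images
import numpy as np
from systems.plot import make_grid

imgs = [np.full((128, 128, 3), c, dtype=np.uint8) for c in (0, 64, 128, 192, 255)]
grid = make_grid(imgs, cols=5)
assert grid.shape == (128, 5 * 128, 3)  # one row of five 128x128 tiles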
systems/system_origin.py ADDED
@@ -0,0 +1,391 @@
+ import os, sys
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
+ import torch
+ import subprocess
+ import numpy as np
+ import models
+ import systems
+ import torch.nn.functional as F
+ from diffusers import DDPMScheduler
+ from systems.base import BaseSystem
+ from my_utils.lr_schedulers import LinearWarmupCosineAnnealingLR
+ from datetime import datetime
+ import logging
+
+ @systems.register("sys_origin")
+ class SingapoSystem(BaseSystem):
+     """Trainer for the B9 model, incorporating classifier-free guidance for the image condition."""
+
+     def __init__(self, hparams):
+         super().__init__(hparams)
+         self.model = models.make(hparams.model.name, hparams.model)
+         # configure the DDPM scheduler
+         self.scheduler = DDPMScheduler(**self.hparams.scheduler.config)
+         # load the dummy DINO features
+         self.dummy_dino = np.load('systems/dino_dummy.npy').astype(np.float32)
+         # use manual optimization
+         self.automatic_optimization = False
+         # save the hyperparameters
+         self.save_hyperparameters()
+
+         self.custom_logger = logging.getLogger(__name__)
+         self.custom_logger.setLevel(logging.INFO)
+         if self.global_rank == 0:
+             self.custom_logger.addHandler(logging.StreamHandler())
+
+     def load_cage_weights(self, pretrained_ckpt=None):
+         ckpt = torch.load(pretrained_ckpt)
+         state_dict = ckpt["state_dict"]
+         # remove the "model." prefix from the keys
+         state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}
+         # load the weights
+         self.model.load_state_dict(state_dict, strict=False)
+         print("[INFO] loaded model weights of the pretrained CAGE.")
+
+     def fg_loss(self, all_attn_maps, loss_masks):
+         """
+         Excite the attention maps within the object regions, while weakening the attention outside the object regions.
+
+         Args:
+             all_attn_maps: cross-attention maps from all layers, shape (B*L, H, 160, 256)
+             loss_masks: object seg mask on the image patches, shape (B, 160, 256)
+
+         Returns:
+             loss: loss on the attention maps
+         """
+         valid_mask = loss_masks['valid_nodes']
+         fg_mask = loss_masks['fg']
+         # get the number of layers and attention heads
+         L = self.hparams.model.n_layers
+         H = all_attn_maps.shape[1]
+         # reshape all the masks to the shape of the attention maps
+         valid_node = valid_mask[:, :, 0].unsqueeze(1).expand(-1, H, -1).unsqueeze(-1).expand(-1, -1, -1, 256).repeat(L, 1, 1, 1)
+         obj_region = fg_mask.unsqueeze(1).expand(-1, H, -1, -1).repeat(L, 1, 1, 1)
+         # construct masks for the object and non-object regions
+         fg_region = torch.logical_and(valid_node, obj_region)
+         bg_region = torch.logical_and(valid_node, ~obj_region)
+         # loss to excite the foreground regions and suppress the background regions
+         loss = 1. - all_attn_maps[fg_region].mean() + all_attn_maps[bg_region].mean()
+         return loss
+
+     def diffuse_process(self, inputs):
+         x = inputs["x"]
+         # sample Gaussian noise
+         noise = torch.randn(x.shape, device=self.device, dtype=x.dtype)
+         # sample a random timestep for each image
+         timesteps = torch.randint(
+             0,
+             self.scheduler.config.num_train_timesteps,
+             (x.shape[0],),
+             device=self.device,
+             dtype=torch.long,
+         )
+         # add Gaussian noise to the input
+         noisy_x = self.scheduler.add_noise(x, noise, timesteps)
+         # update the inputs
+         inputs["noise"] = noise
+         inputs["timesteps"] = timesteps
+         inputs["noisy_x"] = noisy_x
+
+     def prepare_inputs(self, batch, mode='train', n_samples=1):
+         x, c, f = batch
+
+         cat = c["cat"]  # object category
+         attr_mask = c["attr_mask"]  # attention mask for local self-attention (following CAGE)
+         key_pad_mask = c["key_pad_mask"]  # key padding mask for global self-attention (following CAGE)
+         graph_mask = c["adj_mask"]  # attention mask for graph relation self-attention (following CAGE)
+
+         inputs = {}
+         if mode == 'train':
+             # the number of sampled timesteps per iteration
+             n_repeat = self.hparams.n_time_samples
+             # repeat for sampling multiple timesteps
+             x = x.repeat(n_repeat, 1, 1)
+             cat = cat.repeat(n_repeat)
+             f = f.repeat(n_repeat, 1, 1)
+             key_pad_mask = key_pad_mask.repeat(n_repeat, 1, 1)
+             graph_mask = graph_mask.repeat(n_repeat, 1, 1)
+             attr_mask = attr_mask.repeat(n_repeat, 1, 1)
+         elif mode == 'val':
+             noisy_x = torch.randn(x.shape, device=x.device)
+             dummy_f = torch.tensor(self.dummy_dino, device=self.device).unsqueeze(0).repeat(1, 2, 1).expand_as(f)
+             inputs["noisy_x"] = noisy_x
+             inputs["dummy_f"] = dummy_f
+         elif mode == 'test':
+             # repeat for sampling multiple outputs
+             x = x.repeat(n_samples, 1, 1)
+             cat = cat.repeat(n_samples)
+             f = f.repeat(n_samples, 1, 1)
+             key_pad_mask = key_pad_mask.repeat(n_samples, 1, 1)
+             graph_mask = graph_mask.repeat(n_samples, 1, 1)
+             attr_mask = attr_mask.repeat(n_samples, 1, 1)
+             noisy_x = torch.randn(x.shape, device=x.device)
+             dummy_f = torch.tensor(self.dummy_dino, device=self.device).unsqueeze(0).repeat(1, 2, 1).expand_as(f)
+             inputs["noisy_x"] = noisy_x
+             inputs["dummy_f"] = dummy_f.repeat(1, 2, 1)
+         else:
+             raise ValueError(f"Invalid mode: {mode}")
+
+         inputs["x"] = x
+         inputs["f"] = f
+         inputs["cat"] = cat
+         inputs["key_pad_mask"] = key_pad_mask
+         inputs["graph_mask"] = graph_mask
+         inputs["attr_mask"] = attr_mask
+
+         return inputs
+
+     def prepare_loss_mask(self, batch):
+         x, c, _ = batch
+         n_repeat = self.hparams.n_time_samples  # the number of sampled timesteps per iteration
+
+         # mask on the image patches for the foreground regions
+         # mask_fg = c["img_obj_mask"]
+         # if mask_fg is not None:
+         #     mask_fg = mask_fg.repeat(n_repeat, 1, 1)
+
+         # mask on the valid nodes
+         index_tensor = torch.arange(x.shape[1], device=self.device, dtype=torch.int32).unsqueeze(0)  # (1, N)
+         valid_nodes = index_tensor < (c['n_nodes'] * 5).unsqueeze(-1)
+         mask_valid_nodes = valid_nodes.unsqueeze(-1).expand_as(x)
+         mask_valid_nodes = mask_valid_nodes.repeat(n_repeat, 1, 1)
+
+         return {"fg": None, "valid_nodes": mask_valid_nodes}
+
+     def manage_cfg(self, inputs):
+         '''
+         Manage the classifier-free training for the image and graph conditions.
+         The CFG for the object category is managed by the model (i.e. the CombinedTimestepLabelEmbeddings module in norm1 of each attention block).
+         '''
+         img_drop_prob = self.hparams.get("img_drop_prob", 0.0)
+         graph_drop_prob = self.hparams.get("graph_drop_prob", 0.0)
+         drop_img, drop_graph = False, False
+
+         if img_drop_prob > 0.0:
+             drop_img = torch.rand(1) < img_drop_prob
+             if drop_img.item():
+                 dummy_batch = torch.tensor(self.dummy_dino, device=self.device).unsqueeze(0).repeat(1, 2, 1).expand_as(inputs['f'])
+                 inputs['f'] = dummy_batch  # use the dummy DINO features
+
+         if graph_drop_prob > 0.0:
+             if not drop_img:
+                 drop_graph = torch.rand(1) < graph_drop_prob
+                 if drop_graph.item():
+                     inputs['graph_mask'] = None  # for verifying the model only; replace with the line below later and retrain the model
+                     # inputs['graph_mask'] = inputs['key_pad_mask']  # use the key padding mask
+
+     def compute_loss(self, batch, inputs, outputs):
+         loss_dict = {}
+         # loss_weight = self.hparams.get("loss_fg_weight", 1.0)
+
+         # prepare the loss masks
+         loss_masks = self.prepare_loss_mask(batch)
+
+         # diffusion model loss: MSE on the residual noise
+         loss_mse = F.mse_loss(outputs['noise_pred'] * loss_masks['valid_nodes'], inputs['noise'] * loss_masks['valid_nodes'])
+         # attention mask loss: BCE loss on the attention maps
+         # loss_fg = loss_weight * self.fg_loss(outputs['attn_maps'], loss_masks)
+
+         # total loss
+         loss = loss_mse
+
+         # log the losses
+         loss_dict["train/loss_mse"] = loss_mse
+         loss_dict["train/loss_total"] = loss
+
+         return loss, loss_dict
+
+     def training_step(self, batch, batch_idx):
+         # prepare the inputs and GT
+         inputs = self.prepare_inputs(batch, mode='train')
+
+         # manage the classifier-free training
+         self.manage_cfg(inputs)
+
+         # forward: diffusion process
+         self.diffuse_process(inputs)
+
+         # reverse: denoising process
+         outputs = self.model(
+             x=inputs['noisy_x'],
+             cat=inputs['cat'],
+             timesteps=inputs['timesteps'],
+             feat=inputs['f'],
+             key_pad_mask=inputs['key_pad_mask'],
+             graph_mask=inputs['graph_mask'],
+             attr_mask=inputs['attr_mask'],
+         )
+
+         # compute the loss
+         loss, loss_dict = self.compute_loss(batch, inputs, outputs)
+
+         # manual backward
+         opt1, opt2 = self.optimizers()
+         opt1.zero_grad()
+         opt2.zero_grad()
+         self.manual_backward(loss)
+         opt1.step()
+         opt2.step()
+
+         if batch_idx % 20 == 0 and self.global_rank == 0:
+             now = datetime.now()
+             now_str = now.strftime("%Y-%m-%d %H:%M:%S")
+             loss_str = f'Epoch:{self.current_epoch} | Step:{batch_idx:03d} | '
+             for key, value in loss_dict.items():
+                 loss_str += f"{key}: {value.item():.4f} | "
+             self.custom_logger.info(now_str + ' | ' + loss_str)
+         # logging
+         # self.log_dict(loss_dict, sync_dist=True, on_step=True, on_epoch=False)
+
+     def on_train_epoch_end(self):
+         # step the lr schedulers every epoch
+         sch1, sch2 = self.lr_schedulers()
+         sch1.step()
+         sch2.step()
+
+     def inference(self, inputs, is_label_free=False):
+         device = inputs['x'].device
+         omega = self.hparams.get("guidance_scaler", 0)
+         noisy_x = inputs['noisy_x']
+
+         # set the scheduler to denoise in 100 steps
+         self.scheduler.set_timesteps(100)
+         # denoising process
+         for t in self.scheduler.timesteps:
+             timesteps = torch.tensor([t], device=device)
+             outputs_cond = self.model(
+                 x=noisy_x,
+                 cat=inputs['cat'],
+                 timesteps=timesteps,
+                 feat=inputs['f'],
+                 key_pad_mask=inputs['key_pad_mask'],
+                 graph_mask=inputs['graph_mask'],
+                 attr_mask=inputs['attr_mask'],
+                 label_free=is_label_free,
+             )  # take the conditional image features as input
+             if omega != 0:
+                 outputs_free = self.model(
+                     x=noisy_x,
+                     cat=inputs['cat'],
+                     timesteps=timesteps,
+                     feat=inputs['dummy_f'],
+                     key_pad_mask=inputs['key_pad_mask'],
+                     graph_mask=inputs['graph_mask'],
+                     attr_mask=inputs['attr_mask'],
+                     label_free=is_label_free,
+                 )  # take the dummy DINO features for the condition-free mode
+                 noise_pred = (1 + omega) * outputs_cond['noise_pred'] - omega * outputs_free['noise_pred']
+             else:
+                 noise_pred = outputs_cond['noise_pred']
+             noisy_x = self.scheduler.step(noise_pred, t, noisy_x).prev_sample
+
+         return noisy_x
+
+     def validation_step(self, batch, batch_idx):
+         # prepare the inputs and GT
+         inputs = self.prepare_inputs(batch, mode='val')
+         # denoising process for inference
+         out = self.inference(inputs)
+         # compute the metrics
+         # new_out = torch.zeros_like(out).type_as(out).to(out.device)
+         # for b in range(out.shape[0]):
+         #     for k in range(32):
+         #         if out[b][(k + 1) * 6 - 1].mean() > 0.5:
+         #             new_out[b][k * 6: (k + 1) * 6] = out[b][k * 6: (k + 1) * 6]
+         # # zero center
+         # # rescale
+         # # ready
+         # out = new_out
+         # new_out = torch.zeros_like(out).type_as(out).to(out.device)
+         # for b in range(out.shape[0]):
+         #     for k in range(32):
+         #         min_aabb_diff = 1e10
+         #         min_index = k
+         #         aabb_center = (out[b][k * 6][:3] + out[b][k * 6][3:]) / 2
+         #         for k_gt in range(32):
+         #             aabb_gt_center = (batch[1][b][k_gt * 6][:3] + batch[1][b][k_gt * 6][3:]) / 2
+         #             aabb_diff = torch.norm(aabb_center - aabb_gt_center)
+         #             if aabb_diff < min_aabb_diff:
+         #                 min_aabb_diff = aabb_diff
+         #                 min_index = k_gt
+         #         new_out[b][min_index * 6: (min_index + 1) * 6] = out[b][k * 6: (k + 1) * 6]
+         # out = new_out
+
+         log_dict = self.val_compute_metrics(out, inputs['x'], batch[1])
+         self.log_dict(log_dict, on_step=True)
+
+         # visualize the first few results
+         # self.save_val_img(out[:16], inputs['x'][:16], batch[1])
+
+     def test_step(self, batch, batch_idx):
+         # exp_name = self._get_exp_name()
+         # print(self.get_save_path(exp_name))
+         # if batch_idx > 2:
+         #     return
+         is_label_free = self.hparams.get("test_label_free", False)
+         exp_name = self._get_exp_name()
+         model_name = batch[1]["name"][0].replace("/", '@')
+         save_dir = f"{exp_name}/{str(batch_idx)}@{model_name}"
+         print(save_dir)
+         # skip cases that have already been generated
+         if os.path.exists(self.get_save_path(f"{save_dir}/output.png")):
+             return
+         # prepare the inputs and GT
+         inputs = self.prepare_inputs(batch, mode='test', n_samples=5)
+         # denoising process for inference
+         out = self.inference(inputs, is_label_free)
+         # save the results
+         self.save_test_step(out, inputs['x'], batch[1], batch_idx)
+
+     def on_test_end(self):
+         # only run on a single GPU
+         # if self.global_rank == 0:
+         #     exp_name = self._get_exp_name()
+         #     # retrieve parts
+         #     subprocess.run(['python', 'scripts/mesh_retrieval/run_retrieve.py', '--src', self.get_save_path(exp_name), '--json_name', 'object.json', '--gt_data_root', '../singapo'])
+         #     # save metrics
+         #     if not self.hparams.get("test_no_GT", False):
+         #         subprocess.run(['python', 'scripts/eval_metrics.py', '--exp_dir', self.get_save_path(exp_name), '--gt_root', '../acd_data/'])
+         #     # save html
+         #     self._save_html_end()
+         pass
+
+     def configure_optimizers(self):
+         # two separate lists; a chained assignment here would alias one list to both names
+         self.cage_params, self.adapter_params = [], []
+         for name, param in self.model.named_parameters():
+             if "img" in name or "norm5" in name or "norm6" in name:
+                 self.adapter_params.append(param)
+             else:
+                 self.cage_params.append(param)
+         optimizer_adapter = torch.optim.AdamW(
+             self.adapter_params, **self.hparams.optimizer_adapter.args
+         )
+         lr_scheduler_adapter = LinearWarmupCosineAnnealingLR(
+             optimizer_adapter,
+             warmup_epochs=self.hparams.lr_scheduler_adapter.warmup_epochs,
+             max_epochs=self.hparams.lr_scheduler_adapter.max_epochs,
+             warmup_start_lr=self.hparams.lr_scheduler_adapter.warmup_start_lr,
+             eta_min=self.hparams.lr_scheduler_adapter.eta_min,
+         )
+
+         optimizer_cage = torch.optim.AdamW(
+             self.cage_params, **self.hparams.optimizer_cage.args
+         )
+         lr_scheduler_cage = LinearWarmupCosineAnnealingLR(
+             optimizer_cage,
+             warmup_epochs=self.hparams.lr_scheduler_cage.warmup_epochs,
+             max_epochs=self.hparams.lr_scheduler_cage.max_epochs,
+             warmup_start_lr=self.hparams.lr_scheduler_cage.warmup_start_lr,
+             eta_min=self.hparams.lr_scheduler_cage.eta_min,
+         )
+         return (
+             {"optimizer": optimizer_adapter, "lr_scheduler": lr_scheduler_adapter},
+             {"optimizer": optimizer_cage, "lr_scheduler": lr_scheduler_cage},
+         )
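
The guidance combination in inference() is the standard classifier-free guidance extrapolation, noise_pred = (1 + omega) * eps_cond - omega * eps_free, where the unconditional branch is fed the dummy DINO features. A toy numeric sketch of just that combination step:

# a minimal sketch of the CFG combination used in inference(); toy tensors only
import torch

omega = 2.0                          # guidance_scaler in the hparams
eps_cond = torch.tensor([1.0, 0.5])  # noise predicted with the image condition
eps_free = torch.tensor([0.4, 0.5])  # noise predicted with the dummy DINO features
eps = (1 + omega) * eps_cond - omega * eps_free
print(eps)                           # tensor([2.2000, 0.5000]): pushed toward the condition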