Spaces: Running on Zero
Upload 29 files
- .gitattributes +22 -0
- ckpts/dipo.ckpt +3 -0
- examples/1.png +3 -0
- examples/1_open_1.png +3 -0
- examples/1_open_2.png +3 -0
- examples/close1.png +3 -0
- examples/close10.png +3 -0
- examples/close2.png +3 -0
- examples/close3.png +3 -0
- examples/close4.png +3 -0
- examples/close5.png +3 -0
- examples/close6.png +0 -0
- examples/close7.png +3 -0
- examples/close8.png +3 -0
- examples/close9.jpg +3 -0
- examples/open1.png +3 -0
- examples/open10.png +3 -0
- examples/open2.png +3 -0
- examples/open3.png +3 -0
- examples/open4.png +3 -0
- examples/open5.png +3 -0
- examples/open6.png +3 -0
- examples/open7.png +3 -0
- examples/open8.png +3 -0
- examples/open9.jpg +3 -0
- systems/__init__.py +22 -0
- systems/base.py +286 -0
- systems/dino_dummy.npy +3 -0
- systems/plot.py +122 -0
- systems/system_origin.py +391 -0
.gitattributes
CHANGED
@@ -33,3 +33,25 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+examples/1_open_1.png filter=lfs diff=lfs merge=lfs -text
+examples/1_open_2.png filter=lfs diff=lfs merge=lfs -text
+examples/1.png filter=lfs diff=lfs merge=lfs -text
+examples/close1.png filter=lfs diff=lfs merge=lfs -text
+examples/close10.png filter=lfs diff=lfs merge=lfs -text
+examples/close2.png filter=lfs diff=lfs merge=lfs -text
+examples/close3.png filter=lfs diff=lfs merge=lfs -text
+examples/close4.png filter=lfs diff=lfs merge=lfs -text
+examples/close5.png filter=lfs diff=lfs merge=lfs -text
+examples/close7.png filter=lfs diff=lfs merge=lfs -text
+examples/close8.png filter=lfs diff=lfs merge=lfs -text
+examples/close9.jpg filter=lfs diff=lfs merge=lfs -text
+examples/open1.png filter=lfs diff=lfs merge=lfs -text
+examples/open10.png filter=lfs diff=lfs merge=lfs -text
+examples/open2.png filter=lfs diff=lfs merge=lfs -text
+examples/open3.png filter=lfs diff=lfs merge=lfs -text
+examples/open4.png filter=lfs diff=lfs merge=lfs -text
+examples/open5.png filter=lfs diff=lfs merge=lfs -text
+examples/open6.png filter=lfs diff=lfs merge=lfs -text
+examples/open7.png filter=lfs diff=lfs merge=lfs -text
+examples/open8.png filter=lfs diff=lfs merge=lfs -text
+examples/open9.jpg filter=lfs diff=lfs merge=lfs -text
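These per-file LFS rules are what the Hub appends automatically when binary files are committed through its upload APIs. As a rough sketch only (an assumption, since the exact tooling behind this commit is not shown), a commit like "Upload 29 files" can be produced with huggingface_hub; the repo id below is hypothetical:

# Sketch: one way an "Upload 29 files" commit could be produced (assumption).
# The Hub decides LFS handling from .gitattributes and adds rules like the ones above.
from huggingface_hub import upload_folder

upload_folder(
    repo_id="your-username/your-space",  # hypothetical Space id
    repo_type="space",
    folder_path=".",                     # local copy containing ckpts/, examples/, systems/
    commit_message="Upload 29 files",
)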
ckpts/dipo.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:493f551499b95af57b5bb6e872d1107a9cf4056fbf151fc45f416f96a919dad6
size 24565754
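Only this LFS pointer lives in the Git tree; the 24.5 MB checkpoint itself sits in LFS storage. A minimal sketch for fetching and inspecting it with huggingface_hub (the repo id is hypothetical, and the checkpoint is assumed to be a standard Lightning checkpoint, as systems/__init__.py suggests):

# Sketch: resolve the LFS pointer to the real file and peek at its contents (assumption).
import torch
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
    repo_id="your-username/your-space",  # hypothetical Space id
    repo_type="space",
    filename="ckpts/dipo.ckpt",
)
ckpt = torch.load(ckpt_path, map_location="cpu")
print(ckpt.keys())  # a Lightning checkpoint typically exposes "state_dict", "hyper_parameters", ...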
examples/1.png ADDED (Git LFS)
examples/1_open_1.png ADDED (Git LFS)
examples/1_open_2.png ADDED (Git LFS)
examples/close1.png ADDED (Git LFS)
examples/close10.png ADDED (Git LFS)
examples/close2.png ADDED (Git LFS)
examples/close3.png ADDED (Git LFS)
examples/close4.png ADDED (Git LFS)
examples/close5.png ADDED (Git LFS)
examples/close6.png ADDED
examples/close7.png ADDED (Git LFS)
examples/close8.png ADDED (Git LFS)
examples/close9.jpg ADDED (Git LFS)
examples/open1.png ADDED (Git LFS)
examples/open10.png ADDED (Git LFS)
examples/open2.png ADDED (Git LFS)
examples/open3.png ADDED (Git LFS)
examples/open4.png ADDED (Git LFS)
examples/open5.png ADDED (Git LFS)
examples/open6.png ADDED (Git LFS)
examples/open7.png ADDED (Git LFS)
examples/open8.png ADDED (Git LFS)
examples/open9.jpg ADDED (Git LFS)
systems/__init__.py
ADDED
@@ -0,0 +1,22 @@
systems = {}


def register(name):
    def decorator(cls):
        systems[name] = cls
        return cls

    return decorator


def make(name, config, load_from_checkpoint=None):
    if load_from_checkpoint is None:
        system = systems[name](config)
    else:
        system = systems[name].load_from_checkpoint(
            load_from_checkpoint, strict=False, config=config
        )
    return system


from . import system_origin
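systems/__init__.py is a small name-to-class registry: @register adds a class to the systems dict, and make() either constructs it from a config or restores it from a Lightning checkpoint. A minimal usage sketch (the DemoSystem class and cfg dict are hypothetical, for illustration only):

# Sketch of the registry pattern above (hypothetical class and config).
import systems

@systems.register("demo")
class DemoSystem:
    def __init__(self, config):
        self.config = config

cfg = {"lr": 1e-4}                  # hypothetical config
system = systems.make("demo", cfg)  # looks up "demo" and calls DemoSystem(cfg)
# systems.make("sys_origin", cfg, load_from_checkpoint="ckpts/dipo.ckpt") would instead
# restore the registered SingapoSystem via load_from_checkpoint(..., strict=False, config=cfg).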
systems/base.py
ADDED
@@ -0,0 +1,286 @@
import os, sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
import json
import math
import numpy as np
import lightning.pytorch as pl
from metrics.iou_cdist import IoU_cDist
from my_utils.savermixins import SaverMixin
from my_utils.refs import sem_ref, joint_ref
from dataset.utils import convert_data_range, parse_tree
from my_utils.plot import viz_graph, make_grid, add_text
from my_utils.render import draw_boxes_axiss_anim, prepare_meshes
from PIL import Image

class BaseSystem(pl.LightningModule, SaverMixin):
    def __init__(self, hparams):
        super().__init__()
        self.hparams.update(hparams)

    def setup(self, stage: str):
        # configure the logger dir for images
        self.hparams.save_dir = os.path.join(self.hparams.exp_dir, 'output', stage)
        os.makedirs(self.hparams.save_dir, exist_ok=True)


    # --------------------------------- visualization ---------------------------------

    def convert_json(self, x, c, idx, prefix=''):
        out = {"meta": {}, "diffuse_tree": []}

        n_nodes = c[f"{prefix}n_nodes"][idx].item()
        par = c[f"{prefix}parents"][idx].cpu().numpy().tolist()
        adj = c[f"{prefix}adj"][idx].cpu().numpy()
        np.fill_diagonal(adj, 0)  # remove the self-loop for the root node
        if f"{prefix}obj_cat" in c:
            out["meta"]["obj_cat"] = c[f"{prefix}obj_cat"][idx]

        # convert the data to the original range
        data = convert_data_range(x.cpu().numpy())
        # parse the tree
        out["diffuse_tree"] = parse_tree(data, n_nodes, par, adj)
        return out

    # def save_val_img(self, pred, gt, cond):
    #     B = pred.shape[0]
    #     pred_imgs, gt_imgs, gt_graphs_view = [], [], []
    #     for b in range(B):
    #         print(b)
    #         # convert to human readable format json
    #         pred_json = self.convert_json(pred[b], cond, b)
    #         gt_json = self.convert_json(gt[b], cond, b)
    #         # visualize bbox and axis
    #         pred_meshes = prepare_meshes(pred_json)
    #         bbox_0, bbox_1, axiss = (
    #             pred_meshes["bbox_0"],
    #             pred_meshes["bbox_1"],
    #             pred_meshes["axiss"],
    #         )
    #         pred_img = draw_boxes_axiss_anim(
    #             bbox_0, bbox_1, axiss, mode="graph", resolution=128
    #         )
    #         gt_meshes = prepare_meshes(gt_json)
    #         bbox_0, bbox_1, axiss = (
    #             gt_meshes["bbox_0"],
    #             gt_meshes["bbox_1"],
    #             gt_meshes["axiss"],
    #         )
    #         gt_img = draw_boxes_axiss_anim(
    #             bbox_0, bbox_1, axiss, mode="graph", resolution=128
    #         )
    #         # visualize graph
    #         # gt_graph = viz_graph(gt_json, res=128)
    #         # gt_graph = add_text(cond["name"][b], gt_graph)
    #         # GT views
    #         rgb_view = cond["img"][b].cpu().numpy()

    #         pred_imgs.append(pred_img)
    #         gt_imgs.append(gt_img)
    #         gt_graphs_view.append(rgb_view)
    #         # gt_graphs_view.append(gt_graph)

    #     # save images for generated results
    #     epoch = str(self.current_epoch).zfill(5)
    #     # pred_thumbnails = np.concatenate(pred_imgs, axis=1)  # concat batch in width

    #     import ipdb
    #     ipdb.set_trace()
    #     # save images for ground truth
    #     for i in range(math.ceil(len(gt_graphs_view) / 8)):
    #         start = i * 8
    #         end = min((i + 1) * 8, len(gt_graphs_view))
    #         pred_thumbnails = np.concatenate(pred_imgs[start:end], axis=1)
    #         gt_graph_imgs = np.concatenate(gt_graphs_view[start:end], axis=1)
    #         gt_thumbnails = np.concatenate(gt_imgs[start:end], axis=1)  # concat batch in width
    #         grid = np.concatenate([gt_graph_imgs, gt_thumbnails, pred_thumbnails], axis=0)
    #         self.save_rgb_image(f"new_out_valid_{i}.png", grid)

    def save_test_step(self, pred, gt, cond, batch_idx, res=128):
        exp_name = self._get_exp_name()
        model_name = cond["name"][0].replace("/", '@')
        save_dir = f"{exp_name}/{str(batch_idx)}@{model_name}"

        # input image
        input_img = cond["img"][0].cpu().numpy()
        # GT recordings
        if not self.hparams.get('test_no_GT', False):
            gt_json = self.convert_json(gt[0], cond, 0)
            # gt_graph = viz_graph(gt_json, res=256)
            gt_meshes = prepare_meshes(gt_json)
            bbox_0, bbox_1, axiss = (
                gt_meshes["bbox_0"],
                gt_meshes["bbox_1"],
                gt_meshes["axiss"],
            )
            gt_img = draw_boxes_axiss_anim(bbox_0, bbox_1, axiss, mode="graph", resolution=res)
        else:
            # gt_graph = 255 * np.ones((res, res, 3), dtype=np.uint8)
            gt_img = 255 * np.ones((res, 2 * res, 3), dtype=np.uint8)
        gt_block = np.concatenate([input_img, gt_img], axis=1)

        # recordings for generated results
        img_blocks = []
        for b in range(pred.shape[0]):
            pred_json = self.convert_json(pred[b], cond, 0)
            # visualize bbox and axis
            pred_meshes = prepare_meshes(pred_json)
            bbox_0, bbox_1, axiss = (
                pred_meshes["bbox_0"],
                pred_meshes["bbox_1"],
                pred_meshes["axiss"],
            )
            pred_img = draw_boxes_axiss_anim(
                bbox_0, bbox_1, axiss, mode="graph", resolution=res
            )
            img_blocks.append(pred_img)
            self.save_json(f"{save_dir}/{b}/object.json", pred_json)
        # save images for generated results
        img_grid = make_grid(img_blocks, cols=5)
        # visualize the input graph
        # input_graph = viz_graph(pred_json, res=256)

        # save images
        # self.save_rgb_image(f"{save_dir}/gt_graph.png", gt_graph)
        self.save_rgb_image(f"{save_dir}/output.png", img_grid)
        self.save_rgb_image(f"{save_dir}/gt.png", gt_block)
        # self.save_rgb_image(f"{save_dir}/input_graph.png", input_graph)

    def _save_html_end(self):
        exp_name = self._get_exp_name()
        save_dir = self.get_save_path(exp_name)
        cases = sorted(os.listdir(save_dir), key=lambda x: int(x.split("@")[0]))
        html_head = """
        <!DOCTYPE html>
        <html lang="en">
        <head>
            <meta charset="UTF-8">
            <meta name="viewport" content="width=device-width, initial-scale=1.0">
            <title>Test Image Results</title>
            <style>
                table {
                    width: 100%;
                    border-collapse: collapse;
                }
                th, td {
                    border: 1px solid black;
                    padding: 8px;
                    text-align: left;
                }
                .separator {
                    border-top: 2px solid black;
                }
            </style>
        </head>
        <body>
            <table>

        """
        total = len(cases)
        each = 200
        n_pages = total // each + 1
        for p in range(n_pages):
            html_content = html_head
            for i in range(p * each, min((p + 1) * each, total)):
                case = cases[i]
                if self.hparams.get("test_no_GT", False):
                    aid_iou = rid_iou = aid_cdist = rid_cdist = aid_cd = rid_cd = aor = "N/A"
                else:
                    with open(os.path.join(save_dir, case, "metrics.json"), "r") as f:
                        metrics = json.load(f)["avg"]
                    aid_iou = round(metrics["AS-IoU"], 4)
                    rid_iou = round(metrics["RS-IoU"], 4)
                    aid_cdist = round(metrics["AS-cDist"], 4)
                    rid_cdist = round(metrics["RS-cDist"], 4)
                    aid_cd = round(metrics["AS-CD"], 4)
                    rid_cd = round(metrics["RS-CD"], 4)
                    aor = metrics["AOR"]
                    if aor is not None:
                        aor = round(aor, 4)
                html_content += f"""
                <tr>
                    <th>Object ID</th>
                    <th>Metrics (avg) </th>
                    <th>Input image + GT object + GT graph</th>
                    <th>Input graph </th>
                </tr>
                <tr>
                    <td rowspan="3">{case}</td>
                    <td>
                        [AS-cDist] {aid_cdist}<br>
                        [RS-cDist] {rid_cdist}<br>
                        -----------------------<br>
                        [AS-IoU] {aid_iou}<br>
                        [RS-IoU] {rid_iou}<br>
                        -----------------------<br>
                        [RS-CD] {rid_cd}<br>
                        [AS-CD] {aid_cd}<br>
                        -----------------------<br>
                        [AOR] {aor}<br>
                    </td>
                    <td>
                        <img src="{exp_name}/{case}/gt.png" alt="GT Image" style="height: 128px; width: 3*128px;">
                        <img src="{exp_name}/{case}/gt_graph.png" alt="Graph Image" style="height: 128px; width: 3*128px;">
                    </td>
                    <td>
                        <img src="{exp_name}/{case}/input_graph.png" alt="Graph Image" style="height: 128px; width: 3*128px;">
                    </td>
                </tr>
                <tr><th colspan="3">Generated samples</th></tr>
                <tr>
                    <td colspan="3"><img src="{exp_name}/{case}/output.png" alt="Generated Image" style="height: 3*128px; width: 10*128px;"></td>
                </tr>
                <tr class="separator"><td colspan="4"></td></tr>
                """
            html_content += """</table></body></html>"""
            outfile = self.get_save_path(f"{exp_name}_page_{p+1}.html")
            with open(outfile, "w") as file:
                file.write(html_content)

    def val_compute_metrics(self, pred, gt, cond):
        loss_dict = {}
        B = pred.shape[0]
        as_ious = 0.0
        rs_ious = 0.0
        as_cdists = 0.0
        rs_cdists = 0.0
        for b in range(B):
            gt_json = self.convert_json(gt[b], cond, b)
            pred_json = self.convert_json(pred[b], cond, b)
            scores = IoU_cDist(
                pred_json,
                gt_json,
                num_states=5,
                compare_handles=True,
                iou_include_base=True,
            )
            as_ious += scores['AS-IoU']
            rs_ious += scores['RS-IoU']
            as_cdists += scores['AS-cDist']
            rs_cdists += scores['RS-cDist']

        as_ious /= B
        rs_ious /= B
        as_cdists /= B
        rs_cdists /= B

        loss_dict['val/AS-IoU'] = as_ious
        loss_dict['val/RS-IoU'] = rs_ious
        loss_dict['val/AS-cDist'] = as_cdists
        loss_dict['val/RS-cDist'] = rs_cdists

        return loss_dict

    def _get_exp_name(self):
        which_ds = self.hparams.get("test_which", 'pm')
        is_pred_G = self.hparams.get("test_pred_G", False)
        is_label_free = self.hparams.get("test_label_free", False)
        guidance_scaler = self.hparams.get("guidance_scaler", 0)
        # configure the saving directory
        exp_postfix = f"_w={guidance_scaler}_{which_ds}"
        if is_pred_G:
            exp_postfix += "_pred_G"
        if is_label_free:
            exp_postfix += "_label_free"

        exp_name = "epoch_" + str(self.current_epoch).zfill(3) + exp_postfix
        return exp_name
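BaseSystem mixes in SaverMixin from my_utils.savermixins, which is not part of this upload. A minimal stand-in consistent with how the class is called here (get_save_path, save_rgb_image, and save_json resolving against hparams.save_dir) might look like the sketch below; the actual implementation may differ:

# Assumed stand-in for my_utils.savermixins.SaverMixin; signatures are inferred
# from BaseSystem's calls only and are not the actual implementation.
import os, json
import numpy as np
from PIL import Image

class SaverMixin:
    def get_save_path(self, name):
        # resolve a file name under the per-stage output dir created in BaseSystem.setup()
        path = os.path.join(self.hparams.save_dir, name)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        return path

    def save_rgb_image(self, name, img):
        # img: (H, W, 3) array; cast to uint8 before writing the PNG
        Image.fromarray(np.asarray(img).astype(np.uint8)).save(self.get_save_path(name))

    def save_json(self, name, data):
        with open(self.get_save_path(name), "w") as f:
            json.dump(data, f, indent=2)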
systems/dino_dummy.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:67b13dadf868704eb0e5a1b55355d54bce806b7f9d8d877cdf4142f759544bbd
size 1572992
systems/plot.py
ADDED
@@ -0,0 +1,122 @@
import os, sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
import matplotlib
matplotlib.use('Agg')
import numpy as np
import networkx as nx
from io import BytesIO
from PIL import Image, ImageDraw
from matplotlib import pyplot as plt
from sklearn.decomposition import PCA
from singapo_utils.refs import graph_color_ref

def add_text(text, imgarr):
    '''
    Function to add text to an image

    Args:
    - text (str): text to add
    - imgarr (np.array): image array

    Returns:
    - img (np.array): image array with text
    '''
    img = Image.fromarray(imgarr)
    I = ImageDraw.Draw(img)
    I.text((10, 10), text, fill='black')
    return np.asarray(img)

def get_color(ref, n_nodes):
    '''
    Function to color the nodes

    Args:
    - ref (list): list of color references
    - n_nodes (int): number of nodes

    Returns:
    - colors (list): list of colors
    '''
    N = len(ref)
    colors = []
    for i in range(n_nodes):
        colors.append(np.array([[int(i) for i in ref[i%N][4:-1].split(',')]]) / 255.)
    return colors


def make_grid(images, cols=5):
    """
    Arrange a list of images into an N x cols grid.

    Args:
    - images (list): List of Numpy arrays representing the images.
    - cols (int): Number of columns for the grid.

    Returns:
    - grid (numpy array): Numpy array representing the image grid.
    """
    # Determine the dimensions of each image
    img_h, img_w, _ = images[0].shape
    rows = len(images) // cols

    # Initialize a blank canvas
    grid = np.zeros((rows * img_h, cols * img_w, 3), dtype=images[0].dtype)

    # Place each image onto the grid
    for idx, img in enumerate(images):
        y = (idx // cols) * img_h
        x = (idx % cols) * img_w
        grid[y: y + img_h, x: x + img_w] = img

    return grid

def viz_graph(info_dict, res=256):
    '''
    Function to plot the directed graph

    Args:
    - info_dict (dict): output json containing the graph information
    - res (int): resolution of the image

    Returns:
    - img_arr (np.array): image array
    '''
    # build the tree
    tree = info_dict['diffuse_tree']
    edges = []
    for node in tree:
        edges += [(node['id'], child) for child in node['children']]
    G = nx.DiGraph()
    G.add_edges_from(edges)

    # plot the tree
    plt.figure(figsize=(res/100, res/100))

    colors = get_color(graph_color_ref, len(tree))
    pos = nx.nx_agraph.graphviz_layout(G, prog="twopi", args="")
    node_order = sorted(G.nodes())
    nx.draw(G, pos, node_color=colors, nodelist=node_order, edge_color='k', with_labels=False)

    buf = BytesIO()
    plt.savefig(buf, format="png", dpi=100)
    buf.seek(0)
    img = Image.open(buf)
    img_arr = np.asarray(img)
    buf.close()
    plt.clf()
    plt.close()
    return img_arr[:, :, :3]

def viz_patch_feat_pca(feat):
    pca = PCA(n_components=3)
    pca.fit(feat)
    feat_pca = pca.transform(feat)

    t = np.array(feat_pca)
    t_min = t.min(axis=0, keepdims=True)
    t_max = t.max(axis=0, keepdims=True)
    normalized_t = (t - t_min) / (t_max - t_min)

    array = (normalized_t * 255).astype(np.uint8)
    img_array = array.reshape(16, 16, 3)
    return img_array
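A short usage sketch for the helpers above, assuming this repository is on the Python path (the tiles are synthetic; viz_graph additionally needs pygraphviz for the twopi layout):

# Usage sketch: label ten synthetic tiles and arrange them into a 2 x 5 grid.
import numpy as np
from systems.plot import add_text, make_grid

tiles = [np.full((128, 128, 3), 25 * i, dtype=np.uint8) for i in range(10)]
labeled = [add_text(f"sample {i}", t) for i, t in enumerate(tiles)]
grid = make_grid(labeled, cols=5)   # -> (256, 640, 3) array, rows = len(images) // cols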
systems/system_origin.py
ADDED
@@ -0,0 +1,391 @@
import os, sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
import torch
import subprocess
import numpy as np
import models
import systems
import torch.nn.functional as F
from diffusers import DDPMScheduler
from systems.base import BaseSystem
from my_utils.lr_schedulers import LinearWarmupCosineAnnealingLR
from datetime import datetime
import logging

@systems.register("sys_origin")
class SingapoSystem(BaseSystem):
    """Trainer for the B9 model, incorporating classifier-free guidance for the image condition."""

    def __init__(self, hparams):
        super().__init__(hparams)
        self.model = models.make(hparams.model.name, hparams.model)
        # configure the scheduler of DDPM
        self.scheduler = DDPMScheduler(**self.hparams.scheduler.config)
        # load the dummy DINO features
        self.dummy_dino = np.load('systems/dino_dummy.npy').astype(np.float32)
        # use manual optimization
        self.automatic_optimization = False
        # save the hyperparameters
        self.save_hyperparameters()

        self.custom_logger = logging.getLogger(__name__)
        self.custom_logger.setLevel(logging.INFO)
        if self.global_rank == 0:
            self.custom_logger.addHandler(logging.StreamHandler())

    def load_cage_weights(self, pretrained_ckpt=None):
        ckpt = torch.load(pretrained_ckpt)
        state_dict = ckpt["state_dict"]
        # remove the "model." prefix from the keys
        state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}
        # load the weights
        self.model.load_state_dict(state_dict, strict=False)
        # separate the weights of CAGE and our new modules
        print("[INFO] loaded model weights of the pretrained CAGE.")


    def fg_loss(self, all_attn_maps, loss_masks):
        """
        Excite the attention maps within the object regions, while weakening the attention outside the object regions.

        Args:
            all_attn_maps: cross-attention maps from all layers, shape (B*L, H, 160, 256)
            loss_masks: object seg mask on the image patches, shape (B, 160, 256)

        Returns:
            loss: loss on the attention maps
        """
        valid_mask = loss_masks['valid_nodes']
        fg_mask = loss_masks['fg']
        # get the number of layers and batch size
        L = self.hparams.model.n_layers
        H = all_attn_maps.shape[1]
        # Reshape all the masks to the shape of the attention maps
        valid_node = valid_mask[:, :, 0].unsqueeze(1).expand(-1, H, -1).unsqueeze(-1).expand(-1, -1, -1, 256).repeat(L, 1, 1, 1)
        obj_region = fg_mask.unsqueeze(1).expand(-1, H, -1, -1).repeat(L, 1, 1, 1)
        # construct masks for the object and non-object regions
        fg_region = torch.logical_and(valid_node, obj_region)
        bg_region = torch.logical_and(valid_node, ~obj_region)
        # loss to excite the foreground regions
        loss = 1. - all_attn_maps[fg_region].mean() + all_attn_maps[bg_region].mean()
        return loss

    def diffuse_process(self, inputs):
        x = inputs["x"]
        # Sample Gaussian noise
        noise = torch.randn(x.shape, device=self.device, dtype=x.dtype)
        # Sample a random timestep for each image
        timesteps = torch.randint(
            0,
            self.scheduler.config.num_train_timesteps,
            (x.shape[0],),
            device=self.device,
            dtype=torch.long,
        )
        # Add Gaussian noise to the input
        noisy_x = self.scheduler.add_noise(x, noise, timesteps)
        # update the inputs
        inputs["noise"] = noise
        inputs["timesteps"] = timesteps
        inputs["noisy_x"] = noisy_x

    def prepare_inputs(self, batch, mode='train', n_samples=1):
        x, c, f = batch

        cat = c["cat"]  # object category
        attr_mask = c["attr_mask"]  # attention mask for local self-attention (following CAGE)
        key_pad_mask = c["key_pad_mask"]  # key padding mask for global self-attention (following CAGE)
        graph_mask = c["adj_mask"]  # attention mask for graph relation self-attention (following CAGE)

        inputs = {}
        if mode == 'train':
            # the number of sampled timesteps per iteration
            n_repeat = self.hparams.n_time_samples
            # for sampling multiple timesteps
            x = x.repeat(n_repeat, 1, 1)
            cat = cat.repeat(n_repeat)
            f = f.repeat(n_repeat, 1, 1)
            key_pad_mask = key_pad_mask.repeat(n_repeat, 1, 1)
            graph_mask = graph_mask.repeat(n_repeat, 1, 1)
            attr_mask = attr_mask.repeat(n_repeat, 1, 1)
        elif mode == 'val':
            noisy_x = torch.randn(x.shape, device=x.device)
            dummy_f = torch.tensor(self.dummy_dino, device=self.device).unsqueeze(0).repeat(1, 2, 1).expand_as(f)
            inputs["noisy_x"] = noisy_x
            inputs["dummy_f"] = dummy_f
        elif mode == 'test':
            # for sampling multiple outputs
            x = x.repeat(n_samples, 1, 1)
            cat = cat.repeat(n_samples)
            f = f.repeat(n_samples, 1, 1)
            key_pad_mask = key_pad_mask.repeat(n_samples, 1, 1)
            graph_mask = graph_mask.repeat(n_samples, 1, 1)
            attr_mask = attr_mask.repeat(n_samples, 1, 1)
            noisy_x = torch.randn(x.shape, device=x.device)
            dummy_f = torch.tensor(self.dummy_dino, device=self.device).unsqueeze(0).repeat(1, 2, 1).expand_as(f)
            inputs["noisy_x"] = noisy_x
            inputs["dummy_f"] = dummy_f.repeat(1, 2, 1)
        else:
            raise ValueError(f"Invalid mode: {mode}")

        inputs["x"] = x
        inputs["f"] = f
        inputs["cat"] = cat
        inputs["key_pad_mask"] = key_pad_mask
        inputs["graph_mask"] = graph_mask
        inputs["attr_mask"] = attr_mask

        return inputs

    def prepare_loss_mask(self, batch):
        x, c, _ = batch
        n_repeat = self.hparams.n_time_samples  # the number of sampled timesteps per iteration

        # mask on the image patches for the foreground regions
        # mask_fg = c["img_obj_mask"]
        # if mask_fg is not None:
        #     mask_fg = mask_fg.repeat(n_repeat, 1, 1)

        # mask on the valid nodes
        index_tensor = torch.arange(x.shape[1], device=self.device, dtype=torch.int32).unsqueeze(0)  # (1, N)
        valid_nodes = index_tensor < (c['n_nodes'] * 5).unsqueeze(-1)
        mask_valid_nodes = valid_nodes.unsqueeze(-1).expand_as(x)
        mask_valid_nodes = mask_valid_nodes.repeat(n_repeat, 1, 1)

        return {"fg": None, "valid_nodes": mask_valid_nodes}

    def manage_cfg(self, inputs):
        '''
        Manage the classifier-free training for the image and graph conditions.
        The CFG for the object category is managed by the model (i.e. the CombinedTimestepLabelEmbeddings module in norm1 of each attention block).
        '''
        img_drop_prob = self.hparams.get("img_drop_prob", 0.0)
        graph_drop_prob = self.hparams.get("graph_drop_prob", 0.0)
        drop_img, drop_graph = False, False

        if img_drop_prob > 0.0:
            drop_img = torch.rand(1) < img_drop_prob
            if drop_img.item():
                dummy_batch = torch.tensor(self.dummy_dino, device=self.device).unsqueeze(0).repeat(1, 2, 1).expand_as(inputs['f'])
                inputs['f'] = dummy_batch  # use the dummy DINO features

        if graph_drop_prob > 0.0:
            if not drop_img:
                drop_graph = torch.rand(1) < graph_drop_prob
                if drop_graph.item():
                    inputs['graph_mask'] = None  # for verifying the model only; replace with the line below later and retrain the model
                    # inputs['graph_mask'] = inputs['key_pad_mask']  # use the key padding mask

    def compute_loss(self, batch, inputs, outputs):
        loss_dict = {}
        # loss_weight = self.hparams.get("loss_fg_weight", 1.0)

        # prepare the loss masks
        loss_masks = self.prepare_loss_mask(batch)

        # diffusion model loss: MSE on the residual noise
        loss_mse = F.mse_loss(outputs['noise_pred'] * loss_masks['valid_nodes'], inputs['noise'] * loss_masks['valid_nodes'])
        # attention mask loss: BCE loss on the attention maps
        # loss_fg = loss_weight * self.fg_loss(outputs['attn_maps'], loss_masks)

        # total loss
        loss = loss_mse

        # log the losses
        loss_dict["train/loss_mse"] = loss_mse
        loss_dict["train/loss_total"] = loss

        return loss, loss_dict

    def training_step(self, batch, batch_idx):
        # prepare the inputs and GT
        inputs = self.prepare_inputs(batch, mode='train')

        # manage the classifier-free training
        self.manage_cfg(inputs)

        # forward: diffusion process
        self.diffuse_process(inputs)

        # reverse: denoising process
        outputs = self.model(
            x=inputs['noisy_x'],
            cat=inputs['cat'],
            timesteps=inputs['timesteps'],
            feat=inputs['f'],
            key_pad_mask=inputs['key_pad_mask'],
            graph_mask=inputs['graph_mask'],
            attr_mask=inputs['attr_mask'],
        )

        # compute the loss
        loss, loss_dict = self.compute_loss(batch, inputs, outputs)

        # manual backward
        opt1, opt2 = self.optimizers()
        opt1.zero_grad()
        opt2.zero_grad()
        self.manual_backward(loss)
        opt1.step()
        opt2.step()

        if batch_idx % 20 == 0 and self.global_rank == 0:
            now = datetime.now()
            now_str = now.strftime("%Y-%m-%d %H:%M:%S")
            loss_str = f'Epoch:{self.current_epoch} | Step:{batch_idx:03d} | '
            for key, value in loss_dict.items():
                loss_str += f"{key}: {value.item():.4f} | "
            self.custom_logger.info(now_str + ' | ' + loss_str)
        # logging
        # self.log_dict(loss_dict, sync_dist=True, on_step=True, on_epoch=False)

    def on_train_epoch_end(self):
        # step the lr schedulers every epoch
        sch1, sch2 = self.lr_schedulers()
        sch1.step()
        sch2.step()

    def inference(self, inputs, is_label_free=False):
        device = inputs['x'].device
        omega = self.hparams.get("guidance_scaler", 0)
        noisy_x = inputs['noisy_x']

        # set the scheduler to denoise in 100 steps
        self.scheduler.set_timesteps(100)
        # denoising process
        for t in self.scheduler.timesteps:
            timesteps = torch.tensor([t], device=device)
            outputs_cond = self.model(
                x=noisy_x,
                cat=inputs['cat'],
                timesteps=timesteps,
                feat=inputs['f'],
                key_pad_mask=inputs['key_pad_mask'],
                graph_mask=inputs['graph_mask'],
                attr_mask=inputs['attr_mask'],
                label_free=is_label_free,
            )  # take the conditional image features as input
            if omega != 0:
                outputs_free = self.model(
                    x=noisy_x,
                    cat=inputs['cat'],
                    timesteps=timesteps,
                    feat=inputs['dummy_f'],
                    key_pad_mask=inputs['key_pad_mask'],
                    graph_mask=inputs['graph_mask'],
                    attr_mask=inputs['attr_mask'],
                    label_free=is_label_free,
                )  # take the dummy DINO features for the condition-free mode
                noise_pred = (1 + omega) * outputs_cond['noise_pred'] - omega * outputs_free['noise_pred']
            else:
                noise_pred = outputs_cond['noise_pred']
            noisy_x = self.scheduler.step(noise_pred, t, noisy_x).prev_sample

        return noisy_x

    def validation_step(self, batch, batch_idx):
        # prepare the inputs and GT
        inputs = self.prepare_inputs(batch, mode='val')
        # denoising process for inference
        out = self.inference(inputs)
        # compute the metrics
        # new_out = torch.zeros_like(out).type_as(out).to(out.device)
        # for b in range(out.shape[0]):
        #     for k in range(32):
        #         if out[b][(k + 1) * 6 - 1].mean() > 0.5:
        #             new_out[b][k * 6: (k + 1) * 6] = out[b][k * 6: (k + 1) * 6]
        # zero center

        # rescale

        # ready
        # out = new_out
        # new_out = torch.zeros_like(out).type_as(out).to(out.device)
        # for b in range(out.shape[0]):
        #     for k in range(32):
        #         min_aabb_diff = 1e10
        #         min_index = k
        #         aabb_center = (out[b][k * 6][:3] + out[b][k * 6][3:]) / 2
        #         for k_gt in range(32):
        #             aabb_gt_center = (batch[1][b][k_gt * 6][:3] + batch[1][b][k_gt * 6][3:]) / 2
        #             aabb_diff = torch.norm(aabb_center - aabb_gt_center)
        #             if aabb_diff < min_aabb_diff:
        #                 min_aabb_diff = aabb_diff
        #                 min_index = k_gt
        #         new_out[b][min_index * 6: (min_index + 1) * 6] = out[b][k * 6: (k + 1) * 6]
        # out = new_out

        log_dict = self.val_compute_metrics(out, inputs['x'], batch[1])
        self.log_dict(log_dict, on_step=True)

        # visualize the first results
        # self.save_val_img(out[:16], inputs['x'][:16], batch[1])

    def test_step(self, batch, batch_idx):
        # exp_name = self._get_exp_name()
        # print(self.get_save_path(exp_name))
        # if batch_idx > 2:
        #     return
        # return
        is_label_free = self.hparams.get("test_label_free", False)
        exp_name = self._get_exp_name()
        model_name = batch[1]["name"][0].replace("/", '@')
        save_dir = f"{exp_name}/{str(batch_idx)}@{model_name}"
        print(save_dir)
        if os.path.exists(self.get_save_path(f"{save_dir}/output.png")):
            return
        # prepare the inputs and GT
        inputs = self.prepare_inputs(batch, mode='test', n_samples=5)
        # denoising process for inference
        out = self.inference(inputs, is_label_free)
        # save the results
        self.save_test_step(out, inputs['x'], batch[1], batch_idx)

    def on_test_end(self):
        # only run on a single GPU
        # if self.global_rank == 0:
        #     exp_name = self._get_exp_name()
        #     # retrieve parts
        #     subprocess.run(['python', 'scripts/mesh_retrieval/run_retrieve.py', '--src', self.get_save_path(exp_name), '--json_name', 'object.json', '--gt_data_root', '../singapo'])
        #     # save metrics
        #     if not self.hparams.get("test_no_GT", False):
        #         subprocess.run(['python', 'scripts/eval_metrics.py', '--exp_dir', self.get_save_path(exp_name), '--gt_root', '../acd_data/'])
        #     # save html
        #     self._save_html_end()
        pass

    def configure_optimizers(self):
        # keep two separate parameter groups: the adapter modules and the pretrained CAGE weights
        self.cage_params, self.adapter_params = [], []
        for name, param in self.model.named_parameters():
            if "img" in name or "norm5" in name or "norm6" in name:
                self.adapter_params.append(param)
            else:
                self.cage_params.append(param)
        optimizer_adapter = torch.optim.AdamW(
            self.adapter_params, **self.hparams.optimizer_adapter.args
        )
        lr_scheduler_adapter = LinearWarmupCosineAnnealingLR(
            optimizer_adapter,
            warmup_epochs=self.hparams.lr_scheduler_adapter.warmup_epochs,
            max_epochs=self.hparams.lr_scheduler_adapter.max_epochs,
            warmup_start_lr=self.hparams.lr_scheduler_adapter.warmup_start_lr,
            eta_min=self.hparams.lr_scheduler_adapter.eta_min,
        )

        optimizer_cage = torch.optim.AdamW(
            self.cage_params, **self.hparams.optimizer_cage.args
        )
        lr_scheduler_cage = LinearWarmupCosineAnnealingLR(
            optimizer_cage,
            warmup_epochs=self.hparams.lr_scheduler_cage.warmup_epochs,
            max_epochs=self.hparams.lr_scheduler_cage.max_epochs,
            warmup_start_lr=self.hparams.lr_scheduler_cage.warmup_start_lr,
            eta_min=self.hparams.lr_scheduler_cage.eta_min,
        )
        return (
            {"optimizer": optimizer_adapter, "lr_scheduler": lr_scheduler_adapter},
            {"optimizer": optimizer_cage, "lr_scheduler": lr_scheduler_cage},
        )
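The inference loop above applies classifier-free guidance by running the model twice per denoising step, once with the real DINO features and once with the dummy features, and extrapolating between the two predictions. The combination rule, isolated with toy tensors (the shapes are illustrative assumptions, and omega corresponds to hparams.guidance_scaler):

# Sketch of the guidance combination used in SingapoSystem.inference() (toy shapes).
import torch

def cfg_combine(noise_cond, noise_free, omega):
    # omega = 0 reduces to the purely conditional prediction
    return (1 + omega) * noise_cond - omega * noise_free

noise_cond = torch.randn(2, 160, 6)   # assumed (batch, 32 nodes x 5 tokens, attrs) layout
noise_free = torch.randn(2, 160, 6)
guided = cfg_combine(noise_cond, noise_free, omega=1.5)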