# DIPO/models/denoiser.py
import os
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))

import torch
from torch import nn
from diffusers.models.attention import Attention, FeedForward

import models
from models.utils import (
    PEmbeder,
    FinalLayer,
    VisAttnProcessor,
    MyAdaLayerNormZero,
)

class RAPCrossAttnBlock(nn.Module):
def __init__(self, dim, num_layers, num_heads, head_dim, dropout=0.0, img_emb_dims=None):
super().__init__()
self.layers = nn.ModuleList([
Attention(
query_dim=dim,
cross_attention_dim=dim,
heads=num_heads,
dim_head=head_dim,
dropout=dropout,
bias=True,
cross_attention_norm="layer_norm",
processor=VisAttnProcessor(),
)
for _ in range(num_layers)
])
self.norms = nn.ModuleList([
nn.LayerNorm(dim) for _ in range(num_layers)
])
        # MLP over raw image-patch features: Linear + LeakyReLU blocks built from img_emb_dims
        img_emb_layers = []
        for i in range(len(img_emb_dims) - 1):
            img_emb_layers.append(nn.Linear(img_emb_dims[i], img_emb_dims[i + 1]))
            img_emb_layers.append(nn.LeakyReLU(inplace=True))
        img_emb_layers.pop(-1)  # drop the trailing activation so the last Linear is the output
        self.img_emb = nn.Sequential(*img_emb_layers)
self.init_img_emb_weights()
def init_img_emb_weights(self):
for m in self.img_emb.modules():
if isinstance(m, nn.Linear):
nn.init.kaiming_normal_(m.weight, mode="fan_in")
if m.bias is not None:
nn.init.constant_(m.bias, 0)
    def forward(self, img_first, img_second):
        """
        Inputs:
            img_first: (B, Np, img_emb_dims[0]) raw patch features of the first image
            img_second: (B, Np, img_emb_dims[0]) raw patch features of the second image
        Output:
            fused_feat: (B, Np, dim) second-image features refined by cross-attending to the first image
        """
img_first = self.img_emb(img_first)
img_second = self.img_emb(img_second)
fused = img_second
for norm, attn in zip(self.norms, self.layers):
normed = norm(fused)
delta, _ = attn(normed, encoder_hidden_states=img_first, attention_mask=None)
fused = fused + delta # residual connection
return fused
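
# A minimal shape sketch for RAPCrossAttnBlock (hypothetical sizes, not part of the original
# code; assumes img_emb_dims maps the raw patch width down to the attention dim):
#
#   block = RAPCrossAttnBlock(dim=512, num_layers=6, num_heads=8, head_dim=64,
#                             img_emb_dims=[768, 512])
#   img_first  = torch.randn(2, 256, 768)   # (B, Np, 768) patch features of view 1
#   img_second = torch.randn(2, 256, 768)   # (B, Np, 768) patch features of view 2
#   fused = block(img_first, img_second)    # (B, Np, 512)
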
class Attn_Block(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
activation_fn: str = "geglu",
num_embeds_ada_norm: int = None,
attention_bias: bool = False,
norm_elementwise_affine: bool = True,
final_dropout: bool = False,
        class_dropout_prob: float = 0.0,  # for classifier-free guidance
img_emb_dims=None,
):
super().__init__()
self.norm1 = MyAdaLayerNormZero(dim, num_embeds_ada_norm, class_dropout_prob)
self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
self.norm4 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
self.norm5 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
self.norm6 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
self.local_attn = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
)
self.global_attn = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
)
self.graph_attn = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
)
self.img_attn = Attention(
query_dim=dim,
cross_attention_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_norm="layer_norm",
processor=VisAttnProcessor(), # to be removed for release model
)
self.img_attn_second = Attention(
query_dim=dim,
cross_attention_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_norm="layer_norm",
processor=VisAttnProcessor(), # to be removed for release model
)
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
)
        # image embedding layers: Linear + LeakyReLU blocks built from img_emb_dims
        layers = []
        for i in range(len(img_emb_dims) - 1):
            layers.append(nn.Linear(img_emb_dims[i], img_emb_dims[i + 1]))
            layers.append(nn.LeakyReLU(inplace=True))
        layers.pop(-1)  # drop the trailing activation so the last Linear is the output
self.img_emb = nn.Sequential(*layers)
self.init_img_emb_weights()
def init_img_emb_weights(self):
for m in self.img_emb.modules():
if isinstance(m, nn.Linear):
nn.init.kaiming_normal_(m.weight, mode="fan_in")
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(
self,
hidden_states,
img_patches,
fuse_feat,
pad_mask,
attr_mask,
graph_mask,
timestep,
class_labels,
label_free=False,
):
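        """
        hidden_states: (B, K*mode_num, D) attribute tokens; slot 0 of each node is the bbox token
        img_patches:   (B, 2*Np, img_emb_dims[0]) concatenated raw patch features of the two input images
        fuse_feat:     (B, Np, D) fused image features from RAPCrossAttnBlock
        attr_mask / pad_mask / graph_mask: masks for the local, global, and graph self-attentions
        timestep, class_labels: conditioning for the adaptive layer norm; label_free drops the class condition
        """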
# image patches embedding
img_emb = self.img_emb(img_patches)
        # adaptive layer norm conditioned on timestep and class_labels
norm_hidden_states, gate_1, shift_mlp, scale_mlp, gate_mlp, gate_2, gate_3 = (
self.norm1(
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype,
label_free=label_free
)
)
# local attribute self-attention
attr_out = self.local_attn(norm_hidden_states, attention_mask=attr_mask)
attr_out = gate_1.unsqueeze(1) * attr_out
hidden_states = hidden_states + attr_out
# global attribute self-attention
norm_hidden_states = self.norm2(hidden_states)
global_out = self.global_attn(norm_hidden_states, attention_mask=pad_mask)
global_out = gate_2.unsqueeze(1) * global_out
hidden_states = hidden_states + global_out
# graph relation self-attention
norm_hidden_states = self.norm3(hidden_states)
graph_out = self.graph_attn(norm_hidden_states, attention_mask=graph_mask)
graph_out = gate_3.unsqueeze(1) * graph_out
hidden_states = hidden_states + graph_out
img_first, img_second = img_emb.chunk(2, dim=1)
# cross attention with image patches
norm_hidden_states = self.norm4(hidden_states)
        B, Na, D = norm_hidden_states.shape
        Np = img_first.shape[1]  # number of image patches per view
        mode_num = Na // 32  # number of attribute tokens per node (assumes K = 32 nodes)
        reshaped = norm_hidden_states.reshape(B, Na // mode_num, mode_num, D)
        bboxes = reshaped[:, :, 0, :]  # (B, K, D) bbox token of each node
# cross attention between bbox attributes and image patches
bbox_img_out, bbox_cross_attn_map = self.img_attn(
bboxes,
encoder_hidden_states=img_first,
attention_mask=None,
) # cross_attn_map: (B, n_head, K, Np)
        # to reshape the cross_attn_map back to (B, n_head, Na, Np); redundant for other attributes, fix later
        # cross_attn_map_reshape = torch.zeros(size=(B, bbox_cross_attn_map.shape[1], Na // mode_num, mode_num, Np), device=bbox_cross_attn_map.device)
        # cross_attn_map_reshape[:, :, :, 0, :] = bbox_cross_attn_map
        # cross_attn_map = cross_attn_map_reshape.reshape(B, bbox_cross_attn_map.shape[1], Na, Np)
        # assemble the cross-attention output: bbox slots take the attention output,
        # the remaining attribute slots carry the normed features as their residual
        img_out = torch.empty(size=(B, Na // mode_num, mode_num, D), device=hidden_states.device, dtype=hidden_states.dtype)
        img_out[:, :, 0, :] = bbox_img_out
        img_out[:, :, 1:, :] = reshaped[:, :, 1:, :]
        img_out = img_out.reshape(B, Na, D)
        hidden_states = hidden_states + img_out
        norm_hidden_states = self.norm6(hidden_states)
        B, Na, D = norm_hidden_states.shape
        Np = img_second.shape[1]  # number of image patches per view
        mode_num = Na // 32  # number of attribute tokens per node (assumes K = 32 nodes)
        reshaped = norm_hidden_states.reshape(B, Na // mode_num, mode_num, D)
        joints = reshaped  # (B, K, 5, D), assuming mode_num = 5
        joints = joints.reshape(B, Na // mode_num * 5, D)  # all 5 attribute tokens of every node
        # cross attention between all attribute tokens and the fused image features
        joint_img_out, bbox_cross_attn_map = self.img_attn_second(
            joints,
            encoder_hidden_states=fuse_feat,
            attention_mask=None,
        )  # cross_attn_map: (B, n_head, K*5, Np)
        # to reshape the cross_attn_map back to (B, n_head, Na, Np); redundant for other attributes, fix later
        # cross_attn_map_reshape = torch.zeros(size=(B, bbox_cross_attn_map.shape[1], Na // mode_num, mode_num, Np), device=bbox_cross_attn_map.device)
        # cross_attn_map_reshape[:, :, :, 1:5, :] = bbox_cross_attn_map.reshape(
        #     B, bbox_cross_attn_map.shape[1], Na // mode_num, 4, Np
        # )
        # cross_attn_map = cross_attn_map_reshape.reshape(B, bbox_cross_attn_map.shape[1], Na, Np)
        # assemble the cross-attention output for all attribute tokens
        img_out = joint_img_out.reshape(B, Na // mode_num, 5, D)
        img_out = img_out.reshape(B, Na, D)
        hidden_states = hidden_states + img_out
# feed-forward
norm_hidden_states = self.norm5(hidden_states)
norm_hidden_states = (
norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
)
ff_output = self.ff(norm_hidden_states)
ff_output = gate_mlp.unsqueeze(1) * ff_output
hidden_states = ff_output + hidden_states
return hidden_states
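
# Token layout assumed above (derived from Denoiser.forward below): each of the K nodes
# contributes mode_num = 5 attribute tokens in the order [aabb, jtype, jaxis, range, label],
# so for hidden_states of shape (B, K*5, D):
#
#   tokens = hidden_states.reshape(B, K, 5, D)
#   bbox_tokens  = tokens[:, :, 0, :]            # queries of the first image cross-attention
#   joint_tokens = tokens.reshape(B, K * 5, D)   # queries of the second (fused-feature) cross-attention
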
@models.register("denoiser")
class Denoiser(nn.Module):
"""
Denoiser based on CAGE's attribute attention block + our ICA module, with 4 sequential attentions: LA -> GA -> GRA -> ICA
Different image adapters for each layer.
The image cross attention is with key-padding masks (object mask, part mask)
*** The ICA only applies to the bbox attributes, not other attributes such as motion params.***
"""
def __init__(self, hparams):
super().__init__()
self.hparams = hparams
self.K = self.hparams.get("K", 32)
in_ch = hparams.in_ch
attn_dim = hparams.attn_dim
mid_dim = attn_dim // 2
n_head = hparams.n_head
head_dim = attn_dim // n_head
num_embeds_ada_norm = 6 * attn_dim
# embedding layers for different node attributes
self.aabb_emb = nn.Sequential(
nn.Linear(in_ch, mid_dim),
nn.ReLU(inplace=True),
nn.Linear(mid_dim, attn_dim),
)
self.jaxis_emb = nn.Sequential(
nn.Linear(in_ch, mid_dim),
nn.ReLU(inplace=True),
nn.Linear(mid_dim, attn_dim),
)
self.range_emb = nn.Sequential(
nn.Linear(in_ch, mid_dim),
nn.ReLU(inplace=True),
nn.Linear(mid_dim, attn_dim),
)
self.label_emb = nn.Sequential(
nn.Linear(in_ch, mid_dim),
nn.ReLU(inplace=True),
nn.Linear(mid_dim, attn_dim),
)
self.jtype_emb = nn.Sequential(
nn.Linear(in_ch, mid_dim),
nn.ReLU(inplace=True),
nn.Linear(mid_dim, attn_dim),
)
# self.node_type_emb = nn.Sequential(
# nn.Linear(in_ch, mid_dim),
# nn.ReLU(inplace=True),
# nn.Linear(mid_dim, attn_dim),
# )
# positional encoding for nodes and attributes
self.pe_node = PEmbeder(self.K, attn_dim)
self.pe_attr = PEmbeder(self.hparams.mode_num, attn_dim)
# attention layers
self.attn_layers = nn.ModuleList(
[
Attn_Block(
dim=attn_dim,
num_attention_heads=n_head,
attention_head_dim=head_dim,
class_dropout_prob=hparams.get("cat_drop_prob", 0.0),
dropout=hparams.dropout,
activation_fn="geglu",
num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=False,
norm_elementwise_affine=True,
final_dropout=False,
img_emb_dims=hparams.get("img_emb_dims", None),
)
            for _ in range(hparams.n_layers)
]
)
self.image_interaction = RAPCrossAttnBlock(
dim=attn_dim,
num_layers=6,
num_heads=n_head,
head_dim=head_dim,
dropout=hparams.dropout,
img_emb_dims=hparams.get("img_emb_dims", None),
)
self.final_layer = FinalLayer(attn_dim, in_ch)
def forward(
self,
x,
cat,
timesteps,
feat,
key_pad_mask=None,
graph_mask=None,
attr_mask=None,
label_free=False,
):
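        """
        x:            noisy node attributes, reshaped internally to (B, K, 5 * 6)
        cat:          (B,) object-category labels for the adaptive layer norm
        timesteps:    (B,) diffusion timesteps
        feat:         (B, 2*Np, img_emb_dims[0]) concatenated patch features of the two input images
        key_pad_mask / graph_mask / attr_mask: optional attention masks
        label_free:   if True, drop the category condition (classifier-free guidance)
        Returns a dict with 'noise_pred' (per-token noise prediction) and 'attn_maps' (None).
        """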
B = x.shape[0]
        x = x.view(B, self.K, 5 * 6)  # per node: 5 attributes (aabb, jtype, jaxis, range, label), 6 dims each
# embedding layers for different attributes
x_aabb = self.aabb_emb(x[..., :6])
x_jtype = self.jtype_emb(x[..., 6:12])
x_jaxis = self.jaxis_emb(x[..., 12:18])
x_range = self.range_emb(x[..., 18:24])
x_label = self.label_emb(x[..., 24:30])
# x_node_type = self.node_type_emb(x[..., 30:36])
# concatenate all attribute embeddings
        x_ = torch.cat(
            [x_aabb, x_jtype, x_jaxis, x_range, x_label], dim=2
        )  # (B, K, 5*attn_dim)
x_ = x_.view(B, self.K * self.hparams.mode_num, self.hparams.attn_dim)
        # positional encoding for nodes and attributes
        idx_attr = torch.tensor(
            [0, 1, 2, 3, 4], device=x.device, dtype=torch.long
        ).repeat(self.K)  # attribute index within each node (assumes mode_num = 5)
        idx_node = torch.arange(
            self.K, device=x.device, dtype=torch.long
        ).repeat_interleave(self.hparams.mode_num)  # node index of each token
x_ = self.pe_attr(self.pe_node(x_, idx=idx_node), idx=idx_attr)
        # split the concatenated patch features into the two image views and fuse them
        img_first, img_second = feat.chunk(2, dim=1)
        fused_img_feat = self.image_interaction(img_first, img_second)  # (B, Np, D)
# attention layers
        for attn_layer in self.attn_layers:
x_ = attn_layer(
hidden_states=x_,
img_patches=feat,
fuse_feat=fused_img_feat,
timestep=timesteps,
class_labels=cat,
pad_mask=key_pad_mask,
graph_mask=graph_mask,
attr_mask=attr_mask,
label_free=label_free,
)
y = self.final_layer(x_, timesteps, cat)
return {
'noise_pred': y,
'attn_maps': None,
}
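
# A hypothetical end-to-end call sketch (shapes only, not executed; assumes `hparams` is an
# OmegaConf-style config supporting both attribute and .get() access, with in_ch=6,
# attn_dim=512, n_head=8, n_layers=6, mode_num=5, K=32, dropout=0.0, img_emb_dims=[768, 512]):
#
#   denoiser  = Denoiser(hparams)
#   x         = torch.randn(B, 32, 30)            # 5 attributes x 6 dims per node
#   cat       = torch.zeros(B, dtype=torch.long)  # object-category labels
#   timesteps = torch.randint(0, 1000, (B,))      # diffusion timesteps
#   feat      = torch.randn(B, 2 * Np, 768)       # patch features of the two input images
#   out = denoiser(x, cat, timesteps, feat)       # {'noise_pred': ..., 'attn_maps': None}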