# Copyright (c) 2019-present, Meta, Inc. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. # First author is Simon Rouard. import random import typing as tp import torch import torch.nn as nn import torch.nn.functional as F import numpy as np import math from einops import rearrange from torch.nn import TransformerEncoder, TransformerEncoderLayer def create_sin_embedding( length: int, dim: int, shift: int = 0, device="cpu", max_period=10000 ): # We aim for TBC format assert dim % 2 == 0 pos = shift + torch.arange(length, device=device).view(-1, 1, 1) half_dim = dim // 2 adim = torch.arange(dim // 2, device=device).view(1, 1, -1) phase = pos / (max_period ** (adim / (half_dim - 1))) return torch.cat( [ torch.cos(phase), torch.sin(phase), ], dim=-1, ) def create_2d_sin_embedding(d_model, height, width, device="cpu", max_period=10000): """ :param d_model: dimension of the model :param height: height of the positions :param width: width of the positions :return: d_model*height*width position matrix """ if d_model % 4 != 0: raise ValueError( "Cannot use sin/cos positional encoding with " "odd dimension (got dim={:d})".format(d_model) ) pe = torch.zeros(d_model, height, width) # Each dimension use half of d_model d_model = int(d_model / 2) div_term = torch.exp( torch.arange(0.0, d_model, 2) * -(math.log(max_period) / d_model) ) pos_w = torch.arange(0.0, width).unsqueeze(1) pos_h = torch.arange(0.0, height).unsqueeze(1) pe[0:d_model:2, :, :] = ( torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1) ) pe[1:d_model:2, :, :] = ( torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1) ) pe[d_model::2, :, :] = ( torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width) ) pe[d_model + 1:: 2, :, :] = ( torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width) ) return pe[None, :].to(device) def create_sin_embedding_cape( length: int, dim: int, batch_size: int, mean_normalize: bool, augment: bool, # True during training max_global_shift: float = 0.0, # delta max max_local_shift: float = 0.0, # epsilon max max_scale: float = 1.0, device: str = "cpu", max_period: float = 10000.0, ): # We aim for TBC format assert dim % 2 == 0 pos = 1.0 * torch.arange(length).view(-1, 1, 1) # (length, 1, 1) pos = pos.repeat(1, batch_size, 1) # (length, batch_size, 1) if mean_normalize: pos -= torch.nanmean(pos, dim=0, keepdim=True) if augment: delta = np.random.uniform( -max_global_shift, +max_global_shift, size=[1, batch_size, 1] ) delta_local = np.random.uniform( -max_local_shift, +max_local_shift, size=[length, batch_size, 1] ) log_lambdas = np.random.uniform( -np.log(max_scale), +np.log(max_scale), size=[1, batch_size, 1] ) pos = (pos + delta + delta_local) * np.exp(log_lambdas) pos = pos.to(device) half_dim = dim // 2 adim = torch.arange(dim // 2, device=device).view(1, 1, -1) phase = pos / (max_period ** (adim / (half_dim - 1))) return torch.cat( [ torch.cos(phase), torch.sin(phase), ], dim=-1, ).float() def get_causal_mask(length): pos = torch.arange(length) return pos > pos[:, None] def get_elementary_mask( T1, T2, mask_type, sparse_attn_window, global_window, mask_random_seed, sparsity, device, ): """ When the input of the Decoder has length T1 and the output T2 The mask matrix has shape (T2, T1) """ assert mask_type in ["diag", "jmask", "random", "global"] if mask_type == "global": mask = torch.zeros(T2, T1, dtype=torch.bool) mask[:, :global_window] = True line_window = int(global_window * T2 / T1) mask[:line_window, :] = True if mask_type == "diag": mask = torch.zeros(T2, T1, dtype=torch.bool) rows = torch.arange(T2)[:, None] cols = ( (T1 / T2 * rows + torch.arange(-sparse_attn_window, sparse_attn_window + 1)) .long() .clamp(0, T1 - 1) ) mask.scatter_(1, cols, torch.ones(1, dtype=torch.bool).expand_as(cols)) elif mask_type == "jmask": mask = torch.zeros(T2 + 2, T1 + 2, dtype=torch.bool) rows = torch.arange(T2 + 2)[:, None] t = torch.arange(0, int((2 * T1) ** 0.5 + 1)) t = (t * (t + 1) / 2).int() t = torch.cat([-t.flip(0)[:-1], t]) cols = (T1 / T2 * rows + t).long().clamp(0, T1 + 1) mask.scatter_(1, cols, torch.ones(1, dtype=torch.bool).expand_as(cols)) mask = mask[1:-1, 1:-1] elif mask_type == "random": gene = torch.Generator(device=device) gene.manual_seed(mask_random_seed) mask = ( torch.rand(T1 * T2, generator=gene, device=device).reshape(T2, T1) > sparsity ) mask = mask.to(device) return mask def get_mask( T1, T2, mask_type, sparse_attn_window, global_window, mask_random_seed, sparsity, device, ): """ Return a SparseCSRTensor mask that is a combination of elementary masks mask_type can be a combination of multiple masks: for instance "diag_jmask_random" """ from xformers.sparse import SparseCSRTensor # create a list mask_types = mask_type.split("_") all_masks = [ get_elementary_mask( T1, T2, mask, sparse_attn_window, global_window, mask_random_seed, sparsity, device, ) for mask in mask_types ] final_mask = torch.stack(all_masks).sum(axis=0) > 0 return SparseCSRTensor.from_dense(final_mask[None]) class ScaledEmbedding(nn.Module): def __init__( self, num_embeddings: int, embedding_dim: int, scale: float = 1.0, boost: float = 3.0, ): super().__init__() self.embedding = nn.Embedding(num_embeddings, embedding_dim) self.embedding.weight.data *= scale / boost self.boost = boost @property def weight(self): return self.embedding.weight * self.boost def forward(self, x): return self.embedding(x) * self.boost class LayerScale(nn.Module): """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf). This rescales diagonaly residual outputs close to 0 initially, then learnt. """ def __init__(self, channels: int, init: float = 0, channel_last=False): """ channel_last = False corresponds to (B, C, T) tensors channel_last = True corresponds to (T, B, C) tensors """ super().__init__() self.channel_last = channel_last self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True)) self.scale.data[:] = init def forward(self, x): if self.channel_last: return self.scale * x else: return self.scale[:, None] * x class MyGroupNorm(nn.GroupNorm): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def forward(self, x): """ x: (B, T, C) if num_groups=1: Normalisation on all T and C together for each B """ x = x.transpose(1, 2) return super().forward(x).transpose(1, 2) class MyTransformerEncoderLayer(nn.TransformerEncoderLayer): def __init__( self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation=F.relu, group_norm=0, norm_first=False, norm_out=False, layer_norm_eps=1e-5, layer_scale=False, init_values=1e-4, device=None, dtype=None, sparse=False, mask_type="diag", mask_random_seed=42, sparse_attn_window=500, global_window=50, auto_sparsity=False, sparsity=0.95, batch_first=False, ): factory_kwargs = {"device": device, "dtype": dtype} super().__init__( d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout, activation=activation, layer_norm_eps=layer_norm_eps, batch_first=batch_first, norm_first=norm_first, device=device, dtype=dtype, ) self.sparse = sparse self.auto_sparsity = auto_sparsity if sparse: if not auto_sparsity: self.mask_type = mask_type self.sparse_attn_window = sparse_attn_window self.global_window = global_window self.sparsity = sparsity if group_norm: self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs) self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs) self.norm_out = None if self.norm_first & norm_out: self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model) self.gamma_1 = ( LayerScale(d_model, init_values, True) if layer_scale else nn.Identity() ) self.gamma_2 = ( LayerScale(d_model, init_values, True) if layer_scale else nn.Identity() ) if sparse: self.self_attn = MultiheadAttention( d_model, nhead, dropout=dropout, batch_first=batch_first, auto_sparsity=sparsity if auto_sparsity else 0, ) self.__setattr__("src_mask", torch.zeros(1, 1)) self.mask_random_seed = mask_random_seed def forward(self, src, src_mask=None, src_key_padding_mask=None): """ if batch_first = False, src shape is (T, B, C) the case where batch_first=True is not covered """ device = src.device x = src T, B, C = x.shape if self.sparse and not self.auto_sparsity: assert src_mask is None src_mask = self.src_mask if src_mask.shape[-1] != T: src_mask = get_mask( T, T, self.mask_type, self.sparse_attn_window, self.global_window, self.mask_random_seed, self.sparsity, device, ) self.__setattr__("src_mask", src_mask) if self.norm_first: x = x + self.gamma_1( self._sa_block(self.norm1(x), src_mask, src_key_padding_mask) ) x = x + self.gamma_2(self._ff_block(self.norm2(x))) if self.norm_out: x = self.norm_out(x) else: x = self.norm1( x + self.gamma_1(self._sa_block(x, src_mask, src_key_padding_mask)) ) x = self.norm2(x + self.gamma_2(self._ff_block(x))) return x class CrossTransformerEncoderLayer(nn.Module): def __init__( self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, activation=F.relu, layer_norm_eps: float = 1e-5, layer_scale: bool = False, init_values: float = 1e-4, norm_first: bool = False, group_norm: bool = False, norm_out: bool = False, sparse=False, mask_type="diag", mask_random_seed=42, sparse_attn_window=500, global_window=50, sparsity=0.95, auto_sparsity=None, device=None, dtype=None, batch_first=False, ): factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.sparse = sparse self.auto_sparsity = auto_sparsity if sparse: if not auto_sparsity: self.mask_type = mask_type self.sparse_attn_window = sparse_attn_window self.global_window = global_window self.sparsity = sparsity self.cross_attn: nn.Module self.cross_attn = nn.MultiheadAttention( d_model, nhead, dropout=dropout, batch_first=batch_first) # Implementation of Feedforward model self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs) self.dropout = nn.Dropout(dropout) self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs) self.norm_first = norm_first self.norm1: nn.Module self.norm2: nn.Module self.norm3: nn.Module if group_norm: self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs) self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs) self.norm3 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs) else: self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) self.norm_out = None if self.norm_first & norm_out: self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model) self.gamma_1 = ( LayerScale(d_model, init_values, True) if layer_scale else nn.Identity() ) self.gamma_2 = ( LayerScale(d_model, init_values, True) if layer_scale else nn.Identity() ) self.dropout1 = nn.Dropout(dropout) self.dropout2 = nn.Dropout(dropout) # Legacy string support for activation function. if isinstance(activation, str): self.activation = self._get_activation_fn(activation) else: self.activation = activation if sparse: self.cross_attn = MultiheadAttention( d_model, nhead, dropout=dropout, batch_first=batch_first, auto_sparsity=sparsity if auto_sparsity else 0) if not auto_sparsity: self.__setattr__("mask", torch.zeros(1, 1)) self.mask_random_seed = mask_random_seed def forward(self, q, k, mask=None): """ Args: q: tensor of shape (T, B, C) k: tensor of shape (S, B, C) mask: tensor of shape (T, S) """ device = q.device T, B, C = q.shape S, B, C = k.shape if self.sparse and not self.auto_sparsity: assert mask is None mask = self.mask if mask.shape[-1] != S or mask.shape[-2] != T: mask = get_mask( S, T, self.mask_type, self.sparse_attn_window, self.global_window, self.mask_random_seed, self.sparsity, device, ) self.__setattr__("mask", mask) if self.norm_first: x = q + self.gamma_1(self._ca_block(self.norm1(q), self.norm2(k), mask)) x = x + self.gamma_2(self._ff_block(self.norm3(x))) if self.norm_out: x = self.norm_out(x) else: x = self.norm1(q + self.gamma_1(self._ca_block(q, k, mask))) x = self.norm2(x + self.gamma_2(self._ff_block(x))) return x # self-attention block def _ca_block(self, q, k, attn_mask=None): x = self.cross_attn(q, k, k, attn_mask=attn_mask, need_weights=False)[0] return self.dropout1(x) # feed forward block def _ff_block(self, x): x = self.linear2(self.dropout(self.activation(self.linear1(x)))) return self.dropout2(x) def _get_activation_fn(self, activation): if activation == "relu": return F.relu elif activation == "gelu": return F.gelu raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) # ----------------- MULTI-BLOCKS MODELS: ----------------------- # class CrossTransformerEncoder(nn.Module): # def __init__( # self, # args, # dim: int, # emb: str = "sin", # hidden_scale: float = 4.0, # num_heads: int = 8, # num_layers: int = 5, # cross_first: bool = False, # dropout: float = 0.0, # max_positions: int = 10000, # norm_in: bool = True, # norm_in_group: bool = False, # group_norm: int = False, # norm_first: bool = True, # norm_out: bool = True, # max_period: float = 10000.0, # weight_decay: float = 0.0, # lr: tp.Optional[float] = None, # layer_scale: bool = True, # gelu: bool = True, # sin_random_shift: int = 0, # weight_pos_embed: float = 1.0, # cape_mean_normalize: bool = True, # cape_augment: bool = True, # cape_glob_loc_scale: list = [5000.0, 1.0, 1.4], # sparse_self_attn: bool = False, # sparse_cross_attn: bool = False, # mask_type: str = "diag", # mask_random_seed: int = 42, # sparse_attn_window: int = 500, # global_window: int = 100, # auto_sparsity: bool = False, # sparsity: float = 0.95, # ): # super().__init__() # """ # """ # assert dim % num_heads == 0 # self.num_heads = num_heads # hidden_dim = int(dim * hidden_scale) # self.args = args # self.num_layers = num_layers # # classic parity = 1 means that if idx%2 == 1 there is a # # classical encoder else there is a cross encoder # self.classic_parity = 1 if cross_first else 0 # self.emb = emb # self.max_period = max_period # self.weight_decay = weight_decay # self.weight_pos_embed = weight_pos_embed # self.sin_random_shift = sin_random_shift # if emb == "cape": # self.cape_mean_normalize = cape_mean_normalize # self.cape_augment = cape_augment # self.cape_glob_loc_scale = cape_glob_loc_scale # if emb == "scaled": # self.position_embeddings = ScaledEmbedding(max_positions, dim, scale=0.2) # self.lr = lr # activation: tp.Any = F.gelu if gelu else F.relu # self.norm_in: nn.Module # self.norm_in_t: nn.Module # if norm_in: # self.norm_in = nn.LayerNorm(dim) # self.norm_in_t = nn.LayerNorm(dim) # elif norm_in_group: # self.norm_in = MyGroupNorm(int(norm_in_group), dim) # self.norm_in_t = MyGroupNorm(int(norm_in_group), dim) # else: # self.norm_in = nn.Identity() # self.norm_in_t = nn.Identity() # # spectrogram layers # self.layers = nn.ModuleList() # # temporal layers # self.layers_t = nn.ModuleList() # kwargs_common = { # "d_model": dim, # "nhead": num_heads, # "dim_feedforward": hidden_dim, # "dropout": dropout, # "activation": activation, # "group_norm": group_norm, # "norm_first": norm_first, # "norm_out": norm_out, # "layer_scale": layer_scale, # "mask_type": mask_type, # "mask_random_seed": mask_random_seed, # "sparse_attn_window": sparse_attn_window, # "global_window": global_window, # "sparsity": sparsity, # "auto_sparsity": auto_sparsity, # "batch_first": True, # } # kwargs_classic_encoder = dict(kwargs_common) # kwargs_classic_encoder.update({ # "sparse": sparse_self_attn, # }) # kwargs_cross_encoder = dict(kwargs_common) # kwargs_cross_encoder.update({ # "sparse": sparse_cross_attn, # }) # for idx in range(num_layers): # if idx % 2 == self.classic_parity: # self.layers.append(MyTransformerEncoderLayer(**kwargs_classic_encoder)) # self.layers_t.append( # MyTransformerEncoderLayer(**kwargs_classic_encoder) # ) # else: # self.layers.append(CrossTransformerEncoderLayer(**kwargs_cross_encoder)) # self.layers_t.append( # CrossTransformerEncoderLayer(**kwargs_cross_encoder) # ) # # audio llm features # if self.args.network_audio.add_feature in ['beats']: # if self.args.network_audio.add_feature_fusion == 'cat': # self.audio_feat_layers = nn.ModuleList() # self.audio_feat_layers_t = nn.ModuleList() # self.audio_feat_fusion_layers = nn.ModuleList() # self.audio_feat_fusion_layers_t = nn.ModuleList() # for idx in range(self.args.network_audio.fusion_num_layers): # if idx % 2 == self.classic_parity: # self.audio_feat_layers.append(nn.Linear(768, dim)) # self.audio_feat_layers_t.append(nn.Linear(768, dim)) # self.audio_feat_fusion_layers.append(nn.Linear(dim*2, dim)) # self.audio_feat_fusion_layers_t.append(nn.Linear(dim*2, dim)) # elif self.args.network_audio.add_feature_fusion == 'att': # self.audio_feat_fusion_layers = nn.ModuleList() # self.audio_feat_fusion_layers_t = nn.ModuleList() # self.audio_feat_ds = nn.Linear(768, dim) # self.audio_feat_ds_t = nn.Linear(768, dim) # self.pos_encoder = PositionalEncoding(dim, dropout=0.1) # for idx in range(self.args.network_audio.fusion_num_layers): # if idx % 2 == self.classic_parity: # self.audio_feat_fusion_layers.append(CrossTransformerEncoderLayer(**kwargs_cross_encoder)) # self.audio_feat_fusion_layers_t.append(CrossTransformerEncoderLayer(**kwargs_cross_encoder)) # else: # self.audio_feat_fusion_layers.append(MyTransformerEncoderLayer(**kwargs_classic_encoder)) # self.audio_feat_fusion_layers_t.append(MyTransformerEncoderLayer(**kwargs_classic_encoder)) # # text conditioning # if self.args.network_audio.fusion in ['film', 'film_cat', 'film_att']: # self.text_layers = nn.ModuleList() # self.text_layers_t = nn.ModuleList() # self.film_layers = nn.ModuleList() # self.film_layers_t = nn.ModuleList() # for idx in range(self.args.network_audio.fusion_num_layers): # if idx % 2 == self.classic_parity: # self.text_layers.append(nn.Linear(args.network_reference.emb_size, dim)) # self.text_layers_t.append(nn.Linear(args.network_reference.emb_size, dim)) # self.film_layers.append(FiLMLayer(dim, dim)) # self.film_layers_t.append(FiLMLayer(dim, dim)) # if self.args.network_audio.fusion in ['att', 'film_att']: # self.att_text_layers = nn.ModuleList() # self.att_text_layers_t = nn.ModuleList() # self.att_prompt_layers = nn.ModuleList() # self.att_prompt_layers_t = nn.ModuleList() # for idx in range(self.args.network_audio.fusion_num_layers): # if idx % 2 != self.classic_parity: # self.att_text_layers.append(nn.Linear(args.network_reference.emb_size, dim)) # self.att_text_layers_t.append(nn.Linear(args.network_reference.emb_size, dim)) # self.att_prompt_layers.append(CrossTransformerEncoderLayer(**kwargs_cross_encoder)) # self.att_prompt_layers_t.append(CrossTransformerEncoderLayer(**kwargs_cross_encoder)) # # isam layers # if self.args.network_audio.isam == 1: # self.isam_layers = nn.ModuleList() # self.isam_norm_layers = nn.ModuleList() # self.isam_layers_t = nn.ModuleList() # self.isam_norm_layers_t = nn.ModuleList() # for _ in range(num_layers): # encoder_layers = TransformerEncoderLayer(d_model=dim, nhead=1, dim_feedforward=dim*2, batch_first=True) # isam = TransformerEncoder(encoder_layers, num_layers=1) # isam_norm = nn.LayerNorm(dim) # self.isam_layers.append(isam) # self.isam_norm_layers.append(isam_norm) # encoder_layers = TransformerEncoderLayer(d_model=dim, nhead=1, dim_feedforward=dim*2, batch_first=True) # isam_t = TransformerEncoder(encoder_layers, num_layers=1) # isam_norm_t = nn.LayerNorm(dim) # self.isam_layers_t.append(isam_t) # self.isam_norm_layers_t.append(isam_norm_t) # def forward(self, x, xt, text, a_ref, text_embedding, text_attention_mask): # B, C, Fr, T1 = x.shape # pos_emb_2d = create_2d_sin_embedding( # C, Fr, T1, x.device, self.max_period # ) # (1, C, Fr, T1) # pos_emb_2d = rearrange(pos_emb_2d, "b c fr t1 -> b (t1 fr) c") # x = rearrange(x, "b c fr t1 -> b (t1 fr) c") # x = self.norm_in(x) # x = x + self.weight_pos_embed * pos_emb_2d # B, C, T2 = xt.shape # xt = rearrange(xt, "b c t2 -> b t2 c") # now T2, B, C # pos_emb = self._get_pos_embedding(T2, B, C, x.device) # pos_emb = rearrange(pos_emb, "t2 b c -> b t2 c") # xt = self.norm_in_t(xt) # xt = xt + self.weight_pos_embed * pos_emb # if self.args.network_audio.isam == 1 and self.args.network_audio.fusion in ['film', 'film_cat']: # text = torch.cat([text.unsqueeze(1), text.unsqueeze(1)], axis = 1).reshape(text.shape[0]*2, text.shape[1]) # fusion_idx = 0 # non_fusion_idx = 0 # for idx in range(self.num_layers): # # fusion of condition and extra information # if idx < self.args.network_audio.fusion_num_layers: # # audio foundation model feature fusion # if self.args.network_audio.add_feature in ['beats']: # if self.args.network_audio.add_feature_fusion == 'cat': # if idx % 2 == self.classic_parity: # a_feat = self.audio_feat_layers[fusion_idx](a_ref.transpose(1,2)) # a_feat_t = self.audio_feat_layers_t[fusion_idx](a_ref.transpose(1,2)) # if self.args.network_audio.isam == 1: # a_feat = torch.cat([a_feat.unsqueeze(1), a_feat.unsqueeze(1)], axis = 1).reshape(a_feat.shape[0]*2, a_feat.shape[1], a_feat.shape[2]) # a_feat_t = torch.cat([a_feat_t.unsqueeze(1), a_feat_t.unsqueeze(1)], axis = 1).reshape(a_feat_t.shape[0]*2, a_feat_t.shape[1], a_feat_t.shape[2]) # if self.args.network_audio.add_feature_fusion == 'cat': # a_feat = F.interpolate(a_feat.transpose(1,2), (x.shape[1]), mode='linear') # a_feat_t = F.interpolate(a_feat_t.transpose(1,2), (xt.shape[1]), mode='linear') # x = torch.cat((x, a_feat.transpose(1,2)),2) # xt = torch.cat((xt, a_feat_t.transpose(1,2)),2) # x = self.audio_feat_fusion_layers[fusion_idx](x) # xt = self.audio_feat_fusion_layers_t[fusion_idx](xt) # elif self.args.network_audio.add_feature_fusion == 'att': # if idx == 0: # a_feat = self.audio_feat_ds(a_ref.transpose(1,2)) # a_feat_t = self.audio_feat_ds_t(a_ref.transpose(1,2)) # a_feat = self.pos_encoder(a_feat.transpose(0,1)).transpose(0,1) # a_feat_t = self.pos_encoder(a_feat_t.transpose(0,1)).transpose(0,1) # if self.args.network_audio.isam == 1: # a_feat = torch.cat([a_feat.unsqueeze(1), a_feat.unsqueeze(1)], axis = 1).reshape(a_feat.shape[0]*2, a_feat.shape[1], a_feat.shape[2]) # a_feat_t = torch.cat([a_feat_t.unsqueeze(1), a_feat_t.unsqueeze(1)], axis = 1).reshape(a_feat_t.shape[0]*2, a_feat_t.shape[1], a_feat_t.shape[2]) # if idx % 2 == self.classic_parity: # x = self.audio_feat_fusion_layers[idx](x, a_feat) # xt = self.audio_feat_fusion_layers_t[idx](xt, a_feat_t) # else: # a_feat = self.audio_feat_fusion_layers[idx](a_feat) # a_feat_t = self.audio_feat_fusion_layers_t[idx](a_feat_t) # # film fusion for text # if self.args.network_audio.fusion in ['film', 'film_cat', 'film_att']: # if idx % 2 == self.classic_parity: # new_text = self.text_layers[fusion_idx](text) # new_text_t = self.text_layers_t[fusion_idx](text) # x = self.film_layers[fusion_idx](x, new_text) # xt = self.film_layers_t[fusion_idx](xt, new_text_t) # if self.args.network_audio.fusion in ['att', 'film_att']: # if idx % 2 != self.classic_parity: # att_new_text = self.att_text_layers[non_fusion_idx](text_embedding) # att_new_text_t = self.att_text_layers_t[non_fusion_idx](text_embedding) # text_attention_mask_f = text_attention_mask.unsqueeze(1).repeat(1,x.shape[1],1) # text_attention_mask_f = text_attention_mask_f.repeat(self.num_heads,1,1) # text_attention_mask_t = text_attention_mask.unsqueeze(1).repeat(1,xt.shape[1],1) # text_attention_mask_t = text_attention_mask_t.repeat(self.num_heads,1,1) # x = self.att_prompt_layers[non_fusion_idx](x, att_new_text, mask = text_attention_mask_f) # xt = self.att_prompt_layers_t[non_fusion_idx](xt, att_new_text_t, mask = text_attention_mask_t) # # main transformer # if idx % 2 == self.classic_parity: # x = self.layers[idx](x) # xt = self.layers_t[idx](xt) # else: # old_x = x # x = self.layers[idx](x, xt) # xt = self.layers_t[idx](xt, old_x) # # isam layer for outputing residual audio # if self.args.network_audio.isam == 1: # B, S, N = x.shape # inter_spk = x.contiguous().view(B//2, 2, S, N) # inter_spk = inter_spk.permute(0,2,3,1).contiguous().view(B//2*S, 2, N) # # chunk_size = 40000 # # for itr in range(0, inter_spk.shape[0]//chunk_size+1): # # inter_spk[chunk_size*itr:chunk_size*(itr+1),:,:] = self.isam_layers[idx](inter_spk[chunk_size*itr:chunk_size*(itr+1),:,:]) # inter_spk = self.isam_layers[idx](inter_spk) # inter_spk = inter_spk.contiguous().view(B//2, S, 2, N) # inter_spk = inter_spk.permute(0,2,1,3) # inter_spk = inter_spk.contiguous().view(B, S, N) # inter_spk = self.isam_norm_layers[idx](inter_spk) # x += inter_spk # B, S, N = xt.shape # inter_spk = xt.contiguous().view(B//2, 2, S, N) # inter_spk = inter_spk.permute(0,2,3,1).contiguous().view(B//2*S, 2, N) # # chunk_size = 40000 # # for itr in range(0, inter_spk.shape[0]//chunk_size+1): # # inter_spk[chunk_size*itr:chunk_size*(itr+1),:,:] = self.isam_layers_t[idx](inter_spk[chunk_size*itr:chunk_size*(itr+1),:,:]) # inter_spk = self.isam_layers_t[idx](inter_spk) # inter_spk = inter_spk.contiguous().view(B//2, S, 2, N) # inter_spk = inter_spk.permute(0,2,1,3) # inter_spk = inter_spk.contiguous().view(B, S, N) # inter_spk = self.isam_norm_layers_t[idx](inter_spk) # xt += inter_spk # if idx % 2 == self.classic_parity: # fusion_idx+=1 # else: # non_fusion_idx+=1 # x = rearrange(x, "b (t1 fr) c -> b c fr t1", t1=T1) # xt = rearrange(xt, "b t2 c -> b c t2") # return x, xt # def _get_pos_embedding(self, T, B, C, device): # if self.emb == "sin": # shift = random.randrange(self.sin_random_shift + 1) # pos_emb = create_sin_embedding( # T, C, shift=shift, device=device, max_period=self.max_period # ) # elif self.emb == "cape": # if self.training: # pos_emb = create_sin_embedding_cape( # T, # C, # B, # device=device, # max_period=self.max_period, # mean_normalize=self.cape_mean_normalize, # augment=self.cape_augment, # max_global_shift=self.cape_glob_loc_scale[0], # max_local_shift=self.cape_glob_loc_scale[1], # max_scale=self.cape_glob_loc_scale[2], # ) # else: # pos_emb = create_sin_embedding_cape( # T, # C, # B, # device=device, # max_period=self.max_period, # mean_normalize=self.cape_mean_normalize, # augment=False, # ) # elif self.emb == "scaled": # pos = torch.arange(T, device=device) # pos_emb = self.position_embeddings(pos)[:, None] # return pos_emb # def make_optim_group(self): # group = {"params": list(self.parameters()), "weight_decay": self.weight_decay} # if self.lr is not None: # group["lr"] = self.lr # return group # Attention Modules class MultiheadAttention(nn.Module): def __init__( self, embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, auto_sparsity=None, ): super().__init__() assert auto_sparsity is not None, "sanity check" self.num_heads = num_heads self.q = torch.nn.Linear(embed_dim, embed_dim, bias=bias) self.k = torch.nn.Linear(embed_dim, embed_dim, bias=bias) self.v = torch.nn.Linear(embed_dim, embed_dim, bias=bias) self.attn_drop = torch.nn.Dropout(dropout) self.proj = torch.nn.Linear(embed_dim, embed_dim, bias) self.proj_drop = torch.nn.Dropout(dropout) self.batch_first = batch_first self.auto_sparsity = auto_sparsity def forward( self, query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True, ): if not self.batch_first: # N, B, C query = query.permute(1, 0, 2) # B, N_q, C key = key.permute(1, 0, 2) # B, N_k, C value = value.permute(1, 0, 2) # B, N_k, C B, N_q, C = query.shape B, N_k, C = key.shape q = ( self.q(query) .reshape(B, N_q, self.num_heads, C // self.num_heads) .permute(0, 2, 1, 3) ) q = q.flatten(0, 1) k = ( self.k(key) .reshape(B, N_k, self.num_heads, C // self.num_heads) .permute(0, 2, 1, 3) ) k = k.flatten(0, 1) v = ( self.v(value) .reshape(B, N_k, self.num_heads, C // self.num_heads) .permute(0, 2, 1, 3) ) v = v.flatten(0, 1) if self.auto_sparsity: assert attn_mask is None x = dynamic_sparse_attention(q, k, v, sparsity=self.auto_sparsity) else: x = scaled_dot_product_attention(q, k, v, attn_mask, dropout=self.attn_drop) x = x.reshape(B, self.num_heads, N_q, C // self.num_heads) x = x.transpose(1, 2).reshape(B, N_q, C) x = self.proj(x) x = self.proj_drop(x) if not self.batch_first: x = x.permute(1, 0, 2) return x, None def scaled_query_key_softmax(q, k, att_mask): from xformers.ops import masked_matmul q = q / (k.size(-1)) ** 0.5 att = masked_matmul(q, k.transpose(-2, -1), att_mask) att = torch.nn.functional.softmax(att, -1) return att def scaled_dot_product_attention(q, k, v, att_mask, dropout): att = scaled_query_key_softmax(q, k, att_mask=att_mask) att = dropout(att) y = att @ v return y def _compute_buckets(x, R): qq = torch.einsum('btf,bfhi->bhti', x, R) qq = torch.cat([qq, -qq], dim=-1) buckets = qq.argmax(dim=-1) return buckets.permute(0, 2, 1).byte().contiguous() def dynamic_sparse_attention(query, key, value, sparsity, infer_sparsity=True, attn_bias=None): # assert False, "The code for the custom sparse kernel is not ready for release yet." from xformers.ops import find_locations, sparse_memory_efficient_attention n_hashes = 32 proj_size = 4 query, key, value = [x.contiguous() for x in [query, key, value]] with torch.no_grad(): R = torch.randn(1, query.shape[-1], n_hashes, proj_size // 2, device=query.device) bucket_query = _compute_buckets(query, R) bucket_key = _compute_buckets(key, R) row_offsets, column_indices = find_locations( bucket_query, bucket_key, sparsity, infer_sparsity) return sparse_memory_efficient_attention( query, key, value, row_offsets, column_indices, attn_bias) class FiLMLayer(nn.Module): def __init__(self, in_channels, cond_channels): """ Feature-wise Linear Modulation (FiLM) layer Parameters: in_channels: The number of channels in the input feature maps. cond_channels: The number of channels in the conditioning input. """ super(FiLMLayer, self).__init__() self.in_channels = in_channels self.film = nn.Linear(cond_channels, in_channels * 2) def forward(self, x, c): """ Parameters: x (Tensor): The input feature maps with shape [batch_size, time, in_channels]. c (Tensor): The conditioning input with shape [batch_size, cond_channels]. Returns: Tensor: The modulated feature maps with the same shape as input x. """ c = c.unsqueeze(1) film_params = self.film(c) gamma, beta = torch.chunk(film_params, chunks=2, dim=-1) return gamma * x + beta class PositionalEncoding(nn.Module): def __init__(self, d_model, dropout, max_len=5000): super().__init__() self.dropout = nn.Dropout(p=dropout) position = torch.arange(max_len).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) pe = torch.zeros(max_len, 1, d_model) pe[:, 0, 0::2] = torch.sin(position * div_term) pe[:, 0, 1::2] = torch.cos(position * div_term) self.register_buffer('pe', pe) def forward(self, x): """ Args: x: Tensor, shape [seq_len, batch_size, embedding_dim] """ x = x + self.pe[:x.size(0)] return self.dropout(x)