# Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL)
#
# SPDX-License-Identifier: Apache-2.0

import copy
import math
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from packaging.version import parse as V
from rotary_embedding_torch import RotaryEmbedding
from torch.nn import TransformerEncoder, TransformerEncoderLayer

from .espnet2.abs_separator import AbsSeparator
from .espnet2.complex_utils import new_complex_like
from .espnet2.stft_decoder import STFTDecoder
from .espnet2.stft_encoder import STFTEncoder

is_torch_2_0_plus = V(torch.__version__) >= V("2.0.0")

EPS = 1e-8


def _clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


class TFLocoformer(AbsSeparator):
    """TF-Locoformer model presented in [1].

    Reference:
    [1] Kohei Saijo, Gordon Wichern, François G. Germain, Zexu Pan, and Jonathan Le Roux,
    "TF-Locoformer: Transformer with Local Modeling by Convolution for Speech Separation
    and Enhancement," in Proc. International Workshop on Acoustic Signal Enhancement
    (IWAENC), Sep. 2024.

    Args:
        input_dim: int
            Placeholder, not used.
        num_spk: int
            Number of output sources/speakers.
        n_layers: int
            Number of Locoformer blocks.
        emb_dim: int
            Size of the hidden dimension in the encoding Conv2d.
        norm_type: str
            Normalization layer. Must be either "layernorm" or "rmsgroupnorm".
        num_groups: int
            Number of groups in the RMSGroupNorm layer.
        tf_order: str
            Order of frequency and temporal modeling. Must be either "ft" or "tf".
        n_heads: int
            Number of heads in multi-head self-attention.
        flash_attention: bool
            Whether to use flash attention. Only compatible with half precision.
        ffn_type: str or list
            Feed-forward network (FFN) type, chosen from "conv1d" or "swiglu_conv1d".
            Giving a list (e.g., ["conv1d", "conv1d"]) makes the model Macaron-style.
        ffn_hidden_dim: int or list
            Number of hidden dimensions in the FFN.
            Giving a list (e.g., [256, 256]) makes the model Macaron-style.
        conv1d_kernel: int
            Kernel size of the Conv1d.
        conv1d_shift: int
            Shift size of the Conv1d kernel.
        dropout: float
            Dropout probability.
        eps: float
            Small constant for the normalization layer.
""" def __init__( self, args, flash_attention: bool = False, # available when using mixed precision eps: float = 1.0e-5, ): super().__init__() assert is_torch_2_0_plus, "Support only pytorch >= 2.0.0" n_fft=args.network_audio.n_fft stride=args.network_audio.stride window=args.network_audio.window use_builtin_complex=args.network_audio.use_builtin_complex num_spk = args.network_audio.num_spk n_layers=args.network_audio.n_layers norm_type=args.network_audio.norm_type emb_dim=args.network_audio.emb_dim num_groups=args.network_audio.num_groups n_heads=args.network_audio.n_heads tf_order=args.network_audio.tf_order attention_dim=args.network_audio.attention_dim conv1d_kernel=args.network_audio.conv1d_kernel conv1d_shift=args.network_audio.conv1d_shift ffn_type=[args.network_audio.ffn_type] ffn_hidden_dim=[args.network_audio.ffn_hidden_dim] dropout=args.network_audio.dropout self.args = args assert n_fft % 2 == 0 n_freqs = n_fft // 2 + 1 # self.ref_channel = ref_channel self.enc = STFTEncoder( n_fft, n_fft, stride, window=window, use_builtin_complex=use_builtin_complex ) self.dec = STFTDecoder(n_fft, n_fft, stride, window=window) self._num_spk = num_spk self.n_layers = n_layers t_ksize = 3 ks, padding = (t_ksize, 3), (t_ksize // 2, 1) self.conv = nn.Sequential( nn.Conv2d(2, emb_dim, ks, padding=padding), nn.GroupNorm(1, emb_dim, eps=eps), # gLN ) assert attention_dim % n_heads == 0, (attention_dim, n_heads) rope_freq = RotaryEmbedding(attention_dim // n_heads) rope_time = RotaryEmbedding(attention_dim // n_heads) self.blocks = nn.ModuleList([]) for _ in range(n_layers): self.blocks.append( TFLocoformerBlock( args, rope_freq, rope_time, # general setup emb_dim=emb_dim, norm_type=norm_type, num_groups=num_groups, tf_order=tf_order, # self-attention related n_heads=n_heads, flash_attention=flash_attention, attention_dim=attention_dim, # ffn related ffn_type=ffn_type, ffn_hidden_dim=ffn_hidden_dim, conv1d_kernel=conv1d_kernel, conv1d_shift=conv1d_shift, dropout=dropout, eps=eps, ) ) self.deconv = nn.ConvTranspose2d(emb_dim, num_spk * 2, ks, padding=padding) fusion_layers = n_layers //2 # text self.ref_ds = nn.Linear(768, args.network_reference.emb_size) encoder_layers = TransformerEncoderLayer(d_model=args.network_reference.emb_size, nhead=2, dim_feedforward=args.network_reference.emb_size*2, batch_first=True) self.text_net = TransformerEncoder(encoder_layers, num_layers=5) self.summarize = nn.LSTM(args.network_reference.emb_size, args.network_reference.emb_size, num_layers=1, batch_first=True) self.text_layer = _clones(nn.Linear(args.network_reference.emb_size, emb_dim),fusion_layers) if self.args.network_reference.fusion in ['cat']: self.text_fusion = _clones(nn.Linear(emb_dim*2, emb_dim),fusion_layers) elif self.args.network_reference.fusion in ['film']: self.text_fusion = _clones(FiLMLayer(emb_dim, emb_dim),fusion_layers) else: raise NameError('Wrong text feature fusion selection') # audio llm features if self.args.network_audio.add_feature in ['beats']: if self.args.network_audio.add_feature_fusion == 'att': self.audio_feat_ds = nn.Linear(768, emb_dim) self.pos_encoder = PositionalEncoding(emb_dim, dropout=0.1) from .transformer import CrossTransformerEncoderLayer, MyTransformerEncoderLayer kwargs_common = { "d_model": emb_dim, "nhead": 4, "dim_feedforward": emb_dim*4, "dropout": 0.3, "activation": F.gelu, "auto_sparsity": False, "batch_first": True, } kwargs_cross_encoder = dict(kwargs_common) self.audio_feat_ca= _clones(CrossTransformerEncoderLayer(**kwargs_cross_encoder),fusion_layers) 
                self.audio_feat_sa = _clones(
                    MyTransformerEncoderLayer(**kwargs_cross_encoder), fusion_layers
                )
            else:
                raise NameError("Wrong audio feature fusion selection")

    def forward(self, input, ref, a_ref=None):
        n_samples = input.shape[1]
        assert len(input.shape) == 2
        input = input[..., None]  # [B, N, M]

        ilens = [input.shape[1] for _ in range(input.shape[0])]
        ilens = torch.Tensor(ilens).int()

        mix_std_ = torch.std(input, dim=(1, 2), keepdim=True)  # [B, 1, 1]
        mix_std_[mix_std_ == 0] = 1
        input = input / mix_std_  # RMS normalization

        batch = self.enc(input, ilens)[0]  # [B, T, M, F]
        batch0 = batch.transpose(1, 2)  # [B, M, T, F]
        batch = torch.cat((batch0.real, batch0.imag), dim=1)  # [B, 2*M, T, F]
        n_batch, _, n_frames, n_freqs = batch.shape

        with torch.cuda.amp.autocast(enabled=False):
            batch = self.conv(batch)  # [B, -1, T, F]

        # if text cue is included
        text_embedding, text_attention_mask, text_len = ref
        text_embedding = self.ref_ds(text_embedding)
        text_attention_mask = text_attention_mask == 0
        text_embedding = self.text_net(
            text_embedding, src_key_padding_mask=text_attention_mask
        )
        text_vector, _ = self.summarize(text_embedding)
        text_len = text_len - 1
        batch_indices = torch.arange(text_vector.size(0))
        text = text_vector[batch_indices, text_len]

        # audio LLM features
        if self.args.network_audio.add_feature in ["beats"]:
            if self.args.network_audio.add_feature_fusion == "att":
                a_ref = self.audio_feat_ds(a_ref.transpose(1, 2))
                a_ref = self.pos_encoder(a_ref.transpose(0, 1)).transpose(0, 1)

        text_fusion_cnt = 0
        llm_fusion_cnt = 0

        # separation
        for ii in range(self.n_layers):
            batch_cache = batch.clone()
            n_batch_t, channel_t, n_frame_t, n_freq_t = batch.shape

            # text conditioning
            if ii % 2 == 1:
                new_text = self.text_layer[text_fusion_cnt](text)
                if self.args.network_reference.fusion in ["cat"]:
                    new_text = torch.repeat_interleave(
                        new_text.unsqueeze(2), repeats=n_frame_t, dim=2
                    )
                    new_text = torch.repeat_interleave(new_text.unsqueeze(-1), n_freq_t, -1)
                    batch = torch.cat((new_text, batch), axis=1)
                    batch = self.text_fusion[text_fusion_cnt](
                        batch.transpose(1, 3)
                    ).transpose(1, 3)
                elif self.args.network_reference.fusion in ["film"]:
                    # along time
                    new_text = torch.repeat_interleave(new_text.unsqueeze(1), n_freq_t, 1)
                    new_text = new_text.reshape(n_batch_t * n_freq_t, channel_t)
                    batch = batch.transpose(1, 3).reshape(
                        n_batch_t * n_freq_t, n_frame_t, channel_t
                    )
                    batch = self.text_fusion[text_fusion_cnt](batch, new_text)
                    batch = batch.reshape(
                        n_batch_t, n_freq_t, n_frame_t, channel_t
                    ).transpose(1, 3)
                text_fusion_cnt += 1

            # audio LLM features
            if ii % 2 == 0:
                if self.args.network_audio.add_feature in ["beats"]:
                    if self.args.network_audio.add_feature_fusion == "att":
                        a_ref = self.audio_feat_sa[llm_fusion_cnt](a_ref)
                        # along time
                        new_a_ref = torch.repeat_interleave(a_ref.unsqueeze(1), n_freq_t, 1)
                        new_a_ref = new_a_ref.reshape(
                            n_batch_t * n_freq_t, a_ref.shape[1], channel_t
                        )
                        batch = batch.transpose(1, 3).reshape(
                            n_batch_t * n_freq_t, n_frame_t, channel_t
                        )
                        batch = self.audio_feat_ca[llm_fusion_cnt](batch, new_a_ref)
                        batch = batch.reshape(
                            n_batch_t, n_freq_t, n_frame_t, channel_t
                        ).transpose(1, 3)
                        llm_fusion_cnt += 1

            # batch forward
            batch = self.blocks[ii](batch)  # [B, -1, T, F]

            # # skip connection
            # if ii < (self.n_layers - 1):
            #     batch += batch_cache

        with torch.cuda.amp.autocast(enabled=False):
            batch = self.deconv(batch)  # [B, num_spk*2, T, F]

        batch = batch.view([n_batch, self.num_spk, 2, n_frames, n_freqs])
        batch = new_complex_like(batch0, (batch[:, :, 0], batch[:, :, 1]))

        batch = self.dec(batch.view(-1, n_frames, n_freqs), ilens)[0]  # [B, n_srcs, -1]
        batch = self.pad2(batch.view([n_batch, self.num_spk, -1]), n_samples)
        batch = batch * mix_std_  # reverse the RMS normalization
        batch = batch.squeeze(1)

        return batch

    @property
    def num_spk(self):
        return self._num_spk

    @staticmethod
    def pad2(input_tensor, target_len):
        input_tensor = torch.nn.functional.pad(
            input_tensor, (0, target_len - input_tensor.shape[-1])
        )
        return input_tensor


class TFLocoformerBlock(nn.Module):
    def __init__(
        self,
        args,
        rope_freq,
        rope_time,
        # general setup
        emb_dim=128,
        norm_type="rmsgroupnorm",
        num_groups=4,
        tf_order="ft",
        # self-attention related
        n_heads=4,
        flash_attention=False,
        attention_dim=128,
        # ffn related
        ffn_type="swiglu_conv1d",
        ffn_hidden_dim=384,
        conv1d_kernel=4,
        conv1d_shift=1,
        dropout=0.0,
        eps=1.0e-5,
    ):
        super().__init__()

        assert tf_order in ["tf", "ft"], tf_order
        self.tf_order = tf_order
        self.conv1d_kernel = conv1d_kernel
        self.conv1d_shift = conv1d_shift

        self.freq_path = LocoformerBlock(
            rope_freq,
            # general setup
            emb_dim=emb_dim,
            norm_type=norm_type,
            num_groups=num_groups,
            # self-attention related
            n_heads=n_heads,
            flash_attention=flash_attention,
            attention_dim=attention_dim,
            # ffn related
            ffn_type=ffn_type,
            ffn_hidden_dim=ffn_hidden_dim,
            conv1d_kernel=conv1d_kernel,
            conv1d_shift=conv1d_shift,
            dropout=dropout,
            eps=eps,
        )
        self.frame_path = LocoformerBlock(
            rope_time,
            # general setup
            emb_dim=emb_dim,
            norm_type=norm_type,
            num_groups=num_groups,
            # self-attention related
            n_heads=n_heads,
            flash_attention=flash_attention,
            attention_dim=attention_dim,
            # ffn related
            ffn_type=ffn_type,
            ffn_hidden_dim=ffn_hidden_dim,
            conv1d_kernel=conv1d_kernel,
            conv1d_shift=conv1d_shift,
            dropout=dropout,
            eps=eps,
        )

    def forward(self, input):
        """TF-Locoformer forward.

        input: torch.Tensor
            Input tensor, (n_batch, channel, n_frame, n_freq)
        """
        if self.tf_order == "ft":
            output = self.freq_frame_process(input)
        else:
            output = self.frame_freq_process(input)

        return output

    def freq_frame_process(self, input):
        output = input.movedim(1, -1)  # (B, T, F, H)
        output = self.freq_path(output)

        output = output.transpose(1, 2)  # (B, F, T, H)
        output = self.frame_path(output)
        return output.transpose(-1, 1)

    def frame_freq_process(self, input):
        # Input tensor, (n_batch, hidden, n_frame, n_freq)
        output = input.transpose(1, -1)  # (B, F, T, H)
        output = self.frame_path(output)

        output = output.transpose(1, 2)  # (B, T, F, H)
        output = self.freq_path(output)
        return output.movedim(-1, 1)


class LocoformerBlock(nn.Module):
    def __init__(
        self,
        rope,
        # general setup
        emb_dim=128,
        norm_type="rmsgroupnorm",
        num_groups=4,
        # self-attention related
        n_heads=4,
        flash_attention=False,
        attention_dim=128,
        # ffn related
        ffn_type="swiglu_conv1d",
        ffn_hidden_dim=384,
        conv1d_kernel=4,
        conv1d_shift=1,
        dropout=0.0,
        eps=1.0e-5,
    ):
        super().__init__()

        FFN = {
            "conv1d": ConvDeconv1d,
            "swiglu_conv1d": SwiGLUConvDeconv1d,
        }
        Norm = {
            "layernorm": nn.LayerNorm,
            "rmsgroupnorm": RMSGroupNorm,
        }
        assert norm_type in Norm, norm_type

        self.macaron_style = isinstance(ffn_type, list) and len(ffn_type) == 2
        if self.macaron_style:
            assert (
                isinstance(ffn_hidden_dim, list) and len(ffn_hidden_dim) == 2
            ), "Two FFNs required when using Macaron-style model"

        # initialize FFN
        self.ffn_norm = nn.ModuleList([])
        self.ffn = nn.ModuleList([])
        for f_type, f_dim in zip(ffn_type[::-1], ffn_hidden_dim[::-1]):
            assert f_type in FFN, f_type
            if norm_type == "rmsgroupnorm":
                self.ffn_norm.append(Norm[norm_type](num_groups, emb_dim, eps=eps))
            else:
                self.ffn_norm.append(Norm[norm_type](emb_dim, eps=eps))
            self.ffn.append(
                FFN[f_type](
                    emb_dim,
                    f_dim,
                    conv1d_kernel,
                    conv1d_shift,
                    dropout=dropout,
                )
            )
        # initialize self-attention
        if norm_type == "rmsgroupnorm":
            self.attn_norm = Norm[norm_type](num_groups, emb_dim, eps=eps)
        else:
            self.attn_norm = Norm[norm_type](emb_dim, eps=eps)
        self.attn = MultiHeadSelfAttention(
            emb_dim,
            attention_dim=attention_dim,
            n_heads=n_heads,
            rope=rope,
            dropout=dropout,
            flash_attention=flash_attention,
        )

    def forward(self, x):
        """Locoformer block forward.

        Args:
            x: torch.Tensor
                Input tensor, (n_batch, seq1, seq2, channel)
                seq1 (or seq2) is either the number of frames or freqs
        """
        B, T, F, C = x.shape

        if self.macaron_style:
            # FFN before self-attention
            input_ = x
            output = self.ffn_norm[-1](x)  # [B, T, F, C]
            output = self.ffn[-1](output)  # [B, T, F, C]
            output = output + input_
        else:
            output = x

        # Self-attention
        input_ = output
        output = self.attn_norm(output)
        output = output.reshape([B * T, F, C])
        output = self.attn(output)
        output = output.reshape([B, T, F, C]) + input_

        # FFN after self-attention
        input_ = output
        output = self.ffn_norm[0](output)  # [B, T, F, C]
        output = self.ffn[0](output)  # [B, T, F, C]
        output = output + input_

        return output


class MultiHeadSelfAttention(nn.Module):
    def __init__(
        self,
        emb_dim,
        attention_dim,
        n_heads=8,
        dropout=0.0,
        rope=None,
        flash_attention=False,
    ):
        super().__init__()

        self.n_heads = n_heads
        self.dropout = dropout
        self.rope = rope

        self.qkv = nn.Linear(emb_dim, attention_dim * 3, bias=False)
        self.aggregate_heads = nn.Sequential(
            nn.Linear(attention_dim, emb_dim, bias=False), nn.Dropout(dropout)
        )

        if flash_attention:
            self.flash_attention_config = dict(
                enable_flash=True, enable_math=False, enable_mem_efficient=False
            )
        else:
            self.flash_attention_config = dict(
                enable_flash=False, enable_math=True, enable_mem_efficient=True
            )

    def forward(self, input):
        # get query, key, and value
        query, key, value = self.get_qkv(input)

        # rotary positional encoding
        query, key = self.apply_rope(query, key)

        # PyTorch 2.0 flash attention: q, k, v, mask, dropout, softmax_scale
        with torch.backends.cuda.sdp_kernel(**self.flash_attention_config):
            output = F.scaled_dot_product_attention(
                query=query,
                key=key,
                value=value,
                attn_mask=None,
                dropout_p=self.dropout if self.training else 0.0,
            )  # (batch, head, seq_len, -1)

        output = output.transpose(1, 2)  # (batch, seq_len, head, -1)
        output = output.reshape(output.shape[:2] + (-1,))
        return self.aggregate_heads(output)

    def get_qkv(self, input):
        n_batch, seq_len = input.shape[:2]
        x = self.qkv(input).reshape(n_batch, seq_len, 3, self.n_heads, -1)
        x = x.movedim(-2, 1)  # (batch, head, seq_len, 3, -1)
        query, key, value = x[..., 0, :], x[..., 1, :], x[..., 2, :]
        return query, key, value

    @torch.cuda.amp.autocast(enabled=False)
    def apply_rope(self, query, key):
        query = self.rope.rotate_queries_or_keys(query)
        key = self.rope.rotate_queries_or_keys(key)
        return query, key


class ConvDeconv1d(nn.Module):
    def __init__(self, dim, dim_inner, conv1d_kernel, conv1d_shift, dropout=0.0, **kwargs):
        super().__init__()

        self.diff_ks = conv1d_kernel - conv1d_shift
        self.net = nn.Sequential(
            nn.Conv1d(dim, dim_inner, conv1d_kernel, stride=conv1d_shift),
            nn.SiLU(inplace=True),
            nn.Dropout(dropout),
            nn.ConvTranspose1d(dim_inner, dim, conv1d_kernel, stride=conv1d_shift),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """ConvDeconv1d forward.

        Args:
            x: torch.Tensor
                Input tensor, (n_batch, seq1, seq2, channel)
                seq1 (or seq2) is either the number of frames or freqs
        """
        b, s1, s2, h = x.shape
        x = x.view(b * s1, s2, h)
        x = x.transpose(-1, -2)
        x = self.net(x).transpose(-1, -2)
        x = x[..., self.diff_ks // 2 : self.diff_ks // 2 + s2, :]
        return x.view(b, s1, s2, h)
class SwiGLUConvDeconv1d(nn.Module):
    def __init__(self, dim, dim_inner, conv1d_kernel, conv1d_shift, dropout=0.0, **kwargs):
        super().__init__()

        self.conv1d = nn.Conv1d(dim, dim_inner * 2, conv1d_kernel, stride=conv1d_shift)
        self.swish = nn.SiLU()
        self.deconv1d = nn.ConvTranspose1d(dim_inner, dim, conv1d_kernel, stride=conv1d_shift)
        self.dropout = nn.Dropout(dropout)
        self.dim_inner = dim_inner
        self.diff_ks = conv1d_kernel - conv1d_shift
        self.conv1d_kernel = conv1d_kernel
        self.conv1d_shift = conv1d_shift

    def forward(self, x):
        """SwiGLUConvDeconv1d forward.

        Args:
            x: torch.Tensor
                Input tensor, (n_batch, seq1, seq2, channel)
                seq1 (or seq2) is either the number of frames or freqs
        """
        b, s1, s2, h = x.shape
        x = x.contiguous().view(b * s1, s2, h)
        x = x.transpose(-1, -2)

        # padding
        seq_len = (
            math.ceil((s2 + 2 * self.diff_ks - self.conv1d_kernel) / self.conv1d_shift)
            * self.conv1d_shift
            + self.conv1d_kernel
        )
        x = F.pad(x, (self.diff_ks, seq_len - s2 - self.diff_ks))

        # conv-deconv1d
        x = self.conv1d(x)
        gate = self.swish(x[..., self.dim_inner :, :])
        x = x[..., : self.dim_inner, :] * gate
        x = self.dropout(x)
        x = self.deconv1d(x).transpose(-1, -2)

        # cut necessary part
        x = x[..., self.diff_ks : self.diff_ks + s2, :]
        return self.dropout(x).view(b, s1, s2, h)


class RMSGroupNorm(nn.Module):
    def __init__(self, num_groups, dim, eps=1e-8, bias=False):
        """Root Mean Square Group Normalization (RMSGroupNorm).

        Unlike Group Normalization in vision, RMSGroupNorm is applied to each TF bin.

        Args:
            num_groups: int
                Number of groups.
            dim: int
                Number of dimensions.
            eps: float
                Small constant to avoid division by zero.
            bias: bool
                Whether to add a bias term. RMSNorm does not use bias.
        """
        super().__init__()

        assert dim % num_groups == 0, (dim, num_groups)
        self.num_groups = num_groups
        self.dim_per_group = dim // self.num_groups

        self.gamma = nn.Parameter(torch.Tensor(dim).to(torch.float32))
        nn.init.ones_(self.gamma)

        self.bias = bias
        if self.bias:
            self.beta = nn.Parameter(torch.Tensor(dim).to(torch.float32))
            nn.init.zeros_(self.beta)
        self.eps = eps
        self.num_groups = num_groups

    @torch.cuda.amp.autocast(enabled=False)
    def forward(self, input):
        others = input.shape[:-1]
        input = input.view(others + (self.num_groups, self.dim_per_group))

        # normalization
        norm_ = input.norm(2, dim=-1, keepdim=True)
        rms = norm_ * self.dim_per_group ** (-1.0 / 2)
        output = input / (rms + self.eps)

        # reshape and affine transformation
        output = output.view(others + (-1,))
        output = output * self.gamma
        if self.bias:
            output = output + self.beta

        return output


class FiLMLayer(nn.Module):
    def __init__(self, in_channels, cond_channels):
        """Feature-wise Linear Modulation (FiLM) layer.

        Parameters:
            in_channels: The number of channels in the input feature maps.
            cond_channels: The number of channels in the conditioning input.
        """
        super(FiLMLayer, self).__init__()
        self.in_channels = in_channels
        self.film = nn.Linear(cond_channels, in_channels * 2)

    def forward(self, x, c):
        """
        Parameters:
            x (Tensor): The input feature maps with shape [batch_size, time, in_channels].
            c (Tensor): The conditioning input with shape [batch_size, cond_channels].

        Returns:
            Tensor: The modulated feature maps with the same shape as input x.
""" c = c.unsqueeze(1) film_params = self.film(c) gamma, beta = torch.chunk(film_params, chunks=2, dim=-1) return gamma * x + beta class PositionalEncoding(nn.Module): def __init__(self, d_model, dropout, max_len=5000): super().__init__() self.dropout = nn.Dropout(p=dropout) position = torch.arange(max_len).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) pe = torch.zeros(max_len, 1, d_model) pe[:, 0, 0::2] = torch.sin(position * div_term) pe[:, 0, 1::2] = torch.cos(position * div_term) self.register_buffer('pe', pe) def forward(self, x): """ Args: x: Tensor, shape [seq_len, batch_size, embedding_dim] """ x = x + self.pe[:x.size(0)] return self.dropout(x)