""" Authors: Zexu Pan, Shengkui Zhao """ import torch import torch.nn as nn import torch.nn.functional as F from torch.cuda.amp import autocast, GradScaler import soundfile as sf import librosa import tempfile import os import subprocess from tqdm import tqdm from huggingface_hub import hf_hub_download from huggingface_hub import snapshot_download import numpy as np import ffmpeg class SpeechModel: def __init__(self, args): if torch.cuda.is_available(): print('GPU is found and used!') self.device = torch.device('cuda') else: # If no GPU is detected, use the CPU args.use_cuda = 0 self.device = torch.device('cpu') self.args = args self.model = None self.name = None self.data = {} def get_free_gpu(self): try: # Run nvidia-smi to query GPU memory usage and free memory result = subprocess.run(['nvidia-smi', '--query-gpu=memory.used,memory.free', '--format=csv,nounits,noheader'], stdout=subprocess.PIPE) gpu_info = result.stdout.decode('utf-8').strip().split('\n') free_gpu = None max_free_memory = 0 for i, info in enumerate(gpu_info): used, free = map(int, info.split(',')) if free > max_free_memory: max_free_memory = free free_gpu = i return free_gpu except Exception as e: print(f"Error finding free GPU: {e}") return None def load_model(self): checkpoint_path = hf_hub_download(repo_id=f"alibabasglab/{self.args.model_name}", filename="last_checkpoint.pt") # Load the checkpoint file into memory (map_location ensures compatibility with different devices) checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage) # Load the model's state dictionary (weights and biases) into the current model if 'model' in checkpoint: pretrained_model = checkpoint['model'] else: pretrained_model = checkpoint state = self.model.state_dict() for key in state.keys(): if key in pretrained_model and state[key].shape == pretrained_model[key].shape: state[key] = pretrained_model[key] elif key.replace('module.', '') in pretrained_model and state[key].shape == pretrained_model[key.replace('module.', '')].shape: state[key] = pretrained_model[key.replace('module.', '')] elif 'module.'+key in pretrained_model and state[key].shape == pretrained_model['module.'+key].shape: state[key] = pretrained_model['module.'+key] else: raise NameError(f'{key} not loaded') self.model.load_state_dict(state) print(f'Successfully loaded model weights for decoding') def process(self, file_path, text): orig_audio = self.load_data(file_path) text = [text] with torch.no_grad(): chunk_size = 160000 # 240000 print(orig_audio.shape) if orig_audio.shape[0] > chunk_size: output_audio = torch.zeros(1,orig_audio.shape[0]) for itr in range(0, orig_audio.shape[0]//chunk_size): output_audio[:,chunk_size*itr:chunk_size*(itr+1)] = self.model(orig_audio[chunk_size*itr:chunk_size*(itr+1)], text, self.device) output_audio[:,-chunk_size:] = self.model(orig_audio[-chunk_size:], text, self.device) else: output_audio = self.model(orig_audio, text, self.device) output_audio = output_audio.detach().squeeze().cpu().numpy() # residual_audio = residual_audio.detach().squeeze().cpu().numpy() residual_audio = orig_audio - output_audio return orig_audio, output_audio, residual_audio def _audioread(self, path, sampling_rate): data, fs = sf.read(path, dtype='float32') if len(data.shape) >1: if data.shape[0] > data.shape[1]: data = data[:, 0] else: data = data[0, :] if fs != sampling_rate: data = librosa.resample(data, orig_sr=fs, target_sr=sampling_rate) max_val = np.max(np.abs(data)) if max_val > 1: data /= max_val return data def _videoread(self, path, 
    def _videoread(self, path, sampling_rate):
        try:
            # Use ffmpeg to extract a mono audio track at the target rate and pipe it out as WAV
            process = (
                ffmpeg
                .input(path)
                .output('pipe:', format='wav', ar=sampling_rate, ac=1)
                .run(capture_stdout=True, capture_stderr=True)
            )
            # Read the audio samples from the raw output
            audio_data = np.frombuffer(process[0], dtype=np.int16)
            # Convert 16-bit PCM to float in [-1, 1]
            audio_data = audio_data.astype(np.float32) / 32768.0
            max_val = np.max(np.abs(audio_data))
            if max_val > 1:
                audio_data /= max_val
            return audio_data
        except ffmpeg.Error as e:
            print(f"Error loading audio from video: {e.stderr.decode()}")
            return None

    def load_data(self, file_path):
        """
        Detect whether the file is audio or video, then load it.
        - Audio: read with `soundfile` and resample to the target sampling rate.
        - Video: extract the audio track with ffmpeg at the target sampling rate.
        """
        # Check that the file exists
        if not os.path.isfile(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")

        # Supported audio and video extensions
        audio_extensions = ['.wav', '.flac', '.mp3', '.ogg', '.mat']
        video_extensions = ['.mp4', '.mkv', '.avi', '.mov', '.webm']

        _, ext = os.path.splitext(file_path)
        ext = ext.lower()
        if ext in audio_extensions:
            # Handle audio files
            print(f"Processing audio file: {file_path}")
            data = self._audioread(file_path, self.args.sampling_rate)
            return data
        elif ext in video_extensions:
            # Handle video files
            print(f"Processing video file: {file_path}")
            data = self._videoread(file_path, self.args.sampling_rate)
            return data
        else:
            raise ValueError(f"Unsupported file type: {file_path}")


class select_network(nn.Module):
    def __init__(self, args):
        super(select_network, self).__init__()
        self.args = args

        # Separation backbone
        from models.tflocoformer.tflocoformer_separator import TFLocoformer
        self.sep_network = TFLocoformer(args)
        print(f'{args.model_name} running.')

        # Frozen T5 text encoder for the text reference
        from transformers import AutoTokenizer, T5EncoderModel
        model_path = snapshot_download(repo_id="alibabasglab/t5-base")
        model_path = os.path.join(model_path, "t5-base")
        # model_path = hf_hub_download(repo_id="alibabasglab/t5-base", filename="t5-base")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, model_max_length=512)
        self.text_encoder = T5EncoderModel.from_pretrained(model_path)
        # os.environ["TOKENIZERS_PARALLELISM"] = "false"
        for param in self.text_encoder.parameters():
            param.requires_grad = False

        # Frozen BEATs audio encoder for the audio reference features
        from models.beats.BEATs import BEATs, BEATsConfig
        model_path = snapshot_download(repo_id="alibabasglab/beats")
        model_path = os.path.join(model_path, "BEATs_iter3_plus_AS2M.pt")
        checkpoint = torch.load(model_path, map_location='cpu')
        cfg = BEATsConfig(checkpoint['cfg'])
        self.BEATs_model = BEATs(cfg)
        self.BEATs_model.load_state_dict(checkpoint['model'])
        self.BEATs_model.eval()
        for param in self.BEATs_model.parameters():
            param.requires_grad = False

    def forward(self, mixture, t_ref, device):
        mixture = torch.tensor(mixture).to(device)
        mixture = mixture.unsqueeze(0)

        # Encode the text prompt with the frozen T5 encoder
        text_input = self.tokenizer(t_ref, return_tensors="pt", truncation=True, padding="longest")
        text_input_ids = text_input["input_ids"].to(device)
        text_attention_mask = text_input["attention_mask"].to(device)
        text_len = torch.sum(text_attention_mask, dim=1)
        text_embedding = self.text_encoder(input_ids=text_input_ids,
                                           attention_mask=text_attention_mask).last_hidden_state
        t_ref = (text_embedding.clone().detach(), text_attention_mask.clone().detach(), text_len.clone().detach())

        # Extract audio reference features from the mixture with the frozen BEATs encoder
        with torch.no_grad():
            padding_mask = torch.zeros_like(mixture).bool()
            a_ref = self.BEATs_model.extract_features(mixture, padding_mask=padding_mask)[0]
            a_ref = a_ref.transpose(1, 2)

        return self.forward_step(mixture, t_ref, a_ref.clone().detach())
    def forward_step(self, mixture, t_ref, a_ref):
        return self.sep_network(mixture, t_ref, a_ref)


class network_wrapper(SpeechModel):
    def __init__(self, args):
        # Initialize the parent SpeechModel class (device selection, bookkeeping)
        super(network_wrapper, self).__init__(args)
        # Build the separation network (TFLocoformer backbone with frozen T5 and BEATs encoders)
        self.model = select_network(args)
        # Load pre-trained model checkpoint
        self.load_model()
        # Move model to the appropriate device (GPU/CPU)
        self.model.to(self.device)
        # Set the model to evaluation mode (no gradient calculation)
        self.model.eval()
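
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file): illustrates how the
# wrapper is expected to be driven end to end. The argument fields
# (`model_name`, `sampling_rate`, `use_cuda`) are inferred from how
# `self.args` is used above; the model name, input path, and text prompt are
# hypothetical placeholders, and TFLocoformer may require additional
# configuration fields on `args` that are not shown here.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from argparse import Namespace

    args = Namespace(
        model_name='TSE_TFLocoformer',  # hypothetical checkpoint name under the alibabasglab namespace
        sampling_rate=16000,            # assumed; a 16 kHz target is referenced in the original comments
        use_cuda=1,
    )
    wrapper = network_wrapper(args)

    # Extract the sound described by the text prompt from the mixture
    orig, extracted, residual = wrapper.process('mixture.wav', 'a man speaking while a dog barks')
    sf.write('extracted.wav', extracted, args.sampling_rate)
    sf.write('residual.wav', residual, args.sampling_rate)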