import argparse

import torch.nn as nn
import yamlargparse

from networks import network_wrapper


class main(nn.Module):
    """Config loader / network factory.

    Reads per-model settings from ``config/<model_name>.yaml`` via
    yamlargparse and hands the parsed namespace to ``network_wrapper``
    to construct the actual model.
    """

    def __init__(self):
        super().__init__()

    def load_args_eeyd(self, model_name):
        """Parse the YAML config for *model_name* into ``self.args``.

        The resulting namespace also carries ``model_name`` so downstream
        code can identify which configuration was loaded.
        """
        self.config_path = f'config/{model_name}.yaml'
        cfg_parser = yamlargparse.ArgumentParser("Settings")

        # General model and inference settings
        cfg_parser.add_argument('--config', help='Config file path',
                                action=yamlargparse.ActionConfigFile)
        cfg_parser.add_argument('--mode', type=str, default='inference',
                                help='Modes: train or inference')
        cfg_parser.add_argument('--use-cuda', dest='use_cuda', default=1,
                                type=int, help='Enable CUDA (1=True, 0=False)')
        cfg_parser.add_argument('--num-gpu', dest='num_gpu', type=int,
                                default=1, help='Number of GPUs to use')
        cfg_parser.add_argument('--checkpoint-dir', dest='checkpoint_dir',
                                type=str, default='checkpoints/EEYD_base',
                                help='Checkpoint directory')

        # Model-specific settings.
        # NOTE(review): the dict-typed entries are populated from nested YAML
        # sections by yamlargparse; plain argparse would not accept type=dict.
        cfg_parser.add_argument('--network_audio', type=dict)
        cfg_parser.add_argument('--network_reference', type=dict)
        cfg_parser.add_argument('--sampling-rate', dest='sampling_rate',
                                type=int, default=16000, help='Sampling rate')
        cfg_parser.add_argument('--one-time-decode-length',
                                dest='one_time_decode_length', type=int,
                                default=60,
                                help='Max segment length for one-pass decoding')
        cfg_parser.add_argument('--decode-window', dest='decode_window',
                                type=int, default=1,
                                help='Decoding chunk size')
        cfg_parser.add_argument('--output_residual', type=int, default=0)
        cfg_parser.add_argument('--mix_precision', type=int, default=0,
                                help='whether to perform mix precision training')

        # Parse arguments from the config file rather than the process CLI.
        self.args = cfg_parser.parse_args(['--config', self.config_path])
        self.args.model_name = model_name

    def __call__(self, model_name):
        """Load the config for *model_name* and build/return its network."""
        self.load_args_eeyd(model_name)
        self.network = network_wrapper(self.args)
        return self.network


class extract_everything:
    """High-level entry point: builds a model once, then processes audio.

    Each call forwards a waveform and a text prompt to the wrapped
    model's ``process`` method.
    """

    def __init__(self, model_name="EEYD_locoformer"):
        self.args_wrapper = main()
        self.model = self.args_wrapper(model_name)

    def __call__(self, input_wav, input_text_prompt):
        return self.model.process(input_wav, input_text_prompt)