Spaces:
Running
on
Zero
Running
on
Zero
| import argparse | |
| import yamlargparse | |
| import torch.nn as nn | |
| from networks import network_wrapper | |
class main(nn.Module):
    """Build the argument namespace for a model and wrap it in a network.

    Calling an instance with a model name parses ``config/<model_name>.yaml``
    and returns ``network_wrapper(args)``. Note that ``__call__`` is
    overridden here, so calling an instance does not go through the usual
    ``nn.Module`` call path.
    """

    def __init__(self):
        super().__init__()

    def load_args_eeyd(self, model_name):
        """Parse the YAML config for ``model_name`` and store it on ``self.args``.

        Sets ``self.config_path`` to ``config/<model_name>.yaml``, feeds that
        file to the parser through ``--config``, and records ``model_name``
        on the resulting namespace.
        """
        self.config_path = f'config/{model_name}.yaml'

        # (flag, keyword arguments) pairs, registered in this exact order.
        option_table = (
            # General model and inference settings
            ('--config', dict(help='Config file path', action=yamlargparse.ActionConfigFile)),
            ('--mode', dict(type=str, default='inference', help='Modes: train or inference')),
            ('--use-cuda', dict(dest='use_cuda', default=1, type=int, help='Enable CUDA (1=True, 0=False)')),
            ('--num-gpu', dict(dest='num_gpu', type=int, default=1, help='Number of GPUs to use')),
            ('--checkpoint-dir', dict(dest='checkpoint_dir', type=str, default='checkpoints/EEYD_base', help='Checkpoint directory')),
            # Model-specific settings
            ('--network_audio', dict(type=dict)),
            ('--network_reference', dict(type=dict)),
            ('--sampling-rate', dict(dest='sampling_rate', type=int, default=16000, help='Sampling rate')),
            ('--one-time-decode-length', dict(dest='one_time_decode_length', type=int, default=60, help='Max segment length for one-pass decoding')),
            ('--decode-window', dict(dest='decode_window', type=int, default=1, help='Decoding chunk size')),
            ('--output_residual', dict(type=int, default=0)),
            ('--mix_precision', dict(type=int, default=0, help='whether to perform mix precision training')),
        )

        parser = yamlargparse.ArgumentParser("Settings")
        for flag, options in option_table:
            parser.add_argument(flag, **options)

        # Only the config file is passed in; the process's own command line
        # is not consulted.
        cli_tokens = ['--config', self.config_path]
        parsed = parser.parse_args(cli_tokens)
        parsed.model_name = model_name
        self.args = parsed

    def __call__(self, model_name):
        """Load the arguments for ``model_name`` and return the wrapped network."""
        self.load_args_eeyd(model_name)
        wrapped = network_wrapper(self.args)
        self.network = wrapped
        return wrapped
class extract_everything:
    """Callable front-end that builds a model once and forwards inputs to it.

    The model is created by calling a ``main`` instance with ``model_name``;
    each call on this object delegates to that model's ``process`` method.
    """

    def __init__(self, model_name="EEYD_locoformer"):
        # Keep the builder around as well as the model it produces.
        builder = main()
        self.args_wrapper = builder
        self.model = builder(model_name)

    def __call__(self, input_wav, input_text_prompt):
        """Return ``self.model.process(input_wav, input_text_prompt)``."""
        process = self.model.process
        return process(input_wav, input_text_prompt)