EEYD / extract_everything.py
alibabasglab's picture
Update extract_everything.py
4aa542b verified
import argparse
import yamlargparse
import torch.nn as nn
from networks import network_wrapper
class main(nn.Module):
    """Load the YAML settings for an EEYD model and build its network wrapper.

    Calling an instance with a model name parses ``config/<model_name>.yaml``
    into ``self.args`` and returns ``network_wrapper(self.args)``.
    """

    def __init__(self):
        super().__init__()

    def load_args_eeyd(self, model_name):
        """Parse ``config/<model_name>.yaml`` and store the result on ``self.args``.

        The config path is relative, so it is resolved against the current
        working directory. The only command-line token handed to the parser
        is ``--config <path>``; every other option takes either the value
        from that file or the default declared below.

        Args:
            model_name: Base name of the YAML file under ``config/``; also
                stored as ``self.args.model_name``.
        """
        self.config_path = f'config/{model_name}.yaml'
        cli = yamlargparse.ArgumentParser("Settings")

        # (flag, keyword options) pairs, registered in declaration order.
        option_table = (
            # General model and inference settings
            ('--config', {
                'help': 'Config file path',
                'action': yamlargparse.ActionConfigFile,
            }),
            ('--mode', {
                'type': str,
                'default': 'inference',
                'help': 'Modes: train or inference',
            }),
            ('--use-cuda', {
                'dest': 'use_cuda',
                'default': 1,
                'type': int,
                'help': 'Enable CUDA (1=True, 0=False)',
            }),
            ('--num-gpu', {
                'dest': 'num_gpu',
                'type': int,
                'default': 1,
                'help': 'Number of GPUs to use',
            }),
            ('--checkpoint-dir', {
                'dest': 'checkpoint_dir',
                'type': str,
                'default': 'checkpoints/EEYD_base',
                'help': 'Checkpoint directory',
            }),
            # Model-specific settings
            ('--network_audio', {'type': dict}),
            ('--network_reference', {'type': dict}),
            ('--sampling-rate', {
                'dest': 'sampling_rate',
                'type': int,
                'default': 16000,
                'help': 'Sampling rate',
            }),
            ('--one-time-decode-length', {
                'dest': 'one_time_decode_length',
                'type': int,
                'default': 60,
                'help': 'Max segment length for one-pass decoding',
            }),
            ('--decode-window', {
                'dest': 'decode_window',
                'type': int,
                'default': 1,
                'help': 'Decoding chunk size',
            }),
            ('--output_residual', {'type': int, 'default': 0}),
            ('--mix_precision', {
                'type': int,
                'default': 0,
                'help': 'whether to perform mix precision training',
            }),
        )
        for flag, options in option_table:
            cli.add_argument(flag, **options)

        # Parse arguments from the config file
        config_argv = ['--config', self.config_path]
        self.args = cli.parse_args(config_argv)
        self.args.model_name = model_name

    def __call__(self, model_name):
        """Load the settings for ``model_name`` and return the built network.

        NOTE(review): this replaces ``nn.Module.__call__``, so calling an
        instance does not go through ``forward`` -- confirm that is intended.

        Args:
            model_name: Passed through to ``load_args_eeyd``.

        Returns:
            The object produced by ``network_wrapper(self.args)``, which is
            also kept on ``self.network``.
        """
        self.load_args_eeyd(model_name)
        self.network = network_wrapper(self.args)
        return self.network
class extract_everything:
    """Callable facade over the network built by ``main`` for a given model.

    Construction builds the model once; calling the instance forwards both
    arguments to the model's ``process`` method.
    """

    def __init__(self, model_name="EEYD_locoformer"):
        """Build the model for ``model_name``.

        Args:
            model_name: Name handed to the ``main`` instance; it selects the
                YAML config that ``main`` loads.
        """
        builder = main()
        self.args_wrapper = builder
        self.model = builder(model_name)

    def __call__(self, input_wav, input_text_prompt):
        """Return ``self.model.process(input_wav, input_text_prompt)`` unchanged.

        Args:
            input_wav: Forwarded as the first argument of ``process``.
            input_text_prompt: Forwarded as the second argument of ``process``.
        """
        process = self.model.process
        return process(input_wav, input_text_prompt)