|
|
|
|
|
|
|
|
use crate::{Error, Result}; |
|
|
use serde::{Deserialize, Serialize}; |
|
|
use std::path::{Path, PathBuf}; |
|
|
|
|
|
|
|
|
/// Top-level configuration for the synthesis pipeline, grouping the
/// per-component settings. Loadable from YAML (`Config::load`) or JSON
/// (`Config::load_json`) and saveable back to YAML (`Config::save`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
    /// Autoregressive GPT model settings (text tokens -> mel tokens).
    pub gpt: GptConfig,

    /// Vocoder settings (mel features -> waveform).
    pub vocoder: VocoderConfig,

    /// Semantic-to-mel model settings, including audio preprocessing.
    pub s2mel: S2MelConfig,

    /// Tokenizer / dataset settings (BPE model, vocabulary size).
    pub dataset: DatasetConfig,

    /// Emotion-conditioning settings.
    pub emotions: EmotionConfig,

    /// Runtime inference settings (device, precision, sampling).
    pub inference: InferenceConfig,

    /// Directory containing model artifacts; `validate` only warns (does
    /// not error) when it is missing.
    pub model_dir: PathBuf,
}
|
|
|
|
|
|
|
|
/// Architecture and token-space parameters of the autoregressive GPT model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GptConfig {
    /// Number of transformer layers; must be > 0 (checked by `Config::validate`).
    pub layers: usize,

    /// Hidden dimension; must be > 0 and divisible by `heads`
    /// (checked by `Config::validate`).
    pub model_dim: usize,

    /// Number of attention heads; must be > 0 (checked by `Config::validate`).
    pub heads: usize,

    /// Maximum number of text tokens per input sequence.
    pub max_text_tokens: usize,

    /// Maximum number of mel tokens generated per sequence.
    pub max_mel_tokens: usize,

    /// Token id that terminates mel generation.
    pub stop_mel_token: usize,

    /// Token id prepended to the text sequence.
    pub start_text_token: usize,

    /// Token id that starts the mel sequence.
    pub start_mel_token: usize,

    /// Size of the mel-code vocabulary (including special tokens).
    pub num_mel_codes: usize,

    /// Size of the text-token vocabulary.
    pub num_text_tokens: usize,
}
|
|
|
|
|
|
|
|
/// Vocoder selection and runtime options.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VocoderConfig {
    /// Vocoder model identifier (e.g. a BigVGAN variant name).
    pub name: String,

    /// Optional explicit checkpoint path; `None` by default.
    pub checkpoint: Option<PathBuf>,

    /// Run the vocoder in half precision.
    pub use_fp16: bool,

    /// Enable the DeepSpeed execution path.
    pub use_deepspeed: bool,
}
|
|
|
|
|
|
|
|
/// Semantic-to-mel model settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct S2MelConfig {
    /// Path to the s2mel model checkpoint.
    pub checkpoint: PathBuf,

    /// Audio analysis parameters used to compute input features.
    pub preprocess: PreprocessConfig,
}
|
|
|
|
|
|
|
|
/// STFT / mel-spectrogram analysis parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PreprocessConfig {
    /// Sample rate in Hz; must be > 0 (checked by `Config::validate`).
    pub sr: u32,

    /// FFT size in samples; must be > 0 (checked by `Config::validate`).
    pub n_fft: usize,

    /// Hop between successive frames, in samples; must be > 0
    /// (checked by `Config::validate`).
    pub hop_length: usize,

    /// Analysis window length in samples.
    pub win_length: usize,

    /// Number of mel filterbank channels.
    pub n_mels: usize,

    /// Lowest mel-filterbank frequency in Hz.
    pub fmin: f32,

    /// Highest mel-filterbank frequency in Hz.
    pub fmax: f32,
}
|
|
|
|
|
|
|
|
/// Text tokenizer / dataset settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetConfig {
    /// Path to the BPE tokenizer model file.
    pub bpe_model: PathBuf,

    /// Tokenizer vocabulary size.
    pub vocab_size: usize,
}
|
|
|
|
|
|
|
|
/// Emotion-conditioning settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmotionConfig {
    /// Number of emotion dimensions.
    pub num_dims: usize,

    /// Per-dimension category counts; presumably one entry per dimension
    /// (length == `num_dims`) — not enforced here, confirm against consumers.
    pub num: Vec<usize>,

    /// Optional path to a precomputed emotion embedding matrix.
    pub matrix_path: Option<PathBuf>,
}
|
|
|
|
|
|
|
|
/// Runtime inference and sampling settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceConfig {
    /// Execution device identifier (e.g. "cpu").
    pub device: String,

    /// Run inference in half precision.
    pub use_fp16: bool,

    /// Number of sequences processed per batch.
    pub batch_size: usize,

    /// Top-k sampling cutoff.
    pub top_k: usize,

    /// Nucleus sampling threshold; must lie in (0, 1]
    /// (checked by `Config::validate`).
    pub top_p: f32,

    /// Sampling temperature; must be > 0 (checked by `Config::validate`).
    pub temperature: f32,

    /// Penalty applied to repeated tokens during sampling.
    pub repetition_penalty: f32,

    /// Penalty applied to sequence length during decoding.
    pub length_penalty: f32,
}
|
|
|
|
|
impl Default for Config { |
|
|
fn default() -> Self { |
|
|
Self { |
|
|
gpt: GptConfig::default(), |
|
|
vocoder: VocoderConfig::default(), |
|
|
s2mel: S2MelConfig::default(), |
|
|
dataset: DatasetConfig::default(), |
|
|
emotions: EmotionConfig::default(), |
|
|
inference: InferenceConfig::default(), |
|
|
model_dir: PathBuf::from("models"), |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
impl Default for GptConfig { |
|
|
fn default() -> Self { |
|
|
Self { |
|
|
layers: 8, |
|
|
model_dim: 512, |
|
|
heads: 8, |
|
|
max_text_tokens: 120, |
|
|
max_mel_tokens: 250, |
|
|
stop_mel_token: 8193, |
|
|
start_text_token: 8192, |
|
|
start_mel_token: 8192, |
|
|
num_mel_codes: 8194, |
|
|
num_text_tokens: 6681, |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
impl Default for VocoderConfig { |
|
|
fn default() -> Self { |
|
|
Self { |
|
|
name: "bigvgan_v2_22khz_80band_256x".into(), |
|
|
checkpoint: None, |
|
|
use_fp16: true, |
|
|
use_deepspeed: false, |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
impl Default for S2MelConfig { |
|
|
fn default() -> Self { |
|
|
Self { |
|
|
checkpoint: PathBuf::from("models/s2mel.onnx"), |
|
|
preprocess: PreprocessConfig::default(), |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
impl Default for PreprocessConfig { |
|
|
fn default() -> Self { |
|
|
Self { |
|
|
sr: 22050, |
|
|
n_fft: 1024, |
|
|
hop_length: 256, |
|
|
win_length: 1024, |
|
|
n_mels: 80, |
|
|
fmin: 0.0, |
|
|
fmax: 8000.0, |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
impl Default for DatasetConfig { |
|
|
fn default() -> Self { |
|
|
Self { |
|
|
bpe_model: PathBuf::from("models/bpe.model"), |
|
|
vocab_size: 6681, |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
impl Default for EmotionConfig { |
|
|
fn default() -> Self { |
|
|
Self { |
|
|
num_dims: 8, |
|
|
num: vec![5, 6, 8, 6, 5, 4, 7, 6], |
|
|
matrix_path: Some(PathBuf::from("models/emotion_matrix.safetensors")), |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
impl Default for InferenceConfig { |
|
|
fn default() -> Self { |
|
|
Self { |
|
|
device: "cpu".into(), |
|
|
use_fp16: false, |
|
|
batch_size: 1, |
|
|
top_k: 50, |
|
|
top_p: 0.95, |
|
|
temperature: 1.0, |
|
|
repetition_penalty: 1.0, |
|
|
length_penalty: 1.0, |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
impl Config { |
|
|
|
|
|
pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> { |
|
|
let path = path.as_ref(); |
|
|
if !path.exists() { |
|
|
return Err(Error::FileNotFound(path.display().to_string())); |
|
|
} |
|
|
|
|
|
let content = std::fs::read_to_string(path)?; |
|
|
let config: Config = serde_yaml::from_str(&content)?; |
|
|
Ok(config) |
|
|
} |
|
|
|
|
|
|
|
|
pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> { |
|
|
let content = serde_yaml::to_string(self) |
|
|
.map_err(|e| Error::Config(format!("Failed to serialize config: {}", e)))?; |
|
|
std::fs::write(path, content)?; |
|
|
Ok(()) |
|
|
} |
|
|
|
|
|
|
|
|
pub fn load_json<P: AsRef<Path>>(path: P) -> Result<Self> { |
|
|
let path = path.as_ref(); |
|
|
if !path.exists() { |
|
|
return Err(Error::FileNotFound(path.display().to_string())); |
|
|
} |
|
|
|
|
|
let content = std::fs::read_to_string(path)?; |
|
|
let config: Config = serde_json::from_str(&content)?; |
|
|
Ok(config) |
|
|
} |
|
|
|
|
|
|
|
|
pub fn create_default<P: AsRef<Path>>(path: P) -> Result<Self> { |
|
|
let config = Config::default(); |
|
|
config.save(path)?; |
|
|
Ok(config) |
|
|
} |
|
|
|
|
|
|
|
|
pub fn validate(&self) -> Result<()> { |
|
|
|
|
|
if !self.model_dir.exists() { |
|
|
log::warn!( |
|
|
"Model directory does not exist: {}", |
|
|
self.model_dir.display() |
|
|
); |
|
|
} |
|
|
|
|
|
|
|
|
if self.gpt.layers == 0 { |
|
|
return Err(Error::Config("GPT layers must be > 0".into())); |
|
|
} |
|
|
if self.gpt.model_dim == 0 { |
|
|
return Err(Error::Config("GPT model_dim must be > 0".into())); |
|
|
} |
|
|
if self.gpt.heads == 0 { |
|
|
return Err(Error::Config("GPT heads must be > 0".into())); |
|
|
} |
|
|
if !self.gpt.model_dim.is_multiple_of(self.gpt.heads) { |
|
|
return Err(Error::Config( |
|
|
"GPT model_dim must be divisible by heads".into(), |
|
|
)); |
|
|
} |
|
|
|
|
|
|
|
|
if self.s2mel.preprocess.sr == 0 { |
|
|
return Err(Error::Config("Sample rate must be > 0".into())); |
|
|
} |
|
|
if self.s2mel.preprocess.n_fft == 0 { |
|
|
return Err(Error::Config("n_fft must be > 0".into())); |
|
|
} |
|
|
if self.s2mel.preprocess.hop_length == 0 { |
|
|
return Err(Error::Config("hop_length must be > 0".into())); |
|
|
} |
|
|
|
|
|
|
|
|
if self.inference.temperature <= 0.0 { |
|
|
return Err(Error::Config("Temperature must be > 0".into())); |
|
|
} |
|
|
if self.inference.top_p <= 0.0 || self.inference.top_p > 1.0 { |
|
|
return Err(Error::Config("top_p must be in (0, 1]".into())); |
|
|
} |
|
|
|
|
|
Ok(()) |
|
|
} |
|
|
} |
|
|
|