import copy
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from tqdm import tqdm

from transformers import modeling_utils
from transformers.activations import ACT2FN
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutput, ModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.models.auto import AutoModel, AutoModelForCausalLM
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from transformers.utils import logging

from vibevoice.schedule.dpm_solver import DPMSolverMultistepScheduler

from .configuration_vibevoice_streaming import VibeVoiceStreamingConfig
from .modular_vibevoice_diffusion_head import VibeVoiceDiffusionHead

logger = logging.get_logger(__name__)

# Compatibility shim: some transformers versions do not define
# ALL_PARALLEL_STYLES (or leave it as None), which breaks tensor-parallel
# style validation; provide a sensible default.
if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]


class BinaryClassifier(nn.Module):
    """Two-layer MLP head mapping hidden states to a single logit."""

    def __init__(self, hidden_size):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x
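

# Illustrative usage sketch for BinaryClassifier (not part of the public API):
# the head returns a raw logit, so pair it with a loss that applies the
# sigmoid itself rather than adding one inside the module.
#
#   clf = BinaryClassifier(hidden_size=1024)
#   logits = clf(torch.randn(8, 1024))                      # shape: (8, 1)
#   loss = nn.BCEWithLogitsLoss()(logits, torch.ones(8, 1))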


class SpeechConnector(nn.Module):
    """Projects acoustic latents into the language model's hidden space."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, output_dim)
        self.norm = LlamaRMSNorm(output_dim, eps=1e-6)
        self.fc2 = nn.Linear(output_dim, output_dim)

    def forward(self, features, **kwargs):
        x = self.fc1(features)
        x = self.norm(x)
        x = self.fc2(x)
        return x
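

# Illustrative shape sketch for SpeechConnector (the dimensions below are
# assumptions for the example, not values fixed by this module):
#
#   connector = SpeechConnector(input_dim=64, output_dim=2048)
#   latents = torch.randn(2, 100, 64)   # (batch, frames, acoustic_vae_dim)
#   embeds = connector(latents)         # -> (2, 100, 2048): LM hidden size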


class VibeVoiceStreamingPreTrainedModel(PreTrainedModel):
    config_class = VibeVoiceStreamingConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
    _supports_cache_class = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        # The diffusion head defines its own initialization scheme.
        if isinstance(module, VibeVoiceDiffusionHead):
            module.initialize_weights()
            return

        # Take the initializer range from the language model config if
        # available, then the decoder config, then the transformers default.
        if hasattr(self.config, 'language_model_config') and hasattr(self.config.language_model_config, 'initializer_range'):
            std = self.config.language_model_config.initializer_range
        elif hasattr(self.config, 'decoder_config') and hasattr(self.config.decoder_config, 'initializer_range'):
            std = self.config.decoder_config.initializer_range
        else:
            std = 0.02

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            module.bias.data.zero_()


class VibeVoiceStreamingModel(VibeVoiceStreamingPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # Resolve the target dtype from the config; it may be stored as a
        # string such as "bfloat16". Default to float32.
        if hasattr(config, 'torch_dtype') and config.torch_dtype is not None:
            if isinstance(config.torch_dtype, str):
                dtype = getattr(torch, config.torch_dtype)
            else:
                dtype = config.torch_dtype
        else:
            dtype = torch.float32

        # Split the decoder stack in two: the lower layers form the plain
        # text language model, the remaining upper layers form the TTS
        # backbone.
        lm_config = copy.deepcopy(config.decoder_config)
        lm_backbone_num_hidden_layers = getattr(lm_config, 'num_hidden_layers', 24) - config.tts_backbone_num_hidden_layers
        lm_config.num_hidden_layers = lm_backbone_num_hidden_layers
        self.language_model = AutoModel.from_config(lm_config)
        # The lower stack feeds the TTS stack directly, so its final norm is
        # replaced with an identity.
        self.language_model.norm = nn.Identity()

        tts_lm_config = copy.deepcopy(lm_config)
        tts_lm_config.num_hidden_layers = config.tts_backbone_num_hidden_layers
        self.tts_language_model = AutoModel.from_config(tts_lm_config)

        # Learned embeddings distinguishing two input types for the TTS stack.
        self.tts_input_types = nn.Embedding(num_embeddings=2, embedding_dim=config.decoder_config.hidden_size)

        # Acoustic tokenizer and the connector that projects its latents into
        # the language model's hidden space.
        self.acoustic_tokenizer = AutoModel.from_config(config.acoustic_tokenizer_config).to(dtype)
        self.acoustic_connector = SpeechConnector(config.acoustic_vae_dim, lm_config.hidden_size).to(dtype)

        # NaN marks these statistics as "not yet set"; they are expected to
        # be loaded from a checkpoint or computed before use.
        self.register_buffer('speech_scaling_factor', torch.tensor(float('nan')))
        self.register_buffer('speech_bias_factor', torch.tensor(float('nan')))

        # Diffusion-based prediction head for acoustic latents.
        self.prediction_head = AutoModel.from_config(config.diffusion_head_config).to(dtype)

        # Sampling scheduler configured to match the diffusion head.
        self.noise_scheduler = DPMSolverMultistepScheduler(
            num_train_timesteps=config.diffusion_head_config.ddpm_num_steps,
            beta_schedule=config.diffusion_head_config.ddpm_beta_schedule,
            prediction_type=config.diffusion_head_config.prediction_type
        )
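
    # Worked example of the layer split above (numbers are illustrative, not
    # project defaults):
    #
    #   decoder_config.num_hidden_layers        = 28
    #   config.tts_backbone_num_hidden_layers   = 4
    #   -> language_model (lower stack)         = 28 - 4 = 24 layers
    #   -> tts_language_model (upper stack)     = 4 layers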

    def get_input_embeddings(self):
        if hasattr(self.language_model, 'embed_tokens'):
            return self.language_model.embed_tokens

        # Fallback: some wrapped variants rename parameters and record the
        # mapping in `fullmap`; look the embedding table up by original name.
        for name, attr in self.language_model.fullmap.items():
            if attr.orig_name == 'embed_tokens.weight':
                return getattr(self.language_model, name)
        raise AttributeError("could not locate input embeddings on language_model")

    def set_input_embeddings(self, value):
        self.language_model.embed_tokens = value

    def set_speech_tokenizers(self, acoustic_tokenizer=None):
        """Set the acoustic tokenizer used for encoding and decoding speech."""
        self.acoustic_tokenizer = acoustic_tokenizer

        if self.acoustic_tokenizer is not None:
            # Disable dropout and other training-time behavior.
            self.acoustic_tokenizer.eval()

    def forward(self, *args, **kwargs):
        """
        Intentionally not implemented.

        This streaming model is split into two explicit submodules:
        - `language_model` for plain text processing (lower layers).
        - `tts_language_model` for the TTS-related upper layers.

        We deliberately avoid a unified `forward` to prevent accidental calls
        that mix the two responsibilities.

        To use the model:
        - Call `self.language_model(...)` for text embeddings / hidden states.
        - Call `self.tts_language_model(...)` for the TTS portion.
        - Use the dedicated inference class for combined generation logic.
        """
        raise RuntimeError(
            "VibeVoiceStreamingModel.forward is intentionally disabled. "
            "Use `model.language_model(...)` or `model.tts_language_model(...)` instead."
        )
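

# Illustrative call pattern (an assumption based on the docstring above, not
# the project's official inference path; use the dedicated inference class
# for real generation logic):
#
#   model = VibeVoiceStreamingModel(config)
#   text_out = model.language_model(input_ids=input_ids)
#   hidden = text_out.last_hidden_state   # un-normalized: final norm is Identity
#   tts_out = model.tts_language_model(inputs_embeds=hidden)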


AutoModel.register(VibeVoiceStreamingConfig, VibeVoiceStreamingModel)

__all__ = [
    "VibeVoiceStreamingPreTrainedModel",
    "VibeVoiceStreamingModel",
]