// vibevoice-onnx-int4 / vibevoice_browser.js
/**
* VibeVoice TTS for Browser
*
* Loads ONNX models from HuggingFace and runs TTS inference.
*
* Models required:
* - tts_llm_int4.onnx (702 MB) - Text → Hidden States
* - diffusion_head_int4.onnx (25 MB) - Hidden States → Latents
* - vocoder_int4.onnx (339 MB) - Latents → Audio
*/
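//
// Example usage (a minimal sketch based on the class below; the model id is a
// placeholder for whichever HuggingFace repo actually hosts these ONNX files):
//
//   import VibeVoiceTTS from './vibevoice_browser.js';
//
//   const tts = await VibeVoiceTTS.from_pretrained('your-username/vibevoice-onnx-int4', {
//     dtype: 'int4',
//     progress_callback: (p) => console.log(p.status, p.component ?? '', p.progress),
//   });
//   const audio = await tts.generate('Hello from the browser!', { num_inference_steps: 20 });
//   VibeVoiceTTS.playAudio(audio, 24000);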
// ONNX Runtime Web will be loaded from CDN
let ort = null;
export class VibeVoiceTTS {
constructor() {
this.sessions = {};
this.config = null;
this.loaded = false;
}
/**
* Load model from HuggingFace
*/
static async from_pretrained(modelId, options = {}) {
const {
dtype = 'int4',
progress_callback = null,
} = options;
// Load ONNX Runtime if not already loaded
if (!ort) {
ort = await import('https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/esm/ort.min.js'); // pin an exact onnxruntime-web version in production
}
const instance = new VibeVoiceTTS();
// Build base URL
const baseUrl = modelId.startsWith('http')
? modelId
: `https://huggingface.co/${modelId}/resolve/main`;
// Load config
progress_callback?.({ status: 'loading', component: 'config', progress: 0 });
try {
const configResp = await fetch(`${baseUrl}/config.json`);
instance.config = await configResp.json();
} catch (e) {
console.warn('Could not load config, using defaults');
instance.config = {
audio: { sample_rate: 24000, vae_dim: 64 },
diffusion: { num_inference_steps: 20, latent_size: 64, hidden_size: 896 }
};
}
// Session options
const sessionOptions = {
executionProviders: ['wasm'],
graphOptimizationLevel: 'all',
};
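// Note: ['wasm'] is the most portable choice. If a WebGPU-enabled build of
// onnxruntime-web is loaded, listing 'webgpu' ahead of 'wasm' (i.e.
// executionProviders: ['webgpu', 'wasm']) can be tried for faster inference;
// availability depends on the ORT version and browser support.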
// Model files to load
const models = {
tts_llm: dtype === 'fp32' ? 'tts_llm.onnx' : `tts_llm_${dtype}.onnx`,
diffusion_head: dtype === 'fp32' ? 'diffusion_head.onnx' : `diffusion_head_${dtype}.onnx`,
vocoder: dtype === 'fp32' ? 'vocoder.onnx' : `vocoder_${dtype}.onnx`,
};
// Load each model
const totalModels = Object.keys(models).length;
let loadedCount = 0;
for (const [name, filename] of Object.entries(models)) {
progress_callback?.({
status: 'loading',
component: name,
progress: (loadedCount / totalModels) * 100
});
try {
console.log(`Loading ${name} from ${baseUrl}/${filename}...`);
const response = await fetch(`${baseUrl}/${filename}`);
if (!response.ok) {
throw new Error(`HTTP ${response.status}: ${response.statusText}`);
}
const buffer = await response.arrayBuffer();
instance.sessions[name] = await ort.InferenceSession.create(buffer, sessionOptions);
console.log(`✓ Loaded ${name} (${(buffer.byteLength / 1024 / 1024).toFixed(1)} MB)`);
loadedCount++;
} catch (e) {
console.error(`✗ Failed to load ${name}: ${e.message}`);
throw e;
}
}
progress_callback?.({ status: 'ready', progress: 100 });
instance.loaded = true;
return instance;
}
/**
* Simple tokenizer (character-level fallback)
* For production, use the actual Qwen2 tokenizer
*/
tokenize(text) {
// This is a placeholder - real implementation needs Qwen2 tokenizer
// For now, use simple character codes (won't produce good audio)
const tokens = [];
for (const char of text) {
tokens.push(char.charCodeAt(0) % 1000); // Simple mapping
}
return tokens;
}
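// A sketch of how the placeholder above could be replaced with a real tokenizer,
// assuming transformers.js is available from a CDN and that the checkpoint uses a
// Qwen2-style tokenizer (the exact tokenizer repo id here is an assumption):
//
//   const { AutoTokenizer } = await import('https://cdn.jsdelivr.net/npm/@xenova/transformers');
//   const tokenizer = await AutoTokenizer.from_pretrained('Qwen/Qwen2-0.5B');
//   const tokens = tokenizer.encode(text); // array of token ids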
/**
* Generate speech from text
*/
async generate(text, options = {}) {
if (!this.loaded) {
throw new Error('Model not loaded. Call from_pretrained first.');
}
const {
num_inference_steps = 20,
progress_callback = null,
} = options;
console.log(`Generating speech for: "${text.substring(0, 50)}..."`);
// Step 1: Tokenize
progress_callback?.({ stage: 'tokenize', progress: 0 });
const tokens = this.tokenize(text);
const seqLen = tokens.length;
// Step 2: Run LLM
progress_callback?.({ stage: 'llm', progress: 10 });
const inputIds = new ort.Tensor('int64', BigInt64Array.from(tokens.map(BigInt)), [1, seqLen]);
const attentionMask = new ort.Tensor('int64', BigInt64Array.from(Array(seqLen).fill(1n)), [1, seqLen]);
const positionIds = new ort.Tensor('int64', BigInt64Array.from([...Array(seqLen).keys()].map(BigInt)), [1, seqLen]);
const llmOutput = await this.sessions.tts_llm.run({
input_ids: inputIds,
attention_mask: attentionMask,
position_ids: positionIds,
});
// Get hidden states [batch, seq_len, hidden_size]
const hiddenStates = llmOutput.hidden_states;
console.log('LLM output shape:', hiddenStates.dims);
// Step 3: Run Diffusion for each frame
progress_callback?.({ stage: 'diffusion', progress: 30 });
const latentSize = this.config.diffusion?.latent_size || 64;
const numFrames = seqLen; // One latent frame per token (simplified)
// Initialize latents with noise
const allLatents = new Float32Array(numFrames * latentSize);
for (let frame = 0; frame < numFrames; frame++) {
// Get hidden state for this frame
const hiddenSize = this.config.diffusion?.hidden_size || 896;
const frameHidden = new Float32Array(hiddenSize);
for (let i = 0; i < hiddenSize; i++) {
frameHidden[i] = hiddenStates.data[frame * hiddenSize + i];
}
// Run diffusion denoising
let latent = new Float32Array(latentSize);
for (let i = 0; i < latentSize; i++) {
latent[i] = this.randomNormal();
}
const timesteps = this.getTimesteps(num_inference_steps);
for (let step = 0; step < timesteps.length; step++) {
const t = timesteps[step];
const diffusionOutput = await this.sessions.diffusion_head.run({
noisy_latent: new ort.Tensor('float32', latent, [1, latentSize]),
timestep: new ort.Tensor('float32', [t], [1]),
hidden_states: new ort.Tensor('float32', frameHidden, [1, hiddenSize]),
});
const vPred = diffusionOutput.v_prediction.data;
latent = this.denoisingStep(latent, vPred, t, timesteps[step + 1] || 0);
}
// Store frame latent in channel-major order so the flat buffer matches the
// [1, latentSize, numFrames] vocoder input built below
for (let i = 0; i < latentSize; i++) {
allLatents[i * numFrames + frame] = latent[i];
}
progress_callback?.({
stage: 'diffusion',
progress: 30 + (frame / numFrames) * 50
});
}
// Step 4: Run Vocoder
progress_callback?.({ stage: 'vocoder', progress: 80 });
// Reshape latents to [batch, vae_dim, seq_len]
const latentsTensor = new ort.Tensor('float32', allLatents, [1, latentSize, numFrames]);
const vocoderOutput = await this.sessions.vocoder.run({
latents: latentsTensor,
});
const audioData = vocoderOutput.audio.data;
console.log('Audio output length:', audioData.length);
progress_callback?.({ stage: 'done', progress: 100 });
return new Float32Array(audioData);
}
/**
* Get timesteps for diffusion (linear spacing from 999 to 0)
*/
getTimesteps(numSteps) {
const timesteps = [];
for (let i = 0; i < numSteps; i++) {
timesteps.push(Math.floor(999 * (1 - i / (numSteps - 1))));
}
return timesteps;
}
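// For example, getTimesteps(5) yields [999, 749, 499, 249, 0].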
/**
* V-prediction denoising step
*/
denoisingStep(latent, vPred, t, tNext) {
const alpha = this.getAlphaCumprod(t);
const alphaNext = this.getAlphaCumprod(tNext);
const sigma = Math.sqrt(1 - alpha);
const sigmaNext = Math.sqrt(1 - alphaNext);
const sqrtAlpha = Math.sqrt(alpha);
const sqrtAlphaNext = Math.sqrt(alphaNext);
const result = new Float32Array(latent.length);
for (let i = 0; i < latent.length; i++) {
// V-prediction: v = sqrt(alpha_cumprod) * noise - sqrt(1 - alpha_cumprod) * x0
// so x0_pred = sqrt(alpha_cumprod) * x_t - sqrt(1 - alpha_cumprod) * v
const x0Pred = sqrtAlpha * latent[i] - sigma * vPred[i];
// and noise_pred = sqrt(alpha_cumprod) * v + sqrt(1 - alpha_cumprod) * x_t
const epsPred = sqrtAlpha * vPred[i] + sigma * latent[i];
// Deterministic DDIM-style move to the next timestep
result[i] = sqrtAlphaNext * x0Pred + sigmaNext * epsPred;
}
return result;
}
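// Note: at the final step (tNext = 0) alphaNext is ~1 and sigmaNext is ~0 under the
// cosine schedule below, so the returned latent is essentially the predicted clean x0.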
/**
* Cosine schedule alpha_cumprod
*/
getAlphaCumprod(t) {
const s = 0.008;
const tNorm = t / 1000;
const f = Math.cos(((tNorm + s) / (1 + s)) * Math.PI / 2);
return f * f;
}
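// Rough endpoints: getAlphaCumprod(0) is ~0.9998 (almost clean signal) and
// getAlphaCumprod(999) is ~2e-6 (almost pure noise).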
/**
* Random normal (Box-Muller)
*/
randomNormal() {
// Use 1 - Math.random() so u1 is never 0 (Math.log(0) is -Infinity)
const u1 = 1 - Math.random();
const u2 = Math.random();
return Math.sqrt(-2 * Math.log(u1)) * Math.cos(2 * Math.PI * u2);
}
/**
* Play audio data
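* Note: browsers typically require a user gesture (e.g. a click) before an
* AudioContext is allowed to start producing sound.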
*/
static playAudio(audioData, sampleRate = 24000) {
const audioContext = new (window.AudioContext || window.webkitAudioContext)();
const buffer = audioContext.createBuffer(1, audioData.length, sampleRate);
buffer.copyToChannel(audioData, 0);
const source = audioContext.createBufferSource();
source.buffer = buffer;
source.connect(audioContext.destination);
source.start(0);
return { audioContext, buffer };
}
}
export default VibeVoiceTTS;