// VibeVoice-style text-to-speech in the browser: a text LLM produces per-token
// hidden states, a diffusion head denoises one acoustic latent per token, and a
// vocoder decodes the latents into a waveform, all on ONNX Runtime Web (WASM).

// ONNX Runtime Web module, imported lazily on first use.
let ort = null;

export class VibeVoiceTTS {
  constructor() {
    this.sessions = {}; // component name -> ort.InferenceSession
    this.config = null; // parsed config.json, or built-in defaults
    this.loaded = false;
  }
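
  /**
   * Fetch config.json and create the three ONNX sessions (tts_llm,
   * diffusion_head, vocoder) from a Hugging Face model id or a base URL.
   * `dtype` selects the file variant (e.g. tts_llm_int4.onnx); 'fp32' maps
   * to the unsuffixed files.
   */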
  static async from_pretrained(modelId, options = {}) {
    const {
      dtype = 'int4',
      progress_callback = null,
    } = options;

    if (!ort) {
      // Load ONNX Runtime Web from the CDN (pin a specific version in production).
      ort = await import('https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/esm/ort.min.js');
      // The .wasm binaries ship in dist/, not dist/esm/, so point the WASM
      // backend at that directory.
      ort.env.wasm.wasmPaths = 'https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/';
    }

    const instance = new VibeVoiceTTS();

    const baseUrl = modelId.startsWith('http')
      ? modelId
      : `https://huggingface.co/${modelId}/resolve/main`;

    progress_callback?.({ status: 'loading', component: 'config', progress: 0 });
    try {
      const configResp = await fetch(`${baseUrl}/config.json`);
      if (!configResp.ok) {
        throw new Error(`HTTP ${configResp.status}`);
      }
      instance.config = await configResp.json();
    } catch (e) {
      console.warn('Could not load config, using defaults');
      instance.config = {
        audio: { sample_rate: 24000, vae_dim: 64 },
        diffusion: { num_inference_steps: 20, latent_size: 64, hidden_size: 896 },
      };
    }

    const sessionOptions = {
      executionProviders: ['wasm'],
      graphOptimizationLevel: 'all',
    };

    // File name for each component; quantized variants carry a dtype suffix.
    const models = {
      tts_llm: dtype === 'fp32' ? 'tts_llm.onnx' : `tts_llm_${dtype}.onnx`,
      diffusion_head: dtype === 'fp32' ? 'diffusion_head.onnx' : `diffusion_head_${dtype}.onnx`,
      vocoder: dtype === 'fp32' ? 'vocoder.onnx' : `vocoder_${dtype}.onnx`,
    };

    const totalModels = Object.keys(models).length;
    let loadedCount = 0;

    for (const [name, filename] of Object.entries(models)) {
      progress_callback?.({
        status: 'loading',
        component: name,
        progress: (loadedCount / totalModels) * 100,
      });

      try {
        console.log(`Loading ${name} from ${baseUrl}/${filename}...`);
        const response = await fetch(`${baseUrl}/${filename}`);

        if (!response.ok) {
          throw new Error(`HTTP ${response.status}: ${response.statusText}`);
        }

        const buffer = await response.arrayBuffer();
        instance.sessions[name] = await ort.InferenceSession.create(buffer, sessionOptions);
        console.log(`✓ Loaded ${name} (${(buffer.byteLength / 1024 / 1024).toFixed(1)} MB)`);
        loadedCount++;
      } catch (e) {
        console.error(`✗ Failed to load ${name}: ${e.message}`);
        throw e;
      }
    }

    progress_callback?.({ status: 'ready', progress: 100 });
    instance.loaded = true;
    return instance;
  }
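
  /**
   * Placeholder tokenizer: hashes each character code into [0, 1000). A real
   * deployment would load the model's actual tokenizer instead.
   */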
  tokenize(text) {
    const tokens = [];
    for (const char of text) {
      tokens.push(char.charCodeAt(0) % 1000);
    }
    return tokens;
  }
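
  /**
   * Synthesize speech for `text`: tokenize, run the LLM for per-token hidden
   * states, denoise one latent per token with the diffusion head, then vocode
   * the latents into a mono Float32Array of samples (24 kHz by default).
   */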
  async generate(text, options = {}) {
    if (!this.loaded) {
      throw new Error('Model not loaded. Call from_pretrained first.');
    }

    const {
      num_inference_steps = 20,
      progress_callback = null,
    } = options;

    console.log(`Generating speech for: "${text.substring(0, 50)}..."`);

    // Stage 1: tokenize the input text.
    progress_callback?.({ stage: 'tokenize', progress: 0 });
    const tokens = this.tokenize(text);
    const seqLen = tokens.length;

    // Stage 2: run the text LLM once; its hidden states condition the diffusion head.
    progress_callback?.({ stage: 'llm', progress: 10 });
    const inputIds = new ort.Tensor('int64', BigInt64Array.from(tokens.map(BigInt)), [1, seqLen]);
    const attentionMask = new ort.Tensor('int64', BigInt64Array.from(Array(seqLen).fill(1n)), [1, seqLen]);
    const positionIds = new ort.Tensor('int64', BigInt64Array.from([...Array(seqLen).keys()].map(BigInt)), [1, seqLen]);

    const llmOutput = await this.sessions.tts_llm.run({
      input_ids: inputIds,
      attention_mask: attentionMask,
      position_ids: positionIds,
    });

    // Expected shape: [1, seqLen, hiddenSize].
    const hiddenStates = llmOutput.hidden_states;
    console.log('LLM output shape:', hiddenStates.dims);

    // Stage 3: diffusion. One acoustic latent per token, denoised independently,
    // conditioned on that token's hidden state.
    progress_callback?.({ stage: 'diffusion', progress: 30 });
    const latentSize = this.config.diffusion?.latent_size || 64;
    const hiddenSize = this.config.diffusion?.hidden_size || 896;
    const numFrames = seqLen;
    const timesteps = this.getTimesteps(num_inference_steps);

    // Latents are stored channel-major so the buffer matches the vocoder's
    // [1, latentSize, numFrames] input layout below.
    const allLatents = new Float32Array(numFrames * latentSize);

    for (let frame = 0; frame < numFrames; frame++) {
      // Slice out this token's hidden-state vector.
      const frameHidden = new Float32Array(hiddenSize);
      for (let i = 0; i < hiddenSize; i++) {
        frameHidden[i] = hiddenStates.data[frame * hiddenSize + i];
      }

      // Start from pure Gaussian noise.
      let latent = new Float32Array(latentSize);
      for (let i = 0; i < latentSize; i++) {
        latent[i] = this.randomNormal();
      }

      // Reverse diffusion loop.
      for (let step = 0; step < timesteps.length; step++) {
        const t = timesteps[step];

        const diffusionOutput = await this.sessions.diffusion_head.run({
          noisy_latent: new ort.Tensor('float32', latent, [1, latentSize]),
          timestep: new ort.Tensor('float32', new Float32Array([t]), [1]),
          hidden_states: new ort.Tensor('float32', frameHidden, [1, hiddenSize]),
        });

        const vPred = diffusionOutput.v_prediction.data;
        latent = this.denoisingStep(latent, vPred, t, timesteps[step + 1] || 0);
      }

      // Write channel-major (index = channel * numFrames + frame) so the data
      // agrees with the [1, latentSize, numFrames] dims declared below.
      for (let i = 0; i < latentSize; i++) {
        allLatents[i * numFrames + frame] = latent[i];
      }

      progress_callback?.({
        stage: 'diffusion',
        progress: 30 + (frame / numFrames) * 50,
      });
    }

    // Stage 4: vocoder decodes latents into waveform samples.
    progress_callback?.({ stage: 'vocoder', progress: 80 });

    const latentsTensor = new ort.Tensor('float32', allLatents, [1, latentSize, numFrames]);

    const vocoderOutput = await this.sessions.vocoder.run({
      latents: latentsTensor,
    });

    const audioData = vocoderOutput.audio.data;
    console.log('Audio output length:', audioData.length);

    progress_callback?.({ stage: 'done', progress: 100 });

    return new Float32Array(audioData);
  }
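
  /**
   * Evenly spaced timesteps from 999 down to 0 for the reverse diffusion loop.
   */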
  getTimesteps(numSteps) {
    const timesteps = [];
    // Guard the denominator so numSteps === 1 yields [999] rather than NaN.
    const denom = Math.max(numSteps - 1, 1);
    for (let i = 0; i < numSteps; i++) {
      timesteps.push(Math.floor(999 * (1 - i / denom)));
    }
    return timesteps;
  }
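
  /**
   * One denoising update under v-parameterization, written here as a
   * deterministic DDIM-style step. From the model's v-prediction:
   *   x0  = sqrt(alpha_t) * x_t - sqrt(1 - alpha_t) * v
   *   eps = sqrt(alpha_t) * v   + sqrt(1 - alpha_t) * x_t
   * and the step to t_next recombines them:
   *   x_next = sqrt(alpha_next) * x0 + sqrt(1 - alpha_next) * eps
   */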
  denoisingStep(latent, vPred, t, tNext) {
    const alpha = this.getAlphaCumprod(t);
    const alphaNext = this.getAlphaCumprod(tNext);
    const sigma = Math.sqrt(1 - alpha);
    const sigmaNext = Math.sqrt(1 - alphaNext);
    const sqrtAlpha = Math.sqrt(alpha);
    const sqrtAlphaNext = Math.sqrt(alphaNext);

    const result = new Float32Array(latent.length);
    for (let i = 0; i < latent.length; i++) {
      // Predicted clean sample and predicted noise, recovered from v.
      const x0Pred = sqrtAlpha * latent[i] - sigma * vPred[i];
      const epsPred = sqrtAlpha * vPred[i] + sigma * latent[i];

      // Deterministic step toward t_next.
      result[i] = sqrtAlphaNext * x0Pred + sigmaNext * epsPred;
    }
    return result;
  }
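
  /**
   * Squared-cosine noise schedule (the cosine schedule of Nichol & Dhariwal,
   * without the f(0) normalization):
   *   alpha_bar(t) = cos^2(((t/T + s) / (1 + s)) * pi/2)
   */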
  getAlphaCumprod(t) {
    const s = 0.008;
    const tNorm = t / 1000;
    const f = Math.cos(((tNorm + s) / (1 + s)) * Math.PI / 2);
    return f * f;
  }
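
  /**
   * Standard normal sample via the Box-Muller transform.
   */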
  randomNormal() {
    // u1 is shifted into (0, 1] so Math.log never sees zero.
    const u1 = 1 - Math.random();
    const u2 = Math.random();
    return Math.sqrt(-2 * Math.log(u1)) * Math.cos(2 * Math.PI * u2);
  }
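
  /**
   * Play a Float32Array of samples through the Web Audio API. Returns the
   * AudioContext and AudioBuffer so the caller can stop or reuse them.
   */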
  static playAudio(audioData, sampleRate = 24000) {
    const audioContext = new (window.AudioContext || window.webkitAudioContext)();
    const buffer = audioContext.createBuffer(1, audioData.length, sampleRate);
    buffer.copyToChannel(audioData, 0);

    const source = audioContext.createBufferSource();
    source.buffer = buffer;
    source.connect(audioContext.destination);
    source.start(0);

    return { audioContext, buffer };
  }
}

export default VibeVoiceTTS;
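
// Usage sketch. The model id below is a placeholder: point from_pretrained at
// any repo (or base URL) that ships tts_llm / diffusion_head / vocoder ONNX
// files with the input and output names this class expects.
//
//   const tts = await VibeVoiceTTS.from_pretrained('your-org/vibevoice-onnx', {
//     dtype: 'int4',
//     progress_callback: (p) => console.log(p),
//   });
//   const audio = await tts.generate('Hello from the browser!');
//   VibeVoiceTTS.playAudio(audio, tts.config.audio?.sample_rate ?? 24000);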