import argparse
import json
import os
import time
from concurrent.futures import ThreadPoolExecutor

import pandas as pd
from datasets import Audio, Dataset, load_dataset

from constants import DATASETS, FINAL_SIZE
from utils.audio_process import calculate_error_rate, load_audio
from utils.cmu_process import clean_cmu, cmu_to_ipa, text_to_phoneme
from utils.load_model import (
    run_gruut,
    run_hubert_base,
    run_model,
    run_timit,
    run_wavlm_large_phoneme,
    run_whisper,
)

# Map model names to their runner functions
MODEL_RUNNERS = {
    "HuBERT-Base": run_hubert_base,
    "Whisper": run_whisper,
    "HuBERT fine-tuned": run_model,
    "Timit": run_timit,
    "WavLM": run_wavlm_large_phoneme,
    "LJSpeech Gruut": run_gruut,
}


def set_output(model, pre_pho, ref_pho, duration, per, score):
    return {
        "model": model,
        "phonemes": pre_pho,
        "ref_phonemes": ref_pho,
        "duration": duration,
        "PER": per,
        "score": score,
    }


def get_output(model, wav, reference_phoneme):
    """
    Run the given model, compute error rate, and return formatted output.
    """
    if model not in MODEL_RUNNERS:
        raise ValueError(f"Unknown model: {model}")
    run_func = MODEL_RUNNERS[model]
    phonemes, dur = run_func(wav)
    per, score = calculate_error_rate(reference_phoneme, phonemes)
    return set_output(model, phonemes, reference_phoneme, dur, per, score)


def benchmark_all(example):
    """
    Run all models on a single dataset example in parallel.
    """
    # Load waveform manually to avoid datasets' torchcodec dependency
    wav = load_audio(example["audio"])

    reference_phoneme = example["phonetic"]
    reference_phoneme = cmu_to_ipa(clean_cmu(reference_phoneme))

    # Run all models in parallel using ThreadPoolExecutor
    models = [
        "HuBERT-Base",
        "Whisper",
        "HuBERT fine-tuned",
        "Timit",
        "WavLM",
        "LJSpeech Gruut",
    ]
    with ThreadPoolExecutor(max_workers=len(models)) as executor:
        futures = [
            executor.submit(get_output, model, wav, reference_phoneme)
            for model in models
        ]
        results = [future.result() for future in futures]

    return pd.DataFrame(results)


def benchmark_dataset(dataset):
    """
    Run benchmark_all on each sample and compute average PER and duration per model.
    """
    all_results = []
    for example in dataset:
        df = benchmark_all(example)
        all_results.append(df)

    full_df = pd.concat(all_results, ignore_index=True)

    # Compute average PER and duration per model
    avg_stats = (
        full_df.groupby("model")[["PER", "duration"]]
        .mean()
        .reset_index()
        .rename(columns={"PER": "Average PER", "duration": "Average Duration (s)"})
    )
    return full_df, avg_stats


def load_dataset_with_limits(dataset_config, max_samples=None, use_streaming=False):
    """
    Load a dataset with optional size limits and streaming.

    Args:
        dataset_config: Dictionary containing dataset configuration
        max_samples: Maximum number of samples to load (None for no limit)
        use_streaming: Whether to use streaming for large datasets

    Returns:
        Dataset object
    """
    try:
        # Prepare load_dataset arguments
        load_args = {
            "path": dataset_config["name"],
            "split": dataset_config["split"],
        }

        # Add config if specified
        if "config" in dataset_config:
            load_args["name"] = dataset_config["config"]

        # Add streaming if requested
        if use_streaming:
            load_args["streaming"] = True
            print(f"Loading {dataset_config['name']} with streaming...")
        else:
            print(f"Loading {dataset_config['name']}...")

        dataset = load_dataset(**load_args)

        # Apply size limits
        if max_samples is not None:
            print(f"Limiting dataset to {max_samples} samples...")
            if use_streaming:
                dataset = dataset.take(max_samples)
            else:
                dataset = dataset.select(range(min(max_samples, len(dataset))))

        return dataset
    except Exception as e:
        print(f"[warn] skip dataset {dataset_config['name']}: {e}")
        return None


def parse_cli_args():
    """
    Parse and return CLI arguments for the evaluation script.
    """
    parser = argparse.ArgumentParser(description="Phoneme Detection Evaluation")
    parser.add_argument("--max-samples", type=int, default=None,
                        help="Override max_samples for all datasets")
    parser.add_argument("--dataset", type=str, default=None,
                        help="Process only a specific dataset (by name)")
    return parser.parse_args()


def cast_audio_column_safely(dataset):
    """
    Ensure the dataset's 'audio' column is set to non-decoding Audio.
    """
    try:
        dataset = dataset.cast_column("audio", Audio(decode=False))
    except Exception:
        pass
    return dataset


def prepare_dataset_for_evaluation(dataset, dataset_config, max_samples):
    """
    Normalize, deduplicate, and filter dataset examples for evaluation.
    Handles both streaming and non-streaming datasets.
    Returns a finalized small dataset suitable for benchmarking.
    """
    field = dataset_config["field"]
    use_streaming = dataset_config.get("use_streaming", False)

    if use_streaming:
        print("Processing streaming dataset...")
        valid_samples = []
        # Guard against max_samples=None, which would break min()
        streaming_limit = FINAL_SIZE if max_samples is None else min(max_samples, FINAL_SIZE)

        for example in dataset:
            if field == "text":
                # Derive a phonetic reference from the raw transcript
                phonetic_text = text_to_phoneme(example[field])
                example = {**example, "phonetic": phonetic_text}
                current_field = "phonetic"
            else:
                current_field = field

            if current_field in example:
                phoneme_tokens = example[current_field].split()
                # Keep only examples with at least 10 phoneme tokens
                if len(phoneme_tokens) >= 10:
                    valid_samples.append(example)
                    if len(valid_samples) >= streaming_limit:
                        break

        print(f"Found {len(valid_samples)} valid samples")
        if len(valid_samples) == 0:
            print("No valid samples found, skipping dataset")
            return None

        dataset_final = Dataset.from_list(valid_samples)
        return dataset_final

    else:
        if field == "text":
            dataset = dataset.map(lambda x: {"phonetic": text_to_phoneme(x[field])})
            field = "phonetic"

        unique_texts = dataset.unique(field)
        print("Unique phonetic strings (", dataset_config["name"], "):", len(unique_texts))

        # Keep only the first occurrence of each phonetic string (deduplicate)
        seen = set()

        def is_first_occurrence(example):
            value = example[field]
            if value in seen:
                return False
            seen.add(value)
            return True

        dataset_unique = dataset.filter(is_first_occurrence)

        def is_valid(example):
            phoneme_tokens = example[field].split()
            return len(phoneme_tokens) >= 10

        dataset_filtered = dataset_unique.filter(is_valid)

        final_size = min(FINAL_SIZE, len(dataset_filtered))
        dataset_final = dataset_filtered.shuffle(seed=42).select(range(final_size))
        return dataset_final


def evaluate_dataset(dataset_final):
    """
    Run benchmarking on a capped subset of the dataset and return both the
    full per-example results and the aggregated stats per model.
""" benchmark_size = min(FINAL_SIZE, len(dataset_final)) return benchmark_dataset(dataset_final.select(range(benchmark_size))) def update_aggregates(per_model_results, avg_stats, dataset_name): """ Update the aggregate dictionary per model with results from one dataset. """ dataset_key = dataset_name.split("/")[-1] for _, row in avg_stats.iterrows(): model_name = str(row["model"]).replace(" ", "-") per = float(row["Average PER"]) if row["Average PER"] is not None else None avg_dur = float(row["Average Duration (s)"]) if row["Average Duration (s)"] is not None else None if model_name not in per_model_results: per_model_results[model_name] = {} per_model_results[model_name][dataset_key] = {"per": per, "avg_duration": avg_dur} def save_leaderboard_results(per_model_results, results_dir="eval-results"): """ Persist one JSON file per model for the leaderboard app to consume. """ import json, os, time os.makedirs(results_dir, exist_ok=True) timestamp = int(time.time()) for model_name, task_results in per_model_results.items(): org_model = f"{model_name}" payload = { "config": { "model_name": org_model, "model_dtype": "float32", "model_sha": "" }, "results": task_results } out_path = os.path.join(results_dir, f"results_{timestamp}_{model_name}.json") with open(out_path, "w", encoding="utf-8") as f: json.dump(payload, f, ensure_ascii=False, indent=2) print(f"Saved leaderboard result: {out_path}") def process_single_dataset(dataset_config, args, per_model_results): """ Load, normalize, evaluate a single dataset and update aggregates. """ if args.dataset and args.dataset not in dataset_config["name"]: return max_samples = args.max_samples if args.max_samples is not None else dataset_config.get("max_samples") use_streaming = dataset_config.get("use_streaming", False) dataset = load_dataset_with_limits( dataset_config, max_samples=max_samples, use_streaming=use_streaming ) if dataset is None: return dataset = cast_audio_column_safely(dataset) dataset_final = prepare_dataset_for_evaluation(dataset, dataset_config, max_samples) if dataset_final is None: return print(dataset_final) print("Final size:", len(dataset_final)) full_results, avg_stats = evaluate_dataset(dataset_final) print("Average Statistic per model (", dataset_config["name"], "):") print(avg_stats) update_aggregates(per_model_results, avg_stats, dataset_config["name"]) def main(): args = parse_cli_args() per_model_results = {} for dataset_config in DATASETS: process_single_dataset(dataset_config, args, per_model_results) save_leaderboard_results(per_model_results) if __name__ == "__main__": main()