#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# ZeroGPU requirement: `spaces` must be imported first, before anything that may touch CUDA.
import spaces

import traceback
import os
import time
import logging
from pathlib import Path
from typing import Tuple, Optional, Dict, Any
import gc

import gradio as gr
import numpy as np
import soundfile as sf
from huggingface_hub import snapshot_download

# -----------------------------
# Logging
# -----------------------------
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger("mmedit_space")

MMEDIT_REPO_ID = os.environ.get("MMEDIT_REPO_ID", "CocoBro/MMEdit")
MMEDIT_REVISION = os.environ.get("MMEDIT_REVISION", None)
QWEN_REPO_ID = os.environ.get("QWEN_REPO_ID", "Qwen/Qwen2-Audio-7B-Instruct")
QWEN_REVISION = os.environ.get("QWEN_REVISION", None)
HF_TOKEN = os.environ.get("HF_TOKEN", None)

OUTPUT_DIR = Path(os.environ.get("OUTPUT_DIR", "./outputs"))
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# NOTE(review): USE_AMP / AMP_DTYPE are parsed but not read anywhere in this file;
# inference currently always runs in FP32.
USE_AMP = os.environ.get("USE_AMP", "0") == "1"
AMP_DTYPE = os.environ.get("AMP_DTYPE", "bf16")  # "bf16" or "fp16"

# NOTE(review): not read anywhere in this file.
_PIPELINE_CACHE: Dict[str, Tuple[object, object, int]] = {}
# cache: key -> (repo_root, qwen_root)
_MODEL_DIR_CACHE: Dict[str, Tuple[Path, Path]] = {}


# ---------------------------------------------------------
# Download repos (once per process; huggingface_hub also keeps its own on-disk cache)
# ---------------------------------------------------------
def _download_repo(repo_id: str, revision: Optional[str]) -> Path:
    """Snapshot-download ``repo_id`` at ``revision`` and return its resolved local path."""
    # `local_dir` / `local_dir_use_symlinks` are deliberately not passed: the defaults
    # already use the shared HF cache, and `local_dir_use_symlinks` is deprecated/ignored.
    local_path = snapshot_download(
        repo_id=repo_id,
        revision=revision,
        token=HF_TOKEN,  # required for gated / private repos
    )
    return Path(local_path).resolve()


def resolve_model_dirs() -> Tuple[Path, Path]:
    """Return local directories for the MMEdit repo and the Qwen2-Audio repo.

    Results are memoized in ``_MODEL_DIR_CACHE`` keyed by repo ids and revisions,
    so repeated calls in the same process do not hit the Hub again.

    Returns:
        ``(repo_root, qwen_root)`` as resolved absolute paths.
    """
    cache_key = f"{MMEDIT_REPO_ID}@{MMEDIT_REVISION}::{QWEN_REPO_ID}@{QWEN_REVISION}"
    if cache_key in _MODEL_DIR_CACHE:
        return _MODEL_DIR_CACHE[cache_key]

    logger.info(f"Downloading MMEdit repo: {MMEDIT_REPO_ID} (revision={MMEDIT_REVISION})")
    repo_root = _download_repo(MMEDIT_REPO_ID, MMEDIT_REVISION)

    logger.info(f"Downloading Qwen repo: {QWEN_REPO_ID} (revision={QWEN_REVISION})")
    qwen_root = _download_repo(QWEN_REPO_ID, QWEN_REVISION)

    _MODEL_DIR_CACHE[cache_key] = (repo_root, qwen_root)
    return repo_root, qwen_root


# ---------------------------------------------------------
# Audio loading (as requested: orig -> 16k -> target_sr)
# ---------------------------------------------------------
def load_and_process_audio(audio_path: str, target_sr: int):
    """Load ``audio_path`` as a mono 1-D float tensor, resampled to ``target_sr``.

    Resampling is done in two steps (native rate -> 16 kHz -> ``target_sr``) and
    is skipped entirely when ``target_sr`` is falsy or equals the native rate.
    Note that passing through 16 kHz band-limits the signal to 8 kHz even when
    ``target_sr`` is higher; this is the intended pipeline.

    Raises:
        FileNotFoundError: if ``audio_path`` does not exist.
    """
    # Deferred imports: avoid triggering CUDA initialization at startup.
    import torch
    import torchaudio
    import librosa

    path = Path(audio_path)
    if not path.exists():
        raise FileNotFoundError(f"Audio file not found: {audio_path}")

    waveform, orig_sr = torchaudio.load(str(path))  # (C, T)

    # Convert to mono
    if waveform.ndim == 2:
        waveform = waveform.mean(dim=0)  # (T,)
    elif waveform.ndim > 2:
        waveform = waveform.reshape(-1)

    if target_sr and int(target_sr) != int(orig_sr):
        waveform_np = waveform.cpu().numpy()

        # 1) First resample to 16 kHz
        sr_mid = 16000
        if int(orig_sr) != sr_mid:
            waveform_np = librosa.resample(waveform_np, orig_sr=int(orig_sr), target_sr=sr_mid)
            orig_sr_mid = sr_mid
        else:
            orig_sr_mid = int(orig_sr)

        # 2) Then resample to target_sr (e.g. 24 kHz)
        if int(target_sr) != orig_sr_mid:
            waveform_np = librosa.resample(waveform_np, orig_sr=orig_sr_mid, target_sr=int(target_sr))

        waveform = torch.from_numpy(waveform_np)

    return waveform


# ---------------------------------------------------------
# Validate repo layout
# ---------------------------------------------------------
def assert_repo_layout(repo_root: Path) -> None:
    """Check that ``repo_root`` has config.yaml, model.safetensors and vae/*.ckpt.

    NOTE(review): not called anywhere in this file.

    Raises:
        FileNotFoundError: naming the first missing path, or the vae/ dir if it has no .ckpt.
    """
    must = [repo_root / "config.yaml", repo_root / "model.safetensors", repo_root / "vae"]
    for p in must:
        if not p.exists():
            raise FileNotFoundError(f"Missing required path: {p}")
    vae_files = list((repo_root / "vae").glob("*.ckpt"))
    if len(vae_files) == 0:
        raise FileNotFoundError(f"No .ckpt found under: {repo_root/'vae'}")


# ---------------------------------------------------------
# Adapt the path conventions used in config.yaml
# ---------------------------------------------------------
def patch_paths_in_exp_config(exp_cfg: Dict[str, Any], repo_root: Path, qwen_root: Path) -> None:
    """Rewrite the VAE checkpoint and Qwen model paths in ``exp_cfg`` in place.

    The VAE path is re-rooted under ``repo_root`` (keeping any suffix that starts
    at ``vae/``; a bare ``*.ckpt`` filename is placed under ``vae/``). The Qwen
    text-encoder path is set to ``qwen_root``.

    NOTE(review): not called anywhere in this file; ``run_edit`` uses the more
    lenient ``_load_exp_config`` instead.

    Raises:
        FileNotFoundError: if the patched VAE checkpoint does not exist.
    """
    # ---- 1) VAE ckpt ----
    vae_ckpt = exp_cfg["model"]["autoencoder"].get("pretrained_ckpt", None)
    if vae_ckpt:
        vae_ckpt = str(vae_ckpt).replace("\\", "/")
        idx = vae_ckpt.find("vae/")
        if idx != -1:
            vae_rel = vae_ckpt[idx:]  # keep the suffix starting at "vae/"
        else:
            if vae_ckpt.endswith(".ckpt") and "/" not in vae_ckpt:
                vae_rel = f"vae/{vae_ckpt}"
            else:
                vae_rel = vae_ckpt

        vae_path = (repo_root / vae_rel).resolve()
        exp_cfg["model"]["autoencoder"]["pretrained_ckpt"] = str(vae_path)

        if not vae_path.exists():
            raise FileNotFoundError(
                f"VAE ckpt not found after patch:\n"
                f"  original: {vae_ckpt}\n"
                f"  patched : {vae_path}\n"
                f"Repo root: {repo_root}\n"
                f"Expected: {repo_root/'vae'/'*.ckpt'}"
            )

    # ---- 2) Qwen2-Audio model_path ----
    exp_cfg["model"]["content_encoder"]["text_encoder"]["model_path"] = str(qwen_root)


# ---------------------------------------------------------
# GPU inference helpers (all called from inside the @spaces.GPU handler)
# ---------------------------------------------------------
def _load_exp_config(repo_root: Path, qwen_root: Path) -> Dict[str, Any]:
    """Load ``config.yaml`` as a plain dict and point its VAE / Qwen paths at the local repos."""
    from omegaconf import OmegaConf

    exp_cfg = OmegaConf.to_container(OmegaConf.load(repo_root / "config.yaml"), resolve=True)

    # Re-root the VAE checkpoint: prefer <repo>/vae/<name>, then <repo>/<name>;
    # if neither exists the configured value is left untouched.
    vae_ckpt = exp_cfg["model"]["autoencoder"].get("pretrained_ckpt", "")
    if vae_ckpt:
        ckpt_name = Path(vae_ckpt).name
        for candidate in (repo_root / "vae" / ckpt_name, repo_root / ckpt_name):
            if candidate.exists():
                exp_cfg["model"]["autoencoder"]["pretrained_ckpt"] = str(candidate)
                break

    exp_cfg["model"]["content_encoder"]["text_encoder"]["model_path"] = str(qwen_root)
    return exp_cfg


def _move_model_to_device_fp32(model, device):
    """Move ``model`` to ``device`` in FP32, one top-level child at a time; return it."""
    import torch

    logger.info("🛡️ Moving model to GPU in FP32...")
    for name, child in model.named_children():
        # Log before moving so a hang/OOM points at the child that caused it.
        logger.info(f"Moving {name} to GPU (fp32)...")
        child.to(device, dtype=torch.float32)
    # Also covers parameters/buffers registered directly on the root module.
    model.to(device, dtype=torch.float32)
    return model


def _build_model(exp_cfg: Dict[str, Any], repo_root: Path, device):
    """Instantiate the model from ``exp_cfg`` via Hydra, load its weights and move it to ``device``."""
    import hydra
    from safetensors.torch import load_file

    logger.info("Instantiating model (Hydra)...")
    model = hydra.utils.instantiate(exp_cfg["model"], _convert_="all")

    ckpt_path = str(repo_root / "model.safetensors")
    logger.info(f"Loading weights from {ckpt_path}...")
    state_dict = load_file(ckpt_path)
    model.load_pretrained(state_dict)
    del state_dict  # free the CPU copy of the weights immediately
    gc.collect()

    logger.info("Moving model to CUDA (FP32)...")
    model = _move_model_to_device_fp32(model, device)
    model.eval()
    logger.info("Model is moved to CUDA.")
    return model


def _load_scheduler(exp_cfg: Dict[str, Any]):
    """Return the DDIM scheduler named by ``noise_scheduler_name``, else a default 1000-step DDIM."""
    import diffusers.schedulers as noise_schedulers

    scheduler_name = exp_cfg["model"].get("noise_scheduler_name", "")
    try:
        return noise_schedulers.DDIMScheduler.from_pretrained(
            scheduler_name, subfolder="scheduler", token=HF_TOKEN
        )
    except Exception:
        # Best-effort fallback, but no longer silent.
        logger.warning(
            f"Could not load scheduler from {scheduler_name!r}; "
            f"falling back to DDIMScheduler(num_train_timesteps=1000).",
            exc_info=True,
        )
        return noise_schedulers.DDIMScheduler(num_train_timesteps=1000)


def _run_edit_impl(audio_file, caption, num_steps, guidance_scale, guidance_rescale, seed):
    """Build the model, run one edit and write the result; return ``(wav_path, status)``.

    Every GPU-resident object (model, input/output tensors) is a local of this
    frame, so all of it becomes unreachable as soon as this function returns.
    """
    import torch

    logger.info("🚀 Starting ZeroGPU Task...")

    # Prepare paths and config
    repo_root, qwen_root = resolve_model_dirs()
    exp_cfg = _load_exp_config(repo_root, qwen_root)

    device = torch.device("cuda")
    model = _build_model(exp_cfg, repo_root, device)
    scheduler = _load_scheduler(exp_cfg)

    # Run inference
    target_sr = int(exp_cfg.get("sample_rate", 24000))
    torch.manual_seed(int(seed))
    np.random.seed(int(seed))

    wav = load_and_process_audio(audio_file, target_sr).to(device, dtype=torch.float32)

    batch = {
        "audio_id": [Path(audio_file).stem],
        "content": [{"audio": wav, "caption": caption}],
        "task": ["audio_editing"],
        "num_steps": int(num_steps),
        "guidance_scale": float(guidance_scale),
        "guidance_rescale": float(guidance_rescale),
        "use_gt_duration": False,
        "mask_time_aligned_content": False,
    }

    logger.info("Inference running...")
    t0 = time.time()
    with torch.no_grad():
        out = model.inference(scheduler=scheduler, **batch)

    # Takes element [0, 0] of the output; presumably (batch, channel, time) — TODO confirm.
    out_audio = out[0, 0].detach().float().cpu().numpy()
    out_path = OUTPUT_DIR / f"{Path(audio_file).stem}_edited.wav"
    sf.write(str(out_path), out_audio, samplerate=target_sr)

    return str(out_path), f"Success | {time.time()-t0:.2f}s"


@spaces.GPU
def run_edit(audio_file, caption, num_steps, guidance_scale, guidance_rescale, seed):
    """Gradio handler: edit ``audio_file`` according to the ``caption`` instruction.

    Returns:
        ``(output_wav_path, status_text)`` on success, ``(None, message)`` when no
        audio was provided or when any step raised an ``Exception``.
    """
    import torch

    logger.info("🚀 Starting ..")
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False

    try:
        from utils.config import register_omegaconf_resolvers
        register_omegaconf_resolvers()
    except Exception as exc:
        # Best-effort, as before (e.g. `utils.config` may not be importable here);
        # a failure must not abort the request, but it is now logged.
        logger.warning(f"Skipping OmegaConf resolver registration: {exc}")

    if not audio_file:
        return None, "Please upload audio."

    try:
        result = _run_edit_impl(
            audio_file, caption, num_steps, guidance_scale, guidance_rescale, seed
        )
    except Exception as e:
        err = traceback.format_exc()
        logger.error(f"❌ ERROR:\n{err}")
        result = (None, f"Runtime Error: {e}")
    finally:
        # Force cleanup so the next task does not run out of GPU memory. The model and
        # tensors were locals of _run_edit_impl, so they are unreachable by now:
        # collect first, then hand the freed blocks back to the CUDA driver.
        logger.info("Cleaning up...")
        gc.collect()
        torch.cuda.empty_cache()

    return result


# -----------------------------
# UI
# -----------------------------
def build_demo():
    """Build the Gradio Blocks UI and wire the Run button to ``run_edit``."""
    with gr.Blocks(title="MMEdit") as demo:
        gr.Markdown("# MMEdit ZeroGPU (Direct Load)")

        with gr.Row():
            with gr.Column():
                audio_in = gr.Audio(label="Input", type="filepath")
                caption = gr.Textbox(label="Instruction", lines=3)

                gr.Examples(
                    label="Examples (Click to load)",
                    # Format: [[audio_path_1, prompt_1], [audio_path_2, prompt_2], ...]
                    examples=[
                        # Example 1 (the original one)
                        ["./Ym8O802VvJes.wav", "Mix in dog barking around the middle."],
                        ["./YDKM2KjNkX18.wav", "Incorporate Telephone bell ringing into the background."],
                        ["./drop_audiocaps_1.wav", "Erase the rain falling sound from the background."],
                        ["./reorder_audiocaps_1.wav", "Switch the positions of the woman's voice and whistling."],
                    ],
                    # Same order as each example row above: Audio first, then Textbox.
                    inputs=[audio_in, caption],
                    # Recommended on ZeroGPU: avoid expensive example computation at startup.
                    cache_examples=False,
                )

                with gr.Row():
                    num_steps = gr.Slider(10, 100, 50, step=1, label="Steps")
                    guidance_scale = gr.Slider(1.0, 12.0, 5.0, step=0.5, label="Guidance")
                    guidance_rescale = gr.Slider(0.0, 1.0, 0.5, step=0.05, label="Rescale")

                seed = gr.Number(42, label="Seed")
                run_btn = gr.Button("Run", variant="primary")

            with gr.Column():
                out = gr.Audio(label="Output")
                status = gr.Textbox(label="Status")

        run_btn.click(
            run_edit,
            [audio_in, caption, num_steps, guidance_scale, guidance_rescale, seed],
            [out, status],
        )

    return demo


if __name__ == "__main__":
    print("[BOOT] entering main()", flush=True)

    # Download both repos before serving so the time-budgeted @spaces.GPU call does not
    # spend GPU time on a multi-GB first-time download. Best-effort: if this fails,
    # run_edit() calls resolve_model_dirs() again on the first request.
    try:
        resolve_model_dirs()
    except Exception:
        logger.exception("[BOOT] model prefetch failed; will retry on first request")

    demo = build_demo()
    port = int(os.environ.get("PORT", "7860"))
    print(f"[BOOT] launching gradio on 0.0.0.0:{port}", flush=True)
    demo.queue().launch(
        server_name="0.0.0.0",
        server_port=port,
        share=False,
        ssr_mode=False,
    )