#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# ZeroGPU requirement: spaces must be imported before anything that may touch CUDA
import spaces
import traceback
import os
import time
import logging
from pathlib import Path
from typing import Tuple, Optional, Dict, Any
import gc
import gradio as gr
import numpy as np
import soundfile as sf
from huggingface_hub import snapshot_download
# -----------------------------
# Logging
# -----------------------------
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger("mmedit_space")
MMEDIT_REPO_ID = os.environ.get("MMEDIT_REPO_ID", "CocoBro/MMEdit")
MMEDIT_REVISION = os.environ.get("MMEDIT_REVISION", None)
QWEN_REPO_ID = os.environ.get("QWEN_REPO_ID", "Qwen/Qwen2-Audio-7B-Instruct")
QWEN_REVISION = os.environ.get("QWEN_REVISION", None)
HF_TOKEN = os.environ.get("HF_TOKEN", None)
OUTPUT_DIR = Path(os.environ.get("OUTPUT_DIR", "./outputs"))
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
USE_AMP = os.environ.get("USE_AMP", "0") == "1"
AMP_DTYPE = os.environ.get("AMP_DTYPE", "bf16") # "bf16" or "fp16"
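# NOTE: USE_AMP / AMP_DTYPE are read from the environment but are not currently consumed by
# run_edit below, which moves the model to the GPU in FP32.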
_PIPELINE_CACHE: Dict[str, Tuple[object, object, int]] = {}
# cache: key -> (repo_root, qwen_root)
_MODEL_DIR_CACHE: Dict[str, Tuple[Path, Path]] = {}
# ---------------------------------------------------------
# Download the repos (downloaded once; huggingface_hub has its own cache)
# ---------------------------------------------------------
def resolve_model_dirs() -> Tuple[Path, Path]:
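    """Download (or reuse cached) snapshots of the MMEdit and Qwen2-Audio repos.

    Returns (repo_root, qwen_root) as resolved Paths; results are memoized in _MODEL_DIR_CACHE.
    """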
cache_key = f"{MMEDIT_REPO_ID}@{MMEDIT_REVISION}::{QWEN_REPO_ID}@{QWEN_REVISION}"
if cache_key in _MODEL_DIR_CACHE:
return _MODEL_DIR_CACHE[cache_key]
logger.info(f"Downloading MMEdit repo: {MMEDIT_REPO_ID} (revision={MMEDIT_REVISION})")
repo_root = snapshot_download(
repo_id=MMEDIT_REPO_ID,
revision=MMEDIT_REVISION,
local_dir=None,
local_dir_use_symlinks=False,
token=HF_TOKEN,
)
repo_root = Path(repo_root).resolve()
logger.info(f"Downloading Qwen repo: {QWEN_REPO_ID} (revision={QWEN_REVISION})")
qwen_root = snapshot_download(
repo_id=QWEN_REPO_ID,
revision=QWEN_REVISION,
local_dir=None,
local_dir_use_symlinks=False,
        token=HF_TOKEN,  # required for gated models
)
qwen_root = Path(qwen_root).resolve()
_MODEL_DIR_CACHE[cache_key] = (repo_root, qwen_root)
return repo_root, qwen_root
# ---------------------------------------------------------
# Audio loading (resample chain: original sr -> 16 kHz -> target_sr)
# ---------------------------------------------------------
def load_and_process_audio(audio_path: str, target_sr: int):
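    """Load an audio file, mix down to mono, and resample original sr -> 16 kHz -> target_sr.

    Returns a 1-D torch.Tensor of samples at target_sr.
    """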
    # Lazy imports (avoid triggering CUDA initialization during app startup)
import torch
import torchaudio
import librosa
path = Path(audio_path)
if not path.exists():
raise FileNotFoundError(f"Audio file not found: {audio_path}")
waveform, orig_sr = torchaudio.load(str(path)) # (C, T)
# Convert to mono
if waveform.ndim == 2:
waveform = waveform.mean(dim=0) # (T,)
elif waveform.ndim > 2:
waveform = waveform.reshape(-1)
if target_sr and int(target_sr) != int(orig_sr):
waveform_np = waveform.cpu().numpy()
        # 1) Resample to 16 kHz first
sr_mid = 16000
if int(orig_sr) != sr_mid:
waveform_np = librosa.resample(waveform_np, orig_sr=int(orig_sr), target_sr=sr_mid)
orig_sr_mid = sr_mid
else:
orig_sr_mid = int(orig_sr)
        # 2) Then resample to target_sr (e.g. 24 kHz)
if int(target_sr) != orig_sr_mid:
waveform_np = librosa.resample(waveform_np, orig_sr=orig_sr_mid, target_sr=int(target_sr))
waveform = torch.from_numpy(waveform_np)
return waveform
# ---------------------------------------------------------
# Validate the repo layout
# ---------------------------------------------------------
def assert_repo_layout(repo_root: Path) -> None:
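    """Verify the snapshot contains config.yaml, model.safetensors, and a vae/ folder with at least one .ckpt."""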
must = [repo_root / "config.yaml", repo_root / "model.safetensors", repo_root / "vae"]
for p in must:
if not p.exists():
raise FileNotFoundError(f"Missing required path: {p}")
vae_files = list((repo_root / "vae").glob("*.ckpt"))
if len(vae_files) == 0:
raise FileNotFoundError(f"No .ckpt found under: {repo_root/'vae'}")
# ---------------------------------------------------------
# Adapt the path conventions used in config.yaml
# ---------------------------------------------------------
def patch_paths_in_exp_config(exp_cfg: Dict[str, Any], repo_root: Path, qwen_root: Path) -> None:
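    """Point the VAE checkpoint and Qwen2-Audio paths in the loaded config at the downloaded snapshots (mutates exp_cfg in place)."""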
# ---- 1) VAE ckpt ----
vae_ckpt = exp_cfg["model"]["autoencoder"].get("pretrained_ckpt", None)
if vae_ckpt:
vae_ckpt = str(vae_ckpt).replace("\\", "/")
idx = vae_ckpt.find("vae/")
if idx != -1:
            vae_rel = vae_ckpt[idx:]  # keep the path from "vae/" onward
else:
if vae_ckpt.endswith(".ckpt") and "/" not in vae_ckpt:
vae_rel = f"vae/{vae_ckpt}"
else:
vae_rel = vae_ckpt
vae_path = (repo_root / vae_rel).resolve()
exp_cfg["model"]["autoencoder"]["pretrained_ckpt"] = str(vae_path)
if not vae_path.exists():
raise FileNotFoundError(
f"VAE ckpt not found after patch:\n"
f" original: {vae_ckpt}\n"
f" patched : {vae_path}\n"
f"Repo root: {repo_root}\n"
f"Expected: {repo_root/'vae'/'*.ckpt'}"
)
# ---- 2) Qwen2-Audio model_path ----
exp_cfg["model"]["content_encoder"]["text_encoder"]["model_path"] = str(qwen_root)
@spaces.GPU
def run_edit(audio_file, caption, num_steps, guidance_scale, guidance_rescale, seed):
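    """ZeroGPU worker: load config and weights, move the model to the GPU in FP32, run one editing pass, and return (output wav path, status message)."""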
import torch
import hydra
from omegaconf import OmegaConf
from safetensors.torch import load_file
import diffusers.schedulers as noise_schedulers
logger.info("🚀 Starting ..")
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
try:
from utils.config import register_omegaconf_resolvers
register_omegaconf_resolvers()
    except Exception:
        pass  # optional custom resolvers; safe to skip if utils.config is unavailable
    if not audio_file:
        return None, "Please upload an audio file."
model = None
try:
# ==========================================
logger.info("🚀 Starting ZeroGPU Task...")
        # Resolve model paths
repo_root, qwen_root = resolve_model_dirs()
exp_cfg = OmegaConf.to_container(OmegaConf.load(repo_root / "config.yaml"), resolve=True)
#
vae_ckpt = exp_cfg["model"]["autoencoder"].get("pretrained_ckpt", "")
if vae_ckpt:
p1 = repo_root / "vae" / Path(vae_ckpt).name
p2 = repo_root / Path(vae_ckpt).name
            if p1.exists():
                exp_cfg["model"]["autoencoder"]["pretrained_ckpt"] = str(p1)
            elif p2.exists():
                exp_cfg["model"]["autoencoder"]["pretrained_ckpt"] = str(p2)
exp_cfg["model"]["content_encoder"]["text_encoder"]["model_path"] = str(qwen_root)
#
logger.info("Instantiating model (Hydra)...")
model = hydra.utils.instantiate(exp_cfg["model"], _convert_="all")
        # Load weights
ckpt_path = str(repo_root / "model.safetensors")
logger.info(f"Loading weights from {ckpt_path}...")
sd = load_file(ckpt_path)
model.load_pretrained(sd)
        del sd  # release immediately
gc.collect()
# ==========================================
# ==========================================
        device = torch.device("cuda")
        # Move the model onto the GPU, submodule by submodule, in FP32
        def safe_move_model(m, dev):
            logger.info("🛡️ Moving model to GPU in FP32...")
            for name, child in m.named_children():
                logger.info(f"Moving {name} to GPU (fp32)...")
                child.to(dev, dtype=torch.float32)
            m.to(dev, dtype=torch.float32)  # catch any remaining top-level parameters/buffers
            return m
        model = safe_move_model(model, device)
        model.eval()
        logger.info("Model moved to CUDA.")
# Scheduler
try:
scheduler = noise_schedulers.DDIMScheduler.from_pretrained(
exp_cfg["model"].get("noise_scheduler_name", ""),
subfolder="scheduler", token=HF_TOKEN
)
        except Exception:
            logger.warning("Could not load scheduler config; falling back to default DDIMScheduler")
            scheduler = noise_schedulers.DDIMScheduler(num_train_timesteps=1000)
# ==========================================
        # 3. Run inference
# ==========================================
target_sr = int(exp_cfg.get("sample_rate", 24000))
torch.manual_seed(int(seed))
np.random.seed(int(seed))
wav = load_and_process_audio(audio_file, target_sr).to(device, dtype=torch.float32)
batch = {
"audio_id": [Path(audio_file).stem],
"content": [{"audio": wav, "caption": caption}],
"task": ["audio_editing"],
"num_steps": int(num_steps),
"guidance_scale": float(guidance_scale),
"guidance_rescale": float(guidance_rescale),
"use_gt_duration": False,
"mask_time_aligned_content": False
}
logger.info("Inference running...")
t0 = time.time()
with torch.no_grad():
out = model.inference(scheduler=scheduler, **batch)
out_audio = out[0, 0].detach().float().cpu().numpy()
out_path = OUTPUT_DIR / f"{Path(audio_file).stem}_edited.wav"
sf.write(str(out_path), out_audio, samplerate=target_sr)
return str(out_path), f"Success | {time.time()-t0:.2f}s"
except Exception as e:
err = traceback.format_exc()
logger.error(f"❌ ERROR:\n{err}")
return None, f"Runtime Error: {e}"
finally:
        # Force cleanup so the next task has enough GPU memory
        logger.info("Cleaning up...")
        if model is not None:
            del model
torch.cuda.empty_cache()
gc.collect()
# -----------------------------
# UI
# -----------------------------
def build_demo():
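    """Build the Gradio Blocks UI and wire the Run button to run_edit."""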
with gr.Blocks(title="MMEdit") as demo:
gr.Markdown("# MMEdit ZeroGPU (Direct Load)")
with gr.Row():
with gr.Column():
audio_in = gr.Audio(label="Input", type="filepath")
caption = gr.Textbox(label="Instruction", lines=3)
gr.Examples(
label="Examples (Click to load)",
                    # Format: [[audio_path_1, prompt_1], [audio_path_2, prompt_2], ...]
examples=[
                        # Example 1 (original)
["./Ym8O802VvJes.wav", "Mix in dog barking around the middle."],
["./YDKM2KjNkX18.wav", "Incorporate Telephone bell ringing into the background."],
["./drop_audiocaps_1.wav", "Erase the rain falling sound from the background."],
["./reorder_audiocaps_1.wav", "Switch the positions of the woman's voice and whistling."]
],
                    inputs=[audio_in, caption],  # matches the example order: Audio first, then Textbox
                    cache_examples=False,  # keep False on ZeroGPU to avoid slow startup-time computation
)
with gr.Row():
num_steps = gr.Slider(10, 100, 50, step=1, label="Steps")
guidance_scale = gr.Slider(1.0, 12.0, 5.0, step=0.5, label="Guidance")
guidance_rescale = gr.Slider(0.0, 1.0, 0.5, step=0.05, label="Rescale")
seed = gr.Number(42, label="Seed")
run_btn = gr.Button("Run", variant="primary")
with gr.Column():
out = gr.Audio(label="Output")
status = gr.Textbox(label="Status")
run_btn.click(run_edit, [audio_in, caption, num_steps, guidance_scale, guidance_rescale, seed], [out, status])
return demo
if __name__ == "__main__":
print("[BOOT] entering main()", flush=True)
demo = build_demo()
port = int(os.environ.get("PORT", "7860"))
print(f"[BOOT] launching gradio on 0.0.0.0:{port}", flush=True)
demo.queue().launch(
server_name="0.0.0.0",
server_port=port,
share=False,
ssr_mode=False,
)