import spaces  # ZeroGPU support; keep this import first

import gc
import logging
import os
import time
import traceback
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

import gradio as gr
import numpy as np
import soundfile as sf
from huggingface_hub import snapshot_download
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger("mmedit_space")

MMEDIT_REPO_ID = os.environ.get("MMEDIT_REPO_ID", "CocoBro/MMEdit")
MMEDIT_REVISION = os.environ.get("MMEDIT_REVISION", None)

QWEN_REPO_ID = os.environ.get("QWEN_REPO_ID", "Qwen/Qwen2-Audio-7B-Instruct")
QWEN_REVISION = os.environ.get("QWEN_REVISION", None)

HF_TOKEN = os.environ.get("HF_TOKEN", None)

OUTPUT_DIR = Path(os.environ.get("OUTPUT_DIR", "./outputs"))
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

USE_AMP = os.environ.get("USE_AMP", "0") == "1"
AMP_DTYPE = os.environ.get("AMP_DTYPE", "bf16")
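
# Process-level caches; _MODEL_DIR_CACHE lets repeated requests reuse the
# already-downloaded snapshots instead of hitting the Hub again.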
_PIPELINE_CACHE: Dict[str, Tuple[object, object, int]] = {}
_MODEL_DIR_CACHE: Dict[str, Tuple[Path, Path]] = {}

def resolve_model_dirs() -> Tuple[Path, Path]:
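    """Download the MMEdit and Qwen2-Audio snapshots (once per process) and
    return their local root directories."""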
    cache_key = f"{MMEDIT_REPO_ID}@{MMEDIT_REVISION}::{QWEN_REPO_ID}@{QWEN_REVISION}"
    if cache_key in _MODEL_DIR_CACHE:
        return _MODEL_DIR_CACHE[cache_key]

    logger.info(f"Downloading MMEdit repo: {MMEDIT_REPO_ID} (revision={MMEDIT_REVISION})")
    repo_root = snapshot_download(
        repo_id=MMEDIT_REPO_ID,
        revision=MMEDIT_REVISION,
        local_dir=None,
        local_dir_use_symlinks=False,
        token=HF_TOKEN,
    )
    repo_root = Path(repo_root).resolve()

    logger.info(f"Downloading Qwen repo: {QWEN_REPO_ID} (revision={QWEN_REVISION})")
    qwen_root = snapshot_download(
        repo_id=QWEN_REPO_ID,
        revision=QWEN_REVISION,
        local_dir=None,
        local_dir_use_symlinks=False,
        token=HF_TOKEN,
    )
    qwen_root = Path(qwen_root).resolve()

    _MODEL_DIR_CACHE[cache_key] = (repo_root, qwen_root)
    return repo_root, qwen_root


def load_and_process_audio(audio_path: str, target_sr: int):
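    """Load `audio_path`, downmix to mono, and resample to `target_sr`.

    Returns a 1-D float tensor on CPU.
    """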
    import torch
    import torchaudio
    import librosa

    path = Path(audio_path)
    if not path.exists():
        raise FileNotFoundError(f"Audio file not found: {audio_path}")

    waveform, orig_sr = torchaudio.load(str(path))

    if waveform.ndim == 2:
        waveform = waveform.mean(dim=0)
    elif waveform.ndim > 2:
        waveform = waveform.reshape(-1)

    if target_sr and int(target_sr) != int(orig_sr):
        waveform_np = waveform.cpu().numpy()

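        # Two-hop resample: first to a 16 kHz intermediate rate, then to target_sr.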
        sr_mid = 16000
        if int(orig_sr) != sr_mid:
            waveform_np = librosa.resample(waveform_np, orig_sr=int(orig_sr), target_sr=sr_mid)
            orig_sr_mid = sr_mid
        else:
            orig_sr_mid = int(orig_sr)

        if int(target_sr) != orig_sr_mid:
            waveform_np = librosa.resample(waveform_np, orig_sr=orig_sr_mid, target_sr=int(target_sr))

        waveform = torch.from_numpy(waveform_np)

    return waveform


def assert_repo_layout(repo_root: Path) -> None:
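    """Fail fast if the downloaded MMEdit snapshot is missing required files."""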
    must = [repo_root / "config.yaml", repo_root / "model.safetensors", repo_root / "vae"]
    for p in must:
        if not p.exists():
            raise FileNotFoundError(f"Missing required path: {p}")

    vae_files = list((repo_root / "vae").glob("*.ckpt"))
    if len(vae_files) == 0:
        raise FileNotFoundError(f"No .ckpt found under: {repo_root / 'vae'}")


def patch_paths_in_exp_config(exp_cfg: Dict[str, Any], repo_root: Path, qwen_root: Path) -> None:
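    """Rewrite checkpoint paths in the experiment config (in place) so they
    point at the locally downloaded snapshots rather than training-time paths."""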
    vae_ckpt = exp_cfg["model"]["autoencoder"].get("pretrained_ckpt", None)
    if vae_ckpt:
        vae_ckpt = str(vae_ckpt).replace("\\", "/")
        idx = vae_ckpt.find("vae/")
        if idx != -1:
            vae_rel = vae_ckpt[idx:]
        elif vae_ckpt.endswith(".ckpt") and "/" not in vae_ckpt:
            vae_rel = f"vae/{vae_ckpt}"
        else:
            vae_rel = vae_ckpt

        vae_path = (repo_root / vae_rel).resolve()
        exp_cfg["model"]["autoencoder"]["pretrained_ckpt"] = str(vae_path)

        if not vae_path.exists():
            raise FileNotFoundError(
                f"VAE ckpt not found after patch:\n"
                f"  original: {vae_ckpt}\n"
                f"  patched : {vae_path}\n"
                f"Repo root: {repo_root}\n"
                f"Expected: {repo_root / 'vae' / '*.ckpt'}"
            )

    exp_cfg["model"]["content_encoder"]["text_encoder"]["model_path"] = str(qwen_root)


@spaces.GPU
def run_edit(audio_file, caption, num_steps, guidance_scale, guidance_rescale, seed):
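    """ZeroGPU entry point: build the MMEdit pipeline from the downloaded
    snapshots, run one editing pass on `audio_file` guided by `caption`, and
    return (output_wav_path, status_message)."""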
    import torch
    import hydra
    from omegaconf import OmegaConf
    from safetensors.torch import load_file
    import diffusers.schedulers as noise_schedulers

    logger.info("🚀 Starting...")

    # Disable TF32 so matmuls run in full FP32 precision.
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False

    try:
        from utils.config import register_omegaconf_resolvers
        register_omegaconf_resolvers()
    except Exception:
        pass  # Resolver registration is best-effort; the config may not need it.

    if not audio_file:
        return None, "Please upload audio."

    model = None
    try:
        logger.info("🚀 Starting ZeroGPU Task...")

        repo_root, qwen_root = resolve_model_dirs()
        exp_cfg = OmegaConf.to_container(OmegaConf.load(repo_root / "config.yaml"), resolve=True)

        # Re-point the VAE checkpoint and Qwen text-encoder paths at the local snapshots.
        vae_ckpt = exp_cfg["model"]["autoencoder"].get("pretrained_ckpt", "")
        if vae_ckpt:
            p1 = repo_root / "vae" / Path(vae_ckpt).name
            p2 = repo_root / Path(vae_ckpt).name
            if p1.exists():
                exp_cfg["model"]["autoencoder"]["pretrained_ckpt"] = str(p1)
            elif p2.exists():
                exp_cfg["model"]["autoencoder"]["pretrained_ckpt"] = str(p2)
        exp_cfg["model"]["content_encoder"]["text_encoder"]["model_path"] = str(qwen_root)

        logger.info("Instantiating model (Hydra)...")
        model = hydra.utils.instantiate(exp_cfg["model"], _convert_="all")

        ckpt_path = str(repo_root / "model.safetensors")
        logger.info(f"Loading weights from {ckpt_path}...")
        sd = load_file(ckpt_path)
        model.load_pretrained(sd)
        del sd
        gc.collect()

        device = torch.device("cuda")
        logger.info("Moving model to CUDA (FP32)...")
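        # Move submodules one at a time (presumably to smooth out peak memory on
        # ZeroGPU); the final m.to() catches any parameters or buffers registered
        # directly on the top-level module.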
        def safe_move_model(m, dev):
            for name, child in m.named_children():
                logger.info(f"Moving {name} to GPU (fp32)...")
                child.to(dev, dtype=torch.float32)
            m.to(dev, dtype=torch.float32)
            return m

        model = safe_move_model(model, device)
        model.eval()
        logger.info("Model moved to CUDA.")

        try:
            scheduler = noise_schedulers.DDIMScheduler.from_pretrained(
                exp_cfg["model"].get("noise_scheduler_name", ""),
                subfolder="scheduler", token=HF_TOKEN,
            )
        except Exception:
            logger.warning("Could not load configured scheduler; falling back to a default DDIMScheduler.")
            scheduler = noise_schedulers.DDIMScheduler(num_train_timesteps=1000)

        target_sr = int(exp_cfg.get("sample_rate", 24000))
        torch.manual_seed(int(seed))
        np.random.seed(int(seed))

        wav = load_and_process_audio(audio_file, target_sr).to(device, dtype=torch.float32)

        batch = {
            "audio_id": [Path(audio_file).stem],
            "content": [{"audio": wav, "caption": caption}],
            "task": ["audio_editing"],
            "num_steps": int(num_steps),
            "guidance_scale": float(guidance_scale),
            "guidance_rescale": float(guidance_rescale),
            "use_gt_duration": False,
            "mask_time_aligned_content": False,
        }

        logger.info("Inference running...")
        t0 = time.time()
        with torch.no_grad():
            out = model.inference(scheduler=scheduler, **batch)

        out_audio = out[0, 0].detach().float().cpu().numpy()
        out_path = OUTPUT_DIR / f"{Path(audio_file).stem}_edited.wav"
        sf.write(str(out_path), out_audio, samplerate=target_sr)

        return str(out_path), f"Success | {time.time() - t0:.2f}s"

    except Exception as e:
        err = traceback.format_exc()
        logger.error(f"❌ ERROR:\n{err}")
        return None, f"Runtime Error: {e}"

    finally:
        logger.info("Cleaning up...")
        if model is not None:
            del model
        # Drop references first so empty_cache can actually release the memory.
        gc.collect()
        torch.cuda.empty_cache()


def build_demo():
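    """Build the Gradio UI: audio input and edit instruction in, edited audio
    and a status message out."""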
    with gr.Blocks(title="MMEdit") as demo:
        gr.Markdown("# MMEdit ZeroGPU (Direct Load)")
        with gr.Row():
            with gr.Column():
                audio_in = gr.Audio(label="Input", type="filepath")
                caption = gr.Textbox(label="Instruction", lines=3)
                gr.Examples(
                    label="Examples (Click to load)",
                    examples=[
                        ["./Ym8O802VvJes.wav", "Mix in dog barking around the middle."],
                        ["./YDKM2KjNkX18.wav", "Incorporate Telephone bell ringing into the background."],
                        ["./drop_audiocaps_1.wav", "Erase the rain falling sound from the background."],
                        ["./reorder_audiocaps_1.wav", "Switch the positions of the woman's voice and whistling."],
                    ],
                    inputs=[audio_in, caption],
                    cache_examples=False,
                )
                with gr.Row():
                    num_steps = gr.Slider(10, 100, 50, step=1, label="Steps")
                    guidance_scale = gr.Slider(1.0, 12.0, 5.0, step=0.5, label="Guidance")
                    guidance_rescale = gr.Slider(0.0, 1.0, 0.5, step=0.05, label="Rescale")
                    seed = gr.Number(42, label="Seed")
                run_btn = gr.Button("Run", variant="primary")

            with gr.Column():
                out = gr.Audio(label="Output")
                status = gr.Textbox(label="Status")

        run_btn.click(
            run_edit,
            [audio_in, caption, num_steps, guidance_scale, guidance_rescale, seed],
            [out, status],
        )
    return demo


if __name__ == "__main__":
    print("[BOOT] entering main()", flush=True)
    demo = build_demo()
    port = int(os.environ.get("PORT", "7860"))
    print(f"[BOOT] launching gradio on 0.0.0.0:{port}", flush=True)
    demo.queue().launch(
        server_name="0.0.0.0",
        server_port=port,
        share=False,
        ssr_mode=False,
    )