""" Seed-VC Streaming API Server architecture.md と model_ref.md に基づいて実装 """ import io import os import sys import time import uuid from typing import Optional, Dict from argparse import Namespace import numpy as np import soundfile as sf import librosa import torch import torchaudio from fastapi import FastAPI, UploadFile, File, Form, HTTPException from fastapi.responses import Response from pydantic import BaseModel from huggingface_hub import hf_hub_download # Seed-VC sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'seed-vc')) # Hugging Face cache directory (absolute path) cache_dir = '/app/checkpoints' os.makedirs(cache_dir, exist_ok=True) os.environ['HF_HOME'] = cache_dir os.environ['HF_HUB_CACHE'] = cache_dir os.environ['TRANSFORMERS_CACHE'] = cache_dir os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # MPSを無効化してCPUを強制 import torch torch.backends.mps.is_available = lambda: False from inference import load_models # ============================================================================= # Configuration (architecture.md Section 5) # ============================================================================= DEFAULT_SAMPLE_RATE = 16000 DEFAULT_CHUNK_LEN_MS = 1000 DEFAULT_OVERLAP_MS = 200 SESSION_EXPIRE_SEC = 600 # model_ref.md Section 3.1 # Hugging Face Hubから参照音声をダウンロード # リポジトリ: Akatuki25/seed-vc-ref-audios (dataset) DEFAULT_REF_PRESET = "default_female" REF_PRESETS = { "default_female": ("Akatuki25/seed-vc-ref-audios", "default_female.wav"), "default_male": ("Akatuki25/seed-vc-ref-audios", "default_male.wav"), } # ダウンロード済み参照音声のキャッシュ downloaded_ref_cache = {} # ============================================================================= # Global Variables # ============================================================================= # MPSは避ける(seed-vcとの互換性問題) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Seed-VCモデル (inference.py load_models()の戻り値) model = None semantic_fn = None f0_fn = None vocoder_fn = None campplus_model = None to_mel = None mel_fn_args = None model_sr = 22050 # ============================================================================= # Session State (architecture.md Section 4.1) # ============================================================================= class SessionState: def __init__(self, sample_rate: int, tgt_speaker_id: Optional[str] = None): self.sample_rate = sample_rate self.tgt_speaker_id = tgt_speaker_id self.last_output_tail: Optional[np.ndarray] = None # model_ref.md Section 3: 参照音声の管理 self.ref_audio_tensor = None # 参照音声 (model_sr, float tensor) self.ref_mel = None self.ref_semantic = None self.style_embed = None self.last_access_ts = time.time() self.chunk_len_ms = DEFAULT_CHUNK_LEN_MS self.overlap_ms = DEFAULT_OVERLAP_MS SESSIONS: Dict[str, SessionState] = {} # ============================================================================= # FastAPI App # ============================================================================= app = FastAPI(title="Seed-VC Streaming API", version="1.0.0") @app.on_event("startup") async def startup_event(): """モデルロード (architecture.md Section 4.3.1)""" global model, semantic_fn, f0_fn, vocoder_fn, campplus_model, to_mel, mel_fn_args, model_sr print(f"Device: {device}") print("Loading Seed-VC models...") # inference.pyのload_modelsをそのまま使用 args = Namespace( f0_condition=False, # model_ref.md: 22050Hz系を使う checkpoint=None, config=None, fp16=False ) model, semantic_fn, f0_fn, vocoder_fn, campplus_model, to_mel, mel_fn_args = load_models(args) model_sr = 
# =============================================================================
# Pydantic Models (architecture.md Section 3.2)
# =============================================================================
class SessionCreateRequest(BaseModel):
    sample_rate: int = DEFAULT_SAMPLE_RATE
    tgt_speaker_id: Optional[str] = None
    ref_preset_id: Optional[str] = None
    use_uploaded_ref: bool = False
    chunk_len_ms: int = DEFAULT_CHUNK_LEN_MS
    overlap_ms: int = DEFAULT_OVERLAP_MS


class SessionCreateResponse(BaseModel):
    session_id: str
    sample_rate: int
    chunk_len_ms: int
    overlap_ms: int


class SessionEndRequest(BaseModel):
    session_id: str


# =============================================================================
# Utility Functions
# =============================================================================
def load_wav_to_numpy(file_bytes: bytes, target_sr: int) -> tuple[np.ndarray, int]:
    """Decode a WAV file into an int16 numpy array at target_sr."""
    audio, sr = sf.read(io.BytesIO(file_bytes))
    if len(audio.shape) > 1:
        audio = audio.mean(axis=1)  # downmix to mono
    if sr != target_sr:
        audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
        sr = target_sr
    if audio.dtype in (np.float32, np.float64):
        audio = (audio * 32767).clip(-32768, 32767).astype(np.int16)
    return audio, sr


def numpy_to_wav_bytes(audio: np.ndarray, sr: int) -> bytes:
    """Encode a numpy array as 16-bit PCM WAV bytes."""
    buffer = io.BytesIO()
    sf.write(buffer, audio, sr, format="WAV", subtype="PCM_16")
    buffer.seek(0)
    return buffer.read()


def crossfade(prev_tail: Optional[np.ndarray], new_chunk: np.ndarray, fade_len: int) -> np.ndarray:
    """Crossfade the previous output tail into the new chunk (architecture.md Section 4.2.1)."""
    if prev_tail is None:
        return new_chunk
    fade_len = min(fade_len, len(prev_tail), len(new_chunk))
    if fade_len <= 0:
        return new_chunk
    fade_in = np.linspace(0.0, 1.0, fade_len, endpoint=True)
    fade_out = 1.0 - fade_in
    mixed_head = (prev_tail[-fade_len:] * fade_out + new_chunk[:fade_len] * fade_in).astype(np.int16)
    tail = new_chunk[fade_len:]
    return np.concatenate([mixed_head, tail])


def download_ref_preset(preset_id: str) -> str:
    """
    Download a reference-audio preset from the Hugging Face Hub.

    Returns:
        Local file path of the downloaded audio.
    """
    if preset_id in downloaded_ref_cache:
        return downloaded_ref_cache[preset_id]
    if preset_id not in REF_PRESETS:
        raise ValueError(f"Unknown preset_id: {preset_id}")

    repo_id, filename = REF_PRESETS[preset_id]
    print(f"Downloading reference audio from {repo_id}/{filename}...")
    local_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        repo_type="dataset",
        cache_dir=cache_dir
    )
    downloaded_ref_cache[preset_id] = local_path
    print(f"Downloaded to {local_path}")
    return local_path


def prepare_reference_audio(audio_path: str, state: SessionState):
    """
    Prepare the reference audio for a session (model_ref.md Section 3).

    Same logic as main() in inference.py.
    """
    # Load the reference audio
    ref_audio, _ = librosa.load(audio_path, sr=model_sr)
    ref_audio = ref_audio[:model_sr * 25]  # cap at 25 seconds

    # Convert to tensor
    ref_audio_tensor = torch.tensor(ref_audio).unsqueeze(0).float().to(device)
    state.ref_audio_tensor = ref_audio_tensor

    # Mel spectrogram
    state.ref_mel = to_mel(ref_audio_tensor)

    # Whisper semantic features
    ref_waves_16k = torchaudio.functional.resample(ref_audio_tensor, model_sr, 16000)
    state.ref_semantic = semantic_fn(ref_waves_16k)

    # CAMPPlus style embedding
    feat = torchaudio.compliance.kaldi.fbank(
        ref_waves_16k,
        num_mel_bins=80,
        dither=0,
        sample_frequency=16000
    )
    feat = feat - feat.mean(dim=0, keepdim=True)
    state.style_embed = campplus_model(feat.unsqueeze(0))

    print(f"Reference prepared: mel={state.ref_mel.shape}, semantic={state.ref_semantic.shape}")
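# The per-session reference features prepared above are consumed by
# seed_vc_infer() below: ref_semantic conditions the length regulator,
# ref_mel and style_embed condition the CFM decoder, and the prompt portion
# is trimmed from the decoder output before vocoding.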
def seed_vc_infer(chunk_np: np.ndarray, chunk_sr: int, state: SessionState) -> np.ndarray:
    """
    Convert one audio chunk with Seed-VC (architecture.md Section 4.3.2).

    Follows the logic of main() in inference.py.
    """
    # int16 -> float32
    if chunk_np.dtype == np.int16:
        source_audio = chunk_np.astype(np.float32) / 32768.0
    else:
        source_audio = chunk_np.astype(np.float32)

    # Resample to model_sr
    if chunk_sr != model_sr:
        source_audio = librosa.resample(source_audio, orig_sr=chunk_sr, target_sr=model_sr)

    # To tensor
    source_audio = torch.tensor(source_audio).unsqueeze(0).float().to(device)

    # Resample to 16 kHz and extract Whisper semantic features
    converted_waves_16k = torchaudio.functional.resample(source_audio, model_sr, 16000)
    S_alt = semantic_fn(converted_waves_16k)

    # Mel spectrogram
    mel = to_mel(source_audio.to(device).float())

    # Target lengths
    target_lengths = torch.LongTensor([mel.size(2)]).to(device)
    target2_lengths = torch.LongTensor([state.ref_mel.size(2)]).to(device)

    # Length regulator (inference.py line 354-360)
    with torch.no_grad():
        cond, _, _, _, _ = model.length_regulator(
            S_alt, ylens=target_lengths, n_quantizers=3, f0=None
        )
        prompt_condition, _, _, _, _ = model.length_regulator(
            state.ref_semantic, ylens=target2_lengths, n_quantizers=3, f0=None
        )

    # Concatenate prompt and source conditions
    cat_condition = torch.cat([prompt_condition, cond], dim=1)

    # CFM inference (inference.py line 373-376)
    with torch.no_grad():
        vc_target = model.cfm.inference(
            cat_condition,
            torch.LongTensor([cat_condition.size(1)]).to(device),
            state.ref_mel,
            state.style_embed,
            None,
            10,  # diffusion_steps
            inference_cfg_rate=0.7
        )

    # Drop the prompt portion
    vc_target = vc_target[:, :, state.ref_mel.size(-1):]

    # Vocoder (inference.py line 378)
    with torch.no_grad():
        vc_wave = vocoder_fn(vc_target.float()).squeeze()
        vc_wave = vc_wave[None, :]

    # Back to numpy
    output_wave = vc_wave[0].cpu().numpy()

    # Back to int16
    output_int16 = (output_wave * 32767).clip(-32768, 32767).astype(np.int16)
    return output_int16


# =============================================================================
# Endpoints (architecture.md Section 3.2)
# =============================================================================
@app.get("/health")
async def health_check():
    """3.2.1 GET /health"""
    return {"status": "ok"}


@app.post("/session", response_model=SessionCreateResponse)
async def create_session(body: SessionCreateRequest):
    """
    3.2.2 POST /session
    model_ref.md Section 2.2(A)
    """
    session_id = str(uuid.uuid4())
    state = SessionState(
        sample_rate=body.sample_rate,
        tgt_speaker_id=body.tgt_speaker_id
    )
    state.chunk_len_ms = body.chunk_len_ms
    state.overlap_ms = body.overlap_ms

    # Reference-audio setup (model_ref.md Section 3.2)
    # If the client will upload its own reference (use_uploaded_ref=true),
    # preparation is deferred to POST /session/ref; otherwise fall back to
    # the default preset.
    if not body.use_uploaded_ref:
        preset_id = body.ref_preset_id or DEFAULT_REF_PRESET
        wav_path = download_ref_preset(preset_id)
        prepare_reference_audio(wav_path, state)

    SESSIONS[session_id] = state

    return SessionCreateResponse(
        session_id=session_id,
        sample_rate=body.sample_rate,
        chunk_len_ms=body.chunk_len_ms,
        overlap_ms=body.overlap_ms,
    )
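# Illustrative client-side call for the /session/ref endpoint defined below
# (a sketch; assumes the `requests` package and a local "my_ref.wav" file,
# neither of which is part of this server):
#
#   requests.post(
#       "http://localhost:7860/session/ref",
#       data={"session_id": session_id},
#       files={"ref_audio": open("my_ref.wav", "rb")},
#   )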
): """ model_ref.md Section 2.2(B) """ if session_id not in SESSIONS: raise HTTPException(status_code=400, detail="Invalid session_id") state = SESSIONS[session_id] # 一時ファイル保存 import tempfile with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp: content = await ref_audio.read() tmp.write(content) tmp_path = tmp.name try: prepare_reference_audio(tmp_path, state) finally: os.unlink(tmp_path) state.last_access_ts = time.time() return {"status": "ok"} @app.post("/chunk") async def process_chunk( session_id: str = Form(...), chunk_id: int = Form(...), audio: UploadFile = File(...) ): """ 3.2.3 POST /chunk architecture.md Section 3.2.3 サーバ内部処理フロー """ if session_id not in SESSIONS: raise HTTPException(status_code=400, detail="Invalid session_id") state = SESSIONS[session_id] if chunk_id < 0: raise HTTPException(status_code=400, detail="chunk_id must be non-negative") # Step 2: 音声読み込み audio_bytes = await audio.read() chunk_np, chunk_sr = load_wav_to_numpy(audio_bytes, target_sr=state.sample_rate) # Step 3: サンプルレートチェック if chunk_sr != state.sample_rate: raise HTTPException( status_code=400, detail=f"Sample rate mismatch: expected {state.sample_rate}, got {chunk_sr}" ) # Step 4: Seed-VCで変換 converted = seed_vc_infer(chunk_np, chunk_sr, state) # Step 5: クロスフェード fade_len = int(model_sr * state.overlap_ms / 1000) output = crossfade(state.last_output_tail, converted, fade_len) # Step 6: tail更新 if len(output) >= fade_len: state.last_output_tail = output[-fade_len:].copy() else: state.last_output_tail = output.copy() state.last_access_ts = time.time() # Step 7: WAVエンコード wav_bytes = numpy_to_wav_bytes(output, model_sr) return Response( content=wav_bytes, media_type="audio/wav", headers={"X-Chunk-Id": str(chunk_id)} ) @app.post("/end") async def end_session(body: SessionEndRequest): """3.2.4 POST /end""" SESSIONS.pop(body.session_id, None) return {"status": "ended"} if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=7860)