# Dia2-2B / dia2/assets.py
# Uploaded by NariLabs using huggingface_hub (commit 1315cad, verified).
from __future__ import annotations
import json
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
from huggingface_hub import hf_hub_download
# Filename of the optional JSON asset manifest looked up inside a Hub repo.
# Read once at import time; override via the DIA2_ASSET_MANIFEST env var.
ASSET_MANIFEST = os.getenv("DIA2_ASSET_MANIFEST", "dia2_assets.json")
@dataclass(frozen=True)
class AssetBundle:
    """Immutable description of the assets needed to load a Dia2 model.

    Instances are built by ``resolve_assets`` either from a Hub repository
    (paths are whatever ``hf_hub_download`` returned) or from caller-supplied
    config/weights paths (stored as ``str``).
    """

    # Path to the model config file.
    config_path: str
    # Path to the model weights file.
    weights_path: str
    # Tokenizer identifier: the manifest's "tokenizer" entry, falling back to
    # the repo id; None when explicit local paths were given.
    tokenizer_id: Optional[str]
    # Value of the manifest's "mimi" entry; None when absent or when explicit
    # local paths were given.
    mimi_id: Optional[str]
    # Hub repository the assets were resolved from; None (or empty) when
    # explicit local paths were given.
    repo_id: Optional[str]
def resolve_assets(
    *,
    repo: Optional[str],
    config_path: Optional[str | Path],
    weights_path: Optional[str | Path],
    manifest_name: Optional[str] = None,
) -> AssetBundle:
    """Resolve model config and weights from a Hub repo or from local paths.

    Exactly one source must be supplied:

    * ``repo`` on its own: the optional JSON manifest named ``manifest_name``
      is read from the repo (via ``load_manifest``) to choose the config and
      weights filenames, which are then fetched with ``hf_hub_download``.
      Missing, null or empty manifest entries fall back to ``"config.json"``
      and ``"model.safetensors"``.
    * ``config_path`` and ``weights_path`` together: returned as-is (converted
      to ``str``) without tokenizer or Mimi identifiers.

    Args:
        repo: Hub repository id; ``None`` or empty when using local paths.
        config_path: Path to the model config; must be paired with
            ``weights_path`` when no repo is given.
        weights_path: Path to the model weights; must be paired with
            ``config_path`` when no repo is given.
        manifest_name: Manifest filename to look up in ``repo``. ``None`` or
            empty means ``ASSET_MANIFEST``.

    Returns:
        An ``AssetBundle`` describing the resolved assets.

    Raises:
        ValueError: If ``repo`` is combined with an explicit path, or if
            neither a repo nor both paths are supplied.
    """
    repo_id = repo
    manifest_name = manifest_name or ASSET_MANIFEST
    if repo_id and (config_path or weights_path):
        raise ValueError("Provide either repo or config+weights, not both")

    if config_path is not None and weights_path is not None:
        # Explicit local mode: nothing to download or look up.
        return AssetBundle(str(config_path), str(weights_path), None, None, repo_id)

    # At least one path is missing, so the assets must come from a repo.
    # Use truthiness (matching the conflict check above) so an empty repo
    # string is rejected here rather than being sent to the Hub.
    if not repo_id:
        raise ValueError("Must specify repo or config+weights")

    manifest = load_manifest(repo_id, manifest_name)
    # `or` instead of a .get() default so entries explicitly set to null or
    # "" in the manifest still fall back to the standard filenames.
    config_name = manifest.get("config") or "config.json"
    weights_name = manifest.get("weights") or "model.safetensors"
    config_local = hf_hub_download(repo_id, config_name)
    weights_local = hf_hub_download(repo_id, weights_name)
    return AssetBundle(
        config_path=config_local,
        weights_path=weights_local,
        # Without a manifest entry, the tokenizer is assumed to live in the
        # model repo itself.
        tokenizer_id=manifest.get("tokenizer") or repo_id,
        mimi_id=manifest.get("mimi"),
        repo_id=repo_id,
    )
def load_manifest(repo_id: str, manifest_name: str) -> dict:
    """Fetch and parse the optional JSON asset manifest of a Hub repo.

    The manifest is best-effort: if it cannot be downloaded, read, decoded or
    parsed, or if it is not a JSON object, an empty dict is returned so the
    caller falls back to its defaults.

    Args:
        repo_id: Hub repository to download the manifest from.
        manifest_name: Manifest filename inside the repo. An empty or ``None``
            name skips the lookup entirely.

    Returns:
        The manifest as a dict, or ``{}`` when it is unavailable or malformed.
    """
    if not manifest_name:
        return {}
    try:
        path = hf_hub_download(repo_id, manifest_name)
    except Exception:
        # Deliberately broad: the manifest is optional, so any download
        # failure simply means "use defaults".
        return {}
    try:
        # JSON is UTF-8 by specification; don't depend on the platform locale.
        data = json.loads(Path(path).read_text(encoding="utf-8"))
    except (OSError, UnicodeDecodeError, json.JSONDecodeError):
        return {}
    # Callers use .get() on the result, so a top-level list, string, number
    # or null is treated the same as a missing manifest.
    return data if isinstance(data, dict) else {}
# Names exported by `from dia2.assets import *`.
__all__ = [
    "AssetBundle",
    "ASSET_MANIFEST",
    "resolve_assets",
    "load_manifest",
]