import os, json from typing import List, Dict, Any, Optional from PIL import Image import torch import gradio as gr import spaces from huggingface_hub import snapshot_download from diffusers import ( StableDiffusionXLPipeline, StableDiffusionPipeline, DPMSolverMultistepScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, ) # Config (set in Space Secrets if private) MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "DB2169/CyberPony_Lora").strip() CHECKPOINT_FILENAME = os.getenv("CHECKPOINT_FILENAME", "SAFETENSORS_FILENAME.safetensors").strip() HF_TOKEN = os.getenv("HF_TOKEN", None) DO_WARMUP = os.getenv("WARMUP", "1") == "1" # set WARMUP=0 to skip the first warmup call # Optional override: JSON string for LoRA manifest (same shape as loras.json) LORAS_JSON = os.getenv("LORAS_JSON", "").strip() # Where snapshot_download caches the repo in the container REPO_DIR = "/home/user/model" SCHEDULERS = { "default": None, "euler_a": EulerAncestralDiscreteScheduler, "euler": EulerDiscreteScheduler, "ddim": DDIMScheduler, "lms": LMSDiscreteScheduler, "pndm": PNDMScheduler, "dpmpp_2m": DPMSolverMultistepScheduler, } # Globals populated at startup pipe = None IS_SDXL = True LORA_MANIFEST: Dict[str, Dict[str, str]] = {} INIT_ERROR: Optional[str] = None def load_lora_manifest(repo_dir: str) -> Dict[str, Dict[str, str]]: """Manifest load order: 1) Environment variable LORAS_JSON (if provided) 2) loras.json inside the downloaded model repo 3) loras.json at the Space root (next to app.py) 4) Built-in fallback with MoriiMee_Gothic you provided """ # 1) From env JSON if LORAS_JSON: try: parsed = json.loads(LORAS_JSON) if isinstance(parsed, dict): return parsed except Exception as e: print(f"[WARN] Failed to parse LORAS_JSON: {e}") # 2) From repo repo_manifest = os.path.join(repo_dir, "loras.json") if os.path.exists(repo_manifest): try: with open(repo_manifest, "r", encoding="utf-8") as f: parsed = json.load(f) if isinstance(parsed, dict): return parsed except Exception as e: print(f"[WARN] Failed to parse repo loras.json: {e}") # 3) From Space root local_manifest = os.path.join(os.getcwd(), "loras.json") if os.path.exists(local_manifest): try: with open(local_manifest, "r", encoding="utf-8") as f: parsed = json.load(f) if isinstance(parsed, dict): return parsed except Exception as e: print(f"[WARN] Failed to parse local loras.json: {e}") # 4) Built-in fallback: your MoriiMee Gothic LoRA print("[INFO] Using built-in LoRA fallback manifest.") return { "MoriiMee_Gothic": { "repo": "LyliaEngine/MoriiMee_Gothic_Niji_Style_Illustrious_r1", "weight_name": "MoriiMee_Gothic_Niji_Style_Illustrious_r1.safetensors" } } def bootstrap_model(): """ Downloads MODEL_REPO_ID into REPO_DIR and loads the single-file checkpoint, keeping weights on CPU; ZeroGPU attaches GPU only inside @spaces.GPU calls. """ global pipe, IS_SDXL, LORA_MANIFEST, INIT_ERROR INIT_ERROR = None if not MODEL_REPO_ID or not CHECKPOINT_FILENAME: INIT_ERROR = "Missing MODEL_REPO_ID or CHECKPOINT_FILENAME." print(f"[ERROR] {INIT_ERROR}") return try: local_dir = snapshot_download( repo_id=MODEL_REPO_ID, token=HF_TOKEN, local_dir=REPO_DIR, ignore_patterns=["*.md"], ) except Exception as e: INIT_ERROR = f"Failed to download repo {MODEL_REPO_ID}: {e}" print(f"[ERROR] {INIT_ERROR}") return ckpt_path = os.path.join(local_dir, CHECKPOINT_FILENAME) if not os.path.exists(ckpt_path): INIT_ERROR = f"Checkpoint not found at {ckpt_path}. Check CHECKPOINT_FILENAME." print(f"[ERROR] {INIT_ERROR}") return try: # Attempt SDXL first (text_encoder_2 present) _pipe = StableDiffusionXLPipeline.from_single_file( ckpt_path, torch_dtype=torch.float16, use_safetensors=True, add_watermarker=False ) sdxl = True except Exception: try: _pipe = StableDiffusionPipeline.from_single_file( ckpt_path, torch_dtype=torch.float16, use_safetensors=True ) sdxl = False except Exception as e: INIT_ERROR = f"Failed to load pipeline: {e}" print(f"[ERROR] {INIT_ERROR}") return if hasattr(_pipe, "enable_attention_slicing"): _pipe.enable_attention_slicing("max") if hasattr(_pipe, "enable_vae_slicing"): _pipe.enable_vae_slicing() if hasattr(_pipe, "set_progress_bar_config"): _pipe.set_progress_bar_config(disable=True) manifest = load_lora_manifest(local_dir) print(f"[INFO] LoRAs available: {list(manifest.keys())}") # Publish pipe = _pipe IS_SDXL = sdxl LORA_MANIFEST = manifest def apply_loras(selected: List[str], scale: float, repo_dir: str): if not selected or scale <= 0: return for name in selected: meta = LORA_MANIFEST.get(name) if not meta: print(f"[WARN] Requested LoRA '{name}' not in manifest.") continue try: if "path" in meta: pipe.load_lora_weights(os.path.join(repo_dir, meta["path"]), adapter_name=name) else: pipe.load_lora_weights(meta.get("repo", ""), weight_name=meta.get("weight_name"), adapter_name=name) print(f"[INFO] Loaded LoRA: {name}") except Exception as e: print(f"[WARN] LoRA load failed for {name}: {e}") try: pipe.set_adapters(selected, adapter_weights=[float(scale)] * len(selected)) print(f"[INFO] Activated LoRAs: {selected} at scale {scale}") except Exception as e: print(f"[WARN] set_adapters failed: {e}") @spaces.GPU def txt2img( prompt: str, negative: str, width: int, height: int, steps: int, guidance: float, images: int, seed: Optional[int], scheduler: str, loras: List[str], lora_scale: float, fuse_lora: bool, ): if pipe is None: raise RuntimeError(f"Model not initialized. {INIT_ERROR or 'Check Space secrets and logs.'}") local_device = "cuda" if torch.cuda.is_available() else "cpu" pipe.to(local_device) if scheduler in SCHEDULERS and SCHEDULERS[scheduler] is not None: try: pipe.scheduler = SCHEDULERS[scheduler].from_config(pipe.scheduler.config) except Exception as e: print(f"[WARN] Scheduler switch failed: {e}") apply_loras(loras, lora_scale, REPO_DIR) if fuse_lora and loras: try: pipe.fuse_lora(lora_scale=float(lora_scale)) except Exception as e: print(f"[WARN] fuse_lora failed: {e}") generator = torch.Generator(device=local_device).manual_seed(int(seed)) if seed not in (None, "") else None kwargs: Dict[str, Any] = dict( prompt=prompt or "", negative_prompt=negative or None, width=int(width), height=int(height), num_inference_steps=int(steps), guidance_scale=float(guidance), num_images_per_prompt=int(images), generator=generator, ) with torch.inference_mode(): out = pipe(**kwargs) return out.images def warmup(): try: _ = txt2img("warmup", "", 512, 512, 4, 4.0, 1, 1234, "default", [], 0.0, False) except Exception as e: print(f"[WARN] Warmup failed: {e}") # UI with gr.Blocks(title="SDXL Space (ZeroGPU, single-file, LoRA-ready)") as demo: status = gr.Markdown("") with gr.Row(): prompt = gr.Textbox(label="Prompt", lines=3) negative = gr.Textbox(label="Negative Prompt", lines=3) with gr.Row(): width = gr.Slider(256, 1536, 1024, step=64, label="Width") height = gr.Slider(256, 1536, 1024, step=64, label="Height") with gr.Row(): steps = gr.Slider(5, 80, 30, step=1, label="Steps") guidance = gr.Slider(0.0, 20.0, 6.5, step=0.1, label="Guidance") images = gr.Slider(1, 4, 1, step=1, label="Images") with gr.Row(): seed = gr.Number(value=None, precision=0, label="Seed (blank=random)") scheduler = gr.Dropdown(list(SCHEDULERS.keys()), value="dpmpp_2m", label="Scheduler") lora_names = gr.CheckboxGroup(choices=[], label="LoRAs (from loras.json; select any)") lora_scale = gr.Slider(0.0, 1.5, 0.7, step=0.05, label="LoRA scale") fuse = gr.Checkbox(label="Fuse LoRA (faster after load)") btn = gr.Button("Generate", variant="primary", interactive=False) gallery = gr.Gallery(columns=4, height=420) def _startup(): bootstrap_model() if INIT_ERROR: return gr.update(value=f"❌ Init failed: {INIT_ERROR}"), gr.update(choices=[]), gr.update(interactive=False) msg = f"✅ Model loaded from {MODEL_REPO_ID} ({'SDXL' if IS_SDXL else 'SD'})" # Populate LoRA choices (manifest could come from repo, Space file, or built-in fallback) return gr.update(value=msg), gr.update(choices=list(LORA_MANIFEST.keys())), gr.update(interactive=True) demo.load(_startup, outputs=[status, lora_names, btn]) if DO_WARMUP: demo.load(lambda: warmup(), inputs=None, outputs=None) btn.click( txt2img, inputs=[prompt, negative, width, height, steps, guidance, images, seed, scheduler, lora_names, lora_scale, fuse], outputs=[gallery], api_name="txt2img", concurrency_limit=1, concurrency_id="gpu_queue", ) demo.queue(max_size=32, default_concurrency_limit=1).launch()