import os, json
from typing import List, Dict, Any, Optional
from PIL import Image
import torch
import gradio as gr
import spaces
from huggingface_hub import snapshot_download
from diffusers import (
    StableDiffusionXLPipeline,
    StableDiffusionPipeline,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    DDIMScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
)
# Config (set in Space Secrets if private)
MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "DB2169/CyberPony_Lora").strip()
CHECKPOINT_FILENAME = os.getenv("CHECKPOINT_FILENAME", "SAFETENSORS_FILENAME.safetensors").strip()
HF_TOKEN = os.getenv("HF_TOKEN", None)
DO_WARMUP = os.getenv("WARMUP", "1") == "1"  # set WARMUP=0 to skip the first warmup call

# Optional override: JSON string for LoRA manifest (same shape as loras.json)
LORAS_JSON = os.getenv("LORAS_JSON", "").strip()

# Where snapshot_download caches the repo in the container
REPO_DIR = "/home/user/model"
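
# Example Space variables/secrets (illustrative values only; adjust to your repo):
#   MODEL_REPO_ID=DB2169/CyberPony_Lora
#   CHECKPOINT_FILENAME=my_checkpoint.safetensors   (placeholder name)
#   HF_TOKEN=hf_xxx                                 (only needed for private repos)
#   WARMUP=0                                        (skip the startup warmup render)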
SCHEDULERS = {
    "default": None,
    "euler_a": EulerAncestralDiscreteScheduler,
    "euler": EulerDiscreteScheduler,
    "ddim": DDIMScheduler,
    "lms": LMSDiscreteScheduler,
    "pndm": PNDMScheduler,
    "dpmpp_2m": DPMSolverMultistepScheduler,
}

# Globals populated at startup
pipe = None
IS_SDXL = True
LORA_MANIFEST: Dict[str, Dict[str, str]] = {}
INIT_ERROR: Optional[str] = None
def load_lora_manifest(repo_dir: str) -> Dict[str, Dict[str, str]]:
    """Manifest load order:
    1) Environment variable LORAS_JSON (if provided)
    2) loras.json inside the downloaded model repo
    3) loras.json at the Space root (next to app.py)
    4) Built-in fallback (the MoriiMee_Gothic LoRA)
    """
    # 1) From env JSON
    if LORAS_JSON:
        try:
            parsed = json.loads(LORAS_JSON)
            if isinstance(parsed, dict):
                return parsed
        except Exception as e:
            print(f"[WARN] Failed to parse LORAS_JSON: {e}")
    # 2) From repo
    repo_manifest = os.path.join(repo_dir, "loras.json")
    if os.path.exists(repo_manifest):
        try:
            with open(repo_manifest, "r", encoding="utf-8") as f:
                parsed = json.load(f)
            if isinstance(parsed, dict):
                return parsed
        except Exception as e:
            print(f"[WARN] Failed to parse repo loras.json: {e}")
    # 3) From Space root
    local_manifest = os.path.join(os.getcwd(), "loras.json")
    if os.path.exists(local_manifest):
        try:
            with open(local_manifest, "r", encoding="utf-8") as f:
                parsed = json.load(f)
            if isinstance(parsed, dict):
                return parsed
        except Exception as e:
            print(f"[WARN] Failed to parse local loras.json: {e}")
    # 4) Built-in fallback: the MoriiMee_Gothic LoRA
    print("[INFO] Using built-in LoRA fallback manifest.")
    return {
        "MoriiMee_Gothic": {
            "repo": "LyliaEngine/MoriiMee_Gothic_Niji_Style_Illustrious_r1",
            "weight_name": "MoriiMee_Gothic_Niji_Style_Illustrious_r1.safetensors"
        }
    }
def bootstrap_model():
    """
    Downloads MODEL_REPO_ID into REPO_DIR and loads the single-file checkpoint,
    keeping weights on CPU; ZeroGPU attaches GPU only inside @spaces.GPU calls.
    """
    global pipe, IS_SDXL, LORA_MANIFEST, INIT_ERROR
    INIT_ERROR = None
    if not MODEL_REPO_ID or not CHECKPOINT_FILENAME:
        INIT_ERROR = "Missing MODEL_REPO_ID or CHECKPOINT_FILENAME."
        print(f"[ERROR] {INIT_ERROR}")
        return
    try:
        local_dir = snapshot_download(
            repo_id=MODEL_REPO_ID,
            token=HF_TOKEN,
            local_dir=REPO_DIR,
            ignore_patterns=["*.md"],
        )
    except Exception as e:
        INIT_ERROR = f"Failed to download repo {MODEL_REPO_ID}: {e}"
        print(f"[ERROR] {INIT_ERROR}")
        return
    ckpt_path = os.path.join(local_dir, CHECKPOINT_FILENAME)
    if not os.path.exists(ckpt_path):
        INIT_ERROR = f"Checkpoint not found at {ckpt_path}. Check CHECKPOINT_FILENAME."
        print(f"[ERROR] {INIT_ERROR}")
        return
    try:
        # Attempt SDXL first (text_encoder_2 present)
        _pipe = StableDiffusionXLPipeline.from_single_file(
            ckpt_path, torch_dtype=torch.float16, use_safetensors=True, add_watermarker=False
        )
        sdxl = True
    except Exception:
        try:
            _pipe = StableDiffusionPipeline.from_single_file(
                ckpt_path, torch_dtype=torch.float16, use_safetensors=True
            )
            sdxl = False
        except Exception as e:
            INIT_ERROR = f"Failed to load pipeline: {e}"
            print(f"[ERROR] {INIT_ERROR}")
            return
    if hasattr(_pipe, "enable_attention_slicing"):
        _pipe.enable_attention_slicing("max")
    if hasattr(_pipe, "enable_vae_slicing"):
        _pipe.enable_vae_slicing()
    if hasattr(_pipe, "set_progress_bar_config"):
        _pipe.set_progress_bar_config(disable=True)
    manifest = load_lora_manifest(local_dir)
    print(f"[INFO] LoRAs available: {list(manifest.keys())}")
    # Publish
    pipe = _pipe
    IS_SDXL = sdxl
    LORA_MANIFEST = manifest
def apply_loras(selected: List[str], scale: float, repo_dir: str):
    if not selected or scale <= 0:
        return
    for name in selected:
        meta = LORA_MANIFEST.get(name)
        if not meta:
            print(f"[WARN] Requested LoRA '{name}' not in manifest.")
            continue
        try:
            if "path" in meta:
                pipe.load_lora_weights(os.path.join(repo_dir, meta["path"]), adapter_name=name)
            else:
                pipe.load_lora_weights(meta.get("repo", ""), weight_name=meta.get("weight_name"), adapter_name=name)
            print(f"[INFO] Loaded LoRA: {name}")
        except Exception as e:
            print(f"[WARN] LoRA load failed for {name}: {e}")
    try:
        pipe.set_adapters(selected, adapter_weights=[float(scale)] * len(selected))
        print(f"[INFO] Activated LoRAs: {selected} at scale {scale}")
    except Exception as e:
        print(f"[WARN] set_adapters failed: {e}")
# ZeroGPU attaches a GPU only inside functions decorated with @spaces.GPU,
# so the generation entry point is decorated here.
@spaces.GPU
def txt2img(
    prompt: str,
    negative: str,
    width: int,
    height: int,
    steps: int,
    guidance: float,
    images: int,
    seed: Optional[int],
    scheduler: str,
    loras: List[str],
    lora_scale: float,
    fuse_lora: bool,
):
    if pipe is None:
        raise RuntimeError(f"Model not initialized. {INIT_ERROR or 'Check Space secrets and logs.'}")
    local_device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe.to(local_device)
    if scheduler in SCHEDULERS and SCHEDULERS[scheduler] is not None:
        try:
            pipe.scheduler = SCHEDULERS[scheduler].from_config(pipe.scheduler.config)
        except Exception as e:
            print(f"[WARN] Scheduler switch failed: {e}")
    apply_loras(loras, lora_scale, REPO_DIR)
    if fuse_lora and loras:
        try:
            pipe.fuse_lora(lora_scale=float(lora_scale))
        except Exception as e:
            print(f"[WARN] fuse_lora failed: {e}")
    generator = torch.Generator(device=local_device).manual_seed(int(seed)) if seed not in (None, "") else None
    kwargs: Dict[str, Any] = dict(
        prompt=prompt or "",
        negative_prompt=negative or None,
        width=int(width),
        height=int(height),
        num_inference_steps=int(steps),
        guidance_scale=float(guidance),
        num_images_per_prompt=int(images),
        generator=generator,
    )
    with torch.inference_mode():
        out = pipe(**kwargs)
    return out.images
def warmup():
    try:
        _ = txt2img("warmup", "", 512, 512, 4, 4.0, 1, 1234, "default", [], 0.0, False)
    except Exception as e:
        print(f"[WARN] Warmup failed: {e}")
# UI
with gr.Blocks(title="SDXL Space (ZeroGPU, single-file, LoRA-ready)") as demo:
    status = gr.Markdown("")
    with gr.Row():
        prompt = gr.Textbox(label="Prompt", lines=3)
        negative = gr.Textbox(label="Negative Prompt", lines=3)
    with gr.Row():
        width = gr.Slider(256, 1536, 1024, step=64, label="Width")
        height = gr.Slider(256, 1536, 1024, step=64, label="Height")
    with gr.Row():
        steps = gr.Slider(5, 80, 30, step=1, label="Steps")
        guidance = gr.Slider(0.0, 20.0, 6.5, step=0.1, label="Guidance")
        images = gr.Slider(1, 4, 1, step=1, label="Images")
    with gr.Row():
        seed = gr.Number(value=None, precision=0, label="Seed (blank=random)")
        scheduler = gr.Dropdown(list(SCHEDULERS.keys()), value="dpmpp_2m", label="Scheduler")
    lora_names = gr.CheckboxGroup(choices=[], label="LoRAs (from loras.json; select any)")
    lora_scale = gr.Slider(0.0, 1.5, 0.7, step=0.05, label="LoRA scale")
    fuse = gr.Checkbox(label="Fuse LoRA (faster after load)")
    btn = gr.Button("Generate", variant="primary", interactive=False)
    gallery = gr.Gallery(columns=4, height=420)
    def _startup():
        bootstrap_model()
        if INIT_ERROR:
            return gr.update(value=f"❌ Init failed: {INIT_ERROR}"), gr.update(choices=[]), gr.update(interactive=False)
        msg = f"✅ Model loaded from {MODEL_REPO_ID} ({'SDXL' if IS_SDXL else 'SD'})"
        # Populate LoRA choices (manifest could come from repo, Space file, or built-in fallback)
        return gr.update(value=msg), gr.update(choices=list(LORA_MANIFEST.keys())), gr.update(interactive=True)

    demo.load(_startup, outputs=[status, lora_names, btn])
    if DO_WARMUP:
        demo.load(lambda: warmup(), inputs=None, outputs=None)
    btn.click(
        txt2img,
        inputs=[prompt, negative, width, height, steps, guidance, images, seed, scheduler, lora_names, lora_scale, fuse],
        outputs=[gallery],
        api_name="txt2img",
        concurrency_limit=1,
        concurrency_id="gpu_queue",
    )

demo.queue(max_size=32, default_concurrency_limit=1).launch()
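
# --- Optional: calling the exposed endpoint from a client ---
# A minimal client-side sketch, assuming the Space is public and reachable as
# "OWNER/SPACE_NAME" (placeholder). It targets the api_name="txt2img" endpoint
# registered above; the positional arguments mirror the btn.click inputs order.
#
#   from gradio_client import Client
#   client = Client("OWNER/SPACE_NAME")
#   result = client.predict(
#       "a gothic portrait, dramatic lighting",  # prompt
#       "",          # negative prompt
#       1024, 1024,  # width, height
#       30, 6.5,     # steps, guidance
#       1,           # number of images
#       1234,        # seed
#       "dpmpp_2m",  # scheduler
#       [],          # LoRA names
#       0.7,         # LoRA scale
#       False,       # fuse LoRA
#       api_name="/txt2img",
#   )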