DB2169 committed
Commit 3f5ac2a · verified · 1 parent: 9441bd6

Update app.py

Files changed (1)
  1. app.py +39 -22
app.py CHANGED
@@ -1,10 +1,25 @@
-import os, io, json
+import os, json
 from typing import List, Dict, Any, Optional
 from PIL import Image
 import torch
 import gradio as gr
-import spaces
-from huggingface_hub import snapshot_download, HfHubHTTPError
+import spaces  # ZeroGPU decorator
+from huggingface_hub import snapshot_download
+
+# ----------------- Config (set in Space Secrets if private) -----------------
+# Your private repo that contains the base .safetensors and loras.json
+MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "DB2169/CyberPony_Lora").strip()
+# Exact filename of the base checkpoint inside the repo (case-sensitive)
+CHECKPOINT_FILENAME = os.getenv("CHECKPOINT_FILENAME", "SAFETENSORS_FILENAME.safetensors").strip()
+# Personal access token with read scope (required for private repos)
+HF_TOKEN = os.getenv("HF_TOKEN", None)
+# Toggle first-boot warmup (GPU-allocating on ZeroGPU)
+DO_WARMUP = os.getenv("WARMUP", "1") == "1"
+
+# Where snapshot_download will cache the repo
+REPO_DIR = "/home/user/model"
+
+# Supported schedulers
 from diffusers import (
     StableDiffusionXLPipeline,
     StableDiffusionPipeline,
@@ -15,14 +30,6 @@ from diffusers import (
     LMSDiscreteScheduler,
     PNDMScheduler,
 )
-
-MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "").strip()
-CHECKPOINT_FILENAME = os.getenv("CHECKPOINT_FILENAME", "").strip()
-HF_TOKEN = os.getenv("HF_TOKEN", None)
-DO_WARMUP = os.getenv("WARMUP", "1") == "1"
-
-REPO_DIR = "/home/user/model"
-
 SCHEDULERS = {
     "default": None,
     "euler_a": EulerAncestralDiscreteScheduler,
@@ -33,16 +40,23 @@ SCHEDULERS = {
     "dpmpp_2m": DPMSolverMultistepScheduler,
 }
 
+# Globals populated at startup
 pipe = None
 IS_SDXL = True
 LORA_MANIFEST: Dict[str, Dict[str, str]] = {}
-INIT_ERROR: Optional[str] = None  # expose bootstrap error to UI
+INIT_ERROR: Optional[str] = None
 
+# ----------------- Bootstrap (download + load on CPU) -----------------
 def bootstrap_model():
+    """
+    Downloads MODEL_REPO_ID into REPO_DIR and loads the single-file checkpoint.
+    Keeps pipeline on CPU; ZeroGPU attaches GPU inside the @spaces.GPU function.
+    """
     global pipe, IS_SDXL, LORA_MANIFEST, INIT_ERROR
     INIT_ERROR = None
+
     if not MODEL_REPO_ID or not CHECKPOINT_FILENAME:
-        INIT_ERROR = "Missing MODEL_REPO_ID or CHECKPOINT_FILENAME environment variables."
+        INIT_ERROR = "Missing MODEL_REPO_ID or CHECKPOINT_FILENAME."
         print(f"[ERROR] {INIT_ERROR}")
         return
 
@@ -53,12 +67,8 @@ def bootstrap_model():
            local_dir=REPO_DIR,
            ignore_patterns=["*.md"],
        )
-    except HfHubHTTPError as e:
-        INIT_ERROR = f"Failed to download repo {MODEL_REPO_ID}: {e}"
-        print(f"[ERROR] {INIT_ERROR}")
-        return
     except Exception as e:
-        INIT_ERROR = f"Unexpected error while downloading repo: {e}"
+        INIT_ERROR = f"Failed to download repo {MODEL_REPO_ID}: {e}"
         print(f"[ERROR] {INIT_ERROR}")
         return
 
@@ -69,6 +79,7 @@ def bootstrap_model():
        return
 
    try:
+        # Try SDXL first
        _pipe = StableDiffusionXLPipeline.from_single_file(
            ckpt_path, torch_dtype=torch.float16, use_safetensors=True, add_watermarker=False
        )
@@ -84,6 +95,7 @@ def bootstrap_model():
        print(f"[ERROR] {INIT_ERROR}")
        return
 
+    # Light memory/perf tweaks
    if hasattr(_pipe, "enable_attention_slicing"):
        _pipe.enable_attention_slicing("max")
    if hasattr(_pipe, "enable_vae_slicing"):
@@ -91,6 +103,7 @@ def bootstrap_model():
    if hasattr(_pipe, "set_progress_bar_config"):
        _pipe.set_progress_bar_config(disable=True)
 
+    # Load LoRA manifest if present
    man_path = os.path.join(local_dir, "loras.json")
    manifest = {}
    if os.path.exists(man_path):
@@ -100,7 +113,7 @@ def bootstrap_model():
        except Exception as e:
            print(f"[WARN] Failed to parse loras.json: {e}")
 
-    # publish
+    # Publish globals
    global pipe, IS_SDXL, LORA_MANIFEST
    pipe = _pipe
    IS_SDXL = sdxl
@@ -125,6 +138,7 @@ def apply_loras(selected: List[str], scale: float, repo_dir: str):
    except Exception as e:
        print(f"[WARN] set_adapters failed: {e}")
 
+# ----------------- Generation (GPU-attached under ZeroGPU) -----------------
 @spaces.GPU
 def txt2img(
    prompt: str,
@@ -144,15 +158,16 @@ def txt2img(
        raise RuntimeError(f"Model not initialized. {INIT_ERROR or 'Check Space secrets and logs.'}")
 
    local_device = "cuda" if torch.cuda.is_available() else "cpu"
-    local_dtype = torch.float16 if local_device == "cuda" else torch.float32
    pipe.to(local_device)
 
+    # Optional scheduler switch
    if scheduler in SCHEDULERS and SCHEDULERS[scheduler] is not None:
        try:
            pipe.scheduler = SCHEDULERS[scheduler].from_config(pipe.scheduler.config)
        except Exception as e:
            print(f"[WARN] Scheduler switch failed: {e}")
 
+    # Apply LoRAs
    apply_loras(loras, lora_scale, REPO_DIR)
    if fuse_lora and loras:
        try:
@@ -182,8 +197,9 @@ def warmup():
    except Exception as e:
        print(f"[WARN] Warmup failed: {e}")
 
+# ----------------- UI -----------------
 with gr.Blocks(title="SDXL Space (ZeroGPU, single-file, LoRA-ready)") as demo:
-    status = gr.Markdown("")  # show init status/errors
+    status = gr.Markdown("")  # shows init result or errors
 
    with gr.Row():
        prompt = gr.Textbox(label="Prompt", lines=3)
@@ -206,7 +222,7 @@ with gr.Blocks(title="SDXL Space (ZeroGPU, single-file, LoRA-ready)") as demo:
        lora_scale = gr.Slider(0.0, 1.5, 0.7, step=0.05, label="LoRA scale")
        fuse = gr.Checkbox(label="Fuse LoRA (faster after load)")
 
-    btn = gr.Button("Generate", variant="primary", interactive=False)  # locked until model loads
+    btn = gr.Button("Generate", variant="primary", interactive=False)
    gallery = gr.Gallery(columns=4, height=420)
 
    def _startup():
@@ -230,4 +246,5 @@ with gr.Blocks(title="SDXL Space (ZeroGPU, single-file, LoRA-ready)") as demo:
        concurrency_id="gpu_queue",
    )
 
+# Gradio 4.x queue config (no deprecated args)
 demo.queue(max_size=32, default_concurrency_limit=1).launch()
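For context, the updated app.py is configured entirely through environment variables (normally set as Space secrets on Hugging Face). The following is a minimal, hypothetical sketch of running it locally with those same variables; the token value is a placeholder, and the checkpoint filename must match the file actually stored in the repo.

# Hypothetical local-run sketch: set the variables app.py reads at import time,
# then import it; the module-level demo.queue(...).launch() starts the Gradio UI.
import os

os.environ["MODEL_REPO_ID"] = "DB2169/CyberPony_Lora"                   # repo holding the .safetensors and loras.json
os.environ["CHECKPOINT_FILENAME"] = "SAFETENSORS_FILENAME.safetensors"  # exact, case-sensitive checkpoint name
os.environ["HF_TOKEN"] = "hf_xxx"                                       # placeholder read-scope token (private repos only)
os.environ["WARMUP"] = "0"                                              # skip the first-boot warmup pass

import app  # runs app.py's module-level launch()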