DB2169 committed · verified
Commit 1d5499e · 1 Parent(s): 5e98324

Update app.py

Files changed (1):
  app.py  +59 -47
app.py CHANGED
@@ -3,7 +3,8 @@ from typing import List, Dict, Any, Optional
  from PIL import Image
  import torch
  import gradio as gr
- from huggingface_hub import snapshot_download  # pulls your repo at startup
+ import spaces  # ZeroGPU: decorate GPU-bound functions
+ from huggingface_hub import snapshot_download
  from diffusers import (
      StableDiffusionXLPipeline,
      StableDiffusionPipeline,
@@ -15,14 +16,15 @@ from diffusers import (
      PNDMScheduler,
  )
  
- # -------- Configuration (set these in Space Secrets for private repos) --------
- MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "DB2169/CyberPony_Lora")  # e.g., your repo id
- CHECKPOINT_FILENAME = os.getenv("CHECKPOINT_FILENAME", "SAFETENSORS_FILENAME.safetensors")  # exact base ckpt filename
- HF_TOKEN = os.getenv("HF_TOKEN", None)  # optional if repo is public
+ # -------- Configuration (set as Space Secrets if needed) --------
+ MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "DB2169/CyberPony_Lora")  # your model repo id
+ CHECKPOINT_FILENAME = os.getenv("CHECKPOINT_FILENAME", "SAFETENSORS_FILENAME.safetensors")  # exact .safetensors name
+ HF_TOKEN = os.getenv("HF_TOKEN", None)  # only required for private repos
+ DO_WARMUP = os.getenv("WARMUP", "1") == "1"  # set to "0" to disable warmup
  
  # -------- Runtime defaults --------
- device = "cuda" if torch.cuda.is_available() else "cpu"
- dtype = torch.float16 if device == "cuda" else torch.float32
+ REPO_DIR = "/home/user/model"  # local cache mount for snapshot_download
+ # Defer CUDA detection to GPU-run function for ZeroGPU; do not move to CUDA at import time
  
  SCHEDULERS = {
      "default": None,
@@ -34,48 +36,44 @@ SCHEDULERS = {
      "dpmpp_2m": DPMSolverMultistepScheduler,
  }
  
- # Globals filled on startup
+ # Globals populated on startup
  pipe = None
  IS_SDXL = True
  LORA_MANIFEST: Dict[str, Dict[str, str]] = {}
- REPO_DIR = "/home/user/model"  # cached snapshot location in Spaces
  
- # -------- Model bootstrap --------
+ 
+ # -------- Model bootstrap (CPU) --------
  def bootstrap_model():
      global pipe, IS_SDXL, LORA_MANIFEST
-     # Download/copy all repo files locally (weights + manifest)
      local_dir = snapshot_download(
          repo_id=MODEL_REPO_ID,
          token=HF_TOKEN,
          local_dir=REPO_DIR,
          ignore_patterns=["*.md"],
-     )  # downloads your model repo into the container cache [web:362]
- 
+     )
      ckpt_path = os.path.join(local_dir, CHECKPOINT_FILENAME)
      if not os.path.exists(ckpt_path):
          raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
  
-     # Try SDXL single-file, then SD 1.x/2.x single-file
      try:
          _pipe = StableDiffusionXLPipeline.from_single_file(
-             ckpt_path, torch_dtype=dtype, use_safetensors=True, add_watermarker=False
-         )  # SDXL loader [web:104]
+             ckpt_path, torch_dtype=torch.float16, use_safetensors=True, add_watermarker=False
+         )
          sdxl = True
      except Exception:
          _pipe = StableDiffusionPipeline.from_single_file(
-             ckpt_path, torch_dtype=dtype, use_safetensors=True
-         )  # SD 1.x/2.x fallback [web:104]
+             ckpt_path, torch_dtype=torch.float16, use_safetensors=True
+         )
          sdxl = False
  
+     # Keep on CPU until GPU-decorated call (ZeroGPU attaches GPU on demand)
      if hasattr(_pipe, "enable_attention_slicing"):
          _pipe.enable_attention_slicing("max")
      if hasattr(_pipe, "enable_vae_slicing"):
          _pipe.enable_vae_slicing()
      if hasattr(_pipe, "set_progress_bar_config"):
          _pipe.set_progress_bar_config(disable=True)
-     _pipe.to(device)
  
-     # Load LoRA manifest if present
      man_path = os.path.join(local_dir, "loras.json")
      manifest = {}
      if os.path.exists(man_path):
@@ -85,20 +83,21 @@ def bootstrap_model():
          except Exception as e:
              print(f"[WARN] Failed to parse loras.json: {e}")
  
-     # Publish globals
-     return _pipe, sdxl, manifest
+     pipe = _pipe
+     IS_SDXL = sdxl
+     LORA_MANIFEST = manifest
+ 
  
- def apply_loras(selected: List[str], scale: float):
+ def apply_loras(selected: List[str], scale: float, repo_dir: str):
      if not selected or scale <= 0:
          return
-     # Each selected LoRA should exist in manifest; supports repo/weight_name or local 'path'
      for name in selected:
          meta = LORA_MANIFEST.get(name)
          if not meta:
              continue
          try:
              if "path" in meta:
-                 pipe.load_lora_weights(os.path.join(REPO_DIR, meta["path"]), adapter_name=name)
+                 pipe.load_lora_weights(os.path.join(repo_dir, meta["path"]), adapter_name=name)
              else:
                  pipe.load_lora_weights(meta.get("repo", ""), weight_name=meta.get("weight_name"), adapter_name=name)
          except Exception as e:
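
For reference, apply_loras() accepts two entry shapes from loras.json: a local "path" resolved inside the downloaded snapshot, or a Hub "repo" plus "weight_name". A minimal sketch of such a manifest, written from Python with hypothetical adapter and file names:

import json

manifest = {
    "detail_tweaker": {"path": "loras/detail_tweaker.safetensors"},  # resolved relative to REPO_DIR
    "style_lora": {"repo": "some-user/some-lora-repo", "weight_name": "style.safetensors"},
}
with open("loras.json", "w") as f:
    json.dump(manifest, f, indent=2)
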
@@ -108,6 +107,8 @@ def apply_loras(selected: List[str], scale: float):
      except Exception as e:
          print(f"[WARN] set_adapters failed: {e}")
  
+ 
+ @spaces.GPU  # ZeroGPU: allocate/attach GPU for this function call
  def txt2img(
      prompt: str,
      negative: str,
@@ -122,6 +123,11 @@ def txt2img(
      lora_scale: float,
      fuse_lora: bool,
  ):
+     # Resolve device inside GPU context
+     local_device = "cuda" if torch.cuda.is_available() else "cpu"
+     local_dtype = torch.float16 if local_device == "cuda" else torch.float32
+     pipe.to(local_device)
+ 
      # Scheduler swap
      if scheduler in SCHEDULERS and SCHEDULERS[scheduler] is not None:
          try:
@@ -129,8 +135,8 @@ def txt2img(
          except Exception as e:
              print(f"[WARN] Scheduler switch failed: {e}")
  
-     # Apply LoRAs
-     apply_loras(loras, lora_scale)
+     # LoRAs
+     apply_loras(loras, lora_scale, REPO_DIR)
      if fuse_lora and loras:
          try:
              pipe.fuse_lora(lora_scale=float(lora_scale))
@@ -138,7 +144,7 @@ def txt2img(
              print(f"[WARN] fuse_lora failed: {e}")
  
      # Determinism
-     generator = torch.Generator(device=device).manual_seed(int(seed)) if seed not in (None, "") else None
+     generator = torch.Generator(device=local_device).manual_seed(int(seed)) if seed not in (None, "") else None
  
      kwargs: Dict[str, Any] = dict(
          prompt=prompt or "",
@@ -150,60 +156,66 @@ def txt2img(
          num_images_per_prompt=int(images),
          generator=generator,
      )
-     out = pipe(**kwargs)
+     with torch.inference_mode():
+         out = pipe(**kwargs)
      return out.images
  
+ 
  def warmup():
-     # Small, fast call to initialize kernels/graphs so first user is instant
      try:
          _ = txt2img("warmup", "", 512, 512, 4, 4.0, 1, 1234, "default", [], 0.0, False)
      except Exception as e:
          print(f"[WARN] Warmup failed: {e}")
  
- # --------------------------- Build the UI inside Blocks ---------------------------
- with gr.Blocks(title="SDXL Space (single-file, LoRA-ready)") as demo:  # Blocks context required for events [web:371]
-     gr.Markdown("### SDXL text‑to‑image (single‑file checkpoint) with optional LoRAs")  # UI heading [web:147]
+ 
+ # --------------------------- Build UI ---------------------------
+ with gr.Blocks(title="SDXL Space (ZeroGPU, single-file checkpoint, LoRA-ready)") as demo:
+     gr.Markdown("### SDXL text‑to‑image with single‑file checkpoint and optional LoRAs")
+ 
      with gr.Row():
          prompt = gr.Textbox(label="Prompt", lines=3)
          negative = gr.Textbox(label="Negative Prompt", lines=3)
+ 
      with gr.Row():
          width = gr.Slider(256, 1536, 1024, step=64, label="Width")
          height = gr.Slider(256, 1536, 1024, step=64, label="Height")
+ 
      with gr.Row():
          steps = gr.Slider(5, 80, 30, step=1, label="Steps")
          guidance = gr.Slider(0.0, 20.0, 6.5, step=0.1, label="Guidance")
          images = gr.Slider(1, 4, 1, step=1, label="Images")
+ 
      with gr.Row():
          seed = gr.Number(value=None, precision=0, label="Seed (blank=random)")
          scheduler = gr.Dropdown(list(SCHEDULERS.keys()), value="dpmpp_2m", label="Scheduler")
  
-     # LoRA multi-select populated after manifest loads
-     lora_names = gr.CheckboxGroup(choices=[], label="LoRAs (from loras.json)")
+     lora_names = gr.CheckboxGroup(choices=[], label="LoRAs (from loras.json; select any)")
      lora_scale = gr.Slider(0.0, 1.5, 0.7, step=0.05, label="LoRA scale")
      fuse = gr.Checkbox(label="Fuse LoRA (faster after load)")
  
      btn = gr.Button("Generate", variant="primary")
      gallery = gr.Gallery(columns=4, height=420)
  
-     # Startup loader (runs at app load)
+     # Load model + manifest, then populate LoRA choices
      def _startup():
-         global pipe, IS_SDXL, LORA_MANIFEST
-         pipe, IS_SDXL, LORA_MANIFEST = bootstrap_model()
+         bootstrap_model()
          return gr.CheckboxGroup.update(choices=list(LORA_MANIFEST.keys()))
-     demo.load(_startup, outputs=[lora_names])  # fill LoRA list once model is ready [web:147]
  
-     # Warm-up pass after model load for snappy first request
-     demo.load(lambda: warmup(), inputs=None, outputs=None)  # performance warmup [web:356]
+     demo.load(_startup, outputs=[lora_names])
+ 
+     # Optional warmup (costs a tiny GPU run on first boot); set WARMUP=0 to skip
+     if DO_WARMUP:
+         demo.load(lambda: warmup(), inputs=None, outputs=None)
  
-     # Wire the button click inside Blocks, with per-event concurrency control
+     # Event binding inside Blocks; one GPU job at a time for SDXL
      btn.click(
          txt2img,
          inputs=[prompt, negative, width, height, steps, guidance, images, seed, scheduler, lora_names, lora_scale, fuse],
          outputs=[gallery],
          api_name="txt2img",
-         concurrency_limit=1,  # one GPU job at a time for SDXL
-         concurrency_id="gpu_queue",  # shared queue id if you add more GPU events
-     )  # per-event queue parameters in Gradio 4.x [web:388][web:373]
+         concurrency_limit=1,
+         concurrency_id="gpu_queue",
+     )
  
- # Global queue config (no deprecated args)
- demo.queue(max_size=32, default_concurrency_limit=1).launch()  # supported queue pattern in Gradio 4.x [web:373][web:381]
+ # Global queue limits for Gradio 4.x
+ demo.queue(max_size=32, default_concurrency_limit=1).launch()
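
Because the click handler is registered with api_name="txt2img", the endpoint can also be called programmatically once the Space is running. A hedged sketch with gradio_client (the Space id is a placeholder; the positional arguments mirror the inputs list wired to btn.click above):

from gradio_client import Client

client = Client("DB2169/your-space-name")  # hypothetical Space id
result = client.predict(
    "a neon-lit cyberpunk street at night",  # prompt
    "",                                      # negative
    1024, 1024,                              # width, height
    30, 6.5,                                 # steps, guidance
    1, 1234,                                 # images, seed
    "dpmpp_2m",                              # scheduler
    [], 0.7, False,                          # loras, lora_scale, fuse
    api_name="/txt2img",
)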
 