Yinhong Liu commited on
Commit
492cf8f
·
1 Parent(s): 30bc77b

model selection

Browse files
Files changed (1) hide show
  1. app.py +27 -13
app.py CHANGED
@@ -14,25 +14,38 @@ else:
14
  torch_dtype = torch.float32
15
 
16
  MODEL_OPTIONS = {
17
- "Sana": "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
18
- "SD3": "stabilityai/stable-diffusion-3-medium",
19
- "Flux": "black-forest-labs/FLUX.1-dev"
 
 
 
 
 
 
 
 
 
 
20
  }
21
 
 
22
  def load_model(model_choice):
23
  model_repo_id = MODEL_OPTIONS[model_choice]
24
  # pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
25
- if model_choice == 'Sana':
26
  pipe = SanaPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
27
- elif model_choice == 'SD3':
28
- pipe = StableDiffusion3Pipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
 
 
29
  else:
30
  pipe = FluxPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
31
-
32
-
33
  pipe = pipe.to(device)
34
  return pipe
35
 
 
36
  MAX_SEED = np.iinfo(np.int32).max
37
  MAX_IMAGE_SIZE = 1024
38
 
@@ -95,14 +108,15 @@ with gr.Blocks(css=css) as demo:
95
  placeholder="Enter your prompt",
96
  container=False,
97
  )
98
- model_choice = gr.Dropdown(
99
- label="Model Choice",
100
- choices=["Sana", "SD3", "Flux"],
101
- value="Sana"
102
- )
103
 
104
  run_button = gr.Button("Run", scale=0, variant="primary")
105
 
 
 
 
 
 
 
106
  result = gr.Image(label="Result", show_label=False)
107
 
108
  with gr.Accordion("Advanced Settings", open=False):
 
14
  torch_dtype = torch.float32
15
 
16
  MODEL_OPTIONS = {
17
+ "SiD-Flow-SD3-medium": "YGu1998/SiD-Flow-SD3-medium",
18
+ "SiDA-Flow-SD3-medium": "YGu1998/SiDA-Flow-SD3-medium",
19
+ "SiD-Flow-SD3.5-large": "YGu1998/SiD-Flow-SD3.5-large",
20
+ "SiDA-Flow-SD3.5-large": "YGu1998/SiDA-Flow-SD3.5-large",
21
+ "SiD-Flow-Sana-0.6B-512-res": "YGu1998/SiD-Flow-Sana-0.6B-512-res",
22
+ "SiDA-Flow-Sana-0.6B-512-res": "YGu1998/SiDA-Flow-Sana-0.6B-512-res",
23
+ "SiD-Flow-Sana-1.6B-512-res": "YGu1998/SiD-Flow-Sana-1.6B-512-res",
24
+ "SiD-Flow-Sana-Sprint-0.6B-1024-res": "YGu1998/SiD-Flow-Sana-Sprint-0.6B-1024-res",
25
+ "SiDA-Flow-Sana-Sprint-0.6B-1024-res": "YGu1998/SiDA-Flow-Sana-Sprint-0.6B-1024-res",
26
+ "SiD-Flow-Sana-Sprint-1.6B-1024-res": "YGu1998/SiD-Flow-Sana-Sprint-1.6B-1024-res",
27
+ "SiDA-Flow-Sana-Sprint-1.6B-1024-res": "YGu1998/SiDA-Flow-Sana-Sprint-1.6B-1024-res",
28
+ "SiD-Flow-Flux-1024-res": "YGu1998/SiD-Flow-Flux-1024-res",
29
+ "SiD-Flow-Flux-512-res": "YGu1998/SiD-Flow-Flux-512-res",
30
  }
31
 
32
+
33
  def load_model(model_choice):
34
  model_repo_id = MODEL_OPTIONS[model_choice]
35
  # pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
36
+ if model_choice == "Sana":
37
  pipe = SanaPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
38
+ elif model_choice == "SD3":
39
+ pipe = StableDiffusion3Pipeline.from_pretrained(
40
+ model_repo_id, torch_dtype=torch_dtype
41
+ )
42
  else:
43
  pipe = FluxPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
44
+
 
45
  pipe = pipe.to(device)
46
  return pipe
47
 
48
+
49
  MAX_SEED = np.iinfo(np.int32).max
50
  MAX_IMAGE_SIZE = 1024
51
 
 
108
  placeholder="Enter your prompt",
109
  container=False,
110
  )
 
 
 
 
 
111
 
112
  run_button = gr.Button("Run", scale=0, variant="primary")
113
 
114
+ model_choice = gr.Dropdown(
115
+ label="Model Choice",
116
+ choices=list(MODEL_OPTIONS.keys()),
117
+ value="SiD-Flow-SD3-medium",
118
+ )
119
+
120
  result = gr.Image(label="Result", show_label=False)
121
 
122
  with gr.Accordion("Advanced Settings", open=False):