Yinhong Liu committed
Commit 71383c2 · 1 Parent(s): b2e8669

model selection

Files changed (1)
  1. app.py +20 -3
app.py CHANGED
@@ -7,15 +7,23 @@ from diffusers import DiffusionPipeline
 import torch
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
-model_repo_id = "stabilityai/sdxl-turbo"  # Replace to the model you would like to use
 
 if torch.cuda.is_available():
     torch_dtype = torch.float16
 else:
     torch_dtype = torch.float32
 
-pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
-pipe = pipe.to(device)
+MODEL_OPTIONS = {
+    "Sana": "sana-model-repo-id",
+    "SD3": "sd3-model-repo-id",
+    "Flux": "flux-model-repo-id"
+}
+
+def load_model(model_choice):
+    model_repo_id = MODEL_OPTIONS[model_choice]
+    pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
+    pipe = pipe.to(device)
+    return pipe
 
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 1024
@@ -31,6 +39,7 @@ def infer(
     height,
     guidance_scale,
     num_inference_steps,
+    model_choice,
     progress=gr.Progress(track_tqdm=True),
 ):
     if randomize_seed:
@@ -38,6 +47,8 @@ def infer(
 
     generator = torch.Generator().manual_seed(seed)
 
+    pipe = load_model(model_choice)
+
     image = pipe(
         prompt=prompt,
         negative_prompt=negative_prompt,
@@ -76,6 +87,11 @@ with gr.Blocks(css=css) as demo:
                 placeholder="Enter your prompt",
                 container=False,
             )
+            model_choice = gr.Dropdown(
+                label="Model Choice",
+                choices=["Sana", "SD3", "Flux"],
+                value="Sana"
+            )
 
             run_button = gr.Button("Run", scale=0, variant="primary")
 
@@ -146,6 +162,7 @@ with gr.Blocks(css=css) as demo:
                 height,
                 guidance_scale,
                 num_inference_steps,
+                model_choice,
             ],
             outputs=[result, seed],
         )
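
As committed, load_model builds a fresh DiffusionPipeline on every call to infer, even when the user keeps the same model selected. One way to avoid that repeated loading is to cache the pipeline per repo id. The following is a minimal sketch of that idea, not part of the commit; it reuses the same globals as the diff, and the repo ids in MODEL_OPTIONS are the commit's placeholders rather than real Hub ids:

from functools import lru_cache

import torch
from diffusers import DiffusionPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Placeholder ids copied from the diff; swap in real Hub repo ids before use.
MODEL_OPTIONS = {
    "Sana": "sana-model-repo-id",
    "SD3": "sd3-model-repo-id",
    "Flux": "flux-model-repo-id",
}

@lru_cache(maxsize=1)  # keep only the most recently selected pipeline in memory
def load_model(model_choice):
    # Load the selected pipeline once and reuse it until the choice changes.
    model_repo_id = MODEL_OPTIONS[model_choice]
    pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
    return pipe.to(device)

With maxsize=1 the previously selected pipeline becomes eligible for garbage collection when the user switches models, so only one model stays resident at a time; on CUDA an explicit torch.cuda.empty_cache() may still be needed before VRAM is actually returned.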