import gradio as gr
import spaces
import torch
from diffusers import QwenImagePipeline
from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
from optimization import compile_transformer
from hub_utils import _push_compiled_graph_to_hub
from huggingface_hub import whoami

# --- Model Loading ---
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the model pipeline
pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=dtype).to(device)
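# Swap in the FlashAttention-3 ("FA3") attention processor for the transformer's double-stream blocks.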
pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())

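# spaces.GPU runs this function on the Space's ZeroGPU hardware; duration=120 allots up to 120 seconds per call.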
@spaces.GPU(duration=120)
def push_to_hub(repo_id, filename, oauth_token: gr.OAuthToken):
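    """Compile the QwenImage transformer ahead of time and push the resulting `.pt2` archive to `repo_id` on the Hub."""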
    if not filename.endswith(".pt2"):
        raise gr.Error("The filename must end with a `.pt2` extension.")
    
    # whoami() raises if the OAuth token is missing or invalid
    try:
        _ = whoami(oauth_token.token)

        # --- Ahead-of-time compilation ---
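        # compile_transformer (from the local `optimization` module) compiles the pipeline's transformer and
        # returns an object whose `.archive_file` points to the serialized `.pt2` archive pushed below.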
        compiled_transformer = compile_transformer(pipe, prompt="prompt")

        token = oauth_token.token 
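        # Upload the compiled archive to `repo_id` at `path_in_repo` (helper from the local hub_utils module).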
        out = _push_compiled_graph_to_hub(
            compiled_transformer.archive_file,
            repo_id=repo_id,
            token=token,
            path_in_repo=filename
        )
        if not isinstance(out, str) and hasattr(out, "commit_url"):
            commit_url = out.commit_url
            return f"[{commit_url}]({commit_url})"
        else:
            return out
    except Exception as e:
        raise gr.Error(f"Something went wrong. If you haven't logged in, please use the login button at the top left and try again. Details: {e}")
     
css="""
#col-container {
    margin: 0 auto;
    max-width: 520px;
}
"""
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("## Compile QwenImage graph ahead of time & push to the Hub")
        gr.Markdown("Enter a **repo_id** and **filename**. This repo automatically compiles the [QwenImage](https://hf.co/Qwen/Qwen-Image) model on start.")

        repo_id = gr.Textbox(label="repo_id", placeholder="e.g. sayakpaul/qwen-aot")
        filename = gr.Textbox(label="filename", placeholder="e.g. compiled.pt2")

        run = gr.Button("Push graph to Hub", variant="primary")

        markdown_out = gr.Markdown()

    run.click(push_to_hub, inputs=[repo_id, filename], outputs=[markdown_out])

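# Gradio passes the signed-in user's gr.OAuthProfile automatically (None when logged out),
# so no explicit inputs need to be wired to this function.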
def swap_visibility(profile: gr.OAuthProfile | None):
    return gr.update(elem_classes=["main_ui_logged_in"]) if profile else gr.update(elem_classes=["main_ui_logged_out"])
        
css_login = '''
.main_ui_logged_out{opacity: 0.3; pointer-events: none; margin: 0 auto; max-width: 520px}
'''
with gr.Blocks(css=css_login) as demo_login:
    gr.LoginButton()
    with gr.Column(elem_classes="main_ui_logged_out") as main_ui:
        demo.render()
    demo_login.load(fn=swap_visibility, outputs=main_ui)
    
demo_login.queue()
demo_login.launch()
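
# --- Example: consuming the pushed artifact (a sketch, not exercised by this app) ---
# Assuming `compile_transformer` packages the graph with AOTInductor (which the `.pt2`
# extension and the `archive_file` attribute suggest), a downstream user could pull the
# file back from the Hub and load it roughly like this (repo/file names are placeholders
# taken from the UI examples above):
#
#   from huggingface_hub import hf_hub_download
#   import torch
#
#   path = hf_hub_download(repo_id="sayakpaul/qwen-aot", filename="compiled.pt2")
#   compiled_transformer = torch._inductor.aoti_load_package(path)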