File size: 2,897 Bytes
b22b80e
45c0c4e
b05966a
f9f24d7
f5a3617
ee02270
 
afa2559
b05966a
 
 
b22b80e
b05966a
4b0fe46
 
937a94e
699b46e
4b0fe46
ee02270
 
4b0fe46
ee02270
7bf5ca7
 
b22b80e
7bf5ca7
 
b22b80e
4b0fe46
7bf5ca7
4b0fe46
7bf5ca7
 
 
 
 
 
 
4b0fe46
 
 
 
 
 
7bf5ca7
 
 
 
 
b22b80e
7bf5ca7
4b0fe46
 
 
 
 
 
9c2430d
ee02270
 
9c2430d
7bf5ca7
9c2430d
7bf5ca7
9c2430d
ee02270
b22b80e
4b0fe46
7bf5ca7
 
4b0fe46
 
 
7bf5ca7
4b0fe46
7bf5ca7
 
 
 
 
4b0fe46
7bf5ca7
4b0fe46
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import gradio as gr
import spaces
import torch
from diffusers import DiffusionPipeline
from optimization import compile_transformer
from hub_utils import _push_compiled_graph_to_hub
from huggingface_hub import whoami

# --- Model Loading ---
# Weights are loaded in bfloat16; the pipeline is placed on CUDA when a GPU is
# visible to torch, otherwise on the CPU.
dtype = torch.bfloat16
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Build the Flux.1-Dev pipeline once at import time; push_to_hub reuses this
# module-level instance for every compilation request.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/Flux.1-Dev",
    torch_dtype=dtype,
).to(device)


@spaces.GPU(duration=120)
def push_to_hub(repo_id, filename, oauth_token: gr.OAuthToken, progress=gr.Progress(track_tqdm=True)):
    """Compile the Flux.1-Dev transformer ahead of time and push the archive to the Hub.

    The function body runs under ``spaces.GPU(duration=120)``.

    Args:
        repo_id: Hub repository that receives the compiled graph.
        filename: Path of the archive inside ``repo_id``; must end with ``.pt2``.
        oauth_token: OAuth token of the logged-in visitor, injected by Gradio.
        progress: Gradio progress tracker; ``track_tqdm=True`` mirrors tqdm
            progress bars in the UI.

    Returns:
        A Markdown link to the commit when the upload helper returns an object
        exposing ``commit_url``; otherwise the helper's return value unchanged.

    Raises:
        gr.Error: If the inputs are invalid, the visitor is not logged in (or
            the token is rejected), or compilation / upload fails.
    """
    # Validate cheap inputs before spending GPU time on compilation. gr.Error
    # (unlike a bare exception) carries its message through to the UI.
    if not filename or not filename.endswith(".pt2"):
        raise gr.Error("The filename must end with a `.pt2` extension.")
    if not repo_id:
        raise gr.Error("Please provide a `repo_id`.")

    login_hint = (
        "Oops, you forgot to login. Please use the login button on the top left "
        "to push your compiled graph."
    )
    token = getattr(oauth_token, "token", None)
    # huggingface_hub treats a missing token as "use the locally saved one",
    # so refuse explicitly rather than acting with someone else's credentials.
    if not token:
        raise gr.Error(login_hint)
    try:
        # this will throw if token is invalid
        whoami(token)
    except Exception as e:
        raise gr.Error(f"{login_hint} ({e})") from e

    try:
        # --- Ahead-of-time compilation ---
        compiled_transformer = compile_transformer(pipe, prompt="prompt")
        out = _push_compiled_graph_to_hub(
            compiled_transformer.archive_file, repo_id=repo_id, token=token, path_in_repo=filename
        )
    except Exception as e:
        # Report real compile/upload failures as such instead of blaming the login.
        raise gr.Error(f"Failed to compile or push the graph to `{repo_id}`: {e}") from e

    if not isinstance(out, str) and hasattr(out, "commit_url"):
        commit_url = out.commit_url
        return f"[{commit_url}]({commit_url})"
    return out


css = """
#col-container {
    margin: 0 auto;
    max-width: 520px;
}
"""
# Main UI: a centered column holding the two inputs, the trigger button and a
# Markdown area that displays whatever push_to_hub returns.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(
            value="## Compile [Flux.1-Dev](https://hf.co/black-forest-labs/Flux.1-Dev) graph ahead of time & push to the Hub"
        )
        gr.Markdown(
            value=(
                "Enter a **repo_id** and **filename**. "
                "This repo automatically compiles the Flux.1-Dev model ahead of time. "
                "Read more about this in [this post](https://huggingface.co/blog/zerogpu-aoti)."
            )
        )

        repo_id = gr.Textbox(placeholder="e.g. sayakpaul/qwen-aot", label="repo_id")
        filename = gr.Textbox(placeholder="e.g. compiled.pt2", label="filename")

        run = gr.Button(value="Push graph to Hub", variant="primary")

        markdown_out = gr.Markdown()

    # Only the two textboxes are wired as inputs; push_to_hub's `gr.OAuthToken`
    # parameter is not listed here and is supplied by Gradio itself.
    run.click(fn=push_to_hub, inputs=[repo_id, filename], outputs=[markdown_out])


def swap_visibilty(profile: gr.OAuthProfile | None):
    """Return a component update that sets the main UI's CSS class from the login state.

    A truthy ``profile`` selects ``main_ui_logged_in``; ``None`` (or any falsy
    value) selects ``main_ui_logged_out``.
    """
    if profile:
        css_class = "main_ui_logged_in"
    else:
        css_class = "main_ui_logged_out"
    return gr.update(elem_classes=[css_class])


css_login = """
.main_ui_logged_out{opacity: 0.3; pointer-events: none; margin: 0 auto; max-width: 520px}
"""
# Outer app: a login button above the main UI. The main UI column starts with
# the `main_ui_logged_out` class (dimmed, pointer events disabled); on page load
# swap_visibilty replaces that class according to the visitor's login state.
with gr.Blocks(css=css_login) as demo_login:
    gr.LoginButton()
    with gr.Column(elem_classes="main_ui_logged_out") as main_ui:
        demo.render()
    demo_login.load(fn=swap_visibilty, outputs=main_ui)

# `queue()` returns the same Blocks object, so the server is started on it directly.
demo_login.queue().launch()