Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,897 Bytes
b22b80e 45c0c4e b05966a f9f24d7 f5a3617 ee02270 afa2559 b05966a b22b80e b05966a 4b0fe46 937a94e 699b46e 4b0fe46 ee02270 4b0fe46 ee02270 7bf5ca7 b22b80e 7bf5ca7 b22b80e 4b0fe46 7bf5ca7 4b0fe46 7bf5ca7 4b0fe46 7bf5ca7 b22b80e 7bf5ca7 4b0fe46 9c2430d ee02270 9c2430d 7bf5ca7 9c2430d 7bf5ca7 9c2430d ee02270 b22b80e 4b0fe46 7bf5ca7 4b0fe46 7bf5ca7 4b0fe46 7bf5ca7 4b0fe46 7bf5ca7 4b0fe46 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
import gradio as gr
import spaces
import torch
from diffusers import DiffusionPipeline
from optimization import compile_transformer
from hub_utils import _push_compiled_graph_to_hub
from huggingface_hub import whoami
# --- Model Loading ---
# Weights are loaded in bfloat16.
dtype = torch.bfloat16

# Use the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Build the pipeline once at import time, then move it onto the chosen device.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/Flux.1-Dev",
    torch_dtype=dtype,
)
pipe = pipe.to(device)
@spaces.GPU(duration=120)
def push_to_hub(repo_id, filename, oauth_token: gr.OAuthToken, progress=gr.Progress(track_tqdm=True)):
    """Compile the pipeline's transformer ahead of time and push the archive to the Hub.

    Args:
        repo_id: Target Hub repository, forwarded to the upload helper.
        filename: Destination path inside the repo; must end with ``.pt2``.
        oauth_token: OAuth token of the logged-in user. It is not listed in the
            click handler's ``inputs``; Gradio supplies it from the type hint.
        progress: Gradio progress tracker configured to mirror tqdm bars.

    Returns:
        A Markdown link to the commit when the upload helper returns a
        non-string object exposing ``commit_url``; otherwise the helper's
        return value unchanged.

    Raises:
        NotImplementedError: If ``filename`` does not end with ``.pt2``.
        gr.Error: If the token cannot be validated, or if compilation or the
            upload fails. The two cases carry different messages.
    """
    if not filename.endswith(".pt2"):
        raise NotImplementedError("The filename must end with a `.pt2` extension.")

    # Validate the login on its own so that only authentication problems are
    # reported as "forgot to login". Reading `.token` stays inside the `try`
    # so a missing token object is reported the same way.
    try:
        token = oauth_token.token
        whoami(token)  # throws if the token is invalid
    except Exception as e:
        raise gr.Error(
            f"Oops, you forgot to login. Please use the login button on the top left "
            f"to push your compiled graph. {e}"
        ) from e

    # --- Ahead-of-time compilation and upload ---
    # Failures here are unrelated to login and get their own message.
    try:
        compiled_transformer = compile_transformer(pipe, prompt="prompt")
        out = _push_compiled_graph_to_hub(
            compiled_transformer.archive_file,
            repo_id=repo_id,
            token=token,
            path_in_repo=filename,
        )
    except Exception as e:
        raise gr.Error(f"Failed to compile or push the graph: {e}") from e

    # The helper may return a plain string or an object carrying `commit_url`.
    if not isinstance(out, str) and hasattr(out, "commit_url"):
        commit_url = out.commit_url
        return f"[{commit_url}]({commit_url})"
    return out
# Stylesheet for the main form: a centered column capped at 520px.
css = """
#col-container {
margin: 0 auto;
max-width: 520px;
}
"""

# Main UI: title, short usage note, two text inputs, a button and an output area.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(
            value="## Compile [Flux.1-Dev](https://hf.co/black-forest-labs/Flux.1-Dev) graph ahead of time & push to the Hub"
        )
        gr.Markdown(
            value="Enter a **repo_id** and **filename**. This repo automatically compiles the Flux.1-Dev model ahead of time. Read more about this in [this post](https://huggingface.co/blog/zerogpu-aoti)."
        )

        repo_id = gr.Textbox(placeholder="e.g. sayakpaul/qwen-aot", label="repo_id")
        filename = gr.Textbox(placeholder="e.g. compiled.pt2", label="filename")

        run = gr.Button(value="Push graph to Hub", variant="primary")
        markdown_out = gr.Markdown()

    # Only the two textboxes are wired as inputs; the handler's return value
    # is rendered as Markdown.
    run.click(
        fn=push_to_hub,
        inputs=[repo_id, filename],
        outputs=[markdown_out],
    )
def swap_visibilty(profile: gr.OAuthProfile | None):
    """Choose the CSS class of the wrapped UI from the login state.

    Args:
        profile: The logged-in user's OAuth profile, or ``None`` when nobody
            is logged in.

    Returns:
        A ``gr.update`` setting ``elem_classes`` to ``["main_ui_logged_in"]``
        when ``profile`` is truthy, else ``["main_ui_logged_out"]``.
    """
    if profile:
        css_classes = ["main_ui_logged_in"]
    else:
        css_classes = ["main_ui_logged_out"]
    return gr.update(elem_classes=css_classes)
# Logged-out styling: the wrapped UI is dimmed and does not receive pointer events.
css_login = """
.main_ui_logged_out{opacity: 0.3; pointer-events: none; margin: 0 auto; max-width: 520px}
"""

# Outer app: a login button above the main UI, which starts with the logged-out class.
with gr.Blocks(css=css_login) as demo_login:
    gr.LoginButton()
    with gr.Column(elem_classes="main_ui_logged_out") as main_ui:
        demo.render()

    # On page load, swap the column's class according to the login state.
    demo_login.load(outputs=main_ui, fn=swap_visibilty)

# Enable the queue, then start the server.
demo_login.queue().launch()
|