ZIT-Controlnet / app.py
Alexander Bagus
22
548acb6
raw
history blame
7.83 kB
import gradio as gr
import numpy as np
import random
import json
import spaces
import torch
from diffusers import DiffusionPipeline
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from videox_fun.pipeline import ZImageControlPipeline
from videox_fun.models import ZImageControlTransformer2DModel
from transformers import AutoTokenizer, Qwen3ForCausalLM
from diffusers import AutoencoderKL
from utils.image_utils import get_image_latent, scale_image
from utils.prompt_utils import polish_prompt
# ---- Local model artifacts ----
# Base Z-Image-Turbo snapshot; fetch it with:
#   git clone https://huggingface.co/Tongyi-MAI/Z-Image-Turbo
MODEL_LOCAL: str = "models/Z-Image-Turbo/"
# ControlNet-Union weights that are loaded into the transformer's state dict; fetch with:
#   curl -L -o Z-Image-Turbo-Fun-Controlnet-Union.safetensors https://huggingface.co/alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union/resolve/main/Z-Image-Turbo-Fun-Controlnet-Union.safetensors
TRANSFORMER_LOCAL: str = "models/Z-Image-Turbo-Fun-Controlnet-Union.safetensors"

# ---- Limits ----
# Upper bound for random seeds: the largest signed 32-bit integer.
MAX_SEED: int = np.iinfo(np.int32).max
# Maximum image side length (presumably pixels); not referenced elsewhere in this file.
MAX_IMAGE_SIZE: int = 1280

# Floating-point dtype for model weights.
weight_dtype = torch.bfloat16
# ---- Load the control transformer and merge the ControlNet-Union checkpoint ----
transformer = ZImageControlTransformer2DModel.from_pretrained(
    MODEL_LOCAL,
    subfolder="transformer",
    low_cpu_mem_usage=True,
    torch_dtype=weight_dtype,
    transformer_additional_kwargs={
        # Passed through to the videox_fun model config; presumably the layer
        # indices that receive control features and the control input channel
        # count — confirm against ZImageControlTransformer2DModel.
        "control_layers_places": [0, 5, 10, 15, 20, 25],
        "control_in_dim": 16,
    },
).to(weight_dtype)

if TRANSFORMER_LOCAL is not None:
    print(f"From checkpoint: {TRANSFORMER_LOCAL}")
    if TRANSFORMER_LOCAL.endswith("safetensors"):
        from safetensors.torch import load_file
        state_dict = load_file(TRANSFORMER_LOCAL)
    else:
        # NOTE(review): torch.load unpickles arbitrary objects; only point this
        # at trusted checkpoint files.
        state_dict = torch.load(TRANSFORMER_LOCAL, map_location="cpu")
    # Some training checkpoints nest the weights under a "state_dict" key.
    state_dict = state_dict.get("state_dict", state_dict)
    # strict=False tolerates key mismatches; they are only reported as counts.
    m, u = transformer.load_state_dict(state_dict, strict=False)
    print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
    # nn.Module.load_state_dict copies values into the model's own parameters,
    # so the checkpoint dict is no longer needed. As a module-level global it
    # would otherwise keep a second copy of the weights in RAM for the whole
    # process lifetime.
    del state_dict
# ---- Assemble ZImageControlPipeline ----
# VAE, tokenizer and Qwen3 text encoder are all loaded from the same local
# snapshot (MODEL_LOCAL), each from its own subfolder.
vae = AutoencoderKL.from_pretrained(
    MODEL_LOCAL,
    subfolder="vae"
).to(weight_dtype)
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_LOCAL, subfolder="tokenizer"
)
text_encoder = Qwen3ForCausalLM.from_pretrained(
    MODEL_LOCAL, subfolder="text_encoder", torch_dtype=weight_dtype,
    low_cpu_mem_usage=True,
)
# Scheduler is constructed in code rather than loaded from the snapshot.
# NOTE(review): shift=3 is hard-coded; presumably it matches the model's
# scheduler config — confirm against models/Z-Image-Turbo/scheduler.
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3)
pipe = ZImageControlPipeline(
    vae=vae,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    transformer=transformer,
    scheduler=scheduler,
)
# NOTE(review): the transformer is already passed to the constructor above; this
# reassignment looks redundant unless the pipeline wraps or replaces it — confirm.
pipe.transformer = transformer
pipe.to("cuda")
# ======== AoTI compilation + FA3 ========
# Tag ZImageTransformerBlock as the repeated block type of transformer.layers,
# then load ahead-of-time compiled blocks from "zerogpu-aoti/Z-Image"
# (variant "fa3", i.e. FlashAttention-3 per the section title).
# This runs after pipe.to("cuda"); keep that ordering unless confirmed safe to change.
pipe.transformer.layers._repeated_blocks = ["ZImageTransformerBlock"]
spaces.aoti_blocks_load(pipe.transformer.layers,
    "zerogpu-aoti/Z-Image", variant="fa3")
def prepare(prompt, input_image=None):
    """Rewrite the user's prompt with ``polish_prompt`` before generation.

    Args:
        prompt: Raw prompt text from the UI.
        input_image: Currently unused. It defaults to ``None`` so the function
            also works when an event passes only the prompt as input.

    Returns:
        Whatever ``polish_prompt`` returns for *prompt*. The UI shows it in the
        "Polished prompt" textbox.
    """
    return polish_prompt(prompt)
@spaces.GPU
def inference(
    prompt,
    input_image,
    image_scale=1.0,
    control_context_scale=0.75,
    seed=42,
    randomize_seed=True,
    guidance_scale=1.5,
    num_inference_steps=8,
    progress=gr.Progress(track_tqdm=True),
):
    """Run one control-image-conditioned generation for the Gradio UI.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        input_image: Control image from the UI; ``None`` when nothing was uploaded.
        image_scale: Resize factor forwarded to ``scale_image``.
        control_context_scale: Forwarded to the pipeline unchanged.
        seed: Seed used when ``randomize_seed`` is False.
        randomize_seed: If True, draw a fresh seed from ``[0, MAX_SEED]``.
        guidance_scale: Forwarded to the pipeline unchanged.
        num_inference_steps: Step count forwarded to the pipeline.
        progress: Gradio progress tracker; ``track_tqdm=True`` mirrors tqdm bars.

    Returns:
        ``(image, seed)``. ``image`` is the first entry of ``pipe(...).images``,
        or ``None`` when no input image was given. ``seed`` is the seed actually
        used, so the UI can update its seed slider.
    """
    if input_image is None:
        # The handler feeds two outputs (image, seed), so always return a 2-tuple
        # rather than a bare None.
        print("Error: input_image is empty.")
        gr.Warning("Please upload an input image.")
        return None, seed

    # Resize the control image; scale_image also reports the target size.
    input_image, width, height = scale_image(input_image, image_scale)
    # [:, :, 0] takes index 0 of the third axis of the returned tensor
    # (presumably a frame/time axis — confirm against get_image_latent).
    control_image = get_image_latent(input_image, sample_size=[height, width])[:, :, 0]

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Slider values may arrive as floats; manual_seed and the step count need ints.
    seed = int(seed)
    generator = torch.Generator().manual_seed(seed)
    image = pipe(
        prompt=prompt,
        height=height,
        width=width,
        generator=generator,
        guidance_scale=guidance_scale,
        control_image=control_image,
        num_inference_steps=int(num_inference_steps),
        control_context_scale=control_context_scale,
    ).images[0]
    return image, seed
def read_file(path: str) -> str:
    """Return the entire contents of the text file at *path*, decoded as UTF-8."""
    with open(path, encoding='utf-8') as handle:
        return handle.read()
# Page CSS: center the main column and cap its width at 960px.
css = """
#col-container {
margin: 0 auto;
max-width: 960px;
}
"""
# Example rows for gr.Examples, matched positionally to [input_image, prompt].
# JSON is UTF-8 by specification, so decode explicitly instead of relying on
# the platform's default locale encoding (which is how read_file reads too).
with open('static/data.json', 'r', encoding='utf-8') as file:
    data = json.load(file)
examples = data['examples']
with gr.Blocks() as demo:
    with gr.Column(elem_id="col-container"):
        with gr.Column():
            gr.HTML(read_file("static/header.html"))
        with gr.Row(equal_height=True):
            # Left column: control-image upload, prompt and run button.
            with gr.Column():
                input_image = gr.Image(
                    height=290, sources=['upload', 'clipboard'],
                    image_mode='RGB',
                    type="pil", label="Upload")
                prompt = gr.Textbox(
                    label="Prompt",
                    show_label=False,
                    lines=2,
                    placeholder="Enter your prompt",
                    container=False,
                )
                run_button = gr.Button("Run", variant="primary")
            # Right column: generated image plus intermediate outputs.
            with gr.Column():
                output_image = gr.Image(label="Generated image", show_label=False)
                with gr.Accordion("Preprocessor output", open=False):
                    # NOTE(review): no event writes to this component yet.
                    control_image = gr.Image(label="Control image", show_label=False)
                    polished_prompt = gr.Textbox(label="Polished prompt", interactive=False)
        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                image_scale = gr.Slider(
                    label="Image scale",
                    minimum=0.5,
                    maximum=2.0,
                    step=0.1,
                    value=1.0,
                )
                control_context_scale = gr.Slider(
                    label="Control context scale",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.1,
                    value=0.75,
                )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=2.5,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=30,
                    step=1,
                    value=8,
                )
        gr.Examples(examples=examples, inputs=[input_image, prompt])
        gr.HTML(read_file("static/footer.html"))

    # Step 1: polish the prompt and show it in the "Polished prompt" box.
    # Step 2: generate from the polished prompt. The previous `.then()` had no
    # fn, so inference was never invoked and Run produced no image.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=prepare,
        inputs=[prompt, input_image],
        outputs=[polished_prompt],
    ).then(
        fn=inference,
        inputs=[
            polished_prompt,
            input_image,
            image_scale,
            control_context_scale,
            seed,
            randomize_seed,
            guidance_scale,
            num_inference_steps,
        ],
        # inference returns (image, seed); writing the seed back shows the one used.
        outputs=[output_image, seed],
    )
if __name__ == "__main__":
    # Serve the UI with the page CSS; mcp_server=True is passed through to Gradio's launch().
    demo.launch(css=css, mcp_server=True)