# DIPO / app.py
# xinjie.wang
# update
# 187b3a6
import os
import random
import shutil
import string
import subprocess
import zipfile
from types import SimpleNamespace

import gradio as gr
from gradio.themes import Soft
from gradio.themes.utils.colors import gray, neutral, slate, stone, teal, zinc

from inference import run_demo, load_config
# Theme for the demo page: gradio's Soft theme with stone/gray hues and
# compact ("sm") text and spacing.
custom_theme = Soft(
    secondary_hue=gray,
    primary_hue=stone,
    spacing_size="sm",
    text_size="sm",
    radius_size="md",
)
def _remove_result_dirs(prefix='results'):
    """Delete every directory in the working directory whose name starts with *prefix*."""
    for entry in os.listdir('.'):
        if entry.startswith(prefix) and os.path.isdir(entry):
            shutil.rmtree(entry)


def _zip_folder(folder, zip_path):
    """Write all files under *folder* into a deflate-compressed ZIP at *zip_path*.

    Member names inside the archive are paths relative to *folder*.
    """
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as archive:
        for root, _dirs, files in os.walk(folder):
            for name in files:
                abs_path = os.path.join(root, name)
                archive.write(abs_path, arcname=os.path.relpath(abs_path, folder))


def inference_ui(img1, img2, omega, n_denoise_steps):
    """Gradio callback: run the demo on an image pair and package its output.

    Args:
        img1: File path of the first image (the "Closed State" input).
        img2: File path of the second image (the "Opened State" input).
        omega: Guidance value forwarded to ``run_demo`` as ``args.omega``.
        n_denoise_steps: Step count forwarded as ``args.n_denoise_steps``.

    Returns:
        A ``(gif_path, zip_path)`` tuple. Each element is ``None`` when the
        corresponding file does not exist after the run.
    """
    # Each request writes into its own randomly suffixed folder.
    run_dir = 'results' + "_" + ''.join(random.choices(string.ascii_letters, k=16))

    # Drop output left behind by earlier requests. Note this removes *every*
    # directory starting with "results" in the working directory.
    _remove_result_dirs()

    sample_dir = os.path.join(run_dir, "0")
    os.makedirs(sample_dir, exist_ok=True)

    args = SimpleNamespace(
        img_path_1=img1,
        img_path_2=img2,
        ckpt_path='ckpts/dipo.ckpt',
        config_path='configs/config.yaml',
        use_example_graph=False,
        save_dir=run_dir,
        gt_data_root='./data/PartnetMobility',
        n_samples=3,
        omega=omega,
        n_denoise_steps=n_denoise_steps,
    )
    args.config = load_config(args.config_path)
    # NOTE(review): run_demo is presumably what populates <save_dir>/0
    # (animation.gif etc.) -- confirm against inference.py.
    run_demo(args)

    gif_path = os.path.join(sample_dir, "animation.gif")

    # Bundle the sample folder into a ZIP. The archive is written next to the
    # folder (not inside it), so it never ends up containing itself.
    zip_path = os.path.join(run_dir, "output.zip")
    _zip_folder(sample_dir, zip_path)

    return (
        gif_path if os.path.exists(gif_path) else None,
        zip_path if os.path.exists(zip_path) else None,
    )
def prepare_data():
    """Download and unpack ``data.tar`` when local assets are missing.

    Nothing happens when both ``data`` and ``saved_model`` exist in the
    current working directory. Otherwise ``wget`` fetches the archive and
    ``tar`` extracts it in place.

    Failures are reported on stdout and never raised, so the caller can
    still start the app. Extraction is skipped when the download fails.
    """
    if os.path.exists("data") and os.path.exists("saved_model"):
        return

    print("Downloading data.tar from Hugging Face Datasets...")
    data_url = "https://huggingface.co/datasets/wuruiqi0722/DIPO_data/resolve/main/data/data.tar"
    # Argument lists (no shell) so nothing in the command line is re-parsed.
    commands = (
        ["wget", data_url, "-O", "data.tar"],
        ["tar", "-xf", "data.tar"],
    )
    for command in commands:
        try:
            status = subprocess.run(command, check=False).returncode
        except OSError as err:
            # Raised e.g. when the executable is not installed.
            print(f"Could not run {command[0]}: {err}")
            return
        if status != 0:
            print(f"{command[0]} exited with status {status}; data preparation aborted.")
            return
# Demo page: inputs and controls in the left column, the generated animation
# and the download button in the right column.
# NOTE(review): nesting was reconstructed from the statements themselves;
# confirm that gr.Examples belongs at Blocks level rather than inside a column.
with gr.Blocks(theme=custom_theme) as demo:
    gr.Markdown("## DIPO: Dual-State Images Controlled Articulated Object Generation Powered by Diverse Data")
    gr.Markdown("""
[📖 Project Page](https://rq-wu.github.io/projects/DIPO) | [📄 arXiv](https://arxiv.org/abs/2505.20460) | [💻 GitHub](https://github.com/RQ-Wu/DIPO)
""")
    gr.Markdown("Currently, only the articulated object in following categories are supported: `Table`, `Dishwasher`, `StorageFurniture`, `Refrigerator`, `WashingMachine`, `Microwave`, `Oven`.")

    with gr.Row():
        with gr.Column(scale=1):
            # type="filepath" makes the callback receive paths, not arrays.
            img1_input = gr.Image(label="Image: Closed State", type="filepath", height=250)
            img2_input = gr.Image(label="Image: Opened State", type="filepath", height=250)
            omega = gr.Slider(0.0, 1.0, step=0.1, value=0.5, label="Omega (CFG Guidance)")
            n_denoise = gr.Slider(10, 200, step=10, value=100, label="Denoising Steps")
            run_button = gr.Button("🚀 Run Generation (~2mins)")
        with gr.Column(scale=1):
            output_gif = gr.Image(label="GIF Animation", type="filepath", height=678, width=10000)
            # Starts disabled; enabled once a generation run has succeeded.
            zip_download_btn = gr.DownloadButton(label="📦 Download URDF folder", interactive=False)

    # Clickable (closed image, opened image) pairs that fill both inputs.
    gr.Examples(
        examples=[
            ["examples/close10.png", "examples/open10.png"],
            ["examples/close9.jpg", "examples/open9.jpg"],
            ["examples/1.png", "examples/1_open_1.png"],
            ["examples/1.png", "examples/1_open_2.png"],
            ["examples/close1.png", "examples/open1.png"],
            # ["examples/close2.png", "examples/open2.png"],
            # ["examples/close4.png", "examples/open4.png"],
            ["examples/close5.png", "examples/open5.png"],
            ["examples/close6.png", "examples/open6.png"],
            ["examples/close7.png", "examples/open7.png"],
            ["examples/close8.png", "examples/open8.png"],
            ["examples/close3.png", "examples/open3.png"],
        ],
        inputs=[img1_input, img2_input],
        label="📂 Example Inputs"
    )

    # Run inference on click; the chained .success() step only fires when the
    # callback completed without error and makes the download button clickable.
    run_button.click(
        fn=inference_ui,
        inputs=[img1_input, img2_input, omega, n_denoise],
        outputs=[output_gif, zip_download_btn]
    ).success(
        lambda: gr.DownloadButton(interactive=True),
        outputs=[zip_download_btn]
    )
# Script entry point: make sure the data is on disk, then serve the UI.
if __name__ == "__main__":
    # Fetch assets first so they are present before the app starts serving.
    prepare_data()
    demo.launch()