| import requests |
| import os |
| import yaml |
| from .utils import get_ckpt, get_yaml_config |
|
|
|
|
# Size of each chunk written to disk while streaming a download.
_DOWNLOAD_CHUNK_BYTES = 1 << 20
# (connect, read) timeouts in seconds; the read timeout applies per socket
# read, not to the whole transfer, so large files are still allowed.
_DOWNLOAD_TIMEOUT = (10, 60)


def _remote_file_name(url):
    """Return the last path segment of *url*, ignoring any query/fragment."""
    from urllib.parse import urlsplit

    return urlsplit(url).path.split("/")[-1]


def _download_file(url, save_path):
    """Stream *url* to *save_path* without holding the payload in memory.

    The data is written to ``save_path + ".part"`` and moved into place only
    after the transfer completed, so a failed or interrupted download never
    leaves a truncated file at *save_path*.

    Raises:
        requests.HTTPError: if the server answers with an error status
            (via ``raise_for_status``).
    """
    partial_path = save_path + ".part"
    try:
        with requests.get(url, stream=True, timeout=_DOWNLOAD_TIMEOUT) as response:
            response.raise_for_status()
            with open(partial_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=_DOWNLOAD_CHUNK_BYTES):
                    f.write(chunk)
        os.replace(partial_path, save_path)
    finally:
        # Only present here if something above failed before the rename.
        if os.path.exists(partial_path):
            os.remove(partial_path)


def download_ckpt_yaml(model_path, model_name, ckpt_path, yaml_url=None):
    """Download a checkpoint (and optionally a YAML config) into a local cache.

    Files are stored under ``<model_path>/<model_name>/`` using the file name
    taken from the URL path. A file that already exists locally is not
    downloaded again.

    Args:
        model_path: Root directory of the local cache; created if missing.
        model_name: Sub-directory name for this model; created if missing.
        ckpt_path: URL of the checkpoint file.
        yaml_url: Optional URL of the YAML config file.

    Returns:
        ``(local_ckpt_path, local_yaml_path)``; the second element is ``None``
        when *yaml_url* is falsy.
    """
    os.makedirs(model_path, exist_ok=True)
    local_dir = os.path.join(model_path, model_name)
    os.makedirs(local_dir, exist_ok=True)

    local_ckpt_path = os.path.join(local_dir, _remote_file_name(ckpt_path))
    if not os.path.exists(local_ckpt_path):
        print(f"Downloading CKPT to {local_ckpt_path}")
        _download_file(ckpt_path, local_ckpt_path)

    if not yaml_url:
        return local_ckpt_path, None

    local_yaml_path = os.path.join(local_dir, _remote_file_name(yaml_url))
    if not os.path.exists(local_yaml_path):
        print(f"Downloading YAML to {local_yaml_path}")
        _download_file(yaml_url, local_yaml_path)
    return local_ckpt_path, local_yaml_path
|
|
|
|
# (checkpoint URL, YAML config URL) for tokenizers loaded through the Anole VQModel.
_ANOLE_STYLE_SOURCES = {
    "anole": (
        "https://huggingface.co/GAIR/Anole-7b-v0.1/resolve/main/tokenizer/vqgan.ckpt",
        "https://huggingface.co/GAIR/Anole-7b-v0.1/resolve/main/tokenizer/vqgan.yaml",
    ),
    "chameleon": (
        "https://huggingface.co/huaweilin/chameleon_vqvae/resolve/main/vqgan.ckpt",
        "https://huggingface.co/huaweilin/chameleon_vqvae/resolve/main/vqgan.yaml",
    ),
}

# model name -> (VQ_models key, checkpoint URL, target image size)
_LLAMAGEN_VARIANTS = {
    "llamagen-ds16": (
        "VQ-16",
        "https://huggingface.co/FoundationVision/LlamaGen/resolve/main/vq_ds16_c2i.pt",
        (512, 512),
    ),
    "llamagen-ds16-t2i": (
        "VQ-16",
        "https://huggingface.co/peizesun/llamagen_t2i/resolve/main/vq_ds16_t2i.pt",
        (512, 512),
    ),
    "llamagen-ds8": (
        "VQ-8",
        "https://huggingface.co/FoundationVision/LlamaGen/resolve/main/vq_ds8_c2i.pt",
        (256, 256),
    ),
}

# (substring looked for in the model name, Hugging Face repo id); first match wins.
_TITOK_REPOS = (
    ("bl64", "yucornetto/tokenizer_titok_bl64_vq8k_imagenet"),
    ("bl128", "yucornetto/tokenizer_titok_bl128_vq8k_imagenet"),
    ("sl256", "yucornetto/tokenizer_titok_sl256_vq8k_imagenet"),
    ("l32", "yucornetto/tokenizer_titok_l32_imagenet"),
    ("b64", "yucornetto/tokenizer_titok_b64_imagenet"),
    ("s128", "yucornetto/tokenizer_titok_s128_imagenet"),
)


def get_model(model_path, model_name):
    """Build the tokenizer / autoencoder identified by *model_name*.

    Matching is case-insensitive. The names "anole", "chameleon",
    "llamagen-ds16", "llamagen-ds16-t2i", "llamagen-ds8", "flowmo_lo",
    "flowmo_hi" and "open_magvit2" must match exactly; the remaining families
    ("maskbit", "bsqvit", "titok", "janus_pro", "var", "infinity", "sd3.5l",
    "flux.1-dev", "gpt4o") are matched by substring, in that order.

    Args:
        model_path: Local cache directory. When not ``None``, checkpoint and
            config files are first fetched with ``download_ckpt_yaml`` and the
            local paths are used; when ``None`` the remote URLs are handed to
            ``get_ckpt`` / ``get_yaml_config`` directly.
        model_name: Name of the model to build.

    Returns:
        ``(model, data_params)`` where ``model`` has had ``eval()`` called on
        it and ``data_params`` is the preprocessing dict for that model
        (``target_image_size``, ``lock_ratio``, ``center_crop``, ``padding``
        and, for some models, ``standardize``).

    Raises:
        Exception: if *model_name* matches no supported model or variant.
    """
    name = model_name.lower()
    # Forwarded as ``token=`` to the ``from_pretrained`` calls below.
    huggingface_token = os.environ.get("HUGGINGFACE_TOKEN", None)

    def fetch(subdir, ckpt_url, yaml_url=None):
        # Resolve to local files when a cache directory was given, otherwise
        # pass the remote locations through unchanged.
        if model_path is None:
            return ckpt_url, yaml_url
        return download_ckpt_yaml(model_path, subdir, ckpt_url, yaml_url)

    def make_data_params(size, standardize=None):
        params = {
            "target_image_size": size,
            "lock_ratio": True,
            "center_crop": True,
            "padding": False,
        }
        # Only some models carry an explicit "standardize" entry.
        if standardize is not None:
            params["standardize"] = standardize
        return params

    if name in _ANOLE_STYLE_SOURCES:
        from src.vqvaes.anole.anole import VQModel

        ckpt_url, yaml_url = _ANOLE_STYLE_SOURCES[name]
        ckpt_path, yaml_path = fetch(name, ckpt_url, yaml_url)
        config = get_yaml_config(yaml_path)

        params = config["model"]["params"]
        if "lossconfig" in params:
            del params["lossconfig"]
        params["ckpt_path"] = ckpt_path
        model = VQModel(**params)
        data_params = make_data_params((512, 512))

    elif name in _LLAMAGEN_VARIANTS:
        from src.vqvaes.llamagen.llamagen import VQ_models

        arch, ckpt_url, image_size = _LLAMAGEN_VARIANTS[name]
        ckpt_path, _ = fetch(name, ckpt_url)

        model = VQ_models[arch](codebook_size=16384, codebook_embed_dim=8)
        model.load_state_dict(get_ckpt(ckpt_path, key="model"))
        data_params = make_data_params(image_size)

    elif name in ("flowmo_lo", "flowmo_hi"):
        from src.vqvaes.flowmo.flowmo import build_model

        yaml_url = "https://raw.githubusercontent.com/kylesargent/FlowMo/refs/heads/main/flowmo/configs/base.yaml"
        ckpt_url = f"https://huggingface.co/ksarge/FlowMo/resolve/main/{name}.pth"
        ckpt_path, yaml_path = fetch(name, ckpt_url, yaml_url)
        config = get_yaml_config(yaml_path)

        # Both variants share base.yaml and override it per checkpoint.
        if name == "flowmo_lo":
            config.model.context_dim = 18
        else:
            config.model.context_dim = 56
            config.model.codebook_size_for_entropy = 14
        model = build_model(config)
        model.load_state_dict(get_ckpt(ckpt_path, key="model_ema_state_dict"))
        data_params = make_data_params((256, 256))

    elif name == "open_magvit2":
        from src.vqvaes.open_magvit2.open_magvit2 import VQModel

        yaml_url = "https://raw.githubusercontent.com/TencentARC/SEED-Voken/refs/heads/main/configs/Open-MAGVIT2/gpu/imagenet_lfqgan_256_L.yaml"
        ckpt_url = "https://huggingface.co/TencentARC/Open-MAGVIT2-Tokenizer-256-resolution/resolve/main/imagenet_256_L.ckpt"
        ckpt_path, yaml_path = fetch("open_magvit2", ckpt_url, yaml_url)
        config = get_yaml_config(yaml_path)

        model = VQModel(**config.model.init_args)
        model.load_state_dict(get_ckpt(ckpt_path, key="state_dict"))
        data_params = make_data_params((256, 256))

    elif "maskbit" in name:
        bits = next((b for b in ("16bit", "18bit") if b in name), None)
        if bits is None:
            raise Exception(f"Unsupported model: {model_name}")

        from src.vqvaes.maskbit.maskbit import ConvVQModel

        yaml_url = f"https://raw.githubusercontent.com/markweberdev/maskbit/refs/heads/main/configs/tokenizer/maskbit_tokenizer_{bits}.yaml"
        ckpt_url = f"https://huggingface.co/markweber/maskbit_tokenizer_{bits}/resolve/main/maskbit_tokenizer_{bits}.bin"
        ckpt_path, yaml_path = fetch(f"maskbit-{bits}", ckpt_url, yaml_url)

        config = get_yaml_config(yaml_path)
        model = ConvVQModel(config.model.vq_model, legacy=False)
        model.load_pretrained(get_ckpt(ckpt_path, key=None))
        data_params = make_data_params((256, 256), standardize=False)

    elif "bsqvit" in name:
        from src.vqvaes.bsqvit.bsqvit import VITBSQModel

        yaml_url = "https://huggingface.co/huaweilin/bsqvit_256x256/resolve/main/config.yaml"
        ckpt_url = "https://huggingface.co/huaweilin/bsqvit_256x256/resolve/main/checkpoint.pt"
        ckpt_path, yaml_path = fetch("bsqvit", ckpt_url, yaml_url)

        config = get_yaml_config(yaml_path)
        model = VITBSQModel(**config["model"]["params"])
        model.init_from_ckpt(get_ckpt(ckpt_path, key="state_dict"))
        data_params = make_data_params((256, 256), standardize=False)

    elif "titok" in name:
        repo_id = next((repo for tag, repo in _TITOK_REPOS if tag in name), None)
        if repo_id is None:
            raise Exception(f"Unsupported model: {model_name}")

        from src.vqvaes.titok.titok import TiTok

        model = TiTok.from_pretrained(repo_id, token=huggingface_token)
        data_params = make_data_params((256, 256), standardize=False)

    elif "janus_pro" in name:
        from janus.models import MultiModalityCausalLM
        from src.vqvaes.janus_pro.janus_pro import forward
        import types

        model = MultiModalityCausalLM.from_pretrained(
            "deepseek-ai/Janus-Pro-7B",
            trust_remote_code=True,
            token=huggingface_token,
        ).gen_vision_model
        # Replace the sub-module's forward with the project-provided one.
        model.forward = types.MethodType(forward, model)
        data_params = make_data_params((384, 384))

    elif "var" in name:
        from src.vqvaes.var.var_vq import VQVAE

        ckpt_url = "https://huggingface.co/FoundationVision/var/resolve/main/vae_ch160v4096z32.pth"
        ckpt_path, _ = fetch("var", ckpt_url)

        # A "512" in the name selects the larger patch schedule and image size.
        is_512 = "512" in name
        if is_512:
            v_patch_nums = (1, 2, 3, 4, 6, 9, 13, 18, 24, 32)
        else:
            v_patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)
        model = VQVAE(
            vocab_size=4096,
            z_channels=32,
            ch=160,
            test_mode=True,
            share_quant_resi=4,
            v_patch_nums=v_patch_nums,
        )
        model.load_state_dict(get_ckpt(ckpt_path, key=None))
        data_params = make_data_params(
            (512, 512) if is_512 else (256, 256), standardize=False
        )

    elif "infinity" in name:
        # Reject unknown variants explicitly; previously a name without
        # "d32"/"d64" fell through and crashed on an unbound local.
        codebook_dim = next((d for d in (32, 64) if f"d{d}" in name), None)
        if codebook_dim is None:
            raise Exception(f"Unsupported model: {model_name}")

        from src.vqvaes.infinity.vae import vae_model

        ckpt_url = f"https://huggingface.co/FoundationVision/Infinity/resolve/main/infinity_vae_d{codebook_dim}.pth"
        ckpt_path, _ = fetch(f"infinity-d{codebook_dim}", ckpt_url)

        schedule_mode = "dynamic"
        codebook_size = 2**codebook_dim
        patch_size = 16
        encoder_ch_mult = [1, 2, 4, 4, 4]
        decoder_ch_mult = [1, 2, 4, 4, 4]

        ckpt = get_ckpt(ckpt_path, key=None)
        model = vae_model(
            ckpt,
            schedule_mode,
            codebook_dim,
            codebook_size,
            patch_size=patch_size,
            encoder_ch_mult=encoder_ch_mult,
            decoder_ch_mult=decoder_ch_mult,
            test_mode=True,
        )
        data_params = make_data_params((1024, 1024), standardize=False)

    elif "sd3.5l" in name:
        from src.vaes.stable_diffusion.vae import forward
        from diffusers import AutoencoderKL
        import types

        model = AutoencoderKL.from_pretrained(
            "huaweilin/stable-diffusion-3.5-large-vae",
            subfolder="vae",
            token=huggingface_token,
        )
        model.forward = types.MethodType(forward, model)
        data_params = make_data_params((1024, 1024), standardize=True)

    elif "FLUX.1-dev".lower() in name:
        from src.vaes.stable_diffusion.vae import forward
        from diffusers import AutoencoderKL
        import types

        model = AutoencoderKL.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            subfolder="vae",
            token=huggingface_token,
        )
        model.forward = types.MethodType(forward, model)
        data_params = make_data_params((1024, 1024), standardize=True)

    elif "gpt4o" in name:
        from src.vaes.gpt_image.gpt_image import GPTImage

        # This wrapper receives the preprocessing parameters itself.
        data_params = make_data_params((1024, 1024), standardize=False)
        model = GPTImage(data_params)

    else:
        raise Exception(f"Unsupported model: \"{model_name}\"")

    # Best-effort report: counts every parameter yielded by
    # ``model.parameters()``; a failure here is printed, not raised.
    try:
        trainable_params = sum(p.numel() for p in model.parameters())
        print("trainable_params:", trainable_params)
    except Exception as e:
        print(e)

    model.eval()
    return model, data_params
|
|