import os import torch import torch.nn.functional as F import gc from .utils import log, print_memory, apply_lora, clip_encode_image_tiled, fourier_filter import numpy as np import math from tqdm import tqdm from .wanvideo.modules.clip import CLIPModel from .wanvideo.modules.model import WanModel, rope_params from .wanvideo.modules.t5 import T5EncoderModel from .wanvideo.utils.fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps) from .wanvideo.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler from .wanvideo.utils.basic_flowmatch import FlowMatchScheduler from diffusers.schedulers import FlowMatchEulerDiscreteScheduler, DEISMultistepScheduler from .wanvideo.utils.scheduling_flow_match_lcm import FlowMatchLCMScheduler from .enhance_a_video.globals import enable_enhance, disable_enhance, set_enhance_weight, set_num_frames from .taehv import TAEHV from accelerate import init_empty_weights from accelerate.utils import set_module_tensor_to_device from einops import rearrange import folder_paths import comfy.model_management as mm from comfy.utils import load_torch_file, ProgressBar, common_upscale import comfy.model_base import comfy.latent_formats from comfy.clip_vision import clip_preprocess, ClipVisionModel from comfy.sd import load_lora_for_models from comfy.cli_args import args, LatentPreviewMethod script_directory = os.path.dirname(os.path.abspath(__file__)) def add_noise_to_reference_video(image, ratio=None): sigma = torch.ones((image.shape[0],)).to(image.device, image.dtype) * ratio image_noise = torch.randn_like(image) * sigma[:, None, None, None] image_noise = torch.where(image==-1, torch.zeros_like(image), image_noise) image = image + image_noise return image def optimized_scale(positive_flat, negative_flat): # Calculate dot production dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) # Squared norm of uncondition squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8 # st_star = v_cond^T * v_uncond / ||v_uncond||^2 st_star = dot_product / squared_norm return st_star class WanVideoBlockSwap: @classmethod def INPUT_TYPES(s): return { "required": { "blocks_to_swap": ("INT", {"default": 20, "min": 0, "max": 40, "step": 1, "tooltip": "Number of transformer blocks to swap, the 14B model has 40, while the 1.3B model has 30 blocks"}), "offload_img_emb": ("BOOLEAN", {"default": False, "tooltip": "Offload img_emb to offload_device"}), "offload_txt_emb": ("BOOLEAN", {"default": False, "tooltip": "Offload time_emb to offload_device"}), }, "optional": { "use_non_blocking": ("BOOLEAN", {"default": True, "tooltip": "Use non-blocking memory transfer for offloading, reserves more RAM but is faster"}), "vace_blocks_to_swap": ("INT", {"default": 0, "min": 0, "max": 15, "step": 1, "tooltip": "Number of VACE blocks to swap, the VACE model has 15 blocks"}), }, } RETURN_TYPES = ("BLOCKSWAPARGS",) RETURN_NAMES = ("block_swap_args",) FUNCTION = "setargs" CATEGORY = "WanVideoWrapper" DESCRIPTION = "Settings for block swapping, reduces VRAM use by swapping blocks to CPU memory" def setargs(self, **kwargs): return (kwargs, ) class WanVideoVRAMManagement: @classmethod def INPUT_TYPES(s): return { "required": { "offload_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Percentage of parameters to offload"}), }, } RETURN_TYPES = ("VRAM_MANAGEMENTARGS",) RETURN_NAMES = ("vram_management_args",) FUNCTION = "setargs" CATEGORY = "WanVideoWrapper" DESCRIPTION = "Alternative offloading method from DiffSynth-Studio, more aggressive in reducing memory use than block swapping, but can be slower" def setargs(self, **kwargs): return (kwargs, ) class WanVideoTeaCache: @classmethod def INPUT_TYPES(s): return { "required": { "rel_l1_thresh": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Higher values will make TeaCache more aggressive, faster, but may cause artifacts. Good value range for 1.3B: 0.05 - 0.08, for other models 0.15-0.30"}), "start_step": ("INT", {"default": 1, "min": 0, "max": 9999, "step": 1, "tooltip": "Start percentage of the steps to apply TeaCache"}), "end_step": ("INT", {"default": -1, "min": -1, "max": 9999, "step": 1, "tooltip": "End steps to apply TeaCache"}), "cache_device": (["main_device", "offload_device"], {"default": "offload_device", "tooltip": "Device to cache to"}), "use_coefficients": ("BOOLEAN", {"default": True, "tooltip": "Use calculated coefficients for more accuracy. When enabled therel_l1_thresh should be about 10 times higher than without"}), }, "optional": { "mode": (["e", "e0"], {"default": "e", "tooltip": "Choice between using e (time embeds, default) or e0 (modulated time embeds)"}), }, } RETURN_TYPES = ("CACHEARGS",) RETURN_NAMES = ("cache_args",) FUNCTION = "process" CATEGORY = "WanVideoWrapper" DESCRIPTION = """ Patch WanVideo model to use TeaCache. Speeds up inference by caching the output and applying it instead of doing the step. Best results are achieved by choosing the appropriate coefficients for the model. Early steps should never be skipped, with too aggressive values this can happen and the motion suffers. Starting later can help with that too. When NOT using coefficients, the threshold value should be about 10 times smaller than the value used with coefficients. Official recommended values https://github.com/ali-vilab/TeaCache/tree/main/TeaCache4Wan2.1:
+-------------------+--------+---------+--------+ | Model | Low | Medium | High | +-------------------+--------+---------+--------+ | Wan2.1 t2v 1.3B | 0.05 | 0.07 | 0.08 | | Wan2.1 t2v 14B | 0.14 | 0.15 | 0.20 | | Wan2.1 i2v 480P | 0.13 | 0.19 | 0.26 | | Wan2.1 i2v 720P | 0.18 | 0.20 | 0.30 | +-------------------+--------+---------+--------+""" EXPERIMENTAL = True def process(self, rel_l1_thresh, start_step, end_step, cache_device, use_coefficients, mode="e"): if cache_device == "main_device": cache_device = mm.get_torch_device() else: cache_device = mm.unet_offload_device() cache_args = { "cache_type": "TeaCache", "rel_l1_thresh": rel_l1_thresh, "start_step": start_step, "end_step": end_step, "cache_device": cache_device, "use_coefficients": use_coefficients, "mode": mode, } return (cache_args,) class WanVideoMagCache: @classmethod def INPUT_TYPES(s): return { "required": { "magcache_thresh": ("FLOAT", {"default": 0.02, "min": 0.0, "max": 0.3, "step": 0.001, "tooltip": "How strongly to cache the output of diffusion model. This value must be non-negative."}), "magcache_K": ("INT", {"default": 4, "min": 0, "max": 6, "step": 1, "tooltip": "The maxium skip steps of MagCache."}), "start_step": ("INT", {"default": 1, "min": 0, "max": 9999, "step": 1, "tooltip": "Step to start applying MagCache"}), "end_step": ("INT", {"default": -1, "min": -1, "max": 9999, "step": 1, "tooltip": "Step to end applying MagCache"}), "cache_device": (["main_device", "offload_device"], {"default": "offload_device", "tooltip": "Device to cache to"}), }, } RETURN_TYPES = ("CACHEARGS",) RETURN_NAMES = ("cache_args",) FUNCTION = "setargs" CATEGORY = "WanVideoWrapper" EXPERIMENTAL = True DESCRIPTION = "MagCache for WanVideoWrapper, source https://github.com/Zehong-Ma/MagCache" def setargs(self, magcache_thresh, magcache_K, start_step, end_step, cache_device): if cache_device == "main_device": cache_device = mm.get_torch_device() else: cache_device = mm.unet_offload_device() cache_args = { "cache_type": "MagCache", "magcache_thresh": magcache_thresh, "magcache_K": magcache_K, "start_step": start_step, "end_step": end_step, "cache_device": cache_device, } return (cache_args,) class WanVideoModel(comfy.model_base.BaseModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.pipeline = {} def __getitem__(self, k): return self.pipeline[k] def __setitem__(self, k, v): self.pipeline[k] = v try: from comfy.latent_formats import Wan21 latent_format = Wan21 except: #for backwards compatibility log.warning("Wan21 latent format not found, update ComfyUI for better livepreview") from comfy.latent_formats import HunyuanVideo latent_format = HunyuanVideo class WanVideoModelConfig: def __init__(self, dtype): self.unet_config = {} self.unet_extra_config = {} self.latent_format = latent_format self.latent_format.latent_channels = 16 self.manual_cast_dtype = dtype self.sampling_settings = {"multiplier": 1.0} self.memory_usage_factor = 2.0 self.unet_config["disable_unet_model_creation"] = True def filter_state_dict_by_blocks(state_dict, blocks_mapping, layer_filter=[]): filtered_dict = {} if isinstance(layer_filter, str): layer_filters = [layer_filter] if layer_filter else [] else: # Filter out empty strings layer_filters = [f for f in layer_filter if f] if layer_filter else [] #print("layer_filter: ", layer_filters) for key in state_dict: if not any(filter_str in key for filter_str in layer_filters): if 'blocks.' in key: block_pattern = key.split('diffusion_model.')[1].split('.', 2)[0:2] block_key = f'{block_pattern[0]}.{block_pattern[1]}.' if block_key in blocks_mapping: filtered_dict[key] = state_dict[key] else: filtered_dict[key] = state_dict[key] for key in filtered_dict: print(key) #from safetensors.torch import save_file #save_file(filtered_dict, "filtered_state_dict_2.safetensors") return filtered_dict def standardize_lora_key_format(lora_sd): new_sd = {} for k, v in lora_sd.items(): # Diffusers format if k.startswith('transformer.'): k = k.replace('transformer.', 'diffusion_model.') if k.startswith('pipe.dit.'): #unianimate-dit/diffsynth k = k.replace('pipe.dit.', 'diffusion_model.') # Fun LoRA format if k.startswith('lora_unet__'): # Split into main path and weight type parts parts = k.split('.') main_part = parts[0] # e.g. lora_unet__blocks_0_cross_attn_k weight_type = '.'.join(parts[1:]) if len(parts) > 1 else None # e.g. lora_down.weight # Process the main part - convert from underscore to dot format if 'blocks_' in main_part: # Extract components components = main_part[len('lora_unet__'):].split('_') # Start with diffusion_model new_key = "diffusion_model" # Add blocks.N if components[0] == 'blocks': new_key += f".blocks.{components[1]}" # Handle different module types idx = 2 if idx < len(components): if components[idx] == 'self' and idx+1 < len(components) and components[idx+1] == 'attn': new_key += ".self_attn" idx += 2 elif components[idx] == 'cross' and idx+1 < len(components) and components[idx+1] == 'attn': new_key += ".cross_attn" idx += 2 elif components[idx] == 'ffn': new_key += ".ffn" idx += 1 # Add the component (k, q, v, o) and handle img suffix if idx < len(components): component = components[idx] idx += 1 # Check for img suffix if idx < len(components) and components[idx] == 'img': component += '_img' idx += 1 new_key += f".{component}" # Handle weight type - this is the critical fix if weight_type: if weight_type == 'alpha': new_key += '.alpha' elif weight_type == 'lora_down.weight' or weight_type == 'lora_down': new_key += '.lora_A.weight' elif weight_type == 'lora_up.weight' or weight_type == 'lora_up': new_key += '.lora_B.weight' else: # Keep original weight type if not matching our patterns new_key += f'.{weight_type}' # Add .weight suffix if missing if not new_key.endswith('.weight'): new_key += '.weight' k = new_key else: # For other lora_unet__ formats (head, embeddings, etc.) new_key = main_part.replace('lora_unet__', 'diffusion_model.') # Fix specific component naming patterns new_key = new_key.replace('_self_attn', '.self_attn') new_key = new_key.replace('_cross_attn', '.cross_attn') new_key = new_key.replace('_ffn', '.ffn') new_key = new_key.replace('blocks_', 'blocks.') new_key = new_key.replace('head_head', 'head.head') new_key = new_key.replace('img_emb', 'img_emb') new_key = new_key.replace('text_embedding', 'text.embedding') new_key = new_key.replace('time_embedding', 'time.embedding') new_key = new_key.replace('time_projection', 'time.projection') # Replace remaining underscores with dots, carefully parts = new_key.split('.') final_parts = [] for part in parts: if part in ['img_emb', 'self_attn', 'cross_attn']: final_parts.append(part) # Keep these intact else: final_parts.append(part.replace('_', '.')) new_key = '.'.join(final_parts) # Handle weight type if weight_type: if weight_type == 'alpha': new_key += '.alpha' elif weight_type == 'lora_down.weight' or weight_type == 'lora_down': new_key += '.lora_A.weight' elif weight_type == 'lora_up.weight' or weight_type == 'lora_up': new_key += '.lora_B.weight' else: new_key += f'.{weight_type}' if not new_key.endswith('.weight'): new_key += '.weight' k = new_key # Handle special embedded components special_components = { 'time.projection': 'time_projection', 'img.emb': 'img_emb', 'text.emb': 'text_emb', 'time.emb': 'time_emb', } for old, new in special_components.items(): if old in k: k = k.replace(old, new) # Fix diffusion.model -> diffusion_model if k.startswith('diffusion.model.'): k = k.replace('diffusion.model.', 'diffusion_model.') # Finetrainer format if '.attn1.' in k: k = k.replace('.attn1.', '.cross_attn.') k = k.replace('.to_k.', '.k.') k = k.replace('.to_q.', '.q.') k = k.replace('.to_v.', '.v.') k = k.replace('.to_out.0.', '.o.') elif '.attn2.' in k: k = k.replace('.attn2.', '.cross_attn.') k = k.replace('.to_k.', '.k.') k = k.replace('.to_q.', '.q.') k = k.replace('.to_v.', '.v.') k = k.replace('.to_out.0.', '.o.') if "img_attn.proj" in k: k = k.replace("img_attn.proj", "img_attn_proj") if "img_attn.qkv" in k: k = k.replace("img_attn.qkv", "img_attn_qkv") if "txt_attn.proj" in k: k = k.replace("txt_attn.proj", "txt_attn_proj") if "txt_attn.qkv" in k: k = k.replace("txt_attn.qkv", "txt_attn_qkv") new_sd[k] = v return new_sd class WanVideoEnhanceAVideo: @classmethod def INPUT_TYPES(s): return { "required": { "weight": ("FLOAT", {"default": 2.0, "min": 0, "max": 100, "step": 0.01, "tooltip": "The feta Weight of the Enhance-A-Video"}), "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percentage of the steps to apply Enhance-A-Video"}), "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percentage of the steps to apply Enhance-A-Video"}), }, } RETURN_TYPES = ("FETAARGS",) RETURN_NAMES = ("feta_args",) FUNCTION = "setargs" CATEGORY = "WanVideoWrapper" DESCRIPTION = "https://github.com/NUS-HPC-AI-Lab/Enhance-A-Video" def setargs(self, **kwargs): return (kwargs, ) class WanVideoLoraSelect: @classmethod def INPUT_TYPES(s): return { "required": { "lora": (folder_paths.get_filename_list("loras"), {"tooltip": "LORA models are expected to be in ComfyUI/models/loras with .safetensors extension"}), "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.0001, "tooltip": "LORA strength, set to 0.0 to unmerge the LORA"}), }, "optional": { "prev_lora":("WANVIDLORA", {"default": None, "tooltip": "For loading multiple LoRAs"}), "blocks":("SELECTEDBLOCKS", ), "low_mem_load": ("BOOLEAN", {"default": False, "tooltip": "Load the LORA model with less VRAM usage, slower loading"}), } } RETURN_TYPES = ("WANVIDLORA",) RETURN_NAMES = ("lora", ) FUNCTION = "getlorapath" CATEGORY = "WanVideoWrapper" DESCRIPTION = "Select a LoRA model from ComfyUI/models/loras" def getlorapath(self, lora, strength, blocks={}, prev_lora=None, low_mem_load=False): loras_list = [] lora = { "path": folder_paths.get_full_path("loras", lora), "strength": strength, "name": lora.split(".")[0], "blocks": blocks.get("selected_blocks", {}), "layer_filter": blocks.get("layer_filter", ""), "low_mem_load": low_mem_load, } if prev_lora is not None: loras_list.extend(prev_lora) loras_list.append(lora) return (loras_list,) class WanVideoLoraSelectMulti: @classmethod def INPUT_TYPES(s): lora_files = folder_paths.get_filename_list("loras") lora_files = ["none"] + lora_files # Add "none" as the first option return { "required": { "lora_0": (lora_files, {"default": "none"}), "strength_0": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.0001, "tooltip": "LORA strength, set to 0.0 to unmerge the LORA"}), "lora_1": (lora_files, {"default": "none"}), "strength_1": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.0001, "tooltip": "LORA strength, set to 0.0 to unmerge the LORA"}), "lora_2": (lora_files, {"default": "none"}), "strength_2": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.0001, "tooltip": "LORA strength, set to 0.0 to unmerge the LORA"}), "lora_3": (lora_files, {"default": "none"}), "strength_3": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.0001, "tooltip": "LORA strength, set to 0.0 to unmerge the LORA"}), "lora_4": (lora_files, {"default": "none"}), "strength_4": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.0001, "tooltip": "LORA strength, set to 0.0 to unmerge the LORA"}), }, "optional": { "prev_lora":("WANVIDLORA", {"default": None, "tooltip": "For loading multiple LoRAs"}), "blocks":("SELECTEDBLOCKS", ), "low_mem_load": ("BOOLEAN", {"default": False, "tooltip": "Load the LORA model with less VRAM usage, slower loading"}), } } RETURN_TYPES = ("WANVIDLORA",) RETURN_NAMES = ("lora", ) FUNCTION = "getlorapath" CATEGORY = "WanVideoWrapper" DESCRIPTION = "Select a LoRA model from ComfyUI/models/loras" def getlorapath(self, lora_0, strength_0, lora_1, strength_1, lora_2, strength_2, lora_3, strength_3, lora_4, strength_4, blocks={}, prev_lora=None, low_mem_load=False): loras_list = [] if prev_lora is not None: loras_list.extend(prev_lora) # Process each LoRA lora_inputs = [ (lora_0, strength_0), (lora_1, strength_1), (lora_2, strength_2), (lora_3, strength_3), (lora_4, strength_4) ] for lora_name, strength in lora_inputs: # Skip if the LoRA is empty if not lora_name or lora_name == "none": continue lora = { "path": folder_paths.get_full_path("loras", lora_name), "strength": strength, "name": lora_name.split(".")[0], "blocks": blocks.get("selected_blocks", {}), "layer_filter": blocks.get("layer_filter", ""), "low_mem_load": low_mem_load, } loras_list.append(lora) return (loras_list,) class WanVideoVACEModelSelect: @classmethod def INPUT_TYPES(s): return { "required": { "vace_model": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "These models are loaded from the 'ComfyUI/models/diffusion_models' VACE model to use when not using model that has it included"}), }, } RETURN_TYPES = ("VACEPATH",) RETURN_NAMES = ("vace_model", ) FUNCTION = "getvacepath" CATEGORY = "WanVideoWrapper" DESCRIPTION = "VACE model to use when not using model that has it included, loaded from 'ComfyUI/models/diffusion_models'" def getvacepath(self, vace_model): vace_model = { "path": folder_paths.get_full_path("diffusion_models", vace_model), } return (vace_model,) class WanVideoLoraBlockEdit: def __init__(self): self.loaded_lora = None @classmethod def INPUT_TYPES(s): arg_dict = {} argument = ("BOOLEAN", {"default": True}) for i in range(40): arg_dict["blocks.{}.".format(i)] = argument return {"required": arg_dict, "optional": {"layer_filter": ("STRING", {"default": "", "multiline": True})}} RETURN_TYPES = ("SELECTEDBLOCKS", ) RETURN_NAMES = ("blocks", ) OUTPUT_TOOLTIPS = ("The modified lora model",) FUNCTION = "select" CATEGORY = "WanVideoWrapper" def select(self, layer_filter=[], **kwargs): selected_blocks = {k: v for k, v in kwargs.items() if v is True and isinstance(v, bool)} print("Selected blocks LoRA: ", selected_blocks) selected = { "selected_blocks": selected_blocks, "layer_filter": [x.strip() for x in layer_filter.split(",")] } return (selected,) #region Model loading class WanVideoModelLoader: @classmethod def INPUT_TYPES(s): return { "required": { "model": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "These models are loaded from the 'ComfyUI/models/diffusion_models' -folder",}), "base_precision": (["fp32", "bf16", "fp16", "fp16_fast"], {"default": "bf16"}), "quantization": (['disabled', 'fp8_e4m3fn', 'fp8_e4m3fn_fast', 'fp8_e5m2', 'fp8_e4m3fn_fast_no_ffn'], {"default": 'disabled', "tooltip": "optional quantization method"}), "load_device": (["main_device", "offload_device"], {"default": "main_device", "tooltip": "Initial device to load the model to, NOT recommended with the larger models unless you have 48GB+ VRAM"}), }, "optional": { "attention_mode": ([ "sdpa", "flash_attn_2", "flash_attn_3", "sageattn", "flex_attention", #"spargeattn", needs tuning #"spargeattn_tune", ], {"default": "sdpa"}), "compile_args": ("WANCOMPILEARGS", ), "block_swap_args": ("BLOCKSWAPARGS", ), "lora": ("WANVIDLORA", {"default": None}), "vram_management_args": ("VRAM_MANAGEMENTARGS", {"default": None, "tooltip": "Alternative offloading method from DiffSynth-Studio, more aggressive in reducing memory use than block swapping, but can be slower"}), "vace_model": ("VACEPATH", {"default": None, "tooltip": "VACE model to use when not using model that has it included"}), "fantasytalking_model": ("FANTASYTALKINGMODEL", {"default": None, "tooltip": "FantasyTalking model https://github.com/Fantasy-AMAP"}), } } RETURN_TYPES = ("WANVIDEOMODEL",) RETURN_NAMES = ("model", ) FUNCTION = "loadmodel" CATEGORY = "WanVideoWrapper" def loadmodel(self, model, base_precision, load_device, quantization, compile_args=None, attention_mode="sdpa", block_swap_args=None, lora=None, vram_management_args=None, vace_model=None, fantasytalking_model=None): assert not (vram_management_args is not None and block_swap_args is not None), "Can't use both block_swap_args and vram_management_args at the same time" lora_low_mem_load = False if lora is not None: for l in lora: lora_low_mem_load = l.get("low_mem_load") if lora is not None else False transformer = None mm.unload_all_models() mm.cleanup_models() mm.soft_empty_cache() manual_offloading = True if "sage" in attention_mode: try: from sageattention import sageattn except Exception as e: raise ValueError(f"Can't import SageAttention: {str(e)}") device = mm.get_torch_device() offload_device = mm.unet_offload_device() manual_offloading = True transformer_load_device = device if load_device == "main_device" else offload_device base_dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "fp8_e4m3fn_fast": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp16_fast": torch.float16, "fp32": torch.float32}[base_precision] if base_precision == "fp16_fast": if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"): torch.backends.cuda.matmul.allow_fp16_accumulation = True else: raise ValueError("torch.backends.cuda.matmul.allow_fp16_accumulation is not available in this version of torch, requires torch 2.7.0.dev2025 02 26 nightly minimum currently") else: try: if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"): torch.backends.cuda.matmul.allow_fp16_accumulation = False except: pass model_path = folder_paths.get_full_path_or_raise("diffusion_models", model) sd = load_torch_file(model_path, device=transformer_load_device, safe_load=True) if "vace_blocks.0.after_proj.weight" in sd and not "patch_embedding.weight" in sd: raise ValueError("You are attempting to load a VACE module as a WanVideo model, instead you should use the vace_model input and matching T2V base model") if vace_model is not None: vace_sd = load_torch_file(vace_model["path"], device=transformer_load_device, safe_load=True) sd.update(vace_sd) first_key = next(iter(sd)) if first_key.startswith("model.diffusion_model."): new_sd = {} for key, value in sd.items(): new_key = key.replace("model.diffusion_model.", "", 1) new_sd[new_key] = value sd = new_sd elif first_key.startswith("model."): new_sd = {} for key, value in sd.items(): new_key = key.replace("model.", "", 1) new_sd[new_key] = value sd = new_sd if not "patch_embedding.weight" in sd: raise ValueError("Invalid WanVideo model selected") dim = sd["patch_embedding.weight"].shape[0] in_channels = sd["patch_embedding.weight"].shape[1] log.info(f"Detected model in_channels: {in_channels}") ffn_dim = sd["blocks.0.ffn.0.bias"].shape[0] if not "text_embedding.0.weight" in sd: model_type = "no_cross_attn" #minimaxremover elif "model_type.Wan2_1-FLF2V-14B-720P" in sd or "img_emb.emb_pos" in sd or "flf2v" in model.lower(): model_type = "fl2v" elif in_channels in [36, 48]: model_type = "i2v" elif in_channels == 16: model_type = "t2v" elif "control_adapter.conv.weight" in sd: model_type = "t2v" num_heads = 40 if dim == 5120 else 12 num_layers = 40 if dim == 5120 else 30 vace_layers, vace_in_dim = None, None if "vace_blocks.0.after_proj.weight" in sd: if in_channels != 16: raise ValueError("VACE only works properly with T2V models.") model_type = "t2v" if dim == 5120: vace_layers = [0, 5, 10, 15, 20, 25, 30, 35] else: vace_layers = [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28] vace_in_dim = 96 log.info(f"Model type: {model_type}, num_heads: {num_heads}, num_layers: {num_layers}") teacache_coefficients_map = { "1_3B": { "e": [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01], "e0": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02], }, "14B": { "e": [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404], "e0": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01], }, "i2v_480": { "e": [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01], "e0": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01], }, "i2v_720":{ "e": [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683], "e0": [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02], }, } magcache_ratios_map = { "1_3B": np.array([1.0]*2+[1.0124, 1.02213, 1.00166, 1.0041, 0.99791, 1.00061, 0.99682, 0.99762, 0.99634, 0.99685, 0.99567, 0.99586, 0.99416, 0.99422, 0.99578, 0.99575, 0.9957, 0.99563, 0.99511, 0.99506, 0.99535, 0.99531, 0.99552, 0.99549, 0.99541, 0.99539, 0.9954, 0.99536, 0.99489, 0.99485, 0.99518, 0.99514, 0.99484, 0.99478, 0.99481, 0.99479, 0.99415, 0.99413, 0.99419, 0.99416, 0.99396, 0.99393, 0.99388, 0.99386, 0.99349, 0.99349, 0.99309, 0.99304, 0.9927, 0.9927, 0.99228, 0.99226, 0.99171, 0.9917, 0.99137, 0.99135, 0.99068, 0.99063, 0.99005, 0.99003, 0.98944, 0.98942, 0.98849, 0.98849, 0.98758, 0.98757, 0.98644, 0.98643, 0.98504, 0.98503, 0.9836, 0.98359, 0.98202, 0.98201, 0.97977, 0.97978, 0.97717, 0.97718, 0.9741, 0.97411, 0.97003, 0.97002, 0.96538, 0.96541, 0.9593, 0.95933, 0.95086, 0.95089, 0.94013, 0.94019, 0.92402, 0.92414, 0.90241, 0.9026, 0.86821, 0.86868, 0.81838, 0.81939]), "14B": np.array([1.0]*2+[1.02504, 1.03017, 1.00025, 1.00251, 0.9985, 0.99962, 0.99779, 0.99771, 0.9966, 0.99658, 0.99482, 0.99476, 0.99467, 0.99451, 0.99664, 0.99656, 0.99434, 0.99431, 0.99533, 0.99545, 0.99468, 0.99465, 0.99438, 0.99434, 0.99516, 0.99517, 0.99384, 0.9938, 0.99404, 0.99401, 0.99517, 0.99516, 0.99409, 0.99408, 0.99428, 0.99426, 0.99347, 0.99343, 0.99418, 0.99416, 0.99271, 0.99269, 0.99313, 0.99311, 0.99215, 0.99215, 0.99218, 0.99215, 0.99216, 0.99217, 0.99163, 0.99161, 0.99138, 0.99135, 0.98982, 0.9898, 0.98996, 0.98995, 0.9887, 0.98866, 0.98772, 0.9877, 0.98767, 0.98765, 0.98573, 0.9857, 0.98501, 0.98498, 0.9838, 0.98376, 0.98177, 0.98173, 0.98037, 0.98035, 0.97678, 0.97677, 0.97546, 0.97543, 0.97184, 0.97183, 0.96711, 0.96708, 0.96349, 0.96345, 0.95629, 0.95625, 0.94926, 0.94929, 0.93964, 0.93961, 0.92511, 0.92504, 0.90693, 0.90678, 0.8796, 0.87945, 0.86111, 0.86189]), "i2v_480": np.array([1.0]*2+[0.98783, 0.98993, 0.97559, 0.97593, 0.98311, 0.98319, 0.98202, 0.98225, 0.9888, 0.98878, 0.98762, 0.98759, 0.98957, 0.98971, 0.99052, 0.99043, 0.99383, 0.99384, 0.98857, 0.9886, 0.99065, 0.99068, 0.98845, 0.98847, 0.99057, 0.99057, 0.98957, 0.98961, 0.98601, 0.9861, 0.98823, 0.98823, 0.98756, 0.98759, 0.98808, 0.98814, 0.98721, 0.98724, 0.98571, 0.98572, 0.98543, 0.98544, 0.98157, 0.98165, 0.98411, 0.98413, 0.97952, 0.97953, 0.98149, 0.9815, 0.9774, 0.97742, 0.97825, 0.97826, 0.97355, 0.97361, 0.97085, 0.97087, 0.97056, 0.97055, 0.96588, 0.96587, 0.96113, 0.96124, 0.9567, 0.95681, 0.94961, 0.94969, 0.93973, 0.93988, 0.93217, 0.93224, 0.91878, 0.91896, 0.90955, 0.90954, 0.92617, 0.92616]), "i2v_720": np.array([1.0]*2+[0.99428, 0.99498, 0.98588, 0.98621, 0.98273, 0.98281, 0.99018, 0.99023, 0.98911, 0.98917, 0.98646, 0.98652, 0.99454, 0.99456, 0.9891, 0.98909, 0.99124, 0.99127, 0.99102, 0.99103, 0.99215, 0.99212, 0.99515, 0.99515, 0.99576, 0.99572, 0.99068, 0.99072, 0.99097, 0.99097, 0.99166, 0.99169, 0.99041, 0.99042, 0.99201, 0.99198, 0.99101, 0.99101, 0.98599, 0.98603, 0.98845, 0.98844, 0.98848, 0.98851, 0.98862, 0.98857, 0.98718, 0.98719, 0.98497, 0.98497, 0.98264, 0.98263, 0.98389, 0.98393, 0.97938, 0.9794, 0.97535, 0.97536, 0.97498, 0.97499, 0.973, 0.97301, 0.96827, 0.96828, 0.96261, 0.96263, 0.95335, 0.9534, 0.94649, 0.94655, 0.93397, 0.93414, 0.91636, 0.9165, 0.89088, 0.89109, 0.8679, 0.86768]), } model_variant = "14B" #default to this if model_type == "i2v" or model_type == "fl2v": if "480" in model or "fun" in model.lower() or "a2" in model.lower() or "540" in model: #just a guess for the Fun model for now... model_variant = "i2v_480" elif "720" in model: model_variant = "i2v_720" elif model_type == "t2v": model_variant = "14B" if dim == 1536: model_variant = "1_3B" log.info(f"Model variant detected: {model_variant}") TRANSFORMER_CONFIG= { "dim": dim, "ffn_dim": ffn_dim, "eps": 1e-06, "freq_dim": 256, "in_dim": in_channels, "model_type": model_type, "out_dim": 16, "text_len": 512, "num_heads": num_heads, "num_layers": num_layers, "attention_mode": attention_mode, "main_device": device, "offload_device": offload_device, "teacache_coefficients": teacache_coefficients_map[model_variant], "magcache_ratios": magcache_ratios_map[model_variant], "vace_layers": vace_layers, "vace_in_dim": vace_in_dim, "inject_sample_info": True if "fps_embedding.weight" in sd else False, "add_ref_conv": True if "ref_conv.weight" in sd else False, "in_dim_ref_conv": sd["ref_conv.weight"].shape[1] if "ref_conv.weight" in sd else None, "add_control_adapter": True if "control_adapter.conv.weight" in sd else False, } with init_empty_weights(): transformer = WanModel(**TRANSFORMER_CONFIG) transformer.eval() #ReCamMaster if "blocks.0.cam_encoder.weight" in sd: log.info("ReCamMaster model detected, patching model...") import torch.nn as nn for block in transformer.blocks: block.cam_encoder = nn.Linear(12, dim) block.projector = nn.Linear(dim, dim) block.cam_encoder.weight.data.zero_() block.cam_encoder.bias.data.zero_() block.projector.weight = nn.Parameter(torch.eye(dim)) block.projector.bias = nn.Parameter(torch.zeros(dim)) # FantasyTalking https://github.com/Fantasy-AMAP if fantasytalking_model is not None: log.info("FantasyTalking model detected, patching model...") context_dim = fantasytalking_model["sd"]["proj_model.proj.weight"].shape[0] import torch.nn as nn for block in transformer.blocks: block.cross_attn.k_proj = nn.Linear(context_dim, dim, bias=False) block.cross_attn.v_proj = nn.Linear(context_dim, dim, bias=False) sd.update(fantasytalking_model["sd"]) # RealisDance-DiT if "add_conv_in.weight" in sd: def zero_module(module): for p in module.parameters(): torch.nn.init.zeros_(p) return module inner_dim = sd["add_conv_in.weight"].shape[0] add_cond_in_dim = sd["add_conv_in.weight"].shape[1] attn_cond_in_dim = sd["attn_conv_in.weight"].shape[1] transformer.add_conv_in = torch.nn.Conv3d(add_cond_in_dim, inner_dim, kernel_size=transformer.patch_size, stride=transformer.patch_size) transformer.add_proj = zero_module(torch.nn.Linear(inner_dim, inner_dim)) transformer.attn_conv_in = torch.nn.Conv3d(attn_cond_in_dim, inner_dim, kernel_size=transformer.patch_size, stride=transformer.patch_size) comfy_model = WanVideoModel( WanVideoModelConfig(base_dtype), model_type=comfy.model_base.ModelType.FLOW, device=device, ) if quantization == "disabled": for k, v in sd.items(): if isinstance(v, torch.Tensor): if v.dtype == torch.float8_e4m3fn: quantization = "fp8_e4m3fn" break elif v.dtype == torch.float8_e5m2: quantization = "fp8_e5m2" break if "fp8_e4m3fn" in quantization: dtype = torch.float8_e4m3fn elif quantization == "fp8_e5m2": dtype = torch.float8_e5m2 else: dtype = base_dtype params_to_keep = {"norm", "head", "bias", "time_in", "vector_in", "patch_embedding", "time_", "img_emb", "modulation", "text_embedding", "adapter", "add"} #if lora is not None: # transformer_load_device = device if not lora_low_mem_load: log.info("Using accelerate to load and assign model weights to device...") param_count = sum(1 for _ in transformer.named_parameters()) for name, param in tqdm(transformer.named_parameters(), desc=f"Loading transformer parameters to {transformer_load_device}", total=param_count, leave=True): dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype if "patch_embedding" in name: dtype_to_use = torch.float32 set_module_tensor_to_device(transformer, name, device=transformer_load_device, dtype=dtype_to_use, value=sd[name]) comfy_model.diffusion_model = transformer comfy_model.load_device = transformer_load_device patcher = comfy.model_patcher.ModelPatcher(comfy_model, device, offload_device) patcher.model.is_patched = False control_lora = False if lora is not None: for l in lora: log.info(f"Loading LoRA: {l['name']} with strength: {l['strength']}") lora_path = l["path"] lora_strength = l["strength"] lora_sd = load_torch_file(lora_path, safe_load=True) if "dwpose_embedding.0.weight" in lora_sd: #unianimate from .unianimate.nodes import update_transformer log.info("Unianimate LoRA detected, patching model...") transformer = update_transformer(transformer, lora_sd) lora_sd = standardize_lora_key_format(lora_sd) if l["blocks"]: lora_sd = filter_state_dict_by_blocks(lora_sd, l["blocks"], l.get("layer_filter", [])) #spacepxl's control LoRA patch # for key in lora_sd.keys(): # print(key) if "diffusion_model.patch_embedding.lora_A.weight" in lora_sd: log.info("Control-LoRA detected, patching model...") control_lora = True in_cls = transformer.patch_embedding.__class__ # nn.Conv3d old_in_dim = transformer.in_dim # 16 new_in_dim = lora_sd["diffusion_model.patch_embedding.lora_A.weight"].shape[1] assert new_in_dim == 32 new_in = in_cls( new_in_dim, transformer.patch_embedding.out_channels, transformer.patch_embedding.kernel_size, transformer.patch_embedding.stride, transformer.patch_embedding.padding, ).to(device=device, dtype=torch.float32) new_in.weight.zero_() new_in.bias.zero_() new_in.weight[:, :old_in_dim].copy_(transformer.patch_embedding.weight) new_in.bias.copy_(transformer.patch_embedding.bias) transformer.patch_embedding = new_in transformer.expanded_patch_embedding = new_in transformer.register_to_config(in_dim=new_in_dim) patcher, _ = load_lora_for_models(patcher, None, lora_sd, lora_strength, 0) del lora_sd patcher = apply_lora(patcher, device, transformer_load_device, params_to_keep=params_to_keep, dtype=dtype, base_dtype=base_dtype, state_dict=sd, low_mem_load=lora_low_mem_load) #patcher.load(device, full_load=True) patcher.model.is_patched = True if "fast" in quantization: from .fp8_optimization import convert_fp8_linear if quantization == "fp8_e4m3fn_fast_no_ffn": params_to_keep.update({"ffn"}) print(params_to_keep) convert_fp8_linear(patcher.model.diffusion_model, base_dtype, params_to_keep=params_to_keep) del sd if vram_management_args is not None: from .diffsynth.vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear from .wanvideo.modules.model import WanLayerNorm, WanRMSNorm total_params_in_model = sum(p.numel() for p in patcher.model.diffusion_model.parameters()) log.info(f"Total number of parameters in the loaded model: {total_params_in_model}") offload_percent = vram_management_args["offload_percent"] offload_params = int(total_params_in_model * offload_percent) params_to_keep = total_params_in_model - offload_params log.info(f"Selected params to offload: {offload_params}") enable_vram_management( patcher.model.diffusion_model, module_map = { torch.nn.Linear: AutoWrappedLinear, torch.nn.Conv3d: AutoWrappedModule, torch.nn.LayerNorm: AutoWrappedModule, WanLayerNorm: AutoWrappedModule, WanRMSNorm: AutoWrappedModule, }, module_config = dict( offload_dtype=dtype, offload_device=offload_device, onload_dtype=dtype, onload_device=device, computation_dtype=base_dtype, computation_device=device, ), max_num_param=params_to_keep, overflow_module_config = dict( offload_dtype=dtype, offload_device=offload_device, onload_dtype=dtype, onload_device=offload_device, computation_dtype=base_dtype, computation_device=device, ), compile_args = compile_args, ) #compile if compile_args is not None and vram_management_args is None: torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"] try: if hasattr(torch, '_dynamo') and hasattr(torch._dynamo, 'config'): torch._dynamo.config.recompile_limit = compile_args["dynamo_recompile_limit"] except Exception as e: log.warning(f"Could not set recompile_limit: {e}") if compile_args["compile_transformer_blocks_only"]: for i, block in enumerate(patcher.model.diffusion_model.blocks): patcher.model.diffusion_model.blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"]) if vace_layers is not None: for i, block in enumerate(patcher.model.diffusion_model.vace_blocks): patcher.model.diffusion_model.vace_blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"]) else: patcher.model.diffusion_model = torch.compile(patcher.model.diffusion_model, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"]) if load_device == "offload_device" and patcher.model.diffusion_model.device != offload_device: log.info(f"Moving diffusion model from {patcher.model.diffusion_model.device} to {offload_device}") patcher.model.diffusion_model.to(offload_device) gc.collect() mm.soft_empty_cache() patcher.model["dtype"] = base_dtype patcher.model["base_path"] = model_path patcher.model["model_name"] = model patcher.model["manual_offloading"] = manual_offloading patcher.model["quantization"] = quantization patcher.model["auto_cpu_offload"] = True if vram_management_args is not None else False patcher.model["control_lora"] = control_lora if 'transformer_options' not in patcher.model_options: patcher.model_options['transformer_options'] = {} patcher.model_options["transformer_options"]["block_swap_args"] = block_swap_args for model in mm.current_loaded_models: if model._model() == patcher: mm.current_loaded_models.remove(model) return (patcher,) class WanVideoSetBlockSwap: @classmethod def INPUT_TYPES(s): return { "required": { "model": ("WANVIDEOMODEL", ), "block_swap_args": ("BLOCKSWAPARGS", ), } } RETURN_TYPES = ("WANVIDEOMODEL",) RETURN_NAMES = ("model", ) FUNCTION = "loadmodel" CATEGORY = "WanVideoWrapper" def loadmodel(self, model, block_swap_args): patcher = model.clone() if 'transformer_options' not in patcher.model_options: patcher.model_options['transformer_options'] = {} patcher.model_options["transformer_options"]["block_swap_args"] = block_swap_args return (patcher,) #region load VAE class WanVideoVAELoader: @classmethod def INPUT_TYPES(s): return { "required": { "model_name": (folder_paths.get_filename_list("vae"), {"tooltip": "These models are loaded from 'ComfyUI/models/vae'"}), }, "optional": { "precision": (["fp16", "fp32", "bf16"], {"default": "bf16"} ), } } RETURN_TYPES = ("WANVAE",) RETURN_NAMES = ("vae", ) FUNCTION = "loadmodel" CATEGORY = "WanVideoWrapper" DESCRIPTION = "Loads Wan VAE model from 'ComfyUI/models/vae'" def loadmodel(self, model_name, precision): from .wanvideo.wan_video_vae import WanVideoVAE device = mm.get_torch_device() offload_device = mm.unet_offload_device() dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision] #with open(os.path.join(script_directory, 'configs', 'hy_vae_config.json')) as f: # vae_config = json.load(f) model_path = folder_paths.get_full_path("vae", model_name) vae_sd = load_torch_file(model_path, safe_load=True) has_model_prefix = any(k.startswith("model.") for k in vae_sd.keys()) if not has_model_prefix: vae_sd = {f"model.{k}": v for k, v in vae_sd.items()} vae = WanVideoVAE(dtype=dtype) vae.load_state_dict(vae_sd) vae.eval() vae.to(device = offload_device, dtype = dtype) return (vae,) class WanVideoTinyVAELoader: @classmethod def INPUT_TYPES(s): return { "required": { "model_name": (folder_paths.get_filename_list("vae_approx"), {"tooltip": "These models are loaded from 'ComfyUI/models/vae_approx'"}), }, "optional": { "precision": (["fp16", "fp32", "bf16"], {"default": "fp16"}), "parallel": ("BOOLEAN", {"default": False, "tooltip": "uses more memory but is faster"}), } } RETURN_TYPES = ("WANVAE",) RETURN_NAMES = ("vae", ) FUNCTION = "loadmodel" CATEGORY = "WanVideoWrapper" DESCRIPTION = "Loads Wan VAE model from 'ComfyUI/models/vae'" def loadmodel(self, model_name, precision, parallel=False): from .taehv import TAEHV device = mm.get_torch_device() offload_device = mm.unet_offload_device() dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision] model_path = folder_paths.get_full_path("vae_approx", model_name) vae_sd = load_torch_file(model_path, safe_load=True) vae = TAEHV(vae_sd, parallel=parallel) vae.to(device = offload_device, dtype = dtype) return (vae,) class WanVideoTorchCompileSettings: @classmethod def INPUT_TYPES(s): return { "required": { "backend": (["inductor","cudagraphs"], {"default": "inductor"}), "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}), "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}), "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}), "dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.cache_size_limit"}), "compile_transformer_blocks_only": ("BOOLEAN", {"default": True, "tooltip": "Compile only the transformer blocks, usually enough and can make compilation faster and less error prone"}), }, "optional": { "dynamo_recompile_limit": ("INT", {"default": 128, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.recompile_limit"}), }, } RETURN_TYPES = ("WANCOMPILEARGS",) RETURN_NAMES = ("torch_compile_args",) FUNCTION = "set_args" CATEGORY = "WanVideoWrapper" DESCRIPTION = "torch.compile settings, when connected to the model loader, torch.compile of the selected layers is attempted. Requires Triton and torch 2.5.0 is recommended" def set_args(self, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit, compile_transformer_blocks_only, dynamo_recompile_limit=128): compile_args = { "backend": backend, "fullgraph": fullgraph, "mode": mode, "dynamic": dynamic, "dynamo_cache_size_limit": dynamo_cache_size_limit, "dynamo_recompile_limit": dynamo_recompile_limit, "compile_transformer_blocks_only": compile_transformer_blocks_only, } return (compile_args, ) #region TextEncode class LoadWanVideoT5TextEncoder: @classmethod def INPUT_TYPES(s): return { "required": { "model_name": (folder_paths.get_filename_list("text_encoders"), {"tooltip": "These models are loaded from 'ComfyUI/models/text_encoders'"}), "precision": (["fp32", "bf16"], {"default": "bf16"} ), }, "optional": { "load_device": (["main_device", "offload_device"], {"default": "offload_device"}), "quantization": (['disabled', 'fp8_e4m3fn'], {"default": 'disabled', "tooltip": "optional quantization method"}), } } RETURN_TYPES = ("WANTEXTENCODER",) RETURN_NAMES = ("wan_t5_model", ) FUNCTION = "loadmodel" CATEGORY = "WanVideoWrapper" DESCRIPTION = "Loads Wan text_encoder model from 'ComfyUI/models/LLM'" def loadmodel(self, model_name, precision, load_device="offload_device", quantization="disabled"): device = mm.get_torch_device() offload_device = mm.unet_offload_device() text_encoder_load_device = device if load_device == "main_device" else offload_device tokenizer_path = os.path.join(script_directory, "configs", "T5_tokenizer") dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision] model_path = folder_paths.get_full_path("text_encoders", model_name) sd = load_torch_file(model_path, safe_load=True) if "token_embedding.weight" not in sd and "shared.weight" not in sd: raise ValueError("Invalid T5 text encoder model, this node expects the 'umt5-xxl' model") if "scaled_fp8" in sd: raise ValueError("Invalid T5 text encoder model, fp8 scaled is not supported by this node") # Convert state dict keys from T5 format to the expected format if "shared.weight" in sd: log.info("Converting T5 text encoder model to the expected format...") converted_sd = {} for key, value in sd.items(): # Handle encoder block patterns if key.startswith('encoder.block.'): parts = key.split('.') block_num = parts[2] # Self-attention components if 'layer.0.SelfAttention' in key: if key.endswith('.k.weight'): new_key = f"blocks.{block_num}.attn.k.weight" elif key.endswith('.o.weight'): new_key = f"blocks.{block_num}.attn.o.weight" elif key.endswith('.q.weight'): new_key = f"blocks.{block_num}.attn.q.weight" elif key.endswith('.v.weight'): new_key = f"blocks.{block_num}.attn.v.weight" elif 'relative_attention_bias' in key: new_key = f"blocks.{block_num}.pos_embedding.embedding.weight" else: new_key = key # Layer norms elif 'layer.0.layer_norm' in key: new_key = f"blocks.{block_num}.norm1.weight" elif 'layer.1.layer_norm' in key: new_key = f"blocks.{block_num}.norm2.weight" # Feed-forward components elif 'layer.1.DenseReluDense' in key: if 'wi_0' in key: new_key = f"blocks.{block_num}.ffn.gate.0.weight" elif 'wi_1' in key: new_key = f"blocks.{block_num}.ffn.fc1.weight" elif 'wo' in key: new_key = f"blocks.{block_num}.ffn.fc2.weight" else: new_key = key else: new_key = key elif key == "shared.weight": new_key = "token_embedding.weight" elif key == "encoder.final_layer_norm.weight": new_key = "norm.weight" else: new_key = key converted_sd[new_key] = value sd = converted_sd T5_text_encoder = T5EncoderModel( text_len=512, dtype=dtype, device=text_encoder_load_device, state_dict=sd, tokenizer_path=tokenizer_path, quantization=quantization ) text_encoder = { "model": T5_text_encoder, "dtype": dtype, } return (text_encoder,) class LoadWanVideoClipTextEncoder: @classmethod def INPUT_TYPES(s): return { "required": { "model_name": (folder_paths.get_filename_list("clip_vision") + folder_paths.get_filename_list("text_encoders"), {"tooltip": "These models are loaded from 'ComfyUI/models/clip_vision'"}), "precision": (["fp16", "fp32", "bf16"], {"default": "fp16"} ), }, "optional": { "load_device": (["main_device", "offload_device"], {"default": "offload_device"}), } } RETURN_TYPES = ("CLIP_VISION",) RETURN_NAMES = ("wan_clip_vision", ) FUNCTION = "loadmodel" CATEGORY = "WanVideoWrapper" DESCRIPTION = "Loads Wan clip_vision model from 'ComfyUI/models/clip_vision'" def loadmodel(self, model_name, precision, load_device="offload_device"): device = mm.get_torch_device() offload_device = mm.unet_offload_device() text_encoder_load_device = device if load_device == "main_device" else offload_device dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision] model_path = folder_paths.get_full_path("clip_vision", model_name) # We also support legacy setups where the model is in the text_encoders folder if model_path is None: model_path = folder_paths.get_full_path("text_encoders", model_name) sd = load_torch_file(model_path, safe_load=True) if "log_scale" not in sd: raise ValueError("Invalid CLIP model, this node expectes the 'open-clip-xlm-roberta-large-vit-huge-14' model") clip_model = CLIPModel(dtype=dtype, device=device, state_dict=sd) clip_model.model.to(text_encoder_load_device) del sd return (clip_model,) class WanVideoTextEncode: @classmethod def INPUT_TYPES(s): return {"required": { "t5": ("WANTEXTENCODER",), "positive_prompt": ("STRING", {"default": "", "multiline": True} ), "negative_prompt": ("STRING", {"default": "", "multiline": True} ), }, "optional": { "force_offload": ("BOOLEAN", {"default": True}), "model_to_offload": ("WANVIDEOMODEL", {"tooltip": "Model to move to offload_device before encoding"}), } } RETURN_TYPES = ("WANVIDEOTEXTEMBEDS", ) RETURN_NAMES = ("text_embeds",) FUNCTION = "process" CATEGORY = "WanVideoWrapper" DESCRIPTION = "Encodes text prompts into text embeddings. For rudimentary prompt travel you can input multiple prompts separated by '|', they will be equally spread over the video length" def process(self, t5, positive_prompt, negative_prompt,force_offload=True, model_to_offload=None): device = mm.get_torch_device() offload_device = mm.unet_offload_device() if model_to_offload is not None: log.info(f"Moving video model to {offload_device}") model_to_offload.model.to(offload_device) mm.soft_empty_cache() encoder = t5["model"] dtype = t5["dtype"] # Split positive prompts and process each with weights positive_prompts_raw = [p.strip() for p in positive_prompt.split('|')] positive_prompts = [] all_weights = [] for p in positive_prompts_raw: cleaned_prompt, weights = self.parse_prompt_weights(p) positive_prompts.append(cleaned_prompt) all_weights.append(weights) encoder.model.to(device) with torch.autocast(device_type=mm.get_autocast_device(device), dtype=dtype, enabled=True): context = encoder(positive_prompts, device) context_null = encoder([negative_prompt], device) # Apply weights to embeddings if any were extracted for i, weights in enumerate(all_weights): for text, weight in weights.items(): log.info(f"Applying weight {weight} to prompt: {text}") if len(weights) > 0: context[i] = context[i] * weight if force_offload: encoder.model.to(offload_device) mm.soft_empty_cache() prompt_embeds_dict = { "prompt_embeds": context, "negative_prompt_embeds": context_null, } return (prompt_embeds_dict,) def parse_prompt_weights(self, prompt): """Extract text and weights from prompts with (text:weight) format""" import re # Parse all instances of (text:weight) in the prompt pattern = r'\((.*?):([\d\.]+)\)' matches = re.findall(pattern, prompt) # Replace each match with just the text part cleaned_prompt = prompt weights = {} for match in matches: text, weight = match orig_text = f"({text}:{weight})" cleaned_prompt = cleaned_prompt.replace(orig_text, text) weights[text] = float(weight) return cleaned_prompt, weights class WanVideoTextEncodeSingle: @classmethod def INPUT_TYPES(s): return {"required": { "t5": ("WANTEXTENCODER",), "prompt": ("STRING", {"default": "", "multiline": True} ), }, "optional": { "force_offload": ("BOOLEAN", {"default": True}), "model_to_offload": ("WANVIDEOMODEL", {"tooltip": "Model to move to offload_device before encoding"}), } } RETURN_TYPES = ("WANVIDEOTEXTEMBEDS", ) RETURN_NAMES = ("text_embeds",) FUNCTION = "process" CATEGORY = "WanVideoWrapper" DESCRIPTION = "Encodes text prompt into text embedding." def process(self, t5, prompt, force_offload=True, model_to_offload=None): device = mm.get_torch_device() offload_device = mm.unet_offload_device() if model_to_offload is not None: log.info(f"Moving video model to {offload_device}") model_to_offload.model.to(offload_device) mm.soft_empty_cache() encoder = t5["model"] dtype = t5["dtype"] encoder.model.to(device) with torch.autocast(device_type=mm.get_autocast_device(device), dtype=dtype, enabled=True): encoded = encoder([prompt], device) if force_offload: encoder.model.to(offload_device) mm.soft_empty_cache() prompt_embeds_dict = { "prompt_embeds": encoded, "negative_prompt_embeds": None, } return (prompt_embeds_dict,) class WanVideoApplyNAG: @classmethod def INPUT_TYPES(s): return {"required": { "original_text_embeds": ("WANVIDEOTEXTEMBEDS",), "nag_text_embeds": ("WANVIDEOTEXTEMBEDS",), "nag_scale": ("FLOAT", {"default": 11.0, "min": 0.0, "max": 100.0, "step": 0.1}), "nag_tau": ("FLOAT", {"default": 2.5, "min": 0.0, "max": 10.0, "step": 0.1}), "nag_alpha": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01}), }, } RETURN_TYPES = ("WANVIDEOTEXTEMBEDS", ) RETURN_NAMES = ("text_embeds",) FUNCTION = "process" CATEGORY = "WanVideoWrapper" DESCRIPTION = "Adds NAG prompt embeds to original prompt embeds: 'https://github.com/ChenDarYen/Normalized-Attention-Guidance'" def process(self, original_text_embeds, nag_text_embeds, nag_scale, nag_tau, nag_alpha): prompt_embeds_dict_copy = original_text_embeds.copy() prompt_embeds_dict_copy.update({ "nag_prompt_embeds": nag_text_embeds["prompt_embeds"], "nag_params": { "nag_scale": nag_scale, "nag_tau": nag_tau, "nag_alpha": nag_alpha, } }) return (prompt_embeds_dict_copy,) class WanVideoTextEmbedBridge: @classmethod def INPUT_TYPES(s): return {"required": { "positive": ("CONDITIONING",), }, "optional": { "negative": ("CONDITIONING",), } } RETURN_TYPES = ("WANVIDEOTEXTEMBEDS", ) RETURN_NAMES = ("text_embeds",) FUNCTION = "process" CATEGORY = "WanVideoWrapper" DESCRIPTION = "Bridge between ComfyUI native text embedding and WanVideoWrapper text embedding" def process(self, positive, negative=None): device=mm.get_torch_device() prompt_embeds_dict = { "prompt_embeds": positive[0][0].to(device), "negative_prompt_embeds": negative[0][0].to(device) if negative is not None else None, } return (prompt_embeds_dict,) #region clip image encode class WanVideoImageClipEncode: @classmethod def INPUT_TYPES(s): return {"required": { "clip_vision": ("CLIP_VISION",), "image": ("IMAGE", {"tooltip": "Image to encode"}), "vae": ("WANVAE",), "generation_width": ("INT", {"default": 832, "min": 64, "max": 2048, "step": 8, "tooltip": "Width of the image to encode"}), "generation_height": ("INT", {"default": 480, "min": 64, "max": 29048, "step": 8, "tooltip": "Height of the image to encode"}), "num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}), }, "optional": { "force_offload": ("BOOLEAN", {"default": True}), "noise_aug_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Strength of noise augmentation, helpful for I2V where some noise can add motion and give sharper results"}), "latent_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional latent multiplier, helpful for I2V where lower values allow for more motion"}), "clip_embed_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional clip embed multiplier"}), "adjust_resolution": ("BOOLEAN", {"default": True, "tooltip": "Performs the same resolution adjustment as in the original code"}), } } RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", ) RETURN_NAMES = ("image_embeds",) FUNCTION = "process" CATEGORY = "WanVideoWrapper" DEPRECATED = True def process(self, clip_vision, vae, image, num_frames, generation_width, generation_height, force_offload=True, noise_aug_strength=0.0, latent_strength=1.0, clip_embed_strength=1.0, adjust_resolution=True): device = mm.get_torch_device() offload_device = mm.unet_offload_device() self.image_mean = [0.48145466, 0.4578275, 0.40821073] self.image_std = [0.26862954, 0.26130258, 0.27577711] patch_size = (1, 2, 2) vae_stride = (4, 8, 8) H, W = image.shape[1], image.shape[2] max_area = generation_width * generation_height print(clip_vision) clip_vision.model.to(device) if isinstance(clip_vision, ClipVisionModel): clip_context = clip_vision.encode_image(image).last_hidden_state.to(device) else: pixel_values = clip_preprocess(image.to(device), size=224, mean=self.image_mean, std=self.image_std, crop=True).float() clip_context = clip_vision.visual(pixel_values) if clip_embed_strength != 1.0: clip_context *= clip_embed_strength if force_offload: clip_vision.model.to(offload_device) mm.soft_empty_cache() if adjust_resolution: aspect_ratio = H / W lat_h = round( np.sqrt(max_area * aspect_ratio) // vae_stride[1] // patch_size[1] * patch_size[1]) lat_w = round( np.sqrt(max_area / aspect_ratio) // vae_stride[2] // patch_size[2] * patch_size[2]) h = lat_h * vae_stride[1] w = lat_w * vae_stride[2] else: h = generation_height w = generation_width lat_h = h // 8 lat_w = w // 8 # Step 1: Create initial mask with ones for first frame, zeros for others mask = torch.ones(1, num_frames, lat_h, lat_w, device=device) mask[:, 1:] = 0 # Step 2: Repeat first frame 4 times and concatenate with remaining frames first_frame_repeated = torch.repeat_interleave(mask[:, 0:1], repeats=4, dim=1) mask = torch.concat([first_frame_repeated, mask[:, 1:]], dim=1) # Step 3: Reshape mask into groups of 4 frames mask = mask.view(1, mask.shape[1] // 4, 4, lat_h, lat_w) # Step 4: Transpose dimensions and select first batch mask = mask.transpose(1, 2)[0] # Calculate maximum sequence length frames_per_stride = (num_frames - 1) // vae_stride[0] + 1 patches_per_frame = lat_h * lat_w // (patch_size[1] * patch_size[2]) max_seq_len = frames_per_stride * patches_per_frame vae.to(device) # Step 1: Resize and rearrange the input image dimensions #resized_image = image.permute(0, 3, 1, 2) # Rearrange dimensions to (B, C, H, W) #resized_image = torch.nn.functional.interpolate(resized_image, size=(h, w), mode='bicubic') resized_image = common_upscale(image.movedim(-1, 1), w, h, "lanczos", "disabled") resized_image = resized_image.transpose(0, 1) # Transpose to match required format resized_image = resized_image * 2 - 1 if noise_aug_strength > 0.0: resized_image = add_noise_to_reference_video(resized_image, ratio=noise_aug_strength) # Step 2: Create zero padding frames zero_frames = torch.zeros(3, num_frames-1, h, w, device=device) # Step 3: Concatenate image with zero frames concatenated = torch.concat([resized_image.to(device), zero_frames, resized_image.to(device)], dim=1).to(device = device, dtype = vae.dtype) concatenated *= latent_strength y = vae.encode([concatenated], device)[0] y = torch.concat([mask, y]) vae.model.clear_cache() vae.to(offload_device) image_embeds = { "image_embeds": y, "clip_context": clip_context, "max_seq_len": max_seq_len, "num_frames": num_frames, "lat_h": lat_h, "lat_w": lat_w, } return (image_embeds,) class WanVideoImageResizeToClosest: @classmethod def INPUT_TYPES(s): return {"required": { "image": ("IMAGE", {"tooltip": "Image to resize"}), "generation_width": ("INT", {"default": 832, "min": 64, "max": 2048, "step": 8, "tooltip": "Width of the image to encode"}), "generation_height": ("INT", {"default": 480, "min": 64, "max": 29048, "step": 8, "tooltip": "Height of the image to encode"}), "aspect_ratio_preservation": (["keep_input", "stretch_to_new", "crop_to_new"],), }, } RETURN_TYPES = ("IMAGE", "INT", "INT", ) RETURN_NAMES = ("image","width","height",) FUNCTION = "process" CATEGORY = "WanVideoWrapper" DESCRIPTION = "Resizes image to the closest supported resolution based on aspect ratio and max pixels, according to the original code" def process(self, image, generation_width, generation_height, aspect_ratio_preservation ): patch_size = (1, 2, 2) vae_stride = (4, 8, 8) H, W = image.shape[1], image.shape[2] max_area = generation_width * generation_height crop = "disabled" if aspect_ratio_preservation == "keep_input": aspect_ratio = H / W elif aspect_ratio_preservation == "stretch_to_new" or aspect_ratio_preservation == "crop_to_new": aspect_ratio = generation_height / generation_width if aspect_ratio_preservation == "crop_to_new": crop = "center" lat_h = round( np.sqrt(max_area * aspect_ratio) // vae_stride[1] // patch_size[1] * patch_size[1]) lat_w = round( np.sqrt(max_area / aspect_ratio) // vae_stride[2] // patch_size[2] * patch_size[2]) h = lat_h * vae_stride[1] w = lat_w * vae_stride[2] resized_image = common_upscale(image.movedim(-1, 1), w, h, "lanczos", crop).movedim(1, -1) return (resized_image, w, h) #region clip vision class WanVideoClipVisionEncode: @classmethod def INPUT_TYPES(s): return {"required": { "clip_vision": ("CLIP_VISION",), "image_1": ("IMAGE", {"tooltip": "Image to encode"}), "strength_1": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional clip embed multiplier"}), "strength_2": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional clip embed multiplier"}), "crop": (["center", "disabled"], {"default": "center", "tooltip": "Crop image to 224x224 before encoding"}), "combine_embeds": (["average", "sum", "concat", "batch"], {"default": "average", "tooltip": "Method to combine multiple clip embeds"}), "force_offload": ("BOOLEAN", {"default": True}), }, "optional": { "image_2": ("IMAGE", ), "negative_image": ("IMAGE", {"tooltip": "image to use for uncond"}), "tiles": ("INT", {"default": 0, "min": 0, "max": 16, "step": 2, "tooltip": "Use matteo's tiled image encoding for improved accuracy"}), "ratio": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Ratio of the tile average"}), } } RETURN_TYPES = ("WANVIDIMAGE_CLIPEMBEDS",) RETURN_NAMES = ("image_embeds",) FUNCTION = "process" CATEGORY = "WanVideoWrapper" def process(self, clip_vision, image_1, strength_1, strength_2, force_offload, crop, combine_embeds, image_2=None, negative_image=None, tiles=0, ratio=1.0): device = mm.get_torch_device() offload_device = mm.unet_offload_device() image_mean = [0.48145466, 0.4578275, 0.40821073] image_std = [0.26862954, 0.26130258, 0.27577711] if image_2 is not None: image = torch.cat([image_1, image_2], dim=0) else: image = image_1 clip_vision.model.to(device) negative_clip_embeds = None if tiles > 0: log.info("Using tiled image encoding") clip_embeds = clip_encode_image_tiled(clip_vision, image.to(device), tiles=tiles, ratio=ratio) if negative_image is not None: negative_clip_embeds = clip_encode_image_tiled(clip_vision, negative_image.to(device), tiles=tiles, ratio=ratio) else: if isinstance(clip_vision, ClipVisionModel): clip_embeds = clip_vision.encode_image(image).penultimate_hidden_states.to(device) if negative_image is not None: negative_clip_embeds = clip_vision.encode_image(negative_image).penultimate_hidden_states.to(device) else: pixel_values = clip_preprocess(image.to(device), size=224, mean=image_mean, std=image_std, crop=(not crop == "disabled")).float() clip_embeds = clip_vision.visual(pixel_values) if negative_image is not None: pixel_values = clip_preprocess(negative_image.to(device), size=224, mean=image_mean, std=image_std, crop=(not crop == "disabled")).float() negative_clip_embeds = clip_vision.visual(pixel_values) log.info(f"Clip embeds shape: {clip_embeds.shape}, dtype: {clip_embeds.dtype}") weighted_embeds = [] weighted_embeds.append(clip_embeds[0:1] * strength_1) # Handle all additional embeddings if clip_embeds.shape[0] > 1: weighted_embeds.append(clip_embeds[1:2] * strength_2) if clip_embeds.shape[0] > 2: for i in range(2, clip_embeds.shape[0]): weighted_embeds.append(clip_embeds[i:i+1]) # Add as-is without strength modifier # Combine all weighted embeddings if combine_embeds == "average": clip_embeds = torch.mean(torch.stack(weighted_embeds), dim=0) elif combine_embeds == "sum": clip_embeds = torch.sum(torch.stack(weighted_embeds), dim=0) elif combine_embeds == "concat": clip_embeds = torch.cat(weighted_embeds, dim=1) elif combine_embeds == "batch": clip_embeds = torch.cat(weighted_embeds, dim=0) else: clip_embeds = weighted_embeds[0] log.info(f"Combined clip embeds shape: {clip_embeds.shape}") if force_offload: clip_vision.model.to(offload_device) mm.soft_empty_cache() clip_embeds_dict = { "clip_embeds": clip_embeds, "negative_clip_embeds": negative_clip_embeds } return (clip_embeds_dict,) class WanVideoRealisDanceLatents: @classmethod def INPUT_TYPES(s): return {"required": { "ref_latent": ("LATENT", {"tooltip": "Reference image to encode"}), "smpl_latent": ("LATENT", {"tooltip": "SMPL pose image to encode"}), "pose_cond_start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent of the SMPL model"}), "pose_cond_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent of the SMPL model"}), }, "optional": { "hamer_latent": ("LATENT", {"tooltip": "Hamer hand pose image to encode"}), }, } RETURN_TYPES = ("REALISDANCELATENTS",) RETURN_NAMES = ("realisdance_latents",) FUNCTION = "process" CATEGORY = "WanVideoWrapper" def process(self, ref_latent, smpl_latent, pose_cond_start_percent, pose_cond_end_percent, hamer_latent=None): if hamer_latent is None: hamer = torch.zeros_like(smpl_latent["samples"]) else: hamer = hamer_latent["samples"] pose_latent = torch.cat((smpl_latent["samples"], hamer), dim=1) realisdance_latents = { "ref_latent": ref_latent["samples"], "pose_latent": pose_latent, "pose_cond_start_percent": pose_cond_start_percent, "pose_cond_end_percent": pose_cond_end_percent, } return (realisdance_latents,) class WanVideoImageToVideoEncode: @classmethod def INPUT_TYPES(s): return {"required": { "vae": ("WANVAE",), "width": ("INT", {"default": 832, "min": 64, "max": 2048, "step": 8, "tooltip": "Width of the image to encode"}), "height": ("INT", {"default": 480, "min": 64, "max": 29048, "step": 8, "tooltip": "Height of the image to encode"}), "num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}), "noise_aug_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Strength of noise augmentation, helpful for I2V where some noise can add motion and give sharper results"}), "start_latent_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional latent multiplier, helpful for I2V where lower values allow for more motion"}), "end_latent_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional latent multiplier, helpful for I2V where lower values allow for more motion"}), "force_offload": ("BOOLEAN", {"default": True}), }, "optional": { "clip_embeds": ("WANVIDIMAGE_CLIPEMBEDS", {"tooltip": "Clip vision encoded image"}), "start_image": ("IMAGE", {"tooltip": "Image to encode"}), "end_image": ("IMAGE", {"tooltip": "end frame"}), "control_embeds": ("WANVIDIMAGE_EMBEDS", {"tooltip": "Control signal for the Fun -model"}), "fun_or_fl2v_model": ("BOOLEAN", {"default": True, "tooltip": "Enable when using official FLF2V or Fun model"}), "temporal_mask": ("MASK", {"tooltip": "mask"}), "extra_latents": ("LATENT", {"tooltip": "Extra latents to add to the input front, used for Skyreels A2 reference images"}), "tiled_vae": ("BOOLEAN", {"default": False, "tooltip": "Use tiled VAE encoding for reduced memory use"}), "realisdance_latents": ("REALISDANCELATENTS", {"tooltip": "RealisDance latents"}), } } RETURN_TYPES = ("WANVIDIMAGE_EMBEDS",) RETURN_NAMES = ("image_embeds",) FUNCTION = "process" CATEGORY = "WanVideoWrapper" def process(self, vae, width, height, num_frames, force_offload, noise_aug_strength, start_latent_strength, end_latent_strength, start_image=None, end_image=None, control_embeds=None, fun_or_fl2v_model=False, temporal_mask=None, extra_latents=None, clip_embeds=None, tiled_vae=False, realisdance_latents=None): device = mm.get_torch_device() offload_device = mm.unet_offload_device() patch_size = (1, 2, 2) H = height W = width lat_h = H // 8 lat_w = W // 8 num_frames = ((num_frames - 1) // 4) * 4 + 1 two_ref_images = start_image is not None and end_image is not None base_frames = num_frames + (1 if two_ref_images and not fun_or_fl2v_model else 0) if temporal_mask is None: mask = torch.zeros(1, base_frames, lat_h, lat_w, device=device) if start_image is not None: mask[:, 0:start_image.shape[0]] = 1 # First frame if end_image is not None: mask[:, -end_image.shape[0]:] = 1 # End frame if exists else: mask = common_upscale(temporal_mask.unsqueeze(1).to(device), lat_w, lat_h, "nearest", "disabled").squeeze(1) if mask.shape[0] > base_frames: mask = mask[:base_frames] elif mask.shape[0] < base_frames: mask = torch.cat([mask, torch.zeros(base_frames - mask.shape[0], lat_h, lat_w, device=device)]) mask = mask.unsqueeze(0).to(device) # Repeat first frame and optionally end frame start_mask_repeated = torch.repeat_interleave(mask[:, 0:1], repeats=4, dim=1) # T, C, H, W if end_image is not None and not fun_or_fl2v_model: end_mask_repeated = torch.repeat_interleave(mask[:, -1:], repeats=4, dim=1) # T, C, H, W mask = torch.cat([start_mask_repeated, mask[:, 1:-1], end_mask_repeated], dim=1) else: mask = torch.cat([start_mask_repeated, mask[:, 1:]], dim=1) # Reshape mask into groups of 4 frames mask = mask.view(1, mask.shape[1] // 4, 4, lat_h, lat_w) # 1, T, C, H, W mask = mask.movedim(1, 2)[0]# C, T, H, W # Resize and rearrange the input image dimensions if start_image is not None: resized_start_image = common_upscale(start_image.movedim(-1, 1), W, H, "lanczos", "disabled").movedim(0, 1) resized_start_image = resized_start_image * 2 - 1 if noise_aug_strength > 0.0: resized_start_image = add_noise_to_reference_video(resized_start_image, ratio=noise_aug_strength) if end_image is not None: resized_end_image = common_upscale(end_image.movedim(-1, 1), W, H, "lanczos", "disabled").movedim(0, 1) resized_end_image = resized_end_image * 2 - 1 if noise_aug_strength > 0.0: resized_end_image = add_noise_to_reference_video(resized_end_image, ratio=noise_aug_strength) # Concatenate image with zero frames and encode vae.to(device) if temporal_mask is None: if start_image is not None and end_image is None: zero_frames = torch.zeros(3, num_frames-start_image.shape[0], H, W, device=device) concatenated = torch.cat([resized_start_image.to(device), zero_frames], dim=1) elif start_image is None and end_image is not None: zero_frames = torch.zeros(3, num_frames-end_image.shape[0], H, W, device=device) concatenated = torch.cat([zero_frames, resized_end_image.to(device)], dim=1) elif start_image is None and end_image is None: concatenated = torch.zeros(3, num_frames, H, W, device=device) else: if fun_or_fl2v_model: zero_frames = torch.zeros(3, num_frames-(start_image.shape[0]+end_image.shape[0]), H, W, device=device) else: zero_frames = torch.zeros(3, num_frames-1, H, W, device=device) concatenated = torch.cat([resized_start_image.to(device), zero_frames, resized_end_image.to(device)], dim=1) else: temporal_mask = common_upscale(temporal_mask.unsqueeze(1), W, H, "nearest", "disabled").squeeze(1) concatenated = resized_start_image[:,:num_frames] * temporal_mask[:num_frames].unsqueeze(0) y = vae.encode([concatenated.to(device=device, dtype=vae.dtype)], device, end_=(end_image is not None and not fun_or_fl2v_model),tiled=tiled_vae)[0] has_ref = False if extra_latents is not None: samples = extra_latents["samples"].squeeze(0) y = torch.cat([samples, y], dim=1) mask = torch.cat([torch.ones_like(mask[:, 0:samples.shape[1]]), mask], dim=1) num_frames += samples.shape[1] * 4 has_ref = True y[:, :1] *= start_latent_strength y[:, -1:] *= end_latent_strength if control_embeds is None: y = torch.cat([mask, y]) else: if end_image is None: y[:, 1:] = 0 elif start_image is None: y[:, -1:] = 0 else: y[:, 1:-1] = 0 # doesn't seem to work anyway though... # Calculate maximum sequence length patches_per_frame = lat_h * lat_w // (patch_size[1] * patch_size[2]) frames_per_stride = (num_frames - 1) // 4 + (2 if end_image is not None and not fun_or_fl2v_model else 1) max_seq_len = frames_per_stride * patches_per_frame if realisdance_latents is not None: realisdance_latents["ref_latent_neg"] = vae.encode(torch.zeros(1, 3, 1, H, W, device=device, dtype=vae.dtype), device) vae.model.clear_cache() if force_offload: vae.model.to(offload_device) mm.soft_empty_cache() gc.collect() image_embeds = { "image_embeds": y, "clip_context": clip_embeds.get("clip_embeds", None) if clip_embeds is not None else None, "negative_clip_context": clip_embeds.get("negative_clip_embeds", None) if clip_embeds is not None else None, "max_seq_len": max_seq_len, "num_frames": num_frames, "lat_h": lat_h, "lat_w": lat_w, "control_embeds": control_embeds["control_embeds"] if control_embeds is not None else None, "end_image": resized_end_image if end_image is not None else None, "fun_or_fl2v_model": fun_or_fl2v_model, "has_ref": has_ref, "realisdance_latents": realisdance_latents } return (image_embeds,) class WanVideoEmptyEmbeds: @classmethod def INPUT_TYPES(s): return {"required": { "width": ("INT", {"default": 832, "min": 64, "max": 2048, "step": 8, "tooltip": "Width of the image to encode"}), "height": ("INT", {"default": 480, "min": 64, "max": 29048, "step": 8, "tooltip": "Height of the image to encode"}), "num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}), }, "optional": { "control_embeds": ("WANVIDIMAGE_EMBEDS", {"tooltip": "control signal for the Fun -model"}), } } RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", ) RETURN_NAMES = ("image_embeds",) FUNCTION = "process" CATEGORY = "WanVideoWrapper" def process(self, num_frames, width, height, control_embeds=None): vae_stride = (4, 8, 8) target_shape = (16, (num_frames - 1) // vae_stride[0] + 1, height // vae_stride[1], width // vae_stride[2]) embeds = { "target_shape": target_shape, "num_frames": num_frames, "control_embeds": control_embeds["control_embeds"] if control_embeds is not None else None, } return (embeds,) class WanVideoMiniMaxRemoverEmbeds: @classmethod def INPUT_TYPES(s): return {"required": { "width": ("INT", {"default": 832, "min": 64, "max": 2048, "step": 8, "tooltip": "Width of the image to encode"}), "height": ("INT", {"default": 480, "min": 64, "max": 29048, "step": 8, "tooltip": "Height of the image to encode"}), "num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}), "latents": ("LATENT", {"tooltip": "Encoded latents to use as control signals"}), "mask_latents": ("LATENT", {"tooltip": "Encoded latents to use as mask"}), }, } RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", ) RETURN_NAMES = ("image_embeds",) FUNCTION = "process" CATEGORY = "WanVideoWrapper" def process(self, num_frames, width, height, latents, mask_latents): vae_stride = (4, 8, 8) target_shape = (16, (num_frames - 1) // vae_stride[0] + 1, height // vae_stride[1], width // vae_stride[2]) embeds = { "target_shape": target_shape, "num_frames": num_frames, "minimax_latents": latents["samples"].squeeze(0), "minimax_mask_latents": mask_latents["samples"].squeeze(0), } return (embeds,) # region phantom class WanVideoPhantomEmbeds: @classmethod def INPUT_TYPES(s): return {"required": { "num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}), "phantom_latent_1": ("LATENT", {"tooltip": "reference latents for the phantom model"}), "phantom_cfg_scale": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "CFG scale for the extra phantom cond pass"}), "phantom_start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent of the phantom model"}), "phantom_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent of the phantom model"}), }, "optional": { "phantom_latent_2": ("LATENT", {"tooltip": "reference latents for the phantom model"}), "phantom_latent_3": ("LATENT", {"tooltip": "reference latents for the phantom model"}), "phantom_latent_4": ("LATENT", {"tooltip": "reference latents for the phantom model"}), "vace_embeds": ("WANVIDIMAGE_EMBEDS", {"tooltip": "VACE embeds"}), } } RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", ) RETURN_NAMES = ("image_embeds",) FUNCTION = "process" CATEGORY = "WanVideoWrapper" def process(self, num_frames, phantom_cfg_scale, phantom_start_percent, phantom_end_percent, phantom_latent_1, phantom_latent_2=None, phantom_latent_3=None, phantom_latent_4=None, vace_embeds=None): vae_stride = (4, 8, 8) samples = phantom_latent_1["samples"].squeeze(0) if phantom_latent_2 is not None: samples = torch.cat([samples, phantom_latent_2["samples"].squeeze(0)], dim=1) if phantom_latent_3 is not None: samples = torch.cat([samples, phantom_latent_3["samples"].squeeze(0)], dim=1) if phantom_latent_4 is not None: samples = torch.cat([samples, phantom_latent_4["samples"].squeeze(0)], dim=1) C, T, H, W = samples.shape log.info(f"Phantom latents shape: {samples.shape}") target_shape = (16, (num_frames - 1) // vae_stride[0] + 1 + T, H * 8 // vae_stride[1], W * 8 // vae_stride[2]) embeds = { "target_shape": target_shape, "num_frames": num_frames, "phantom_latents": samples, "phantom_cfg_scale": phantom_cfg_scale, "phantom_start_percent": phantom_start_percent, "phantom_end_percent": phantom_end_percent, } if vace_embeds is not None: vace_input = { "vace_context": vace_embeds["vace_context"], "vace_scale": vace_embeds["vace_scale"], "has_ref": vace_embeds["has_ref"], "vace_start_percent": vace_embeds["vace_start_percent"], "vace_end_percent": vace_embeds["vace_end_percent"], "vace_seq_len": vace_embeds["vace_seq_len"], "additional_vace_inputs": vace_embeds["additional_vace_inputs"], } embeds.update(vace_input) return (embeds,) class WanVideoControlEmbeds: @classmethod def INPUT_TYPES(s): return {"required": { "latents": ("LATENT", {"tooltip": "Encoded latents to use as control signals"}), "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent of the control signal"}), "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent of the control signal"}), }, "optional": { "fun_ref_image": ("LATENT", {"tooltip": "Reference latent for the Fun 1.1 -model"}), } } RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", ) RETURN_NAMES = ("image_embeds",) FUNCTION = "process" CATEGORY = "WanVideoWrapper" def process(self, latents, start_percent, end_percent, fun_ref_image=None): samples = latents["samples"].squeeze(0) C, T, H, W = samples.shape num_frames = (T - 1) * 4 + 1 seq_len = math.ceil((H * W) / 4 * ((num_frames - 1) // 4 + 1)) embeds = { "max_seq_len": seq_len, "target_shape": samples.shape, "num_frames": num_frames, "control_embeds": { "control_images": samples, "start_percent": start_percent, "end_percent": end_percent, "fun_ref_image": fun_ref_image["samples"][:,:, 0] if fun_ref_image is not None else None, } } return (embeds,) class WanVideoSLG: @classmethod def INPUT_TYPES(s): return {"required": { "blocks": ("STRING", {"default": "10", "tooltip": "Blocks to skip uncond on, separated by comma, index starts from 0"}), "start_percent": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent of the control signal"}), "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent of the control signal"}), }, } RETURN_TYPES = ("SLGARGS", ) RETURN_NAMES = ("slg_args",) FUNCTION = "process" CATEGORY = "WanVideoWrapper" DESCRIPTION = "Skips uncond on the selected blocks" def process(self, blocks, start_percent, end_percent): slg_block_list = [int(x.strip()) for x in blocks.split(",")] slg_args = { "blocks": slg_block_list, "start_percent": start_percent, "end_percent": end_percent, } return (slg_args,) #region VACE class WanVideoVACEEncode: @classmethod def INPUT_TYPES(s): return {"required": { "vae": ("WANVAE",), "width": ("INT", {"default": 832, "min": 64, "max": 2048, "step": 8, "tooltip": "Width of the image to encode"}), "height": ("INT", {"default": 480, "min": 64, "max": 29048, "step": 8, "tooltip": "Height of the image to encode"}), "num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}), "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}), "vace_start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent of the steps to apply VACE"}), "vace_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent of the steps to apply VACE"}), }, "optional": { "input_frames": ("IMAGE",), "ref_images": ("IMAGE",), "input_masks": ("MASK",), "prev_vace_embeds": ("WANVIDIMAGE_EMBEDS",), "tiled_vae": ("BOOLEAN", {"default": False, "tooltip": "Use tiled VAE encoding for reduced memory use"}), }, } RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", ) RETURN_NAMES = ("vace_embeds",) FUNCTION = "process" CATEGORY = "WanVideoWrapper" def process(self, vae, width, height, num_frames, strength, vace_start_percent, vace_end_percent, input_frames=None, ref_images=None, input_masks=None, prev_vace_embeds=None, tiled_vae=False): self.device = mm.get_torch_device() offload_device = mm.unet_offload_device() self.vae = vae.to(self.device) self.vae_stride = (4, 8, 8) width = (width // 16) * 16 height = (height // 16) * 16 target_shape = (16, (num_frames - 1) // self.vae_stride[0] + 1, height // self.vae_stride[1], width // self.vae_stride[2]) # vace context encode if input_frames is None: input_frames = torch.zeros((1, 3, num_frames, height, width), device=self.device, dtype=self.vae.dtype) else: input_frames = input_frames[:num_frames] input_frames = common_upscale(input_frames.clone().movedim(-1, 1), width, height, "lanczos", "disabled").movedim(1, -1) input_frames = input_frames.to(self.vae.dtype).to(self.device).unsqueeze(0).permute(0, 4, 1, 2, 3) # B, C, T, H, W input_frames = input_frames * 2 - 1 if input_masks is None: input_masks = torch.ones_like(input_frames, device=self.device) else: print("input_masks shape", input_masks.shape) input_masks = input_masks[:num_frames] input_masks = common_upscale(input_masks.clone().unsqueeze(1), width, height, "nearest-exact", "disabled").squeeze(1) input_masks = input_masks.to(self.vae.dtype).to(self.device) input_masks = input_masks.unsqueeze(-1).unsqueeze(0).permute(0, 4, 1, 2, 3).repeat(1, 3, 1, 1, 1) # B, C, T, H, W if ref_images is not None: # Create padded image if ref_images.shape[0] > 1: ref_images = torch.cat([ref_images[i] for i in range(ref_images.shape[0])], dim=1).unsqueeze(0) B, H, W, C = ref_images.shape current_aspect = W / H target_aspect = width / height if current_aspect > target_aspect: # Image is wider than target, pad height new_h = int(W / target_aspect) pad_h = (new_h - H) // 2 padded = torch.ones(ref_images.shape[0], new_h, W, ref_images.shape[3], device=ref_images.device, dtype=ref_images.dtype) padded[:, pad_h:pad_h+H, :, :] = ref_images ref_images = padded elif current_aspect < target_aspect: # Image is taller than target, pad width new_w = int(H * target_aspect) pad_w = (new_w - W) // 2 padded = torch.ones(ref_images.shape[0], H, new_w, ref_images.shape[3], device=ref_images.device, dtype=ref_images.dtype) padded[:, :, pad_w:pad_w+W, :] = ref_images ref_images = padded ref_images = common_upscale(ref_images.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1) ref_images = ref_images.to(self.vae.dtype).to(self.device).unsqueeze(0).permute(0, 4, 1, 2, 3).unsqueeze(0) ref_images = ref_images * 2 - 1 z0 = self.vace_encode_frames(input_frames, ref_images, masks=input_masks, tiled_vae=tiled_vae) self.vae.model.clear_cache() m0 = self.vace_encode_masks(input_masks, ref_images) z = self.vace_latent(z0, m0) self.vae.to(offload_device) vace_input = { "vace_context": z, "vace_scale": strength, "has_ref": ref_images is not None, "num_frames": num_frames, "target_shape": target_shape, "vace_start_percent": vace_start_percent, "vace_end_percent": vace_end_percent, "vace_seq_len": math.ceil((z[0].shape[2] * z[0].shape[3]) / 4 * z[0].shape[1]), "additional_vace_inputs": [], } if prev_vace_embeds is not None: if "additional_vace_inputs" in prev_vace_embeds and prev_vace_embeds["additional_vace_inputs"]: vace_input["additional_vace_inputs"] = prev_vace_embeds["additional_vace_inputs"].copy() vace_input["additional_vace_inputs"].append(prev_vace_embeds) return (vace_input,) def vace_encode_frames(self, frames, ref_images, masks=None, tiled_vae=False): if ref_images is None: ref_images = [None] * len(frames) else: assert len(frames) == len(ref_images) if masks is None: latents = self.vae.encode(frames, device=self.device, tiled=tiled_vae) else: inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)] reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)] inactive = self.vae.encode(inactive, device=self.device, tiled=tiled_vae) reactive = self.vae.encode(reactive, device=self.device, tiled=tiled_vae) latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)] self.vae.model.clear_cache() cat_latents = [] for latent, refs in zip(latents, ref_images): if refs is not None: if masks is None: ref_latent = self.vae.encode(refs, device=self.device, tiled=tiled_vae) else: print("refs shape", refs.shape)#torch.Size([3, 1, 512, 512]) ref_latent = self.vae.encode(refs, device=self.device, tiled=tiled_vae) ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent] assert all([x.shape[1] == 1 for x in ref_latent]) latent = torch.cat([*ref_latent, latent], dim=1) cat_latents.append(latent) return cat_latents def vace_encode_masks(self, masks, ref_images=None): if ref_images is None: ref_images = [None] * len(masks) else: assert len(masks) == len(ref_images) result_masks = [] for mask, refs in zip(masks, ref_images): c, depth, height, width = mask.shape new_depth = int((depth + 3) // self.vae_stride[0]) height = 2 * (int(height) // (self.vae_stride[1] * 2)) width = 2 * (int(width) // (self.vae_stride[2] * 2)) # reshape mask = mask[0, :, :, :] mask = mask.view( depth, height, self.vae_stride[1], width, self.vae_stride[1] ) # depth, height, 8, width, 8 mask = mask.permute(2, 4, 0, 1, 3) # 8, 8, depth, height, width mask = mask.reshape( self.vae_stride[1] * self.vae_stride[2], depth, height, width ) # 8*8, depth, height, width # interpolation mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0) if refs is not None: length = len(refs) mask_pad = torch.zeros_like(mask[:, :length, :, :]) mask = torch.cat((mask_pad, mask), dim=1) result_masks.append(mask) return result_masks def vace_latent(self, z, m): return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)] class WanVideoVACEStartToEndFrame: @classmethod def INPUT_TYPES(s): return {"required": { "num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}), "empty_frame_level": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "White level of empty frame to use"}), }, "optional": { "start_image": ("IMAGE",), "end_image": ("IMAGE",), "control_images": ("IMAGE",), "inpaint_mask": ("MASK", {"tooltip": "Inpaint mask to use for the empty frames"}), }, } RETURN_TYPES = ("IMAGE", "MASK", ) RETURN_NAMES = ("images", "masks",) FUNCTION = "process" CATEGORY = "WanVideoWrapper" DESCRIPTION = "Helper node to create start/end frame batch and masks for VACE" def process(self, num_frames, empty_frame_level, start_image=None, end_image=None, control_images=None, inpaint_mask=None): B, H, W, C = start_image.shape if start_image is not None else end_image.shape device = start_image.device if start_image is not None else end_image.device masks = torch.ones((num_frames, H, W), device=device) if control_images is not None: control_images = common_upscale(control_images.movedim(-1, 1), W, H, "lanczos", "disabled").movedim(1, -1) if start_image is not None and end_image is not None: if start_image.shape != end_image.shape: end_image = common_upscale(end_image.movedim(-1, 1), W, H, "lanczos", "disabled").movedim(1, -1) if control_images is None: empty_frames = torch.ones((num_frames - start_image.shape[0] - end_image.shape[0], H, W, 3), device=device) * empty_frame_level else: empty_frames = control_images[start_image.shape[0]:num_frames - end_image.shape[0]] out_batch = torch.cat([start_image, empty_frames, end_image], dim=0) masks[0:start_image.shape[0]] = 0 masks[-end_image.shape[0]:] = 0 elif start_image is not None: if control_images is None: empty_frames = torch.ones((num_frames - start_image.shape[0], H, W, 3), device=device) * empty_frame_level else: empty_frames = control_images[start_image.shape[0]:num_frames] out_batch = torch.cat([start_image, empty_frames], dim=0) masks[0:start_image.shape[0]] = 0 elif end_image is not None: if control_images is None: empty_frames = torch.ones((num_frames - end_image.shape[0], H, W, 3), device=device) * empty_frame_level else: empty_frames = control_images[:num_frames - end_image.shape[0]] out_batch = torch.cat([empty_frames, end_image], dim=0) masks[-end_image.shape[0]:] = 0 if inpaint_mask is not None: inpaint_mask = common_upscale(inpaint_mask.unsqueeze(1), W, H, "nearest-exact", "disabled").squeeze(1).to(device) if inpaint_mask.shape[0] > num_frames: inpaint_mask = inpaint_mask[:num_frames] elif inpaint_mask.shape[0] < num_frames: inpaint_mask = inpaint_mask.repeat(num_frames // inpaint_mask.shape[0] + 1, 1, 1)[:num_frames] empty_mask = torch.ones_like(masks, device=device) masks = inpaint_mask * empty_mask return (out_batch.cpu().float(), masks.cpu().float()) #region context options class WanVideoContextOptions: @classmethod def INPUT_TYPES(s): return {"required": { "context_schedule": (["uniform_standard", "uniform_looped", "static_standard"],), "context_frames": ("INT", {"default": 81, "min": 2, "max": 1000, "step": 1, "tooltip": "Number of pixel frames in the context, NOTE: the latent space has 4 frames in 1"} ), "context_stride": ("INT", {"default": 4, "min": 4, "max": 100, "step": 1, "tooltip": "Context stride as pixel frames, NOTE: the latent space has 4 frames in 1"} ), "context_overlap": ("INT", {"default": 16, "min": 4, "max": 100, "step": 1, "tooltip": "Context overlap as pixel frames, NOTE: the latent space has 4 frames in 1"} ), "freenoise": ("BOOLEAN", {"default": True, "tooltip": "Shuffle the noise"}), "verbose": ("BOOLEAN", {"default": False, "tooltip": "Print debug output"}), }, "optional": { "vae": ("WANVAE",), } } RETURN_TYPES = ("WANVIDCONTEXT", ) RETURN_NAMES = ("context_options",) FUNCTION = "process" CATEGORY = "WanVideoWrapper" DESCRIPTION = "Context options for WanVideo, allows splitting the video into context windows and attemps blending them for longer generations than the model and memory otherwise would allow." def process(self, context_schedule, context_frames, context_stride, context_overlap, freenoise, verbose, image_cond_start_step=6, image_cond_window_count=2, vae=None): context_options = { "context_schedule":context_schedule, "context_frames":context_frames, "context_stride":context_stride, "context_overlap":context_overlap, "freenoise":freenoise, "verbose":verbose, "vae": vae, } return (context_options,) class CreateCFGScheduleFloatList: @classmethod def INPUT_TYPES(s): return {"required": { "steps": ("INT", {"default": 30, "min": 2, "max": 1000, "step": 1, "tooltip": "Number of steps to schedule cfg for"} ), "cfg_scale_start": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 30.0, "step": 0.01, "round": 0.01, "tooltip": "CFG scale to use for the steps"}), "cfg_scale_end": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 30.0, "step": 0.01, "round": 0.01, "tooltip": "CFG scale to use for the steps"}), "interpolation": (["linear", "ease_in", "ease_out"], {"default": "linear", "tooltip": "Interpolation method to use for the cfg scale"}), "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "round": 0.01,"tooltip": "Start percent of the steps to apply cfg"}), "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "round": 0.01,"tooltip": "End percent of the steps to apply cfg"}), } } RETURN_TYPES = ("FLOAT", ) RETURN_NAMES = ("float_list",) FUNCTION = "process" CATEGORY = "WanVideoWrapper" DESCRIPTION = "Helper node to generate a list of floats that can be used to schedule cfg scale for the steps, outside the set range cfg is set to 1.0" def process(self, steps, cfg_scale_start, cfg_scale_end, interpolation, start_percent, end_percent): # Create a list of floats for the cfg schedule cfg_list = [1.0] * steps start_idx = min(int(steps * start_percent), steps - 1) end_idx = min(int(steps * end_percent), steps - 1) for i in range(start_idx, end_idx + 1): if i >= steps: break if end_idx == start_idx: t = 0 else: t = (i - start_idx) / (end_idx - start_idx) if interpolation == "linear": factor = t elif interpolation == "ease_in": factor = t * t elif interpolation == "ease_out": factor = t * (2 - t) cfg_list[i] = round(cfg_scale_start + factor * (cfg_scale_end - cfg_scale_start), 2) # If start_percent > 0, always include the first step if start_percent > 0: cfg_list[0] = 1.0 return (cfg_list,) class WanVideoFlowEdit: @classmethod def INPUT_TYPES(s): return {"required": { "source_embeds": ("WANVIDEOTEXTEMBEDS", ), "skip_steps": ("INT", {"default": 4, "min": 0}), "drift_steps": ("INT", {"default": 0, "min": 0}), "drift_flow_shift": ("FLOAT", {"default": 3.0, "min": 1.0, "max": 30.0, "step": 0.01}), "source_cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}), "drift_cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}), }, "optional": { "source_image_embeds": ("WANVIDIMAGE_EMBEDS", ), } } RETURN_TYPES = ("FLOWEDITARGS", ) RETURN_NAMES = ("flowedit_args",) FUNCTION = "process" CATEGORY = "WanVideoWrapper" DESCRIPTION = "Flowedit options for WanVideo" def process(self, **kwargs): return (kwargs,) class WanVideoLoopArgs: @classmethod def INPUT_TYPES(s): return {"required": { "shift_skip": ("INT", {"default": 6, "min": 0, "tooltip": "Skip step of latent shift"}), "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent of the looping effect"}), "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent of the looping effect"}), }, } RETURN_TYPES = ("LOOPARGS", ) RETURN_NAMES = ("loop_args",) FUNCTION = "process" CATEGORY = "WanVideoWrapper" DESCRIPTION = "Looping through latent shift as shown in https://github.com/YisuiTT/Mobius/" def process(self, **kwargs): return (kwargs,) class WanVideoExperimentalArgs: @classmethod def INPUT_TYPES(s): return {"required": { "video_attention_split_steps": ("STRING", {"default": "", "tooltip": "Steps to split self attention when using multiple prompts"}), "cfg_zero_star": ("BOOLEAN", {"default": False, "tooltip": "https://github.com/WeichenFan/CFG-Zero-star"}), "use_zero_init": ("BOOLEAN", {"default": False}), "zero_star_steps": ("INT", {"default": 0, "min": 0, "tooltip": "Steps to split self attention when using multiple prompts"}), "use_fresca": ("BOOLEAN", {"default": False, "tooltip": "https://github.com/WikiChao/FreSca"}), "fresca_scale_low": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), "fresca_scale_high": ("FLOAT", {"default": 1.25, "min": 0.0, "max": 10.0, "step": 0.01}), "fresca_freq_cutoff": ("INT", {"default": 20, "min": 0, "max": 10000, "step": 1}), }, } RETURN_TYPES = ("EXPERIMENTALARGS", ) RETURN_NAMES = ("exp_args",) FUNCTION = "process" CATEGORY = "WanVideoWrapper" DESCRIPTION = "Experimental stuff" EXPERIMENTAL = True def process(self, **kwargs): return (kwargs,) #region Sampler class WanVideoSampler: @classmethod def INPUT_TYPES(s): return { "required": { "model": ("WANVIDEOMODEL",), "image_embeds": ("WANVIDIMAGE_EMBEDS", ), "steps": ("INT", {"default": 30, "min": 1}), "cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}), "shift": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 1000.0, "step": 0.01}), "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), "force_offload": ("BOOLEAN", {"default": True, "tooltip": "Moves the model to the offload device after sampling"}), "scheduler": (["unipc", "unipc/beta", "dpm++", "dpm++/beta","dpm++_sde", "dpm++_sde/beta", "euler", "euler/beta", "euler/accvideo", "deis", "lcm", "lcm/beta", "flowmatch_causvid", "flowmatch_distill"], { "default": 'unipc' }), "riflex_freq_index": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1, "tooltip": "Frequency index for RIFLEX, disabled when 0, default 6. Allows for new frames to be generated after without looping"}), }, "optional": { "text_embeds": ("WANVIDEOTEXTEMBEDS", ), "samples": ("LATENT", {"tooltip": "init Latents to use for video2video process"} ), "denoise_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), "feta_args": ("FETAARGS", ), "context_options": ("WANVIDCONTEXT", ), "cache_args": ("CACHEARGS", ), "flowedit_args": ("FLOWEDITARGS", ), "batched_cfg": ("BOOLEAN", {"default": False, "tooltip": "Batch cond and uncond for faster sampling, possibly faster on some hardware, uses more memory"}), "slg_args": ("SLGARGS", ), "rope_function": (["default", "comfy"], {"default": "comfy", "tooltip": "Comfy's RoPE implementation doesn't use complex numbers and can thus be compiled, that should be a lot faster when using torch.compile"}), "loop_args": ("LOOPARGS", ), "experimental_args": ("EXPERIMENTALARGS", ), "sigmas": ("SIGMAS", ), "unianimate_poses": ("UNIANIMATE_POSE", ), "fantasytalking_embeds": ("FANTASYTALKING_EMBEDS", ), "uni3c_embeds": ("UNI3C_EMBEDS", ), } } RETURN_TYPES = ("LATENT", ) RETURN_NAMES = ("samples",) FUNCTION = "process" CATEGORY = "WanVideoWrapper" def process(self, model, image_embeds, shift, steps, cfg, seed, scheduler, riflex_freq_index, text_embeds=None, force_offload=True, samples=None, feta_args=None, denoise_strength=1.0, context_options=None, cache_args=None, teacache_args=None, flowedit_args=None, batched_cfg=False, slg_args=None, rope_function="default", loop_args=None, experimental_args=None, sigmas=None, unianimate_poses=None, fantasytalking_embeds=None, uni3c_embeds=None): patcher = model model = model.model transformer = model.diffusion_model dtype = model["dtype"] control_lora = model["control_lora"] transformer_options = patcher.model_options.get("transformer_options", None) device = mm.get_torch_device() offload_device = mm.unet_offload_device() steps = int(steps/denoise_strength) if text_embeds == None: text_embeds = { "prompt_embeds": [], "negative_prompt_embeds": [], } if isinstance(cfg, list): if steps != len(cfg): log.info(f"Received {len(cfg)} cfg values, but only {steps} steps. Setting step count to match.") steps = len(cfg) timesteps = None if 'unipc' in scheduler: sample_scheduler = FlowUniPCMultistepScheduler(shift=shift) if sigmas is None: sample_scheduler.set_timesteps(steps, device=device, shift=shift, use_beta_sigmas=('beta' in scheduler)) else: sample_scheduler.sigmas = sigmas.to(device) sample_scheduler.timesteps = (sample_scheduler.sigmas[:-1] * 1000).to(torch.int64).to(device) sample_scheduler.num_inference_steps = len(sample_scheduler.timesteps) elif scheduler in ['euler/beta', 'euler']: sample_scheduler = FlowMatchEulerDiscreteScheduler(shift=shift, use_beta_sigmas=(scheduler == 'euler/beta')) if flowedit_args: #seems to work better timesteps, _ = retrieve_timesteps(sample_scheduler, device=device, sigmas=get_sampling_sigmas(steps, shift)) else: sample_scheduler.set_timesteps(steps, device=device, sigmas=sigmas.tolist() if sigmas is not None else None) elif scheduler in ['euler/accvideo']: if steps != 50: raise Exception("Steps must be set to 50 for accvideo scheduler, 10 actual steps are used") sample_scheduler = FlowMatchEulerDiscreteScheduler(shift=shift, use_beta_sigmas=(scheduler == 'euler/beta')) sample_scheduler.set_timesteps(steps, device=device, sigmas=sigmas.tolist() if sigmas is not None else None) start_latent_list = [0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50] sample_scheduler.sigmas = sample_scheduler.sigmas[start_latent_list] steps = len(start_latent_list) - 1 sample_scheduler.timesteps = timesteps = sample_scheduler.timesteps[start_latent_list[:steps]] elif 'dpm++' in scheduler: if 'sde' in scheduler: algorithm_type = "sde-dpmsolver++" else: algorithm_type = "dpmsolver++" sample_scheduler = FlowDPMSolverMultistepScheduler(shift=shift, algorithm_type=algorithm_type) if sigmas is None: sample_scheduler.set_timesteps(steps, device=device, use_beta_sigmas=('beta' in scheduler)) else: sample_scheduler.sigmas = sigmas.to(device) sample_scheduler.timesteps = (sample_scheduler.sigmas[:-1] * 1000).to(torch.int64).to(device) sample_scheduler.num_inference_steps = len(sample_scheduler.timesteps) elif scheduler == 'deis': sample_scheduler = DEISMultistepScheduler(use_flow_sigmas=True, prediction_type="flow_prediction", flow_shift=shift) sample_scheduler.set_timesteps(steps, device=device) sample_scheduler.sigmas[-1] = 1e-6 elif 'lcm' in scheduler: sample_scheduler = FlowMatchLCMScheduler(shift=shift, use_beta_sigmas=(scheduler == 'lcm/beta')) sample_scheduler.set_timesteps(steps, device=device, sigmas=sigmas.tolist() if sigmas is not None else None) elif 'flowmatch_causvid' in scheduler: if transformer.dim == 5120: denoising_list = [999, 934, 862, 756, 603, 410, 250, 140, 74] else: if steps != 4: raise ValueError("CausVid 1.3B schedule is only for 4 steps") denoising_list = [1000, 750, 500, 250] sample_scheduler = FlowMatchScheduler(num_inference_steps=steps, shift=shift, sigma_min=0, extra_one_step=True) sample_scheduler.timesteps = torch.tensor(denoising_list)[:steps].to(device) sample_scheduler.sigmas = torch.cat([sample_scheduler.timesteps / 1000, torch.tensor([0.0], device=device)]) elif 'flowmatch_distill' in scheduler: sample_scheduler = FlowMatchScheduler( shift=shift, sigma_min=0.0, extra_one_step=True ) sample_scheduler.set_timesteps(1000, training=True) denoising_step_list = torch.tensor([999, 750, 500, 250] , dtype=torch.long) temp_timesteps = torch.cat((sample_scheduler.timesteps.cpu(), torch.tensor([0], dtype=torch.float32))) denoising_step_list = temp_timesteps[1000 - denoising_step_list] print("denoising_step_list: ", denoising_step_list) #denoising_step_list = [999, 750, 500, 250] if steps != 4: raise ValueError("This scheduler is only for 4 steps") #sample_scheduler = FlowMatchScheduler(num_inference_steps=steps, shift=shift, sigma_min=0, extra_one_step=True) sample_scheduler.timesteps = torch.tensor(denoising_step_list)[:steps].to(device) sample_scheduler.sigmas = torch.cat([sample_scheduler.timesteps / 1000, torch.tensor([0.0], device=device)]) if timesteps is None: timesteps = sample_scheduler.timesteps log.info(f"timesteps: {timesteps}") if denoise_strength < 1.0: steps = int(steps * denoise_strength) timesteps = timesteps[-(steps + 1):] seed_g = torch.Generator(device=torch.device("cpu")) seed_g.manual_seed(seed) control_latents = control_camera_latents = clip_fea = clip_fea_neg = end_image = recammaster = camera_embed = unianim_data = None vace_data = vace_context = vace_scale = None fun_or_fl2v_model = has_ref = drop_last = False phantom_latents = None fun_ref_image = None image_cond = image_embeds.get("image_embeds", None) ATI_tracks = None add_cond = attn_cond = attn_cond_neg = None if image_cond is not None: log.info(f"image_cond shape: {image_cond.shape}") #ATI tracks if transformer_options is not None: ATI_tracks = transformer_options.get("ati_tracks", None) if ATI_tracks is not None: from .ATI.motion_patch import patch_motion topk = transformer_options.get("ati_topk", 2) temperature = transformer_options.get("ati_temperature", 220.0) ati_start_percent = transformer_options.get("ati_start_percent", 0.0) ati_end_percent = transformer_options.get("ati_end_percent", 1.0) image_cond_ati = patch_motion(ATI_tracks.to(image_cond.device, image_cond.dtype), image_cond, topk=topk, temperature=temperature) log.info(f"ATI tracks shape: {ATI_tracks.shape}") realisdance_latents = image_embeds.get("realisdance_latents", None) if realisdance_latents is not None: add_cond = realisdance_latents["pose_latent"] attn_cond = realisdance_latents["ref_latent"] attn_cond_neg = realisdance_latents["ref_latent_neg"] add_cond_start_percent = realisdance_latents["pose_cond_start_percent"] add_cond_end_percent = realisdance_latents["pose_cond_end_percent"] end_image = image_embeds.get("end_image", None) lat_h = image_embeds.get("lat_h", None) lat_w = image_embeds.get("lat_w", None) if lat_h is None or lat_w is None: raise ValueError("Clip encoded image embeds must be provided for I2V (Image to Video) model") fun_or_fl2v_model = image_embeds.get("fun_or_fl2v_model", False) noise = torch.randn( 16, (image_embeds["num_frames"] - 1) // 4 + (2 if end_image is not None and not fun_or_fl2v_model else 1), lat_h, lat_w, dtype=torch.float32, generator=seed_g, device=torch.device("cpu")) seq_len = image_embeds["max_seq_len"] clip_fea = image_embeds.get("clip_context", None) if clip_fea is not None: clip_fea = clip_fea.to(dtype) clip_fea_neg = image_embeds.get("negative_clip_context", None) if clip_fea_neg is not None: clip_fea_neg = clip_fea_neg.to(dtype) control_embeds = image_embeds.get("control_embeds", None) if control_embeds is not None: if transformer.in_dim not in [48, 32]: raise ValueError("Control signal only works with Fun-Control model") control_latents = control_embeds.get("control_images", None) control_camera_latents = control_embeds.get("control_camera_latents", None) control_camera_start_percent = control_embeds.get("control_camera_start_percent", 0.0) control_camera_end_percent = control_embeds.get("control_camera_end_percent", 1.0) control_start_percent = control_embeds.get("start_percent", 0.0) control_end_percent = control_embeds.get("end_percent", 1.0) drop_last = image_embeds.get("drop_last", False) has_ref = image_embeds.get("has_ref", False) else: #t2v target_shape = image_embeds.get("target_shape", None) if target_shape is None: raise ValueError("Empty image embeds must be provided for T2V (Text to Video") has_ref = image_embeds.get("has_ref", False) vace_context = image_embeds.get("vace_context", None) vace_scale = image_embeds.get("vace_scale", None) if not isinstance(vace_scale, list): vace_scale = [vace_scale] * (steps+1) vace_start_percent = image_embeds.get("vace_start_percent", 0.0) vace_end_percent = image_embeds.get("vace_end_percent", 1.0) vace_seqlen = image_embeds.get("vace_seq_len", None) vace_additional_embeds = image_embeds.get("additional_vace_inputs", []) if vace_context is not None: vace_data = [ {"context": vace_context, "scale": vace_scale, "start": vace_start_percent, "end": vace_end_percent, "seq_len": vace_seqlen } ] if len(vace_additional_embeds) > 0: for i in range(len(vace_additional_embeds)): if vace_additional_embeds[i].get("has_ref", False): has_ref = True vace_scale = vace_additional_embeds[i]["vace_scale"] if not isinstance(vace_scale, list): vace_scale = [vace_scale] * (steps+1) vace_data.append({ "context": vace_additional_embeds[i]["vace_context"], "scale": vace_scale, "start": vace_additional_embeds[i]["vace_start_percent"], "end": vace_additional_embeds[i]["vace_end_percent"], "seq_len": vace_additional_embeds[i]["vace_seq_len"] }) noise = torch.randn( target_shape[0], target_shape[1] + 1 if has_ref else target_shape[1], target_shape[2], target_shape[3], dtype=torch.float32, device=torch.device("cpu"), generator=seed_g) seq_len = math.ceil((noise.shape[2] * noise.shape[3]) / 4 * noise.shape[1]) recammaster = image_embeds.get("recammaster", None) if recammaster is not None: camera_embed = recammaster.get("camera_embed", None) recam_latents = recammaster.get("source_latents", None) orig_noise_len = noise.shape[1] log.info(f"RecamMaster camera embed shape: {camera_embed.shape}") log.info(f"RecamMaster source video shape: {recam_latents.shape}") seq_len *= 2 control_embeds = image_embeds.get("control_embeds", None) if control_embeds is not None: control_latents = control_embeds.get("control_images", None) if control_latents is not None: control_latents = control_latents.to(device) control_camera_latents = control_embeds.get("control_camera_latents", None) control_camera_start_percent = control_embeds.get("control_camera_start_percent", 0.0) control_camera_end_percent = control_embeds.get("control_camera_end_percent", 1.0) if control_camera_latents is not None: control_camera_latents = control_camera_latents.to(device) if control_lora: image_cond = control_latents.to(device) if not patcher.model.is_patched: log.info("Re-loading control LoRA...") patcher = apply_lora(patcher, device, device, low_mem_load=False) patcher.model.is_patched = True else: if transformer.in_dim not in [48, 32]: raise ValueError("Control signal only works with Fun-Control model") image_cond = torch.zeros_like(noise).to(device) #fun control clip_fea = None fun_ref_image = control_embeds.get("fun_ref_image", None) control_start_percent = control_embeds.get("start_percent", 0.0) control_end_percent = control_embeds.get("end_percent", 1.0) else: if transformer.in_dim == 36: #fun inp mask_latents = torch.tile( torch.zeros_like(noise[:1]), [4, 1, 1, 1] ) masked_video_latents_input = torch.zeros_like(noise) image_cond = torch.cat([mask_latents, masked_video_latents_input], dim=0).to(device) phantom_latents = image_embeds.get("phantom_latents", None) phantom_cfg_scale = image_embeds.get("phantom_cfg_scale", None) if not isinstance(phantom_cfg_scale, list): phantom_cfg_scale = [phantom_cfg_scale] * (steps +1) phantom_start_percent = image_embeds.get("phantom_start_percent", 0.0) phantom_end_percent = image_embeds.get("phantom_end_percent", 1.0) if phantom_latents is not None: phantom_latents = phantom_latents.to(device) latent_video_length = noise.shape[1] if unianimate_poses is not None: transformer.dwpose_embedding.to(device, model["dtype"]) dwpose_data = unianimate_poses["pose"].to(device, model["dtype"]) dwpose_data = torch.cat([dwpose_data[:,:,:1].repeat(1,1,3,1,1), dwpose_data], dim=2) dwpose_data = transformer.dwpose_embedding(dwpose_data) log.info(f"UniAnimate pose embed shape: {dwpose_data.shape}") if dwpose_data.shape[2] > latent_video_length: log.warning(f"UniAnimate pose embed length {dwpose_data.shape[2]} is longer than the video length {latent_video_length}, truncating") dwpose_data = dwpose_data[:,:, :latent_video_length] elif dwpose_data.shape[2] < latent_video_length: log.warning(f"UniAnimate pose embed length {dwpose_data.shape[2]} is shorter than the video length {latent_video_length}, padding with last pose") pad_len = latent_video_length - dwpose_data.shape[2] pad = dwpose_data[:,:,:1].repeat(1,1,pad_len,1,1) dwpose_data = torch.cat([dwpose_data, pad], dim=2) dwpose_data_flat = rearrange(dwpose_data, 'b c f h w -> b (f h w) c').contiguous() random_ref_dwpose_data = None if image_cond is not None: transformer.randomref_embedding_pose.to(device) random_ref_dwpose = unianimate_poses.get("ref", None) if random_ref_dwpose is not None: random_ref_dwpose_data = transformer.randomref_embedding_pose( random_ref_dwpose.to(device) ).unsqueeze(2).to(model["dtype"]) # [1, 20, 104, 60] unianim_data = { "dwpose": dwpose_data_flat, "random_ref": random_ref_dwpose_data.squeeze(0) if random_ref_dwpose_data is not None else None, "strength": unianimate_poses["strength"], "start_percent": unianimate_poses["start_percent"], "end_percent": unianimate_poses["end_percent"] } audio_proj = None if fantasytalking_embeds is not None: audio_proj = fantasytalking_embeds["audio_proj"].to(device) audio_context_lens = fantasytalking_embeds["audio_context_lens"] audio_scale = fantasytalking_embeds["audio_scale"] audio_cfg_scale = fantasytalking_embeds["audio_cfg_scale"] if not isinstance(audio_cfg_scale, list): audio_cfg_scale = [audio_cfg_scale] * (steps +1) log.info(f"Audio proj shape: {audio_proj.shape}, audio context lens: {audio_context_lens}") minimax_latents = minimax_mask_latents = None minimax_latents = image_embeds.get("minimax_latents", None) minimax_mask_latents = image_embeds.get("minimax_mask_latents", None) if minimax_latents is not None: log.info(f"minimax_latents: {minimax_latents.shape}") log.info(f"minimax_mask_latents: {minimax_mask_latents.shape}") minimax_latents = minimax_latents.to(device, dtype) minimax_mask_latents = minimax_mask_latents.to(device, dtype) is_looped = False if context_options is not None: def create_window_mask(noise_pred_context, c, latent_video_length, context_overlap, looped=False): window_mask = torch.ones_like(noise_pred_context) # Apply left-side blending for all except first chunk (or always in loop mode) if min(c) > 0 or (looped and max(c) == latent_video_length - 1): ramp_up = torch.linspace(0, 1, context_overlap, device=noise_pred_context.device) ramp_up = ramp_up.view(1, -1, 1, 1) window_mask[:, :context_overlap] = ramp_up # Apply right-side blending for all except last chunk (or always in loop mode) if max(c) < latent_video_length - 1 or (looped and min(c) == 0): ramp_down = torch.linspace(1, 0, context_overlap, device=noise_pred_context.device) ramp_down = ramp_down.view(1, -1, 1, 1) window_mask[:, -context_overlap:] = ramp_down return window_mask context_schedule = context_options["context_schedule"] context_frames = (context_options["context_frames"] - 1) // 4 + 1 context_stride = context_options["context_stride"] // 4 context_overlap = context_options["context_overlap"] // 4 context_vae = context_options.get("vae", None) if context_vae is not None: context_vae.to(device) self.window_tracker = WindowTracker(verbose=context_options["verbose"]) # Get total number of prompts num_prompts = len(text_embeds["prompt_embeds"]) log.info(f"Number of prompts: {num_prompts}") # Calculate which section this context window belongs to section_size = latent_video_length / num_prompts log.info(f"Section size: {section_size}") is_looped = context_schedule == "uniform_looped" seq_len = math.ceil((noise.shape[2] * noise.shape[3]) / 4 * context_frames) if context_options["freenoise"]: log.info("Applying FreeNoise") # code from AnimateDiff-Evolved by Kosinkadink (https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved) delta = context_frames - context_overlap for start_idx in range(0, latent_video_length-context_frames, delta): place_idx = start_idx + context_frames if place_idx >= latent_video_length: break end_idx = place_idx - 1 if end_idx + delta >= latent_video_length: final_delta = latent_video_length - place_idx list_idx = torch.tensor(list(range(start_idx,start_idx+final_delta)), device=torch.device("cpu"), dtype=torch.long) list_idx = list_idx[torch.randperm(final_delta, generator=seed_g)] noise[:, place_idx:place_idx + final_delta, :, :] = noise[:, list_idx, :, :] break list_idx = torch.tensor(list(range(start_idx,start_idx+delta)), device=torch.device("cpu"), dtype=torch.long) list_idx = list_idx[torch.randperm(delta, generator=seed_g)] noise[:, place_idx:place_idx + delta, :, :] = noise[:, list_idx, :, :] log.info(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap") from .context import get_context_scheduler context = get_context_scheduler(context_schedule) if samples is not None: input_samples = samples["samples"].squeeze(0).to(noise) if input_samples.shape[1] != noise.shape[1]: input_samples = torch.cat([input_samples[:, :1].repeat(1, noise.shape[1] - input_samples.shape[1], 1, 1), input_samples], dim=1) original_image = input_samples.to(device) if denoise_strength < 1.0: latent_timestep = timesteps[:1].to(noise) noise = noise * latent_timestep / 1000 + (1 - latent_timestep / 1000) * input_samples mask = samples.get("mask", None) if mask is not None: if mask.shape[2] != noise.shape[1]: mask = torch.cat([torch.zeros(1, noise.shape[0], noise.shape[1] - mask.shape[2], noise.shape[2], noise.shape[3]), mask], dim=2) latent = noise.to(device) freqs = None transformer.rope_embedder.k = None transformer.rope_embedder.num_frames = None if rope_function=="comfy": transformer.rope_embedder.k = riflex_freq_index transformer.rope_embedder.num_frames = latent_video_length else: d = transformer.dim // transformer.num_heads freqs = torch.cat([ rope_params(1024, d - 4 * (d // 6), L_test=latent_video_length, k=riflex_freq_index), rope_params(1024, 2 * (d // 6)), rope_params(1024, 2 * (d // 6)) ], dim=1) if not isinstance(cfg, list): cfg = [cfg] * (steps +1) log.info(f"Seq len: {seq_len}") pbar = ProgressBar(steps) if args.preview_method in [LatentPreviewMethod.Auto, LatentPreviewMethod.Latent2RGB]: #default for latent2rgb from latent_preview import prepare_callback else: from .latent_preview import prepare_callback #custom for tiny VAE previews callback = prepare_callback(patcher, steps) #blockswap init if transformer_options is not None: block_swap_args = transformer_options.get("block_swap_args", None) if block_swap_args is not None: transformer.use_non_blocking = block_swap_args.get("use_non_blocking", True) for name, param in transformer.named_parameters(): if "block" not in name: param.data = param.data.to(device) if "control_adapter" in name: param.data = param.data.to(device) elif block_swap_args["offload_txt_emb"] and "txt_emb" in name: param.data = param.data.to(offload_device, non_blocking=transformer.use_non_blocking) elif block_swap_args["offload_img_emb"] and "img_emb" in name: param.data = param.data.to(offload_device, non_blocking=transformer.use_non_blocking) transformer.block_swap( block_swap_args["blocks_to_swap"] - 1 , block_swap_args["offload_txt_emb"], block_swap_args["offload_img_emb"], vace_blocks_to_swap = block_swap_args.get("vace_blocks_to_swap", None), ) elif model["auto_cpu_offload"]: for module in transformer.modules(): if hasattr(module, "offload"): module.offload() if hasattr(module, "onload"): module.onload() elif model["manual_offloading"]: transformer.to(device) #controlnet controlnet_latents = controlnet = None if transformer_options is not None: controlnet = transformer_options.get("controlnet", None) if controlnet is not None: self.controlnet = controlnet["controlnet"] controlnet_start = controlnet["controlnet_start"] controlnet_end = controlnet["controlnet_end"] controlnet_latents = controlnet["control_latents"] controlnet["controlnet_weight"] = controlnet["controlnet_strength"] controlnet["controlnet_stride"] = controlnet["control_stride"] #uni3c pcd_data = None if uni3c_embeds is not None: transformer.controlnet = uni3c_embeds["controlnet"] pcd_data = { "render_latent": uni3c_embeds["render_latent"], "render_mask": uni3c_embeds["render_mask"], "camera_embedding": uni3c_embeds["camera_embedding"], "controlnet_weight": uni3c_embeds["controlnet_weight"], "start": uni3c_embeds["start"], "end": uni3c_embeds["end"], } #feta if feta_args is not None and latent_video_length > 1: set_enhance_weight(feta_args["weight"]) feta_start_percent = feta_args["start_percent"] feta_end_percent = feta_args["end_percent"] if context_options is not None: set_num_frames(context_frames) else: set_num_frames(latent_video_length) enable_enhance() else: feta_args = None disable_enhance() # Initialize Cache if enabled transformer.enable_teacache = transformer.enable_magcache = False if teacache_args is not None: #for backward compatibility on old workflows cache_args = teacache_args if cache_args is not None: transformer.cache_device = cache_args["cache_device"] if cache_args["cache_type"] == "TeaCache": log.info(f"TeaCache: Using cache device: {transformer.cache_device}") transformer.teacache_state.clear_all() transformer.enable_teacache = True transformer.rel_l1_thresh = cache_args["rel_l1_thresh"] transformer.teacache_start_step = cache_args["start_step"] transformer.teacache_end_step = len(timesteps)-1 if cache_args["end_step"] == -1 else cache_args["end_step"] transformer.teacache_use_coefficients = cache_args["use_coefficients"] transformer.teacache_mode = cache_args["mode"] elif cache_args["cache_type"] == "MagCache": log.info(f"MagCache: Using cache device: {transformer.cache_device}") transformer.magcache_state.clear_all() transformer.enable_magcache = True transformer.magcache_start_step = cache_args["start_step"] transformer.magcache_end_step = len(timesteps)-1 if cache_args["end_step"] == -1 else cache_args["end_step"] transformer.magcache_thresh = cache_args["magcache_thresh"] transformer.magcache_K = cache_args["magcache_K"] if slg_args is not None: assert batched_cfg is not None, "Batched cfg is not supported with SLG" transformer.slg_blocks = slg_args["blocks"] transformer.slg_start_percent = slg_args["start_percent"] transformer.slg_end_percent = slg_args["end_percent"] else: transformer.slg_blocks = None self.cache_state = [None, None] if phantom_latents is not None: log.info(f"Phantom latents shape: {phantom_latents.shape}") self.cache_state = [None, None, None] self.cache_state_source = [None, None] self.cache_states_context = [] if flowedit_args is not None: source_embeds = flowedit_args["source_embeds"] source_image_embeds = flowedit_args.get("source_image_embeds", image_embeds) source_image_cond = source_image_embeds.get("image_embeds", None) source_clip_fea = source_image_embeds.get("clip_fea", clip_fea) if source_image_cond is not None: source_image_cond = source_image_cond.to(dtype) skip_steps = flowedit_args["skip_steps"] drift_steps = flowedit_args["drift_steps"] source_cfg = flowedit_args["source_cfg"] if not isinstance(source_cfg, list): source_cfg = [source_cfg] * (steps +1) drift_cfg = flowedit_args["drift_cfg"] if not isinstance(drift_cfg, list): drift_cfg = [drift_cfg] * (steps +1) x_init = samples["samples"].clone().squeeze(0).to(device) x_tgt = samples["samples"].squeeze(0).to(device) sample_scheduler = FlowMatchEulerDiscreteScheduler( num_train_timesteps=1000, shift=flowedit_args["drift_flow_shift"], use_dynamic_shifting=False) sampling_sigmas = get_sampling_sigmas(steps, flowedit_args["drift_flow_shift"]) drift_timesteps, _ = retrieve_timesteps( sample_scheduler, device=device, sigmas=sampling_sigmas) if drift_steps > 0: drift_timesteps = torch.cat([drift_timesteps, torch.tensor([0]).to(drift_timesteps.device)]).to(drift_timesteps.device) timesteps[-drift_steps:] = drift_timesteps[-drift_steps:] use_cfg_zero_star = use_fresca = False if experimental_args is not None: video_attention_split_steps = experimental_args.get("video_attention_split_steps", []) if video_attention_split_steps: transformer.video_attention_split_steps = [int(x.strip()) for x in video_attention_split_steps.split(",")] else: transformer.video_attention_split_steps = [] use_zero_init = experimental_args.get("use_zero_init", True) use_cfg_zero_star = experimental_args.get("cfg_zero_star", False) zero_star_steps = experimental_args.get("zero_star_steps", 0) use_fresca = experimental_args.get("use_fresca", False) if use_fresca: fresca_scale_low = experimental_args.get("fresca_scale_low", 1.0) fresca_scale_high = experimental_args.get("fresca_scale_high", 1.25) fresca_freq_cutoff = experimental_args.get("fresca_freq_cutoff", 20) #region model pred def predict_with_cfg(z, cfg_scale, positive_embeds, negative_embeds, timestep, idx, image_cond=None, clip_fea=None, control_latents=None, vace_data=None, unianim_data=None, audio_proj=None, control_camera_latents=None, add_cond=None, cache_state=None): z = z.to(dtype) with torch.autocast(device_type=mm.get_autocast_device(device), dtype=dtype, enabled=("fp8" in model["quantization"])): if use_cfg_zero_star and (idx <= zero_star_steps) and use_zero_init: return z*0, None nonlocal patcher current_step_percentage = idx / len(timesteps) control_lora_enabled = False image_cond_input = None if control_latents is not None: if control_lora: control_lora_enabled = True else: if (control_start_percent <= current_step_percentage <= control_end_percent) or \ (control_end_percent > 0 and idx == 0 and current_step_percentage >= control_start_percent): image_cond_input = torch.cat([control_latents.to(z), image_cond.to(z)]) else: image_cond_input = torch.cat([torch.zeros_like(image_cond, dtype=dtype), image_cond.to(z)]) if fun_ref_image is not None: fun_ref_input = fun_ref_image.to(z) else: fun_ref_input = torch.zeros_like(z, dtype=z.dtype)[:, 0].unsqueeze(1) #fun_ref_input = None if control_lora: if not control_start_percent <= current_step_percentage <= control_end_percent: control_lora_enabled = False if patcher.model.is_patched: log.info("Unloading LoRA...") patcher.unpatch_model(device) patcher.model.is_patched = False else: image_cond_input = control_latents.to(z) if not patcher.model.is_patched: log.info("Loading LoRA...") patcher = apply_lora(patcher, device, device, low_mem_load=False) patcher.model.is_patched = True elif ATI_tracks is not None and ((ati_start_percent <= current_step_percentage <= ati_end_percent) or (ati_end_percent > 0 and idx == 0 and current_step_percentage >= ati_start_percent)): image_cond_input = image_cond_ati.to(z) else: image_cond_input = image_cond.to(z) if image_cond is not None else None if control_camera_latents is not None: if (control_camera_start_percent <= current_step_percentage <= control_camera_end_percent) or \ (control_end_percent > 0 and idx == 0 and current_step_percentage >= control_camera_start_percent): control_camera_input = control_camera_latents.to(z) else: control_camera_input = None if recammaster is not None: z = torch.cat([z, recam_latents.to(z)], dim=1) use_phantom = False if phantom_latents is not None: if (phantom_start_percent <= current_step_percentage <= phantom_end_percent) or \ (phantom_end_percent > 0 and idx == 0 and current_step_percentage >= phantom_start_percent): z_pos = torch.cat([z[:,:-phantom_latents.shape[1]], phantom_latents.to(z)], dim=1) z_phantom_img = torch.cat([z[:,:-phantom_latents.shape[1]], phantom_latents.to(z)], dim=1) z_neg = torch.cat([z[:,:-phantom_latents.shape[1]], torch.zeros_like(phantom_latents).to(z)], dim=1) use_phantom = True if cache_state is not None and len(cache_state) != 3: cache_state.append(None) if not use_phantom: z_pos = z_neg = z if controlnet_latents is not None: if (controlnet_start <= current_step_percentage < controlnet_end): self.controlnet.to(device) controlnet_states = self.controlnet( hidden_states=z.unsqueeze(0).to(device, self.controlnet.dtype), timestep=timestep, encoder_hidden_states=positive_embeds[0].unsqueeze(0).to(device, self.controlnet.dtype), attention_kwargs=None, controlnet_states=controlnet_latents.to(device, self.controlnet.dtype), return_dict=False, )[0] if isinstance(controlnet_states, (tuple, list)): controlnet["controlnet_states"] = [x.to(z) for x in controlnet_states] else: controlnet["controlnet_states"] = controlnet_states.to(z) add_cond_input = None if add_cond is not None: if (add_cond_start_percent <= current_step_percentage <= add_cond_end_percent) or \ (add_cond_end_percent > 0 and idx == 0 and current_step_percentage >= add_cond_start_percent): add_cond_input = add_cond if minimax_latents is not None: z_pos = z_neg = torch.cat([z, minimax_latents, minimax_mask_latents], dim=0) base_params = { 'seq_len': seq_len, 'device': device, 'freqs': freqs, 't': timestep, 'current_step': idx, 'control_lora_enabled': control_lora_enabled, 'camera_embed': camera_embed, 'unianim_data': unianim_data, 'fun_ref': fun_ref_input if fun_ref_image is not None else None, 'fun_camera': control_camera_input if control_camera_latents is not None else None, 'audio_proj': audio_proj if fantasytalking_embeds is not None else None, 'audio_context_lens': audio_context_lens if fantasytalking_embeds is not None else None, 'audio_scale': audio_scale if fantasytalking_embeds is not None else None, "pcd_data": pcd_data, "controlnet": controlnet, "add_cond": add_cond_input, "nag_params": text_embeds.get("nag_params", {}), "nag_context": text_embeds.get("nag_prompt_embeds", None), } batch_size = 1 if not math.isclose(cfg_scale, 1.0) and len(positive_embeds) > 1: negative_embeds = negative_embeds * len(positive_embeds) if not batched_cfg: #cond noise_pred_cond, cache_state_cond = transformer( [z_pos], context=positive_embeds, y=[image_cond_input] if image_cond_input is not None else None, clip_fea=clip_fea, is_uncond=False, current_step_percentage=current_step_percentage, pred_id=cache_state[0] if cache_state else None, vace_data=vace_data, attn_cond=attn_cond, **base_params ) noise_pred_cond = noise_pred_cond[0].to(intermediate_device) if math.isclose(cfg_scale, 1.0): if use_fresca: noise_pred_cond = fourier_filter( noise_pred_cond, scale_low=fresca_scale_low, scale_high=fresca_scale_high, freq_cutoff=fresca_freq_cutoff, ) return noise_pred_cond, [cache_state_cond] #uncond if fantasytalking_embeds is not None: if not math.isclose(audio_cfg_scale[idx], 1.0): base_params['audio_proj'] = None noise_pred_uncond, cache_state_uncond = transformer( [z_neg], context=negative_embeds, clip_fea=clip_fea_neg if clip_fea_neg is not None else clip_fea, y=[image_cond_input] if image_cond_input is not None else None, is_uncond=True, current_step_percentage=current_step_percentage, pred_id=cache_state[1] if cache_state else None, vace_data=vace_data, attn_cond=attn_cond_neg, **base_params ) noise_pred_uncond = noise_pred_uncond[0].to(intermediate_device) #phantom if use_phantom and not math.isclose(phantom_cfg_scale[idx], 1.0): noise_pred_phantom, cache_state_phantom = transformer( [z_phantom_img], context=negative_embeds, clip_fea=clip_fea_neg if clip_fea_neg is not None else clip_fea, y=[image_cond_input] if image_cond_input is not None else None, is_uncond=True, current_step_percentage=current_step_percentage, pred_id=cache_state[2] if cache_state else None, vace_data=None, **base_params ) noise_pred_phantom = noise_pred_phantom[0].to(intermediate_device) noise_pred = noise_pred_uncond + phantom_cfg_scale[idx] * (noise_pred_phantom - noise_pred_uncond) + cfg_scale * (noise_pred_cond - noise_pred_phantom) return noise_pred, [cache_state_cond, cache_state_uncond, cache_state_phantom] #fantasytalking if fantasytalking_embeds is not None: if not math.isclose(audio_cfg_scale[idx], 1.0): if cache_state is not None and len(cache_state) != 3: cache_state.append(None) base_params['audio_proj'] = None noise_pred_no_audio, cache_state_audio = transformer( [z_pos], context=positive_embeds, y=[image_cond_input] if image_cond_input is not None else None, clip_fea=clip_fea, is_uncond=False, current_step_percentage=current_step_percentage, pred_id=cache_state[2] if cache_state else None, vace_data=vace_data, **base_params ) noise_pred_no_audio = noise_pred_no_audio[0].to(intermediate_device) noise_pred = ( noise_pred_uncond + cfg_scale * (noise_pred_no_audio - noise_pred_uncond) + audio_cfg_scale[idx] * (noise_pred_cond - noise_pred_no_audio) ) return noise_pred, [cache_state_cond, cache_state_uncond, cache_state_audio] #batched else: cache_state_uncond = None [noise_pred_cond, noise_pred_uncond], cache_state_cond = transformer( [z] + [z], context=positive_embeds + negative_embeds, y=[image_cond_input] + [image_cond_input] if image_cond_input is not None else None, clip_fea=clip_fea.repeat(2,1,1), is_uncond=False, current_step_percentage=current_step_percentage, pred_id=cache_state[0] if cache_state else None, **base_params ) #cfg #https://github.com/WeichenFan/CFG-Zero-star/ if use_cfg_zero_star: alpha = optimized_scale( noise_pred_cond.view(batch_size, -1), noise_pred_uncond.view(batch_size, -1) ).view(batch_size, 1, 1, 1) else: alpha = 1.0 #https://github.com/WikiChao/FreSca if use_fresca: filtered_cond = fourier_filter( noise_pred_cond - noise_pred_uncond, scale_low=fresca_scale_low, scale_high=fresca_scale_high, freq_cutoff=fresca_freq_cutoff, ) noise_pred = noise_pred_uncond * alpha + cfg_scale * filtered_cond * alpha else: noise_pred = noise_pred_uncond * alpha + cfg_scale * (noise_pred_cond - noise_pred_uncond * alpha) return noise_pred, [cache_state_cond, cache_state_uncond] log.info(f"Sampling {(latent_video_length-1) * 4 + 1} frames at {latent.shape[3]*8}x{latent.shape[2]*8} with {steps} steps") intermediate_device = device # diff diff prep masks = None if samples is not None and mask is not None: mask = 1 - mask thresholds = torch.arange(len(timesteps), dtype=original_image.dtype) / len(timesteps) thresholds = thresholds.unsqueeze(1).unsqueeze(1).unsqueeze(1).unsqueeze(1).to(device) masks = mask.repeat(len(timesteps), 1, 1, 1, 1).to(device) masks = masks > thresholds latent_shift_loop = False if loop_args is not None: latent_shift_loop = True is_looped = True latent_skip = loop_args["shift_skip"] latent_shift_start_percent = loop_args["start_percent"] latent_shift_end_percent = loop_args["end_percent"] shift_idx = 0 #clear memory before sampling mm.unload_all_models() mm.soft_empty_cache() gc.collect() try: torch.cuda.reset_peak_memory_stats(device) except: pass #region main loop start for idx, t in enumerate(tqdm(timesteps)): if flowedit_args is not None: if idx < skip_steps: continue # diff diff if masks is not None: if idx < len(timesteps) - 1: noise_timestep = timesteps[idx+1] image_latent = sample_scheduler.scale_noise( original_image, torch.tensor([noise_timestep]), noise.to(device) ) mask = masks[idx] mask = mask.to(latent) latent = image_latent * mask + latent * (1-mask) # end diff diff latent_model_input = latent.to(device) timestep = torch.tensor([t]).to(device) current_step_percentage = idx / len(timesteps) ### latent shift if latent_shift_loop: if latent_shift_start_percent <= current_step_percentage <= latent_shift_end_percent: latent_model_input = torch.cat([latent_model_input[:, shift_idx:]] + [latent_model_input[:, :shift_idx]], dim=1) #enhance-a-video if feta_args is not None and feta_start_percent <= current_step_percentage <= feta_end_percent: enable_enhance() else: disable_enhance() #flow-edit if flowedit_args is not None: sigma = t / 1000.0 sigma_prev = (timesteps[idx + 1] if idx < len(timesteps) - 1 else timesteps[-1]) / 1000.0 noise = torch.randn(x_init.shape, generator=seed_g, device=torch.device("cpu")) if idx < len(timesteps) - drift_steps: cfg = drift_cfg zt_src = (1-sigma) * x_init + sigma * noise.to(t) zt_tgt = x_tgt + zt_src - x_init #source if idx < len(timesteps) - drift_steps: if context_options is not None: counter = torch.zeros_like(zt_src, device=intermediate_device) vt_src = torch.zeros_like(zt_src, device=intermediate_device) context_queue = list(context(idx, steps, latent_video_length, context_frames, context_stride, context_overlap)) for c in context_queue: window_id = self.window_tracker.get_window_id(c) if cache_args is not None: current_teacache = self.window_tracker.get_teacache(window_id, self.cache_state) else: current_teacache = None prompt_index = min(int(max(c) / section_size), num_prompts - 1) if context_options["verbose"]: log.info(f"Prompt index: {prompt_index}") if len(source_embeds["prompt_embeds"]) > 1: positive = source_embeds["prompt_embeds"][prompt_index] else: positive = source_embeds["prompt_embeds"] partial_img_emb = None if source_image_cond is not None: partial_img_emb = source_image_cond[:, c, :, :] partial_img_emb[:, 0, :, :] = source_image_cond[:, 0, :, :].to(intermediate_device) partial_zt_src = zt_src[:, c, :, :] vt_src_context, new_teacache = predict_with_cfg( partial_zt_src, cfg[idx], positive, source_embeds["negative_prompt_embeds"], timestep, idx, partial_img_emb, control_latents, source_clip_fea, current_teacache) if cache_args is not None: self.window_tracker.cache_states[window_id] = new_teacache window_mask = create_window_mask(vt_src_context, c, latent_video_length, context_overlap) vt_src[:, c, :, :] += vt_src_context * window_mask counter[:, c, :, :] += window_mask vt_src /= counter else: vt_src, self.cache_state_source = predict_with_cfg( zt_src, cfg[idx], source_embeds["prompt_embeds"], source_embeds["negative_prompt_embeds"], timestep, idx, source_image_cond, source_clip_fea, control_latents, cache_state=self.cache_state_source) else: if idx == len(timesteps) - drift_steps: x_tgt = zt_tgt zt_tgt = x_tgt vt_src = 0 #target if context_options is not None: counter = torch.zeros_like(zt_tgt, device=intermediate_device) vt_tgt = torch.zeros_like(zt_tgt, device=intermediate_device) context_queue = list(context(idx, steps, latent_video_length, context_frames, context_stride, context_overlap)) for c in context_queue: window_id = self.window_tracker.get_window_id(c) if cache_args is not None: current_teacache = self.window_tracker.get_teacache(window_id, self.cache_state) else: current_teacache = None prompt_index = min(int(max(c) / section_size), num_prompts - 1) if context_options["verbose"]: log.info(f"Prompt index: {prompt_index}") if len(text_embeds["prompt_embeds"]) > 1: positive = text_embeds["prompt_embeds"][prompt_index] else: positive = text_embeds["prompt_embeds"] partial_img_emb = None partial_control_latents = None if image_cond is not None: partial_img_emb = image_cond[:, c, :, :] partial_img_emb[:, 0, :, :] = image_cond[:, 0, :, :].to(intermediate_device) if control_latents is not None: partial_control_latents = control_latents[:, c, :, :] partial_zt_tgt = zt_tgt[:, c, :, :] vt_tgt_context, new_teacache = predict_with_cfg( partial_zt_tgt, cfg[idx], positive, text_embeds["negative_prompt_embeds"], timestep, idx, partial_img_emb, partial_control_latents, clip_fea, current_teacache) if cache_args is not None: self.window_tracker.cache_states[window_id] = new_teacache window_mask = create_window_mask(vt_tgt_context, c, latent_video_length, context_overlap) vt_tgt[:, c, :, :] += vt_tgt_context * window_mask counter[:, c, :, :] += window_mask vt_tgt /= counter else: vt_tgt, self.cache_state = predict_with_cfg( zt_tgt, cfg[idx], text_embeds["prompt_embeds"], text_embeds["negative_prompt_embeds"], timestep, idx, image_cond, clip_fea, control_latents, cache_state=self.cache_state) v_delta = vt_tgt - vt_src x_tgt = x_tgt.to(torch.float32) v_delta = v_delta.to(torch.float32) x_tgt = x_tgt + (sigma_prev - sigma) * v_delta x0 = x_tgt #context windowing elif context_options is not None: counter = torch.zeros_like(latent_model_input, device=intermediate_device) noise_pred = torch.zeros_like(latent_model_input, device=intermediate_device) context_queue = list(context(idx, steps, latent_video_length, context_frames, context_stride, context_overlap)) for c in context_queue: window_id = self.window_tracker.get_window_id(c) if cache_args is not None: current_teacache = self.window_tracker.get_teacache(window_id, self.cache_state) else: current_teacache = None prompt_index = min(int(max(c) / section_size), num_prompts - 1) if context_options["verbose"]: log.info(f"Prompt index: {prompt_index}") # Use the appropriate prompt for this section if len(text_embeds["prompt_embeds"]) > 1: positive = text_embeds["prompt_embeds"][prompt_index] else: positive = text_embeds["prompt_embeds"] partial_img_emb = None partial_control_latents = None if image_cond is not None: partial_img_emb = image_cond[:, c] partial_img_emb[:, 0] = image_cond[:, 0].to(intermediate_device) if control_latents is not None: partial_control_latents = control_latents[:, c] partial_control_camera_latents = None if control_camera_latents is not None: partial_control_camera_latents = control_camera_latents[:, :, c] partial_vace_context = None if vace_data is not None: window_vace_data = [] for vace_entry in vace_data: partial_context = vace_entry["context"][0][:, c] if has_ref: partial_context[:, 0] = vace_entry["context"][0][:, 0] window_vace_data.append({ "context": [partial_context], "scale": vace_entry["scale"], "start": vace_entry["start"], "end": vace_entry["end"], "seq_len": vace_entry["seq_len"] }) partial_vace_context = window_vace_data partial_audio_proj = None if fantasytalking_embeds is not None: partial_audio_proj = audio_proj[:, c] partial_latent_model_input = latent_model_input[:, c] partial_unianim_data = None if unianim_data is not None: partial_dwpose = dwpose_data[:, :, c] partial_dwpose_flat=rearrange(partial_dwpose, 'b c f h w -> b (f h w) c') partial_unianim_data = { "dwpose": partial_dwpose_flat, "random_ref": unianim_data["random_ref"], "strength": unianimate_poses["strength"], "start_percent": unianimate_poses["start_percent"], "end_percent": unianimate_poses["end_percent"] } partial_add_cond = None if add_cond is not None: partial_add_cond = add_cond[:, :, c].to(device, dtype) noise_pred_context, new_teacache = predict_with_cfg( partial_latent_model_input, cfg[idx], positive, text_embeds["negative_prompt_embeds"], timestep, idx, partial_img_emb, clip_fea, partial_control_latents, partial_vace_context, partial_unianim_data,partial_audio_proj, partial_control_camera_latents, partial_add_cond, current_teacache) if cache_args is not None: self.window_tracker.cache_states[window_id] = new_teacache window_mask = create_window_mask(noise_pred_context, c, latent_video_length, context_overlap, looped=is_looped) noise_pred[:, c] += noise_pred_context * window_mask counter[:, c] += window_mask noise_pred /= counter #region normal inference else: noise_pred, self.cache_state = predict_with_cfg( latent_model_input, cfg[idx], text_embeds["prompt_embeds"], text_embeds["negative_prompt_embeds"], timestep, idx, image_cond, clip_fea, control_latents, vace_data, unianim_data, audio_proj, control_camera_latents, add_cond, cache_state=self.cache_state) if latent_shift_loop: #reverse latent shift if latent_shift_start_percent <= current_step_percentage <= latent_shift_end_percent: noise_pred = torch.cat([noise_pred[:, latent_video_length - shift_idx:]] + [noise_pred[:, :latent_video_length - shift_idx]], dim=1) shift_idx = (shift_idx + latent_skip) % latent_video_length if flowedit_args is None: latent = latent.to(intermediate_device) step_args = { "generator": seed_g, } if isinstance(sample_scheduler, DEISMultistepScheduler) or isinstance(sample_scheduler, FlowMatchScheduler): step_args.pop("generator", None) temp_x0 = sample_scheduler.step( noise_pred[:, :orig_noise_len].unsqueeze(0) if recammaster is not None else noise_pred.unsqueeze(0), t, latent[:, :orig_noise_len].unsqueeze(0) if recammaster is not None else latent.unsqueeze(0), #return_dict=False, **step_args)[0] latent = temp_x0.squeeze(0) x0 = latent.to(device) if callback is not None: if recammaster is not None: callback_latent = (latent_model_input[:, :orig_noise_len].to(device) - noise_pred[:, :orig_noise_len].to(device) * t.to(device) / 1000).detach().permute(1,0,2,3) elif phantom_latents is not None: callback_latent = (latent_model_input[:,:-phantom_latents.shape[1]].to(device) - noise_pred[:,:-phantom_latents.shape[1]].to(device) * t.to(device) / 1000).detach().permute(1,0,2,3) else: callback_latent = (latent_model_input.to(device) - noise_pred.to(device) * t.to(device) / 1000).detach().permute(1,0,2,3) callback(idx, callback_latent, None, steps) else: pbar.update(1) del latent_model_input, timestep else: if callback is not None: callback_latent = (zt_tgt.to(device) - vt_tgt.to(device) * t.to(device) / 1000).detach().permute(1,0,2,3) callback(idx, callback_latent, None, steps) else: pbar.update(1) if phantom_latents is not None: x0 = x0[:,:-phantom_latents.shape[1]] if cache_args is not None: cache_type = cache_args["cache_type"] states = transformer.teacache_state.states if cache_type == "TeaCache" else transformer.magcache_state.states state_names = { 0: "conditional", 1: "unconditional" } for pred_id, state in states.items(): name = state_names.get(pred_id, f"prediction_{pred_id}") if 'skipped_steps' in state: log.info(f"{cache_type} skipped: {len(state['skipped_steps'])} {name} steps: {state['skipped_steps']}") transformer.teacache_state.clear_all() transformer.magcache_state.clear_all() del states if force_offload: if model["manual_offloading"]: transformer.to(offload_device) mm.soft_empty_cache() gc.collect() try: print_memory(device) torch.cuda.reset_peak_memory_stats(device) except: pass return ({ "samples": x0.unsqueeze(0).cpu(), "looped": is_looped, "end_image": end_image if not fun_or_fl2v_model else None, "has_ref": has_ref, "drop_last": drop_last, }, ) class WindowTracker: def __init__(self, verbose=False): self.window_map = {} # Maps frame sequence to persistent ID self.next_id = 0 self.cache_states = {} # Maps persistent ID to teacache state self.verbose = verbose def get_window_id(self, frames): key = tuple(sorted(frames)) # Order-independent frame sequence if key not in self.window_map: self.window_map[key] = self.next_id if self.verbose: log.info(f"New window pattern {key} -> ID {self.next_id}") self.next_id += 1 return self.window_map[key] def get_teacache(self, window_id, base_state): if window_id not in self.cache_states: if self.verbose: log.info(f"Initializing persistent teacache for window {window_id}") self.cache_states[window_id] = base_state.copy() return self.cache_states[window_id] #region VideoDecode class WanVideoDecode: @classmethod def INPUT_TYPES(s): return {"required": { "vae": ("WANVAE",), "samples": ("LATENT",), "enable_vae_tiling": ("BOOLEAN", {"default": False, "tooltip": ( "Drastically reduces memory use but will introduce seams at tile stride boundaries. " "The location and number of seams is dictated by the tile stride size. " "The visibility of seams can be controlled by increasing the tile size. " "Seams become less obvious at 1.5x stride and are barely noticeable at 2x stride size. " "Which is to say if you use a stride width of 160, the seams are barely noticeable with a tile width of 320." )}), "tile_x": ("INT", {"default": 272, "min": 40, "max": 2048, "step": 8, "tooltip": "Tile width in pixels. Smaller values use less VRAM but will make seams more obvious."}), "tile_y": ("INT", {"default": 272, "min": 40, "max": 2048, "step": 8, "tooltip": "Tile height in pixels. Smaller values use less VRAM but will make seams more obvious."}), "tile_stride_x": ("INT", {"default": 144, "min": 32, "max": 2040, "step": 8, "tooltip": "Tile stride width in pixels. Smaller values use less VRAM but will introduce more seams."}), "tile_stride_y": ("INT", {"default": 128, "min": 32, "max": 2040, "step": 8, "tooltip": "Tile stride height in pixels. Smaller values use less VRAM but will introduce more seams."}), }, } @classmethod def VALIDATE_INPUTS(s, tile_x, tile_y, tile_stride_x, tile_stride_y): if tile_x <= tile_stride_x: return "Tile width must be larger than the tile stride width." if tile_y <= tile_stride_y: return "Tile height must be larger than the tile stride height." return True RETURN_TYPES = ("IMAGE",) RETURN_NAMES = ("images",) FUNCTION = "decode" CATEGORY = "WanVideoWrapper" def decode(self, vae, samples, enable_vae_tiling, tile_x, tile_y, tile_stride_x, tile_stride_y): device = mm.get_torch_device() offload_device = mm.unet_offload_device() mm.soft_empty_cache() latents = samples["samples"] end_image = samples.get("end_image", None) has_ref = samples.get("has_ref", False) drop_last = samples.get("drop_last", False) is_looped = samples.get("looped", False) vae.to(device) latents = latents.to(device = device, dtype = vae.dtype) mm.soft_empty_cache() if has_ref: latents = latents[:, :, 1:] if drop_last: latents = latents[:, :, :-1] #if is_looped: # latents = torch.cat([latents[:, :, :warmup_latent_count],latents], dim=2) if type(vae).__name__ == "TAEHV": images = vae.decode_video(latents.permute(0, 2, 1, 3, 4))[0].permute(1, 0, 2, 3) images = torch.clamp(images, 0.0, 1.0) images = images.permute(1, 2, 3, 0).cpu().float() return (images,) else: if end_image is not None: enable_vae_tiling = False images = vae.decode(latents, device=device, end_=(end_image is not None), tiled=enable_vae_tiling, tile_size=(tile_x//8, tile_y//8), tile_stride=(tile_stride_x//8, tile_stride_y//8))[0] vae.model.clear_cache() #images = (images - images.min()) / (images.max() - images.min()) images = torch.clamp(images, -1.0, 1.0) images = (images + 1.0) / 2.0 if is_looped: #images = images[:, warmup_latent_count * 4:] temp_latents = torch.cat([latents[:, :, -3:]] + [latents[:, :, :2]], dim=2) temp_images = vae.decode(temp_latents, device=device, end_=(end_image is not None), tiled=enable_vae_tiling, tile_size=(tile_x//8, tile_y//8), tile_stride=(tile_stride_x//8, tile_stride_y//8))[0] temp_images = (temp_images - temp_images.min()) / (temp_images.max() - temp_images.min()) out = temp_images[:, 9:] out = torch.cat([out, images[:, 5:]], dim=1) images = out if end_image is not None: #end_image = (end_image - end_image.min()) / (end_image.max() - end_image.min()) #image[:, -1] = end_image[:, 0].to(image) #not sure about this images = images[:, 0:-1] vae.model.clear_cache() vae.to(offload_device) mm.soft_empty_cache() images = torch.clamp(images, 0.0, 1.0) images = images.permute(1, 2, 3, 0).cpu().float() return (images,) #region VideoEncode class WanVideoEncode: @classmethod def INPUT_TYPES(s): return {"required": { "vae": ("WANVAE",), "image": ("IMAGE",), "enable_vae_tiling": ("BOOLEAN", {"default": False, "tooltip": "Drastically reduces memory use but may introduce seams"}), "tile_x": ("INT", {"default": 272, "min": 64, "max": 2048, "step": 1, "tooltip": "Tile size in pixels, smaller values use less VRAM, may introduce more seams"}), "tile_y": ("INT", {"default": 272, "min": 64, "max": 2048, "step": 1, "tooltip": "Tile size in pixels, smaller values use less VRAM, may introduce more seams"}), "tile_stride_x": ("INT", {"default": 144, "min": 32, "max": 2048, "step": 32, "tooltip": "Tile stride in pixels, smaller values use less VRAM, may introduce more seams"}), "tile_stride_y": ("INT", {"default": 128, "min": 32, "max": 2048, "step": 32, "tooltip": "Tile stride in pixels, smaller values use less VRAM, may introduce more seams"}), }, "optional": { "noise_aug_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Strength of noise augmentation, helpful for leapfusion I2V where some noise can add motion and give sharper results"}), "latent_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional latent multiplier, helpful for leapfusion I2V where lower values allow for more motion"}), "mask": ("MASK", ), } } RETURN_TYPES = ("LATENT",) RETURN_NAMES = ("samples",) FUNCTION = "encode" CATEGORY = "WanVideoWrapper" def encode(self, vae, image, enable_vae_tiling, tile_x, tile_y, tile_stride_x, tile_stride_y, noise_aug_strength=0.0, latent_strength=1.0, mask=None): device = mm.get_torch_device() offload_device = mm.unet_offload_device() vae.to(device) image = image.clone() B, H, W, C = image.shape if W % 16 != 0 or H % 16 != 0: new_height = (H // 16) * 16 new_width = (W // 16) * 16 log.warning(f"Image size {W}x{H} is not divisible by 16, resizing to {new_width}x{new_height}") image = common_upscale(image.movedim(-1, 1), new_width, new_height, "lanczos", "disabled").movedim(1, -1) image = image.to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3) # B, C, T, H, W if noise_aug_strength > 0.0: image = add_noise_to_reference_video(image, ratio=noise_aug_strength) if isinstance(vae, TAEHV): latents = vae.encode_video(image.permute(0, 2, 1, 3, 4), parallel=False)# B, T, C, H, W latents = latents.permute(0, 2, 1, 3, 4) else: latents = vae.encode(image * 2.0 - 1.0, device=device, tiled=enable_vae_tiling, tile_size=(tile_x//8, tile_y//8), tile_stride=(tile_stride_x//8, tile_stride_y//8)) vae.model.clear_cache() if latent_strength != 1.0: latents *= latent_strength log.info(f"encoded latents shape {latents.shape}") latent_mask = None if mask is None: vae.to(offload_device) else: #latent_mask = mask.clone().to(vae.dtype).to(device) * 2.0 - 1.0 #latent_mask = latent_mask.unsqueeze(0).unsqueeze(0).repeat(1, 3, 1, 1, 1) #latent_mask = vae.encode(latent_mask, device=device, tiled=enable_vae_tiling, tile_size=(tile_x, tile_y), tile_stride=(tile_stride_x, tile_stride_y)) target_h, target_w = latents.shape[3:] mask = torch.nn.functional.interpolate( mask.unsqueeze(0).unsqueeze(0), # Add batch and channel dims [1,1,T,H,W] size=(latents.shape[2], target_h, target_w), mode='trilinear', align_corners=False ).squeeze(0) # Remove batch dim, keep channel dim # Add batch & channel dims for final output latent_mask = mask.unsqueeze(0).repeat(1, latents.shape[1], 1, 1, 1) log.info(f"latent mask shape {latent_mask.shape}") vae.to(offload_device) mm.soft_empty_cache() return ({"samples": latents, "mask": latent_mask},) NODE_CLASS_MAPPINGS = { "WanVideoSampler": WanVideoSampler, "WanVideoDecode": WanVideoDecode, "WanVideoTextEncode": WanVideoTextEncode, "WanVideoTextEncodeSingle": WanVideoTextEncodeSingle, "WanVideoModelLoader": WanVideoModelLoader, "WanVideoVAELoader": WanVideoVAELoader, "LoadWanVideoT5TextEncoder": LoadWanVideoT5TextEncoder, "WanVideoImageClipEncode": WanVideoImageClipEncode,#deprecated "WanVideoClipVisionEncode": WanVideoClipVisionEncode, "WanVideoImageToVideoEncode": WanVideoImageToVideoEncode, "LoadWanVideoClipTextEncoder": LoadWanVideoClipTextEncoder, "WanVideoEncode": WanVideoEncode, "WanVideoBlockSwap": WanVideoBlockSwap, "WanVideoTorchCompileSettings": WanVideoTorchCompileSettings, "WanVideoEmptyEmbeds": WanVideoEmptyEmbeds, "WanVideoLoraSelect": WanVideoLoraSelect, "WanVideoLoraBlockEdit": WanVideoLoraBlockEdit, "WanVideoEnhanceAVideo": WanVideoEnhanceAVideo, "WanVideoContextOptions": WanVideoContextOptions, "WanVideoTeaCache": WanVideoTeaCache, "WanVideoMagCache": WanVideoMagCache, "WanVideoVRAMManagement": WanVideoVRAMManagement, "WanVideoTextEmbedBridge": WanVideoTextEmbedBridge, "WanVideoFlowEdit": WanVideoFlowEdit, "WanVideoControlEmbeds": WanVideoControlEmbeds, "WanVideoSLG": WanVideoSLG, "WanVideoTinyVAELoader": WanVideoTinyVAELoader, "WanVideoLoopArgs": WanVideoLoopArgs, "WanVideoImageResizeToClosest": WanVideoImageResizeToClosest, "WanVideoSetBlockSwap": WanVideoSetBlockSwap, "WanVideoExperimentalArgs": WanVideoExperimentalArgs, "WanVideoVACEEncode": WanVideoVACEEncode, "WanVideoVACEStartToEndFrame": WanVideoVACEStartToEndFrame, "WanVideoVACEModelSelect": WanVideoVACEModelSelect, "WanVideoPhantomEmbeds": WanVideoPhantomEmbeds, "CreateCFGScheduleFloatList": CreateCFGScheduleFloatList, "WanVideoRealisDanceLatents": WanVideoRealisDanceLatents, "WanVideoApplyNAG": WanVideoApplyNAG, "WanVideoMiniMaxRemoverEmbeds": WanVideoMiniMaxRemoverEmbeds, "WanVideoLoraSelectMulti": WanVideoLoraSelectMulti } NODE_DISPLAY_NAME_MAPPINGS = { "WanVideoSampler": "WanVideo Sampler", "WanVideoDecode": "WanVideo Decode", "WanVideoTextEncode": "WanVideo TextEncode", "WanVideoTextEncodeSingle": "WanVideo TextEncodeSingle", "WanVideoTextImageEncode": "WanVideo TextImageEncode (IP2V)", "WanVideoModelLoader": "WanVideo Model Loader", "WanVideoVAELoader": "WanVideo VAE Loader", "LoadWanVideoT5TextEncoder": "Load WanVideo T5 TextEncoder", "WanVideoImageClipEncode": "WanVideo ImageClip Encode (Deprecated)", "WanVideoClipVisionEncode": "WanVideo ClipVision Encode", "WanVideoImageToVideoEncode": "WanVideo ImageToVideo Encode", "LoadWanVideoClipTextEncoder": "Load WanVideo Clip Encoder", "WanVideoEncode": "WanVideo Encode", "WanVideoBlockSwap": "WanVideo BlockSwap", "WanVideoTorchCompileSettings": "WanVideo Torch Compile Settings", "WanVideoEmptyEmbeds": "WanVideo Empty Embeds", "WanVideoLoraSelect": "WanVideo Lora Select", "WanVideoLoraBlockEdit": "WanVideo Lora Block Edit", "WanVideoEnhanceAVideo": "WanVideo Enhance-A-Video", "WanVideoContextOptions": "WanVideo Context Options", "WanVideoTeaCache": "WanVideo TeaCache", "WanVideoMagCache": "WanVideo MagCache", "WanVideoVRAMManagement": "WanVideo VRAM Management", "WanVideoTextEmbedBridge": "WanVideo TextEmbed Bridge", "WanVideoFlowEdit": "WanVideo FlowEdit", "WanVideoControlEmbeds": "WanVideo Control Embeds", "WanVideoSLG": "WanVideo SLG", "WanVideoTinyVAELoader": "WanVideo Tiny VAE Loader", "WanVideoLoopArgs": "WanVideo Loop Args", "WanVideoImageResizeToClosest": "WanVideo Image Resize To Closest", "WanVideoSetBlockSwap": "WanVideo Set BlockSwap", "WanVideoExperimentalArgs": "WanVideo Experimental Args", "WanVideoVACEEncode": "WanVideo VACE Encode", "WanVideoVACEStartToEndFrame": "WanVideo VACE Start To End Frame", "WanVideoVACEModelSelect": "WanVideo VACE Model Select", "WanVideoPhantomEmbeds": "WanVideo Phantom Embeds", "CreateCFGScheduleFloatList": "WanVideo CFG Schedule Float List", "WanVideoRealisDanceLatents": "WanVideo RealisDance Latents", "WanVideoApplyNAG": "WanVideo Apply NAG", "WanVideoMiniMaxRemoverEmbeds": "WanVideo MiniMax Remover Embeds", "WanVideoLoraSelectMulti": "WanVideo Lora Select Multi" }