| import os | |
| import torch | |
| import torch.nn.functional as F | |
| import gc | |
| from .utils import log, print_memory, apply_lora, clip_encode_image_tiled, fourier_filter | |
| import numpy as np | |
| import math | |
| from tqdm import tqdm | |
| from .wanvideo.modules.clip import CLIPModel | |
| from .wanvideo.modules.model import WanModel, rope_params | |
| from .wanvideo.modules.t5 import T5EncoderModel | |
| from .wanvideo.utils.fm_solvers import (FlowDPMSolverMultistepScheduler, | |
| get_sampling_sigmas, retrieve_timesteps) | |
| from .wanvideo.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler | |
| from .wanvideo.utils.basic_flowmatch import FlowMatchScheduler | |
| from diffusers.schedulers import FlowMatchEulerDiscreteScheduler, DEISMultistepScheduler | |
| from .wanvideo.utils.scheduling_flow_match_lcm import FlowMatchLCMScheduler | |
| from .enhance_a_video.globals import enable_enhance, disable_enhance, set_enhance_weight, set_num_frames | |
| from .taehv import TAEHV | |
| from accelerate import init_empty_weights | |
| from accelerate.utils import set_module_tensor_to_device | |
| from einops import rearrange | |
| import folder_paths | |
| import comfy.model_management as mm | |
| from comfy.utils import load_torch_file, ProgressBar, common_upscale | |
| import comfy.model_base | |
| import comfy.latent_formats | |
| from comfy.clip_vision import clip_preprocess, ClipVisionModel | |
| from comfy.sd import load_lora_for_models | |
| from comfy.cli_args import args, LatentPreviewMethod | |
| script_directory = os.path.dirname(os.path.abspath(__file__)) | |
| def add_noise_to_reference_video(image, ratio=None): | |
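| """Add Gaussian noise scaled by `ratio` to a reference video tensor; pixels equal | |
| to -1 (assumed here to mark padded/masked regions) are left noise-free.""" | |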
| sigma = torch.ones((image.shape[0],)).to(image.device, image.dtype) * ratio | |
| image_noise = torch.randn_like(image) * sigma[:, None, None, None] | |
| image_noise = torch.where(image==-1, torch.zeros_like(image), image_noise) | |
| image = image + image_noise | |
| return image | |
| def optimized_scale(positive_flat, negative_flat): | |
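| """Compute the per-sample scale st* = (v_cond . v_uncond) / ||v_uncond||^2 between | |
| flattened conditional and unconditional predictions, as detailed in the comments below.""" | |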
| # Calculate dot product | |
| dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) | |
| # Squared norm of the unconditional prediction | |
| squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8 | |
| # st_star = v_cond^T * v_uncond / ||v_uncond||^2 | |
| st_star = dot_product / squared_norm | |
| return st_star | |
| class WanVideoBlockSwap: | |
| def INPUT_TYPES(s): | |
| return { | |
| "required": { | |
| "blocks_to_swap": ("INT", {"default": 20, "min": 0, "max": 40, "step": 1, "tooltip": "Number of transformer blocks to swap, the 14B model has 40, while the 1.3B model has 30 blocks"}), | |
| "offload_img_emb": ("BOOLEAN", {"default": False, "tooltip": "Offload img_emb to offload_device"}), | |
| "offload_txt_emb": ("BOOLEAN", {"default": False, "tooltip": "Offload time_emb to offload_device"}), | |
| }, | |
| "optional": { | |
| "use_non_blocking": ("BOOLEAN", {"default": True, "tooltip": "Use non-blocking memory transfer for offloading, reserves more RAM but is faster"}), | |
| "vace_blocks_to_swap": ("INT", {"default": 0, "min": 0, "max": 15, "step": 1, "tooltip": "Number of VACE blocks to swap, the VACE model has 15 blocks"}), | |
| }, | |
| } | |
| RETURN_TYPES = ("BLOCKSWAPARGS",) | |
| RETURN_NAMES = ("block_swap_args",) | |
| FUNCTION = "setargs" | |
| CATEGORY = "WanVideoWrapper" | |
| DESCRIPTION = "Settings for block swapping, reduces VRAM use by swapping blocks to CPU memory" | |
| def setargs(self, **kwargs): | |
| return (kwargs, ) | |
| class WanVideoVRAMManagement: | |
| def INPUT_TYPES(s): | |
| return { | |
| "required": { | |
| "offload_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Percentage of parameters to offload"}), | |
| }, | |
| } | |
| RETURN_TYPES = ("VRAM_MANAGEMENTARGS",) | |
| RETURN_NAMES = ("vram_management_args",) | |
| FUNCTION = "setargs" | |
| CATEGORY = "WanVideoWrapper" | |
| DESCRIPTION = "Alternative offloading method from DiffSynth-Studio, more aggressive in reducing memory use than block swapping, but can be slower" | |
| def setargs(self, **kwargs): | |
| return (kwargs, ) | |
| class WanVideoTeaCache: | |
| def INPUT_TYPES(s): | |
| return { | |
| "required": { | |
| "rel_l1_thresh": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.001, | |
| "tooltip": "Higher values will make TeaCache more aggressive, faster, but may cause artifacts. Good value range for 1.3B: 0.05 - 0.08, for other models 0.15-0.30"}), | |
| "start_step": ("INT", {"default": 1, "min": 0, "max": 9999, "step": 1, "tooltip": "Start percentage of the steps to apply TeaCache"}), | |
| "end_step": ("INT", {"default": -1, "min": -1, "max": 9999, "step": 1, "tooltip": "End steps to apply TeaCache"}), | |
| "cache_device": (["main_device", "offload_device"], {"default": "offload_device", "tooltip": "Device to cache to"}), | |
| "use_coefficients": ("BOOLEAN", {"default": True, "tooltip": "Use calculated coefficients for more accuracy. When enabled therel_l1_thresh should be about 10 times higher than without"}), | |
| }, | |
| "optional": { | |
| "mode": (["e", "e0"], {"default": "e", "tooltip": "Choice between using e (time embeds, default) or e0 (modulated time embeds)"}), | |
| }, | |
| } | |
| RETURN_TYPES = ("CACHEARGS",) | |
| RETURN_NAMES = ("cache_args",) | |
| FUNCTION = "process" | |
| CATEGORY = "WanVideoWrapper" | |
| DESCRIPTION = """ | |
| Patch WanVideo model to use TeaCache. Speeds up inference by caching the output and | |
| applying it instead of doing the step. Best results are achieved by choosing the | |
| appropriate coefficients for the model. Early steps should never be skipped; with too | |
| aggressive values this can happen and the motion suffers. Starting later can help with that too. | |
| When NOT using coefficients, the threshold value should be | |
| about 10 times smaller than the value used with coefficients. | |
| Official recommended values https://github.com/ali-vilab/TeaCache/tree/main/TeaCache4Wan2.1: | |
| <pre style='font-family:monospace'> | |
| +-------------------+--------+---------+--------+ | |
| | Model | Low | Medium | High | | |
| +-------------------+--------+---------+--------+ | |
| | Wan2.1 t2v 1.3B | 0.05 | 0.07 | 0.08 | | |
| | Wan2.1 t2v 14B | 0.14 | 0.15 | 0.20 | | |
| | Wan2.1 i2v 480P | 0.13 | 0.19 | 0.26 | | |
| | Wan2.1 i2v 720P | 0.18 | 0.20 | 0.30 | | |
| +-------------------+--------+---------+--------+ | |
| </pre> | |
| """ | |
| EXPERIMENTAL = True | |
| def process(self, rel_l1_thresh, start_step, end_step, cache_device, use_coefficients, mode="e"): | |
| if cache_device == "main_device": | |
| cache_device = mm.get_torch_device() | |
| else: | |
| cache_device = mm.unet_offload_device() | |
| cache_args = { | |
| "cache_type": "TeaCache", | |
| "rel_l1_thresh": rel_l1_thresh, | |
| "start_step": start_step, | |
| "end_step": end_step, | |
| "cache_device": cache_device, | |
| "use_coefficients": use_coefficients, | |
| "mode": mode, | |
| } | |
| return (cache_args,) | |
| class WanVideoMagCache: | |
| def INPUT_TYPES(s): | |
| return { | |
| "required": { | |
| "magcache_thresh": ("FLOAT", {"default": 0.02, "min": 0.0, "max": 0.3, "step": 0.001, "tooltip": "How strongly to cache the output of diffusion model. This value must be non-negative."}), | |
| "magcache_K": ("INT", {"default": 4, "min": 0, "max": 6, "step": 1, "tooltip": "The maxium skip steps of MagCache."}), | |
| "start_step": ("INT", {"default": 1, "min": 0, "max": 9999, "step": 1, "tooltip": "Step to start applying MagCache"}), | |
| "end_step": ("INT", {"default": -1, "min": -1, "max": 9999, "step": 1, "tooltip": "Step to end applying MagCache"}), | |
| "cache_device": (["main_device", "offload_device"], {"default": "offload_device", "tooltip": "Device to cache to"}), | |
| }, | |
| } | |
| RETURN_TYPES = ("CACHEARGS",) | |
| RETURN_NAMES = ("cache_args",) | |
| FUNCTION = "setargs" | |
| CATEGORY = "WanVideoWrapper" | |
| EXPERIMENTAL = True | |
| DESCRIPTION = "MagCache for WanVideoWrapper, source https://github.com/Zehong-Ma/MagCache" | |
| def setargs(self, magcache_thresh, magcache_K, start_step, end_step, cache_device): | |
| if cache_device == "main_device": | |
| cache_device = mm.get_torch_device() | |
| else: | |
| cache_device = mm.unet_offload_device() | |
| cache_args = { | |
| "cache_type": "MagCache", | |
| "magcache_thresh": magcache_thresh, | |
| "magcache_K": magcache_K, | |
| "start_step": start_step, | |
| "end_step": end_step, | |
| "cache_device": cache_device, | |
| } | |
| return (cache_args,) | |
| class WanVideoModel(comfy.model_base.BaseModel): | |
| def __init__(self, *args, **kwargs): | |
| super().__init__(*args, **kwargs) | |
| self.pipeline = {} | |
| def __getitem__(self, k): | |
| return self.pipeline[k] | |
| def __setitem__(self, k, v): | |
| self.pipeline[k] = v | |
| try: | |
| from comfy.latent_formats import Wan21 | |
| latent_format = Wan21 | |
| except ImportError: #for backwards compatibility | |
| log.warning("Wan21 latent format not found, update ComfyUI for better livepreview") | |
| from comfy.latent_formats import HunyuanVideo | |
| latent_format = HunyuanVideo | |
| class WanVideoModelConfig: | |
| def __init__(self, dtype): | |
| self.unet_config = {} | |
| self.unet_extra_config = {} | |
| self.latent_format = latent_format | |
| self.latent_format.latent_channels = 16 | |
| self.manual_cast_dtype = dtype | |
| self.sampling_settings = {"multiplier": 1.0} | |
| self.memory_usage_factor = 2.0 | |
| self.unet_config["disable_unet_model_creation"] = True | |
| def filter_state_dict_by_blocks(state_dict, blocks_mapping, layer_filter=[]): | |
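| """Return a copy of `state_dict` that keeps only 'blocks.N.' keys whose block prefix | |
| appears in `blocks_mapping`, drops any key containing one of the `layer_filter` | |
| substrings, and keeps all remaining non-block keys unchanged.""" | |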
| filtered_dict = {} | |
| if isinstance(layer_filter, str): | |
| layer_filters = [layer_filter] if layer_filter else [] | |
| else: | |
| # Filter out empty strings | |
| layer_filters = [f for f in layer_filter if f] if layer_filter else [] | |
| #print("layer_filter: ", layer_filters) | |
| for key in state_dict: | |
| if not any(filter_str in key for filter_str in layer_filters): | |
| if 'blocks.' in key: | |
| block_pattern = key.split('diffusion_model.')[1].split('.', 2)[0:2] | |
| block_key = f'{block_pattern[0]}.{block_pattern[1]}.' | |
| if block_key in blocks_mapping: | |
| filtered_dict[key] = state_dict[key] | |
| else: | |
| filtered_dict[key] = state_dict[key] | |
| for key in filtered_dict: | |
| print(key) | |
| #from safetensors.torch import save_file | |
| #save_file(filtered_dict, "filtered_state_dict_2.safetensors") | |
| return filtered_dict | |
| def standardize_lora_key_format(lora_sd): | |
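| """Normalize LoRA keys from the formats handled below (Diffusers 'transformer.', | |
| DiffSynth 'pipe.dit.', Fun-style 'lora_unet__', Finetrainer 'attn1'/'attn2') to the | |
| 'diffusion_model.' naming this wrapper expects, e.g. | |
| 'lora_unet__blocks_0_cross_attn_k.lora_down.weight' -> | |
| 'diffusion_model.blocks.0.cross_attn.k.lora_A.weight'.""" | |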
| new_sd = {} | |
| for k, v in lora_sd.items(): | |
| # Diffusers format | |
| if k.startswith('transformer.'): | |
| k = k.replace('transformer.', 'diffusion_model.') | |
| if k.startswith('pipe.dit.'): #unianimate-dit/diffsynth | |
| k = k.replace('pipe.dit.', 'diffusion_model.') | |
| # Fun LoRA format | |
| if k.startswith('lora_unet__'): | |
| # Split into main path and weight type parts | |
| parts = k.split('.') | |
| main_part = parts[0] # e.g. lora_unet__blocks_0_cross_attn_k | |
| weight_type = '.'.join(parts[1:]) if len(parts) > 1 else None # e.g. lora_down.weight | |
| # Process the main part - convert from underscore to dot format | |
| if 'blocks_' in main_part: | |
| # Extract components | |
| components = main_part[len('lora_unet__'):].split('_') | |
| # Start with diffusion_model | |
| new_key = "diffusion_model" | |
| # Add blocks.N | |
| if components[0] == 'blocks': | |
| new_key += f".blocks.{components[1]}" | |
| # Handle different module types | |
| idx = 2 | |
| if idx < len(components): | |
| if components[idx] == 'self' and idx+1 < len(components) and components[idx+1] == 'attn': | |
| new_key += ".self_attn" | |
| idx += 2 | |
| elif components[idx] == 'cross' and idx+1 < len(components) and components[idx+1] == 'attn': | |
| new_key += ".cross_attn" | |
| idx += 2 | |
| elif components[idx] == 'ffn': | |
| new_key += ".ffn" | |
| idx += 1 | |
| # Add the component (k, q, v, o) and handle img suffix | |
| if idx < len(components): | |
| component = components[idx] | |
| idx += 1 | |
| # Check for img suffix | |
| if idx < len(components) and components[idx] == 'img': | |
| component += '_img' | |
| idx += 1 | |
| new_key += f".{component}" | |
| # Handle weight type - this is the critical fix | |
| if weight_type: | |
| if weight_type == 'alpha': | |
| new_key += '.alpha' | |
| elif weight_type == 'lora_down.weight' or weight_type == 'lora_down': | |
| new_key += '.lora_A.weight' | |
| elif weight_type == 'lora_up.weight' or weight_type == 'lora_up': | |
| new_key += '.lora_B.weight' | |
| else: | |
| # Keep original weight type if not matching our patterns | |
| new_key += f'.{weight_type}' | |
| # Add .weight suffix if missing | |
| if not new_key.endswith('.weight'): | |
| new_key += '.weight' | |
| k = new_key | |
| else: | |
| # For other lora_unet__ formats (head, embeddings, etc.) | |
| new_key = main_part.replace('lora_unet__', 'diffusion_model.') | |
| # Fix specific component naming patterns | |
| new_key = new_key.replace('_self_attn', '.self_attn') | |
| new_key = new_key.replace('_cross_attn', '.cross_attn') | |
| new_key = new_key.replace('_ffn', '.ffn') | |
| new_key = new_key.replace('blocks_', 'blocks.') | |
| new_key = new_key.replace('head_head', 'head.head') | |
| new_key = new_key.replace('img_emb', 'img_emb') | |
| new_key = new_key.replace('text_embedding', 'text.embedding') | |
| new_key = new_key.replace('time_embedding', 'time.embedding') | |
| new_key = new_key.replace('time_projection', 'time.projection') | |
| # Replace remaining underscores with dots, carefully | |
| parts = new_key.split('.') | |
| final_parts = [] | |
| for part in parts: | |
| if part in ['img_emb', 'self_attn', 'cross_attn']: | |
| final_parts.append(part) # Keep these intact | |
| else: | |
| final_parts.append(part.replace('_', '.')) | |
| new_key = '.'.join(final_parts) | |
| # Handle weight type | |
| if weight_type: | |
| if weight_type == 'alpha': | |
| new_key += '.alpha' | |
| elif weight_type == 'lora_down.weight' or weight_type == 'lora_down': | |
| new_key += '.lora_A.weight' | |
| elif weight_type == 'lora_up.weight' or weight_type == 'lora_up': | |
| new_key += '.lora_B.weight' | |
| else: | |
| new_key += f'.{weight_type}' | |
| if not new_key.endswith('.weight'): | |
| new_key += '.weight' | |
| k = new_key | |
| # Handle special embedded components | |
| special_components = { | |
| 'time.projection': 'time_projection', | |
| 'img.emb': 'img_emb', | |
| 'text.emb': 'text_emb', | |
| 'time.emb': 'time_emb', | |
| } | |
| for old, new in special_components.items(): | |
| if old in k: | |
| k = k.replace(old, new) | |
| # Fix diffusion.model -> diffusion_model | |
| if k.startswith('diffusion.model.'): | |
| k = k.replace('diffusion.model.', 'diffusion_model.') | |
| # Finetrainer format | |
| if '.attn1.' in k: | |
| k = k.replace('.attn1.', '.self_attn.') | |
| k = k.replace('.to_k.', '.k.') | |
| k = k.replace('.to_q.', '.q.') | |
| k = k.replace('.to_v.', '.v.') | |
| k = k.replace('.to_out.0.', '.o.') | |
| elif '.attn2.' in k: | |
| k = k.replace('.attn2.', '.cross_attn.') | |
| k = k.replace('.to_k.', '.k.') | |
| k = k.replace('.to_q.', '.q.') | |
| k = k.replace('.to_v.', '.v.') | |
| k = k.replace('.to_out.0.', '.o.') | |
| if "img_attn.proj" in k: | |
| k = k.replace("img_attn.proj", "img_attn_proj") | |
| if "img_attn.qkv" in k: | |
| k = k.replace("img_attn.qkv", "img_attn_qkv") | |
| if "txt_attn.proj" in k: | |
| k = k.replace("txt_attn.proj", "txt_attn_proj") | |
| if "txt_attn.qkv" in k: | |
| k = k.replace("txt_attn.qkv", "txt_attn_qkv") | |
| new_sd[k] = v | |
| return new_sd | |
| class WanVideoEnhanceAVideo: | |
| def INPUT_TYPES(s): | |
| return { | |
| "required": { | |
| "weight": ("FLOAT", {"default": 2.0, "min": 0, "max": 100, "step": 0.01, "tooltip": "The feta Weight of the Enhance-A-Video"}), | |
| "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percentage of the steps to apply Enhance-A-Video"}), | |
| "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percentage of the steps to apply Enhance-A-Video"}), | |
| }, | |
| } | |
| RETURN_TYPES = ("FETAARGS",) | |
| RETURN_NAMES = ("feta_args",) | |
| FUNCTION = "setargs" | |
| CATEGORY = "WanVideoWrapper" | |
| DESCRIPTION = "https://github.com/NUS-HPC-AI-Lab/Enhance-A-Video" | |
| def setargs(self, **kwargs): | |
| return (kwargs, ) | |
| class WanVideoLoraSelect: | |
| def INPUT_TYPES(s): | |
| return { | |
| "required": { | |
| "lora": (folder_paths.get_filename_list("loras"), | |
| {"tooltip": "LORA models are expected to be in ComfyUI/models/loras with .safetensors extension"}), | |
| "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.0001, "tooltip": "LORA strength, set to 0.0 to unmerge the LORA"}), | |
| }, | |
| "optional": { | |
| "prev_lora":("WANVIDLORA", {"default": None, "tooltip": "For loading multiple LoRAs"}), | |
| "blocks":("SELECTEDBLOCKS", ), | |
| "low_mem_load": ("BOOLEAN", {"default": False, "tooltip": "Load the LORA model with less VRAM usage, slower loading"}), | |
| } | |
| } | |
| RETURN_TYPES = ("WANVIDLORA",) | |
| RETURN_NAMES = ("lora", ) | |
| FUNCTION = "getlorapath" | |
| CATEGORY = "WanVideoWrapper" | |
| DESCRIPTION = "Select a LoRA model from ComfyUI/models/loras" | |
| def getlorapath(self, lora, strength, blocks={}, prev_lora=None, low_mem_load=False): | |
| loras_list = [] | |
| lora = { | |
| "path": folder_paths.get_full_path("loras", lora), | |
| "strength": strength, | |
| "name": lora.split(".")[0], | |
| "blocks": blocks.get("selected_blocks", {}), | |
| "layer_filter": blocks.get("layer_filter", ""), | |
| "low_mem_load": low_mem_load, | |
| } | |
| if prev_lora is not None: | |
| loras_list.extend(prev_lora) | |
| loras_list.append(lora) | |
| return (loras_list,) | |
| class WanVideoLoraSelectMulti: | |
| def INPUT_TYPES(s): | |
| lora_files = folder_paths.get_filename_list("loras") | |
| lora_files = ["none"] + lora_files # Add "none" as the first option | |
| return { | |
| "required": { | |
| "lora_0": (lora_files, {"default": "none"}), | |
| "strength_0": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.0001, "tooltip": "LORA strength, set to 0.0 to unmerge the LORA"}), | |
| "lora_1": (lora_files, {"default": "none"}), | |
| "strength_1": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.0001, "tooltip": "LORA strength, set to 0.0 to unmerge the LORA"}), | |
| "lora_2": (lora_files, {"default": "none"}), | |
| "strength_2": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.0001, "tooltip": "LORA strength, set to 0.0 to unmerge the LORA"}), | |
| "lora_3": (lora_files, {"default": "none"}), | |
| "strength_3": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.0001, "tooltip": "LORA strength, set to 0.0 to unmerge the LORA"}), | |
| "lora_4": (lora_files, {"default": "none"}), | |
| "strength_4": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.0001, "tooltip": "LORA strength, set to 0.0 to unmerge the LORA"}), | |
| }, | |
| "optional": { | |
| "prev_lora":("WANVIDLORA", {"default": None, "tooltip": "For loading multiple LoRAs"}), | |
| "blocks":("SELECTEDBLOCKS", ), | |
| "low_mem_load": ("BOOLEAN", {"default": False, "tooltip": "Load the LORA model with less VRAM usage, slower loading"}), | |
| } | |
| } | |
| RETURN_TYPES = ("WANVIDLORA",) | |
| RETURN_NAMES = ("lora", ) | |
| FUNCTION = "getlorapath" | |
| CATEGORY = "WanVideoWrapper" | |
| DESCRIPTION = "Select a LoRA model from ComfyUI/models/loras" | |
| def getlorapath(self, lora_0, strength_0, lora_1, strength_1, lora_2, strength_2, | |
| lora_3, strength_3, lora_4, strength_4, blocks={}, prev_lora=None, | |
| low_mem_load=False): | |
| loras_list = [] | |
| if prev_lora is not None: | |
| loras_list.extend(prev_lora) | |
| # Process each LoRA | |
| lora_inputs = [ | |
| (lora_0, strength_0), | |
| (lora_1, strength_1), | |
| (lora_2, strength_2), | |
| (lora_3, strength_3), | |
| (lora_4, strength_4) | |
| ] | |
| for lora_name, strength in lora_inputs: | |
| # Skip if the LoRA is empty | |
| if not lora_name or lora_name == "none": | |
| continue | |
| lora = { | |
| "path": folder_paths.get_full_path("loras", lora_name), | |
| "strength": strength, | |
| "name": lora_name.split(".")[0], | |
| "blocks": blocks.get("selected_blocks", {}), | |
| "layer_filter": blocks.get("layer_filter", ""), | |
| "low_mem_load": low_mem_load, | |
| } | |
| loras_list.append(lora) | |
| return (loras_list,) | |
| class WanVideoVACEModelSelect: | |
| def INPUT_TYPES(s): | |
| return { | |
| "required": { | |
| "vace_model": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "These models are loaded from the 'ComfyUI/models/diffusion_models' VACE model to use when not using model that has it included"}), | |
| }, | |
| } | |
| RETURN_TYPES = ("VACEPATH",) | |
| RETURN_NAMES = ("vace_model", ) | |
| FUNCTION = "getvacepath" | |
| CATEGORY = "WanVideoWrapper" | |
| DESCRIPTION = "VACE model to use when not using model that has it included, loaded from 'ComfyUI/models/diffusion_models'" | |
| def getvacepath(self, vace_model): | |
| vace_model = { | |
| "path": folder_paths.get_full_path("diffusion_models", vace_model), | |
| } | |
| return (vace_model,) | |
| class WanVideoLoraBlockEdit: | |
| def __init__(self): | |
| self.loaded_lora = None | |
| def INPUT_TYPES(s): | |
| arg_dict = {} | |
| argument = ("BOOLEAN", {"default": True}) | |
| for i in range(40): | |
| arg_dict["blocks.{}.".format(i)] = argument | |
| return {"required": arg_dict, "optional": {"layer_filter": ("STRING", {"default": "", "multiline": True})}} | |
| RETURN_TYPES = ("SELECTEDBLOCKS", ) | |
| RETURN_NAMES = ("blocks", ) | |
| OUTPUT_TOOLTIPS = ("The modified lora model",) | |
| FUNCTION = "select" | |
| CATEGORY = "WanVideoWrapper" | |
| def select(self, layer_filter="", **kwargs): | |
| selected_blocks = {k: v for k, v in kwargs.items() if v is True and isinstance(v, bool)} | |
| print("Selected blocks LoRA: ", selected_blocks) | |
| selected = { | |
| "selected_blocks": selected_blocks, | |
| "layer_filter": [x.strip() for x in layer_filter.split(",")] | |
| } | |
| return (selected,) | |
| #region Model loading | |
| class WanVideoModelLoader: | |
| def INPUT_TYPES(s): | |
| return { | |
| "required": { | |
| "model": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "These models are loaded from the 'ComfyUI/models/diffusion_models' -folder",}), | |
| "base_precision": (["fp32", "bf16", "fp16", "fp16_fast"], {"default": "bf16"}), | |
| "quantization": (['disabled', 'fp8_e4m3fn', 'fp8_e4m3fn_fast', 'fp8_e5m2', 'fp8_e4m3fn_fast_no_ffn'], {"default": 'disabled', "tooltip": "optional quantization method"}), | |
| "load_device": (["main_device", "offload_device"], {"default": "main_device", "tooltip": "Initial device to load the model to, NOT recommended with the larger models unless you have 48GB+ VRAM"}), | |
| }, | |
| "optional": { | |
| "attention_mode": ([ | |
| "sdpa", | |
| "flash_attn_2", | |
| "flash_attn_3", | |
| "sageattn", | |
| "flex_attention", | |
| #"spargeattn", needs tuning | |
| #"spargeattn_tune", | |
| ], {"default": "sdpa"}), | |
| "compile_args": ("WANCOMPILEARGS", ), | |
| "block_swap_args": ("BLOCKSWAPARGS", ), | |
| "lora": ("WANVIDLORA", {"default": None}), | |
| "vram_management_args": ("VRAM_MANAGEMENTARGS", {"default": None, "tooltip": "Alternative offloading method from DiffSynth-Studio, more aggressive in reducing memory use than block swapping, but can be slower"}), | |
| "vace_model": ("VACEPATH", {"default": None, "tooltip": "VACE model to use when not using model that has it included"}), | |
| "fantasytalking_model": ("FANTASYTALKINGMODEL", {"default": None, "tooltip": "FantasyTalking model https://github.com/Fantasy-AMAP"}), | |
| } | |
| } | |
| RETURN_TYPES = ("WANVIDEOMODEL",) | |
| RETURN_NAMES = ("model", ) | |
| FUNCTION = "loadmodel" | |
| CATEGORY = "WanVideoWrapper" | |
| def loadmodel(self, model, base_precision, load_device, quantization, | |
| compile_args=None, attention_mode="sdpa", block_swap_args=None, lora=None, vram_management_args=None, vace_model=None, fantasytalking_model=None): | |
| assert not (vram_management_args is not None and block_swap_args is not None), "Can't use both block_swap_args and vram_management_args at the same time" | |
| lora_low_mem_load = False | |
| if lora is not None: | |
| for l in lora: | |
| lora_low_mem_load = lora_low_mem_load or l.get("low_mem_load", False) | |
| transformer = None | |
| mm.unload_all_models() | |
| mm.cleanup_models() | |
| mm.soft_empty_cache() | |
| manual_offloading = True | |
| if "sage" in attention_mode: | |
| try: | |
| from sageattention import sageattn | |
| except Exception as e: | |
| raise ValueError(f"Can't import SageAttention: {str(e)}") | |
| device = mm.get_torch_device() | |
| offload_device = mm.unet_offload_device() | |
| manual_offloading = True | |
| transformer_load_device = device if load_device == "main_device" else offload_device | |
| base_dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "fp8_e4m3fn_fast": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp16_fast": torch.float16, "fp32": torch.float32}[base_precision] | |
| if base_precision == "fp16_fast": | |
| if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"): | |
| torch.backends.cuda.matmul.allow_fp16_accumulation = True | |
| else: | |
| raise ValueError("torch.backends.cuda.matmul.allow_fp16_accumulation is not available in this version of torch, requires torch 2.7.0.dev2025 02 26 nightly minimum currently") | |
| else: | |
| try: | |
| if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"): | |
| torch.backends.cuda.matmul.allow_fp16_accumulation = False | |
| except: | |
| pass | |
| model_path = folder_paths.get_full_path_or_raise("diffusion_models", model) | |
| sd = load_torch_file(model_path, device=transformer_load_device, safe_load=True) | |
| if "vace_blocks.0.after_proj.weight" in sd and not "patch_embedding.weight" in sd: | |
| raise ValueError("You are attempting to load a VACE module as a WanVideo model, instead you should use the vace_model input and matching T2V base model") | |
| if vace_model is not None: | |
| vace_sd = load_torch_file(vace_model["path"], device=transformer_load_device, safe_load=True) | |
| sd.update(vace_sd) | |
| first_key = next(iter(sd)) | |
| if first_key.startswith("model.diffusion_model."): | |
| new_sd = {} | |
| for key, value in sd.items(): | |
| new_key = key.replace("model.diffusion_model.", "", 1) | |
| new_sd[new_key] = value | |
| sd = new_sd | |
| elif first_key.startswith("model."): | |
| new_sd = {} | |
| for key, value in sd.items(): | |
| new_key = key.replace("model.", "", 1) | |
| new_sd[new_key] = value | |
| sd = new_sd | |
| if not "patch_embedding.weight" in sd: | |
| raise ValueError("Invalid WanVideo model selected") | |
| dim = sd["patch_embedding.weight"].shape[0] | |
| in_channels = sd["patch_embedding.weight"].shape[1] | |
| log.info(f"Detected model in_channels: {in_channels}") | |
| ffn_dim = sd["blocks.0.ffn.0.bias"].shape[0] | |
| if not "text_embedding.0.weight" in sd: | |
| model_type = "no_cross_attn" #minimaxremover | |
| elif "model_type.Wan2_1-FLF2V-14B-720P" in sd or "img_emb.emb_pos" in sd or "flf2v" in model.lower(): | |
| model_type = "fl2v" | |
| elif in_channels in [36, 48]: | |
| model_type = "i2v" | |
| elif in_channels == 16: | |
| model_type = "t2v" | |
| elif "control_adapter.conv.weight" in sd: | |
| model_type = "t2v" | |
| num_heads = 40 if dim == 5120 else 12 | |
| num_layers = 40 if dim == 5120 else 30 | |
| vace_layers, vace_in_dim = None, None | |
| if "vace_blocks.0.after_proj.weight" in sd: | |
| if in_channels != 16: | |
| raise ValueError("VACE only works properly with T2V models.") | |
| model_type = "t2v" | |
| if dim == 5120: | |
| vace_layers = [0, 5, 10, 15, 20, 25, 30, 35] | |
| else: | |
| vace_layers = [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28] | |
| vace_in_dim = 96 | |
| log.info(f"Model type: {model_type}, num_heads: {num_heads}, num_layers: {num_layers}") | |
| teacache_coefficients_map = { | |
| "1_3B": { | |
| "e": [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01], | |
| "e0": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02], | |
| }, | |
| "14B": { | |
| "e": [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404], | |
| "e0": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01], | |
| }, | |
| "i2v_480": { | |
| "e": [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01], | |
| "e0": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01], | |
| }, | |
| "i2v_720":{ | |
| "e": [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683], | |
| "e0": [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02], | |
| }, | |
| } | |
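| # Per-step magnitude ratios for MagCache (see the linked MagCache repo), presumably used | |
| # to estimate how much each step changes the output and thus which steps can be skipped. | |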
| magcache_ratios_map = { | |
| "1_3B": np.array([1.0]*2+[1.0124, 1.02213, 1.00166, 1.0041, 0.99791, 1.00061, 0.99682, 0.99762, 0.99634, 0.99685, 0.99567, 0.99586, 0.99416, 0.99422, 0.99578, 0.99575, 0.9957, 0.99563, 0.99511, 0.99506, 0.99535, 0.99531, 0.99552, 0.99549, 0.99541, 0.99539, 0.9954, 0.99536, 0.99489, 0.99485, 0.99518, 0.99514, 0.99484, 0.99478, 0.99481, 0.99479, 0.99415, 0.99413, 0.99419, 0.99416, 0.99396, 0.99393, 0.99388, 0.99386, 0.99349, 0.99349, 0.99309, 0.99304, 0.9927, 0.9927, 0.99228, 0.99226, 0.99171, 0.9917, 0.99137, 0.99135, 0.99068, 0.99063, 0.99005, 0.99003, 0.98944, 0.98942, 0.98849, 0.98849, 0.98758, 0.98757, 0.98644, 0.98643, 0.98504, 0.98503, 0.9836, 0.98359, 0.98202, 0.98201, 0.97977, 0.97978, 0.97717, 0.97718, 0.9741, 0.97411, 0.97003, 0.97002, 0.96538, 0.96541, 0.9593, 0.95933, 0.95086, 0.95089, 0.94013, 0.94019, 0.92402, 0.92414, 0.90241, 0.9026, 0.86821, 0.86868, 0.81838, 0.81939]), | |
| "14B": np.array([1.0]*2+[1.02504, 1.03017, 1.00025, 1.00251, 0.9985, 0.99962, 0.99779, 0.99771, 0.9966, 0.99658, 0.99482, 0.99476, 0.99467, 0.99451, 0.99664, 0.99656, 0.99434, 0.99431, 0.99533, 0.99545, 0.99468, 0.99465, 0.99438, 0.99434, 0.99516, 0.99517, 0.99384, 0.9938, 0.99404, 0.99401, 0.99517, 0.99516, 0.99409, 0.99408, 0.99428, 0.99426, 0.99347, 0.99343, 0.99418, 0.99416, 0.99271, 0.99269, 0.99313, 0.99311, 0.99215, 0.99215, 0.99218, 0.99215, 0.99216, 0.99217, 0.99163, 0.99161, 0.99138, 0.99135, 0.98982, 0.9898, 0.98996, 0.98995, 0.9887, 0.98866, 0.98772, 0.9877, 0.98767, 0.98765, 0.98573, 0.9857, 0.98501, 0.98498, 0.9838, 0.98376, 0.98177, 0.98173, 0.98037, 0.98035, 0.97678, 0.97677, 0.97546, 0.97543, 0.97184, 0.97183, 0.96711, 0.96708, 0.96349, 0.96345, 0.95629, 0.95625, 0.94926, 0.94929, 0.93964, 0.93961, 0.92511, 0.92504, 0.90693, 0.90678, 0.8796, 0.87945, 0.86111, 0.86189]), | |
| "i2v_480": np.array([1.0]*2+[0.98783, 0.98993, 0.97559, 0.97593, 0.98311, 0.98319, 0.98202, 0.98225, 0.9888, 0.98878, 0.98762, 0.98759, 0.98957, 0.98971, 0.99052, 0.99043, 0.99383, 0.99384, 0.98857, 0.9886, 0.99065, 0.99068, 0.98845, 0.98847, 0.99057, 0.99057, 0.98957, 0.98961, 0.98601, 0.9861, 0.98823, 0.98823, 0.98756, 0.98759, 0.98808, 0.98814, 0.98721, 0.98724, 0.98571, 0.98572, 0.98543, 0.98544, 0.98157, 0.98165, 0.98411, 0.98413, 0.97952, 0.97953, 0.98149, 0.9815, 0.9774, 0.97742, 0.97825, 0.97826, 0.97355, 0.97361, 0.97085, 0.97087, 0.97056, 0.97055, 0.96588, 0.96587, 0.96113, 0.96124, 0.9567, 0.95681, 0.94961, 0.94969, 0.93973, 0.93988, 0.93217, 0.93224, 0.91878, 0.91896, 0.90955, 0.90954, 0.92617, 0.92616]), | |
| "i2v_720": np.array([1.0]*2+[0.99428, 0.99498, 0.98588, 0.98621, 0.98273, 0.98281, 0.99018, 0.99023, 0.98911, 0.98917, 0.98646, 0.98652, 0.99454, 0.99456, 0.9891, 0.98909, 0.99124, 0.99127, 0.99102, 0.99103, 0.99215, 0.99212, 0.99515, 0.99515, 0.99576, 0.99572, 0.99068, 0.99072, 0.99097, 0.99097, 0.99166, 0.99169, 0.99041, 0.99042, 0.99201, 0.99198, 0.99101, 0.99101, 0.98599, 0.98603, 0.98845, 0.98844, 0.98848, 0.98851, 0.98862, 0.98857, 0.98718, 0.98719, 0.98497, 0.98497, 0.98264, 0.98263, 0.98389, 0.98393, 0.97938, 0.9794, 0.97535, 0.97536, 0.97498, 0.97499, 0.973, 0.97301, 0.96827, 0.96828, 0.96261, 0.96263, 0.95335, 0.9534, 0.94649, 0.94655, 0.93397, 0.93414, 0.91636, 0.9165, 0.89088, 0.89109, 0.8679, 0.86768]), | |
| } | |
| model_variant = "14B" #default to this | |
| if model_type == "i2v" or model_type == "fl2v": | |
| if "480" in model or "fun" in model.lower() or "a2" in model.lower() or "540" in model: #just a guess for the Fun model for now... | |
| model_variant = "i2v_480" | |
| elif "720" in model: | |
| model_variant = "i2v_720" | |
| elif model_type == "t2v": | |
| model_variant = "14B" | |
| if dim == 1536: | |
| model_variant = "1_3B" | |
| log.info(f"Model variant detected: {model_variant}") | |
| TRANSFORMER_CONFIG= { | |
| "dim": dim, | |
| "ffn_dim": ffn_dim, | |
| "eps": 1e-06, | |
| "freq_dim": 256, | |
| "in_dim": in_channels, | |
| "model_type": model_type, | |
| "out_dim": 16, | |
| "text_len": 512, | |
| "num_heads": num_heads, | |
| "num_layers": num_layers, | |
| "attention_mode": attention_mode, | |
| "main_device": device, | |
| "offload_device": offload_device, | |
| "teacache_coefficients": teacache_coefficients_map[model_variant], | |
| "magcache_ratios": magcache_ratios_map[model_variant], | |
| "vace_layers": vace_layers, | |
| "vace_in_dim": vace_in_dim, | |
| "inject_sample_info": True if "fps_embedding.weight" in sd else False, | |
| "add_ref_conv": True if "ref_conv.weight" in sd else False, | |
| "in_dim_ref_conv": sd["ref_conv.weight"].shape[1] if "ref_conv.weight" in sd else None, | |
| "add_control_adapter": True if "control_adapter.conv.weight" in sd else False, | |
| } | |
| with init_empty_weights(): | |
| transformer = WanModel(**TRANSFORMER_CONFIG) | |
| transformer.eval() | |
| #ReCamMaster | |
| if "blocks.0.cam_encoder.weight" in sd: | |
| log.info("ReCamMaster model detected, patching model...") | |
| import torch.nn as nn | |
| for block in transformer.blocks: | |
| block.cam_encoder = nn.Linear(12, dim) | |
| block.projector = nn.Linear(dim, dim) | |
| block.cam_encoder.weight.data.zero_() | |
| block.cam_encoder.bias.data.zero_() | |
| block.projector.weight = nn.Parameter(torch.eye(dim)) | |
| block.projector.bias = nn.Parameter(torch.zeros(dim)) | |
| # FantasyTalking https://github.com/Fantasy-AMAP | |
| if fantasytalking_model is not None: | |
| log.info("FantasyTalking model detected, patching model...") | |
| context_dim = fantasytalking_model["sd"]["proj_model.proj.weight"].shape[0] | |
| import torch.nn as nn | |
| for block in transformer.blocks: | |
| block.cross_attn.k_proj = nn.Linear(context_dim, dim, bias=False) | |
| block.cross_attn.v_proj = nn.Linear(context_dim, dim, bias=False) | |
| sd.update(fantasytalking_model["sd"]) | |
| # RealisDance-DiT | |
| if "add_conv_in.weight" in sd: | |
| def zero_module(module): | |
| for p in module.parameters(): | |
| torch.nn.init.zeros_(p) | |
| return module | |
| inner_dim = sd["add_conv_in.weight"].shape[0] | |
| add_cond_in_dim = sd["add_conv_in.weight"].shape[1] | |
| attn_cond_in_dim = sd["attn_conv_in.weight"].shape[1] | |
| transformer.add_conv_in = torch.nn.Conv3d(add_cond_in_dim, inner_dim, kernel_size=transformer.patch_size, stride=transformer.patch_size) | |
| transformer.add_proj = zero_module(torch.nn.Linear(inner_dim, inner_dim)) | |
| transformer.attn_conv_in = torch.nn.Conv3d(attn_cond_in_dim, inner_dim, kernel_size=transformer.patch_size, stride=transformer.patch_size) | |
| comfy_model = WanVideoModel( | |
| WanVideoModelConfig(base_dtype), | |
| model_type=comfy.model_base.ModelType.FLOW, | |
| device=device, | |
| ) | |
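| # If quantization is left on 'disabled' but the checkpoint already stores fp8 weights, | |
| # adopt the matching fp8 mode automatically. | |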
| if quantization == "disabled": | |
| for k, v in sd.items(): | |
| if isinstance(v, torch.Tensor): | |
| if v.dtype == torch.float8_e4m3fn: | |
| quantization = "fp8_e4m3fn" | |
| break | |
| elif v.dtype == torch.float8_e5m2: | |
| quantization = "fp8_e5m2" | |
| break | |
| if "fp8_e4m3fn" in quantization: | |
| dtype = torch.float8_e4m3fn | |
| elif quantization == "fp8_e5m2": | |
| dtype = torch.float8_e5m2 | |
| else: | |
| dtype = base_dtype | |
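| # Parameter-name substrings that stay in base_dtype (norms, embeddings, biases, etc.) | |
| # instead of being cast to the fp8 dtype selected above. | |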
| params_to_keep = {"norm", "head", "bias", "time_in", "vector_in", "patch_embedding", "time_", "img_emb", "modulation", "text_embedding", "adapter", "add"} | |
| #if lora is not None: | |
| # transformer_load_device = device | |
| if not lora_low_mem_load: | |
| log.info("Using accelerate to load and assign model weights to device...") | |
| param_count = sum(1 for _ in transformer.named_parameters()) | |
| for name, param in tqdm(transformer.named_parameters(), | |
| desc=f"Loading transformer parameters to {transformer_load_device}", | |
| total=param_count, | |
| leave=True): | |
| dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype | |
| if "patch_embedding" in name: | |
| dtype_to_use = torch.float32 | |
| set_module_tensor_to_device(transformer, name, device=transformer_load_device, dtype=dtype_to_use, value=sd[name]) | |
| comfy_model.diffusion_model = transformer | |
| comfy_model.load_device = transformer_load_device | |
| patcher = comfy.model_patcher.ModelPatcher(comfy_model, device, offload_device) | |
| patcher.model.is_patched = False | |
| control_lora = False | |
| if lora is not None: | |
| for l in lora: | |
| log.info(f"Loading LoRA: {l['name']} with strength: {l['strength']}") | |
| lora_path = l["path"] | |
| lora_strength = l["strength"] | |
| lora_sd = load_torch_file(lora_path, safe_load=True) | |
| if "dwpose_embedding.0.weight" in lora_sd: #unianimate | |
| from .unianimate.nodes import update_transformer | |
| log.info("Unianimate LoRA detected, patching model...") | |
| transformer = update_transformer(transformer, lora_sd) | |
| lora_sd = standardize_lora_key_format(lora_sd) | |
| if l["blocks"]: | |
| lora_sd = filter_state_dict_by_blocks(lora_sd, l["blocks"], l.get("layer_filter", [])) | |
| #spacepxl's control LoRA patch | |
| # for key in lora_sd.keys(): | |
| # print(key) | |
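| # The control LoRA expands patch_embedding from 16 to 32 input channels: the original | |
| # weights are copied into the first 16 channels and the new channels start zero-initialized. | |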
| if "diffusion_model.patch_embedding.lora_A.weight" in lora_sd: | |
| log.info("Control-LoRA detected, patching model...") | |
| control_lora = True | |
| in_cls = transformer.patch_embedding.__class__ # nn.Conv3d | |
| old_in_dim = transformer.in_dim # 16 | |
| new_in_dim = lora_sd["diffusion_model.patch_embedding.lora_A.weight"].shape[1] | |
| assert new_in_dim == 32 | |
| new_in = in_cls( | |
| new_in_dim, | |
| transformer.patch_embedding.out_channels, | |
| transformer.patch_embedding.kernel_size, | |
| transformer.patch_embedding.stride, | |
| transformer.patch_embedding.padding, | |
| ).to(device=device, dtype=torch.float32) | |
| new_in.weight.zero_() | |
| new_in.bias.zero_() | |
| new_in.weight[:, :old_in_dim].copy_(transformer.patch_embedding.weight) | |
| new_in.bias.copy_(transformer.patch_embedding.bias) | |
| transformer.patch_embedding = new_in | |
| transformer.expanded_patch_embedding = new_in | |
| transformer.register_to_config(in_dim=new_in_dim) | |
| patcher, _ = load_lora_for_models(patcher, None, lora_sd, lora_strength, 0) | |
| del lora_sd | |
| patcher = apply_lora(patcher, device, transformer_load_device, params_to_keep=params_to_keep, dtype=dtype, base_dtype=base_dtype, state_dict=sd, low_mem_load=lora_low_mem_load) | |
| #patcher.load(device, full_load=True) | |
| patcher.model.is_patched = True | |
| if "fast" in quantization: | |
| from .fp8_optimization import convert_fp8_linear | |
| if quantization == "fp8_e4m3fn_fast_no_ffn": | |
| params_to_keep.update({"ffn"}) | |
| print(params_to_keep) | |
| convert_fp8_linear(patcher.model.diffusion_model, base_dtype, params_to_keep=params_to_keep) | |
| del sd | |
| if vram_management_args is not None: | |
| from .diffsynth.vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear | |
| from .wanvideo.modules.model import WanLayerNorm, WanRMSNorm | |
| total_params_in_model = sum(p.numel() for p in patcher.model.diffusion_model.parameters()) | |
| log.info(f"Total number of parameters in the loaded model: {total_params_in_model}") | |
| offload_percent = vram_management_args["offload_percent"] | |
| offload_params = int(total_params_in_model * offload_percent) | |
| params_to_keep = total_params_in_model - offload_params | |
| log.info(f"Selected params to offload: {offload_params}") | |
| enable_vram_management( | |
| patcher.model.diffusion_model, | |
| module_map = { | |
| torch.nn.Linear: AutoWrappedLinear, | |
| torch.nn.Conv3d: AutoWrappedModule, | |
| torch.nn.LayerNorm: AutoWrappedModule, | |
| WanLayerNorm: AutoWrappedModule, | |
| WanRMSNorm: AutoWrappedModule, | |
| }, | |
| module_config = dict( | |
| offload_dtype=dtype, | |
| offload_device=offload_device, | |
| onload_dtype=dtype, | |
| onload_device=device, | |
| computation_dtype=base_dtype, | |
| computation_device=device, | |
| ), | |
| max_num_param=params_to_keep, | |
| overflow_module_config = dict( | |
| offload_dtype=dtype, | |
| offload_device=offload_device, | |
| onload_dtype=dtype, | |
| onload_device=offload_device, | |
| computation_dtype=base_dtype, | |
| computation_device=device, | |
| ), | |
| compile_args = compile_args, | |
| ) | |
| #compile | |
| if compile_args is not None and vram_management_args is None: | |
| torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"] | |
| try: | |
| if hasattr(torch, '_dynamo') and hasattr(torch._dynamo, 'config'): | |
| torch._dynamo.config.recompile_limit = compile_args["dynamo_recompile_limit"] | |
| except Exception as e: | |
| log.warning(f"Could not set recompile_limit: {e}") | |
| if compile_args["compile_transformer_blocks_only"]: | |
| for i, block in enumerate(patcher.model.diffusion_model.blocks): | |
| patcher.model.diffusion_model.blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"]) | |
| if vace_layers is not None: | |
| for i, block in enumerate(patcher.model.diffusion_model.vace_blocks): | |
| patcher.model.diffusion_model.vace_blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"]) | |
| else: | |
| patcher.model.diffusion_model = torch.compile(patcher.model.diffusion_model, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"]) | |
| if load_device == "offload_device" and patcher.model.diffusion_model.device != offload_device: | |
| log.info(f"Moving diffusion model from {patcher.model.diffusion_model.device} to {offload_device}") | |
| patcher.model.diffusion_model.to(offload_device) | |
| gc.collect() | |
| mm.soft_empty_cache() | |
| patcher.model["dtype"] = base_dtype | |
| patcher.model["base_path"] = model_path | |
| patcher.model["model_name"] = model | |
| patcher.model["manual_offloading"] = manual_offloading | |
| patcher.model["quantization"] = quantization | |
| patcher.model["auto_cpu_offload"] = True if vram_management_args is not None else False | |
| patcher.model["control_lora"] = control_lora | |
| if 'transformer_options' not in patcher.model_options: | |
| patcher.model_options['transformer_options'] = {} | |
| patcher.model_options["transformer_options"]["block_swap_args"] = block_swap_args | |
| for model in mm.current_loaded_models: | |
| if model._model() == patcher: | |
| mm.current_loaded_models.remove(model) | |
| return (patcher,) | |
| class WanVideoSetBlockSwap: | |
| def INPUT_TYPES(s): | |
| return { | |
| "required": { | |
| "model": ("WANVIDEOMODEL", ), | |
| "block_swap_args": ("BLOCKSWAPARGS", ), | |
| } | |
| } | |
| RETURN_TYPES = ("WANVIDEOMODEL",) | |
| RETURN_NAMES = ("model", ) | |
| FUNCTION = "loadmodel" | |
| CATEGORY = "WanVideoWrapper" | |
| def loadmodel(self, model, block_swap_args): | |
| patcher = model.clone() | |
| if 'transformer_options' not in patcher.model_options: | |
| patcher.model_options['transformer_options'] = {} | |
| patcher.model_options["transformer_options"]["block_swap_args"] = block_swap_args | |
| return (patcher,) | |
| #region load VAE | |
| class WanVideoVAELoader: | |
| def INPUT_TYPES(s): | |
| return { | |
| "required": { | |
| "model_name": (folder_paths.get_filename_list("vae"), {"tooltip": "These models are loaded from 'ComfyUI/models/vae'"}), | |
| }, | |
| "optional": { | |
| "precision": (["fp16", "fp32", "bf16"], | |
| {"default": "bf16"} | |
| ), | |
| } | |
| } | |
| RETURN_TYPES = ("WANVAE",) | |
| RETURN_NAMES = ("vae", ) | |
| FUNCTION = "loadmodel" | |
| CATEGORY = "WanVideoWrapper" | |
| DESCRIPTION = "Loads Wan VAE model from 'ComfyUI/models/vae'" | |
| def loadmodel(self, model_name, precision): | |
| from .wanvideo.wan_video_vae import WanVideoVAE | |
| device = mm.get_torch_device() | |
| offload_device = mm.unet_offload_device() | |
| dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision] | |
| #with open(os.path.join(script_directory, 'configs', 'hy_vae_config.json')) as f: | |
| # vae_config = json.load(f) | |
| model_path = folder_paths.get_full_path("vae", model_name) | |
| vae_sd = load_torch_file(model_path, safe_load=True) | |
| has_model_prefix = any(k.startswith("model.") for k in vae_sd.keys()) | |
| if not has_model_prefix: | |
| vae_sd = {f"model.{k}": v for k, v in vae_sd.items()} | |
| vae = WanVideoVAE(dtype=dtype) | |
| vae.load_state_dict(vae_sd) | |
| vae.eval() | |
| vae.to(device = offload_device, dtype = dtype) | |
| return (vae,) | |
| class WanVideoTinyVAELoader: | |
| def INPUT_TYPES(s): | |
| return { | |
| "required": { | |
| "model_name": (folder_paths.get_filename_list("vae_approx"), {"tooltip": "These models are loaded from 'ComfyUI/models/vae_approx'"}), | |
| }, | |
| "optional": { | |
| "precision": (["fp16", "fp32", "bf16"], {"default": "fp16"}), | |
| "parallel": ("BOOLEAN", {"default": False, "tooltip": "uses more memory but is faster"}), | |
| } | |
| } | |
| RETURN_TYPES = ("WANVAE",) | |
| RETURN_NAMES = ("vae", ) | |
| FUNCTION = "loadmodel" | |
| CATEGORY = "WanVideoWrapper" | |
| DESCRIPTION = "Loads Wan VAE model from 'ComfyUI/models/vae'" | |
| def loadmodel(self, model_name, precision, parallel=False): | |
| from .taehv import TAEHV | |
| device = mm.get_torch_device() | |
| offload_device = mm.unet_offload_device() | |
| dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision] | |
| model_path = folder_paths.get_full_path("vae_approx", model_name) | |
| vae_sd = load_torch_file(model_path, safe_load=True) | |
| vae = TAEHV(vae_sd, parallel=parallel) | |
| vae.to(device = offload_device, dtype = dtype) | |
| return (vae,) | |
| class WanVideoTorchCompileSettings: | |
| def INPUT_TYPES(s): | |
| return { | |
| "required": { | |
| "backend": (["inductor","cudagraphs"], {"default": "inductor"}), | |
| "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}), | |
| "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}), | |
| "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}), | |
| "dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.cache_size_limit"}), | |
| "compile_transformer_blocks_only": ("BOOLEAN", {"default": True, "tooltip": "Compile only the transformer blocks, usually enough and can make compilation faster and less error prone"}), | |
| }, | |
| "optional": { | |
| "dynamo_recompile_limit": ("INT", {"default": 128, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.recompile_limit"}), | |
| }, | |
| } | |
| RETURN_TYPES = ("WANCOMPILEARGS",) | |
| RETURN_NAMES = ("torch_compile_args",) | |
| FUNCTION = "set_args" | |
| CATEGORY = "WanVideoWrapper" | |
| DESCRIPTION = "torch.compile settings, when connected to the model loader, torch.compile of the selected layers is attempted. Requires Triton and torch 2.5.0 is recommended" | |
| def set_args(self, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit, compile_transformer_blocks_only, dynamo_recompile_limit=128): | |
| compile_args = { | |
| "backend": backend, | |
| "fullgraph": fullgraph, | |
| "mode": mode, | |
| "dynamic": dynamic, | |
| "dynamo_cache_size_limit": dynamo_cache_size_limit, | |
| "dynamo_recompile_limit": dynamo_recompile_limit, | |
| "compile_transformer_blocks_only": compile_transformer_blocks_only, | |
| } | |
| return (compile_args, ) | |
| #region TextEncode | |
| class LoadWanVideoT5TextEncoder: | |
| def INPUT_TYPES(s): | |
| return { | |
| "required": { | |
| "model_name": (folder_paths.get_filename_list("text_encoders"), {"tooltip": "These models are loaded from 'ComfyUI/models/text_encoders'"}), | |
| "precision": (["fp32", "bf16"], | |
| {"default": "bf16"} | |
| ), | |
| }, | |
| "optional": { | |
| "load_device": (["main_device", "offload_device"], {"default": "offload_device"}), | |
| "quantization": (['disabled', 'fp8_e4m3fn'], {"default": 'disabled', "tooltip": "optional quantization method"}), | |
| } | |
| } | |
| RETURN_TYPES = ("WANTEXTENCODER",) | |
| RETURN_NAMES = ("wan_t5_model", ) | |
| FUNCTION = "loadmodel" | |
| CATEGORY = "WanVideoWrapper" | |
| DESCRIPTION = "Loads Wan text_encoder model from 'ComfyUI/models/LLM'" | |
| def loadmodel(self, model_name, precision, load_device="offload_device", quantization="disabled"): | |
| device = mm.get_torch_device() | |
| offload_device = mm.unet_offload_device() | |
| text_encoder_load_device = device if load_device == "main_device" else offload_device | |
| tokenizer_path = os.path.join(script_directory, "configs", "T5_tokenizer") | |
| dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision] | |
| model_path = folder_paths.get_full_path("text_encoders", model_name) | |
| sd = load_torch_file(model_path, safe_load=True) | |
| if "token_embedding.weight" not in sd and "shared.weight" not in sd: | |
| raise ValueError("Invalid T5 text encoder model, this node expects the 'umt5-xxl' model") | |
| if "scaled_fp8" in sd: | |
| raise ValueError("Invalid T5 text encoder model, fp8 scaled is not supported by this node") | |
| # Convert state dict keys from T5 format to the expected format | |
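| # e.g. 'encoder.block.0.layer.0.SelfAttention.q.weight' -> 'blocks.0.attn.q.weight', | |
| # 'shared.weight' -> 'token_embedding.weight' | |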
| if "shared.weight" in sd: | |
| log.info("Converting T5 text encoder model to the expected format...") | |
| converted_sd = {} | |
| for key, value in sd.items(): | |
| # Handle encoder block patterns | |
| if key.startswith('encoder.block.'): | |
| parts = key.split('.') | |
| block_num = parts[2] | |
| # Self-attention components | |
| if 'layer.0.SelfAttention' in key: | |
| if key.endswith('.k.weight'): | |
| new_key = f"blocks.{block_num}.attn.k.weight" | |
| elif key.endswith('.o.weight'): | |
| new_key = f"blocks.{block_num}.attn.o.weight" | |
| elif key.endswith('.q.weight'): | |
| new_key = f"blocks.{block_num}.attn.q.weight" | |
| elif key.endswith('.v.weight'): | |
| new_key = f"blocks.{block_num}.attn.v.weight" | |
| elif 'relative_attention_bias' in key: | |
| new_key = f"blocks.{block_num}.pos_embedding.embedding.weight" | |
| else: | |
| new_key = key | |
| # Layer norms | |
| elif 'layer.0.layer_norm' in key: | |
| new_key = f"blocks.{block_num}.norm1.weight" | |
| elif 'layer.1.layer_norm' in key: | |
| new_key = f"blocks.{block_num}.norm2.weight" | |
| # Feed-forward components | |
| elif 'layer.1.DenseReluDense' in key: | |
| if 'wi_0' in key: | |
| new_key = f"blocks.{block_num}.ffn.gate.0.weight" | |
| elif 'wi_1' in key: | |
| new_key = f"blocks.{block_num}.ffn.fc1.weight" | |
| elif 'wo' in key: | |
| new_key = f"blocks.{block_num}.ffn.fc2.weight" | |
| else: | |
| new_key = key | |
| else: | |
| new_key = key | |
| elif key == "shared.weight": | |
| new_key = "token_embedding.weight" | |
| elif key == "encoder.final_layer_norm.weight": | |
| new_key = "norm.weight" | |
| else: | |
| new_key = key | |
| converted_sd[new_key] = value | |
| sd = converted_sd | |
| T5_text_encoder = T5EncoderModel( | |
| text_len=512, | |
| dtype=dtype, | |
| device=text_encoder_load_device, | |
| state_dict=sd, | |
| tokenizer_path=tokenizer_path, | |
| quantization=quantization | |
| ) | |
| text_encoder = { | |
| "model": T5_text_encoder, | |
| "dtype": dtype, | |
| } | |
| return (text_encoder,) | |
| class LoadWanVideoClipTextEncoder: | |
| def INPUT_TYPES(s): | |
| return { | |
| "required": { | |
| "model_name": (folder_paths.get_filename_list("clip_vision") + folder_paths.get_filename_list("text_encoders"), {"tooltip": "These models are loaded from 'ComfyUI/models/clip_vision'"}), | |
| "precision": (["fp16", "fp32", "bf16"], | |
| {"default": "fp16"} | |
| ), | |
| }, | |
| "optional": { | |
| "load_device": (["main_device", "offload_device"], {"default": "offload_device"}), | |
| } | |
| } | |
| RETURN_TYPES = ("CLIP_VISION",) | |
| RETURN_NAMES = ("wan_clip_vision", ) | |
| FUNCTION = "loadmodel" | |
| CATEGORY = "WanVideoWrapper" | |
| DESCRIPTION = "Loads Wan clip_vision model from 'ComfyUI/models/clip_vision'" | |
| def loadmodel(self, model_name, precision, load_device="offload_device"): | |
| device = mm.get_torch_device() | |
| offload_device = mm.unet_offload_device() | |
| text_encoder_load_device = device if load_device == "main_device" else offload_device | |
| dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision] | |
| model_path = folder_paths.get_full_path("clip_vision", model_name) | |
| # We also support legacy setups where the model is in the text_encoders folder | |
| if model_path is None: | |
| model_path = folder_paths.get_full_path("text_encoders", model_name) | |
| sd = load_torch_file(model_path, safe_load=True) | |
| if "log_scale" not in sd: | |
| raise ValueError("Invalid CLIP model, this node expectes the 'open-clip-xlm-roberta-large-vit-huge-14' model") | |
| clip_model = CLIPModel(dtype=dtype, device=device, state_dict=sd) | |
| clip_model.model.to(text_encoder_load_device) | |
| del sd | |
| return (clip_model,) | |
| class WanVideoTextEncode: | |
| def INPUT_TYPES(s): | |
| return {"required": { | |
| "t5": ("WANTEXTENCODER",), | |
| "positive_prompt": ("STRING", {"default": "", "multiline": True} ), | |
| "negative_prompt": ("STRING", {"default": "", "multiline": True} ), | |
| }, | |
| "optional": { | |
| "force_offload": ("BOOLEAN", {"default": True}), | |
| "model_to_offload": ("WANVIDEOMODEL", {"tooltip": "Model to move to offload_device before encoding"}), | |
| } | |
| } | |
| RETURN_TYPES = ("WANVIDEOTEXTEMBEDS", ) | |
| RETURN_NAMES = ("text_embeds",) | |
| FUNCTION = "process" | |
| CATEGORY = "WanVideoWrapper" | |
| DESCRIPTION = "Encodes text prompts into text embeddings. For rudimentary prompt travel you can input multiple prompts separated by '|', they will be equally spread over the video length" | |
| def process(self, t5, positive_prompt, negative_prompt,force_offload=True, model_to_offload=None): | |
| device = mm.get_torch_device() | |
| offload_device = mm.unet_offload_device() | |
| if model_to_offload is not None: | |
| log.info(f"Moving video model to {offload_device}") | |
| model_to_offload.model.to(offload_device) | |
| mm.soft_empty_cache() | |
| encoder = t5["model"] | |
| dtype = t5["dtype"] | |
| # Split positive prompts and process each with weights | |
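| # e.g. "first prompt|second prompt" yields two embeddings spread over the video length, | |
| # and "(slow motion:1.2)"-style weights are multiplied into the matching embedding below. | |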
| positive_prompts_raw = [p.strip() for p in positive_prompt.split('|')] | |
| positive_prompts = [] | |
| all_weights = [] | |
| for p in positive_prompts_raw: | |
| cleaned_prompt, weights = self.parse_prompt_weights(p) | |
| positive_prompts.append(cleaned_prompt) | |
| all_weights.append(weights) | |
| encoder.model.to(device) | |
| with torch.autocast(device_type=mm.get_autocast_device(device), dtype=dtype, enabled=True): | |
| context = encoder(positive_prompts, device) | |
| context_null = encoder([negative_prompt], device) | |
| # Apply weights to embeddings if any were extracted | |
| for i, weights in enumerate(all_weights): | |
| for text, weight in weights.items(): | |
| log.info(f"Applying weight {weight} to prompt: {text}") | |
| if len(weights) > 0: | |
| context[i] = context[i] * weight | |
| if force_offload: | |
| encoder.model.to(offload_device) | |
| mm.soft_empty_cache() | |
| prompt_embeds_dict = { | |
| "prompt_embeds": context, | |
| "negative_prompt_embeds": context_null, | |
| } | |
| return (prompt_embeds_dict,) | |
| def parse_prompt_weights(self, prompt): | |
| """Extract text and weights from prompts with (text:weight) format""" | |
| import re | |
| # Parse all instances of (text:weight) in the prompt | |
| pattern = r'\((.*?):([\d\.]+)\)' | |
| matches = re.findall(pattern, prompt) | |
| # Replace each match with just the text part | |
| cleaned_prompt = prompt | |
| weights = {} | |
| for match in matches: | |
| text, weight = match | |
| orig_text = f"({text}:{weight})" | |
| cleaned_prompt = cleaned_prompt.replace(orig_text, text) | |
| weights[text] = float(weight) | |
| return cleaned_prompt, weights | |
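| # Illustrative sketch only; the helper below is hypothetical and not referenced by any node. | |
| # It shows how the '(text:weight)' syntax and the '|' prompt-travel separator handled above | |
| # behave on a concrete prompt. | |
| def _example_parse_prompt_weights(): | |
|     import re | |
|     prompt = "a cat (running:1.2) in the rain | a dog (sleeping:0.8)" | |
|     sub_prompts = [p.strip() for p in prompt.split('|')]  # one sub-prompt per video segment | |
|     pattern = r'\((.*?):([\d\.]+)\)' | |
|     results = [] | |
|     for p in sub_prompts: | |
|         weights = {} | |
|         cleaned = p | |
|         for text, weight in re.findall(pattern, p): | |
|             cleaned = cleaned.replace(f"({text}:{weight})", text) | |
|             weights[text] = float(weight) | |
|         results.append((cleaned, weights)) | |
|     # results[0] -> ("a cat running in the rain", {"running": 1.2}) | |
|     # results[1] -> ("a dog sleeping", {"sleeping": 0.8}) | |
|     return results | |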
| class WanVideoTextEncodeSingle: | |
| def INPUT_TYPES(s): | |
| return {"required": { | |
| "t5": ("WANTEXTENCODER",), | |
| "prompt": ("STRING", {"default": "", "multiline": True} ), | |
| }, | |
| "optional": { | |
| "force_offload": ("BOOLEAN", {"default": True}), | |
| "model_to_offload": ("WANVIDEOMODEL", {"tooltip": "Model to move to offload_device before encoding"}), | |
| } | |
| } | |
| RETURN_TYPES = ("WANVIDEOTEXTEMBEDS", ) | |
| RETURN_NAMES = ("text_embeds",) | |
| FUNCTION = "process" | |
| CATEGORY = "WanVideoWrapper" | |
| DESCRIPTION = "Encodes text prompt into text embedding." | |
| def process(self, t5, prompt, force_offload=True, model_to_offload=None): | |
| device = mm.get_torch_device() | |
| offload_device = mm.unet_offload_device() | |
| if model_to_offload is not None: | |
| log.info(f"Moving video model to {offload_device}") | |
| model_to_offload.model.to(offload_device) | |
| mm.soft_empty_cache() | |
| encoder = t5["model"] | |
| dtype = t5["dtype"] | |
| encoder.model.to(device) | |
| with torch.autocast(device_type=mm.get_autocast_device(device), dtype=dtype, enabled=True): | |
| encoded = encoder([prompt], device) | |
| if force_offload: | |
| encoder.model.to(offload_device) | |
| mm.soft_empty_cache() | |
| prompt_embeds_dict = { | |
| "prompt_embeds": encoded, | |
| "negative_prompt_embeds": None, | |
| } | |
| return (prompt_embeds_dict,) | |
| class WanVideoApplyNAG: | |
| def INPUT_TYPES(s): | |
| return {"required": { | |
| "original_text_embeds": ("WANVIDEOTEXTEMBEDS",), | |
| "nag_text_embeds": ("WANVIDEOTEXTEMBEDS",), | |
| "nag_scale": ("FLOAT", {"default": 11.0, "min": 0.0, "max": 100.0, "step": 0.1}), | |
| "nag_tau": ("FLOAT", {"default": 2.5, "min": 0.0, "max": 10.0, "step": 0.1}), | |
| "nag_alpha": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01}), | |
| }, | |
| } | |
| RETURN_TYPES = ("WANVIDEOTEXTEMBEDS", ) | |
| RETURN_NAMES = ("text_embeds",) | |
| FUNCTION = "process" | |
| CATEGORY = "WanVideoWrapper" | |
| DESCRIPTION = "Adds NAG prompt embeds to original prompt embeds: 'https://github.com/ChenDarYen/Normalized-Attention-Guidance'" | |
| def process(self, original_text_embeds, nag_text_embeds, nag_scale, nag_tau, nag_alpha): | |
| prompt_embeds_dict_copy = original_text_embeds.copy() | |
| prompt_embeds_dict_copy.update({ | |
| "nag_prompt_embeds": nag_text_embeds["prompt_embeds"], | |
| "nag_params": { | |
| "nag_scale": nag_scale, | |
| "nag_tau": nag_tau, | |
| "nag_alpha": nag_alpha, | |
| } | |
| }) | |
| return (prompt_embeds_dict_copy,) | |
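| # Minimal sketch (illustrative only, not used by any node): after applying NAG the text_embeds | |
| # dict simply gains two extra keys next to the original prompt embeds. Tensor shapes below are | |
| # placeholders; the real embeds come from the T5 encoder nodes above. | |
| def _example_nag_embeds_dict(): | |
|     original = {"prompt_embeds": [torch.zeros(8, 4096)], "negative_prompt_embeds": [torch.zeros(8, 4096)]} | |
|     nag = {"prompt_embeds": [torch.zeros(8, 4096)]} | |
|     out = original.copy() | |
|     out.update({ | |
|         "nag_prompt_embeds": nag["prompt_embeds"], | |
|         "nag_params": {"nag_scale": 11.0, "nag_tau": 2.5, "nag_alpha": 0.25}, | |
|     }) | |
|     return out  # keys: prompt_embeds, negative_prompt_embeds, nag_prompt_embeds, nag_params | |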
| class WanVideoTextEmbedBridge: | |
| def INPUT_TYPES(s): | |
| return {"required": { | |
| "positive": ("CONDITIONING",), | |
| }, | |
| "optional": { | |
| "negative": ("CONDITIONING",), | |
| } | |
| } | |
| RETURN_TYPES = ("WANVIDEOTEXTEMBEDS", ) | |
| RETURN_NAMES = ("text_embeds",) | |
| FUNCTION = "process" | |
| CATEGORY = "WanVideoWrapper" | |
| DESCRIPTION = "Bridge between ComfyUI native text embedding and WanVideoWrapper text embedding" | |
| def process(self, positive, negative=None): | |
| device = mm.get_torch_device() | |
| prompt_embeds_dict = { | |
| "prompt_embeds": positive[0][0].to(device), | |
| "negative_prompt_embeds": negative[0][0].to(device) if negative is not None else None, | |
| } | |
| return (prompt_embeds_dict,) | |
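| # Sketch of the bridge above (helper is illustrative only), assuming the standard ComfyUI | |
| # CONDITIONING layout of [[cond_tensor, extras_dict], ...]: only the first entry's tensor is | |
| # used. Shapes below are placeholders. | |
| def _example_conditioning_bridge(): | |
|     positive = [[torch.zeros(1, 77, 4096), {}]] | |
|     negative = [[torch.zeros(1, 77, 4096), {}]] | |
|     return { | |
|         "prompt_embeds": positive[0][0], | |
|         "negative_prompt_embeds": negative[0][0], | |
|     } | |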
| #region clip image encode | |
| class WanVideoImageClipEncode: | |
| def INPUT_TYPES(s): | |
| return {"required": { | |
| "clip_vision": ("CLIP_VISION",), | |
| "image": ("IMAGE", {"tooltip": "Image to encode"}), | |
| "vae": ("WANVAE",), | |
| "generation_width": ("INT", {"default": 832, "min": 64, "max": 2048, "step": 8, "tooltip": "Width of the image to encode"}), | |
| "generation_height": ("INT", {"default": 480, "min": 64, "max": 29048, "step": 8, "tooltip": "Height of the image to encode"}), | |
| "num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}), | |
| }, | |
| "optional": { | |
| "force_offload": ("BOOLEAN", {"default": True}), | |
| "noise_aug_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Strength of noise augmentation, helpful for I2V where some noise can add motion and give sharper results"}), | |
| "latent_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional latent multiplier, helpful for I2V where lower values allow for more motion"}), | |
| "clip_embed_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional clip embed multiplier"}), | |
| "adjust_resolution": ("BOOLEAN", {"default": True, "tooltip": "Performs the same resolution adjustment as in the original code"}), | |
| } | |
| } | |
| RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", ) | |
| RETURN_NAMES = ("image_embeds",) | |
| FUNCTION = "process" | |
| CATEGORY = "WanVideoWrapper" | |
| DEPRECATED = True | |
| def process(self, clip_vision, vae, image, num_frames, generation_width, generation_height, force_offload=True, noise_aug_strength=0.0, | |
| latent_strength=1.0, clip_embed_strength=1.0, adjust_resolution=True): | |
| device = mm.get_torch_device() | |
| offload_device = mm.unet_offload_device() | |
| self.image_mean = [0.48145466, 0.4578275, 0.40821073] | |
| self.image_std = [0.26862954, 0.26130258, 0.27577711] | |
| patch_size = (1, 2, 2) | |
| vae_stride = (4, 8, 8) | |
| H, W = image.shape[1], image.shape[2] | |
| max_area = generation_width * generation_height | |
| print(clip_vision) | |
| clip_vision.model.to(device) | |
| if isinstance(clip_vision, ClipVisionModel): | |
| clip_context = clip_vision.encode_image(image).last_hidden_state.to(device) | |
| else: | |
| pixel_values = clip_preprocess(image.to(device), size=224, mean=self.image_mean, std=self.image_std, crop=True).float() | |
| clip_context = clip_vision.visual(pixel_values) | |
| if clip_embed_strength != 1.0: | |
| clip_context *= clip_embed_strength | |
| if force_offload: | |
| clip_vision.model.to(offload_device) | |
| mm.soft_empty_cache() | |
| if adjust_resolution: | |
| aspect_ratio = H / W | |
| lat_h = round( | |
| np.sqrt(max_area * aspect_ratio) // vae_stride[1] // | |
| patch_size[1] * patch_size[1]) | |
| lat_w = round( | |
| np.sqrt(max_area / aspect_ratio) // vae_stride[2] // | |
| patch_size[2] * patch_size[2]) | |
| h = lat_h * vae_stride[1] | |
| w = lat_w * vae_stride[2] | |
| else: | |
| h = generation_height | |
| w = generation_width | |
| lat_h = h // 8 | |
| lat_w = w // 8 | |
| # Step 1: Create initial mask with ones for first frame, zeros for others | |
| mask = torch.ones(1, num_frames, lat_h, lat_w, device=device) | |
| mask[:, 1:] = 0 | |
| # Step 2: Repeat first frame 4 times and concatenate with remaining frames | |
| first_frame_repeated = torch.repeat_interleave(mask[:, 0:1], repeats=4, dim=1) | |
| mask = torch.concat([first_frame_repeated, mask[:, 1:]], dim=1) | |
| # Step 3: Reshape mask into groups of 4 frames | |
| mask = mask.view(1, mask.shape[1] // 4, 4, lat_h, lat_w) | |
| # Step 4: Transpose dimensions and select first batch | |
| mask = mask.transpose(1, 2)[0] | |
| # Calculate maximum sequence length | |
| frames_per_stride = (num_frames - 1) // vae_stride[0] + 1 | |
| patches_per_frame = lat_h * lat_w // (patch_size[1] * patch_size[2]) | |
| max_seq_len = frames_per_stride * patches_per_frame | |
| vae.to(device) | |
| # Step 1: Resize and rearrange the input image dimensions | |
| #resized_image = image.permute(0, 3, 1, 2) # Rearrange dimensions to (B, C, H, W) | |
| #resized_image = torch.nn.functional.interpolate(resized_image, size=(h, w), mode='bicubic') | |
| resized_image = common_upscale(image.movedim(-1, 1), w, h, "lanczos", "disabled") | |
| resized_image = resized_image.transpose(0, 1) # Transpose to match required format | |
| resized_image = resized_image * 2 - 1 | |
| if noise_aug_strength > 0.0: | |
| resized_image = add_noise_to_reference_video(resized_image, ratio=noise_aug_strength) | |
| # Step 2: Create zero padding frames | |
| zero_frames = torch.zeros(3, num_frames-1, h, w, device=device) | |
| # Step 3: Concatenate image with zero frames | |
| concatenated = torch.concat([resized_image.to(device), zero_frames, resized_image.to(device)], dim=1).to(device = device, dtype = vae.dtype) | |
| concatenated *= latent_strength | |
| y = vae.encode([concatenated], device)[0] | |
| y = torch.concat([mask, y]) | |
| vae.model.clear_cache() | |
| vae.to(offload_device) | |
| image_embeds = { | |
| "image_embeds": y, | |
| "clip_context": clip_context, | |
| "max_seq_len": max_seq_len, | |
| "num_frames": num_frames, | |
| "lat_h": lat_h, | |
| "lat_w": lat_w, | |
| } | |
| return (image_embeds,) | |
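| # Worked shape example (sketch only, not used by the node): the mask construction above for | |
| # the defaults of 81 frames at 480x832. The latent grid is 60x104 (VAE stride 8), the first | |
| # frame's mask is repeated 4x so the 84 temporal entries pack into 21 groups of 4, and | |
| # max_seq_len counts 2x2 patches per latent frame. | |
| def _example_i2v_mask_shapes(): | |
|     num_frames, lat_h, lat_w = 81, 480 // 8, 832 // 8 | |
|     mask = torch.ones(1, num_frames, lat_h, lat_w) | |
|     mask[:, 1:] = 0 | |
|     mask = torch.cat([mask[:, 0:1].repeat_interleave(4, dim=1), mask[:, 1:]], dim=1)  # (1, 84, 60, 104) | |
|     mask = mask.view(1, mask.shape[1] // 4, 4, lat_h, lat_w).transpose(1, 2)[0]       # (4, 21, 60, 104) | |
|     frames_per_stride = (num_frames - 1) // 4 + 1                                     # 21 | |
|     patches_per_frame = lat_h * lat_w // (2 * 2)                                      # 1560 | |
|     return mask.shape, frames_per_stride * patches_per_frame                          # 32760 | |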
| class WanVideoImageResizeToClosest: | |
| def INPUT_TYPES(s): | |
| return {"required": { | |
| "image": ("IMAGE", {"tooltip": "Image to resize"}), | |
| "generation_width": ("INT", {"default": 832, "min": 64, "max": 2048, "step": 8, "tooltip": "Width of the image to encode"}), | |
| "generation_height": ("INT", {"default": 480, "min": 64, "max": 29048, "step": 8, "tooltip": "Height of the image to encode"}), | |
| "aspect_ratio_preservation": (["keep_input", "stretch_to_new", "crop_to_new"],), | |
| }, | |
| } | |
| RETURN_TYPES = ("IMAGE", "INT", "INT", ) | |
| RETURN_NAMES = ("image","width","height",) | |
| FUNCTION = "process" | |
| CATEGORY = "WanVideoWrapper" | |
| DESCRIPTION = "Resizes image to the closest supported resolution based on aspect ratio and max pixels, according to the original code" | |
| def process(self, image, generation_width, generation_height, aspect_ratio_preservation ): | |
| patch_size = (1, 2, 2) | |
| vae_stride = (4, 8, 8) | |
| H, W = image.shape[1], image.shape[2] | |
| max_area = generation_width * generation_height | |
| crop = "disabled" | |
| if aspect_ratio_preservation == "keep_input": | |
| aspect_ratio = H / W | |
| elif aspect_ratio_preservation == "stretch_to_new" or aspect_ratio_preservation == "crop_to_new": | |
| aspect_ratio = generation_height / generation_width | |
| if aspect_ratio_preservation == "crop_to_new": | |
| crop = "center" | |
| lat_h = round( | |
| np.sqrt(max_area * aspect_ratio) // vae_stride[1] // | |
| patch_size[1] * patch_size[1]) | |
| lat_w = round( | |
| np.sqrt(max_area / aspect_ratio) // vae_stride[2] // | |
| patch_size[2] * patch_size[2]) | |
| h = lat_h * vae_stride[1] | |
| w = lat_w * vae_stride[2] | |
| resized_image = common_upscale(image.movedim(-1, 1), w, h, "lanczos", crop).movedim(1, -1) | |
| return (resized_image, w, h) | |
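| # Worked example (sketch only) of the rounding above: a 1920x1080 input resized toward the | |
| # default 832x480 pixel budget (max_area = 832*480) lands on 832x464, i.e. a 104x58 latent, | |
| # since both latent sides are floored to the VAE stride of 8 and then to even patch multiples. | |
| def _example_closest_resolution(H=1080, W=1920, generation_width=832, generation_height=480): | |
|     max_area = generation_width * generation_height | |
|     aspect_ratio = H / W | |
|     lat_h = round(np.sqrt(max_area * aspect_ratio) // 8 // 2 * 2)   # 58 | |
|     lat_w = round(np.sqrt(max_area / aspect_ratio) // 8 // 2 * 2)   # 104 | |
|     return lat_w * 8, lat_h * 8                                     # (832, 464) | |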
| #region clip vision | |
| class WanVideoClipVisionEncode: | |
| def INPUT_TYPES(s): | |
| return {"required": { | |
| "clip_vision": ("CLIP_VISION",), | |
| "image_1": ("IMAGE", {"tooltip": "Image to encode"}), | |
| "strength_1": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional clip embed multiplier"}), | |
| "strength_2": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional clip embed multiplier"}), | |
| "crop": (["center", "disabled"], {"default": "center", "tooltip": "Crop image to 224x224 before encoding"}), | |
| "combine_embeds": (["average", "sum", "concat", "batch"], {"default": "average", "tooltip": "Method to combine multiple clip embeds"}), | |
| "force_offload": ("BOOLEAN", {"default": True}), | |
| }, | |
| "optional": { | |
| "image_2": ("IMAGE", ), | |
| "negative_image": ("IMAGE", {"tooltip": "image to use for uncond"}), | |
| "tiles": ("INT", {"default": 0, "min": 0, "max": 16, "step": 2, "tooltip": "Use matteo's tiled image encoding for improved accuracy"}), | |
| "ratio": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Ratio of the tile average"}), | |
| } | |
| } | |
| RETURN_TYPES = ("WANVIDIMAGE_CLIPEMBEDS",) | |
| RETURN_NAMES = ("image_embeds",) | |
| FUNCTION = "process" | |
| CATEGORY = "WanVideoWrapper" | |
| def process(self, clip_vision, image_1, strength_1, strength_2, force_offload, crop, combine_embeds, image_2=None, negative_image=None, tiles=0, ratio=1.0): | |
| device = mm.get_torch_device() | |
| offload_device = mm.unet_offload_device() | |
| image_mean = [0.48145466, 0.4578275, 0.40821073] | |
| image_std = [0.26862954, 0.26130258, 0.27577711] | |
| if image_2 is not None: | |
| image = torch.cat([image_1, image_2], dim=0) | |
| else: | |
| image = image_1 | |
| clip_vision.model.to(device) | |
| negative_clip_embeds = None | |
| if tiles > 0: | |
| log.info("Using tiled image encoding") | |
| clip_embeds = clip_encode_image_tiled(clip_vision, image.to(device), tiles=tiles, ratio=ratio) | |
| if negative_image is not None: | |
| negative_clip_embeds = clip_encode_image_tiled(clip_vision, negative_image.to(device), tiles=tiles, ratio=ratio) | |
| else: | |
| if isinstance(clip_vision, ClipVisionModel): | |
| clip_embeds = clip_vision.encode_image(image).penultimate_hidden_states.to(device) | |
| if negative_image is not None: | |
| negative_clip_embeds = clip_vision.encode_image(negative_image).penultimate_hidden_states.to(device) | |
| else: | |
| pixel_values = clip_preprocess(image.to(device), size=224, mean=image_mean, std=image_std, crop=(not crop == "disabled")).float() | |
| clip_embeds = clip_vision.visual(pixel_values) | |
| if negative_image is not None: | |
| pixel_values = clip_preprocess(negative_image.to(device), size=224, mean=image_mean, std=image_std, crop=(not crop == "disabled")).float() | |
| negative_clip_embeds = clip_vision.visual(pixel_values) | |
| log.info(f"Clip embeds shape: {clip_embeds.shape}, dtype: {clip_embeds.dtype}") | |
| weighted_embeds = [] | |
| weighted_embeds.append(clip_embeds[0:1] * strength_1) | |
| # Handle all additional embeddings | |
| if clip_embeds.shape[0] > 1: | |
| weighted_embeds.append(clip_embeds[1:2] * strength_2) | |
| if clip_embeds.shape[0] > 2: | |
| for i in range(2, clip_embeds.shape[0]): | |
| weighted_embeds.append(clip_embeds[i:i+1]) # Add as-is without strength modifier | |
| # Combine all weighted embeddings | |
| if combine_embeds == "average": | |
| clip_embeds = torch.mean(torch.stack(weighted_embeds), dim=0) | |
| elif combine_embeds == "sum": | |
| clip_embeds = torch.sum(torch.stack(weighted_embeds), dim=0) | |
| elif combine_embeds == "concat": | |
| clip_embeds = torch.cat(weighted_embeds, dim=1) | |
| elif combine_embeds == "batch": | |
| clip_embeds = torch.cat(weighted_embeds, dim=0) | |
| else: | |
| clip_embeds = weighted_embeds[0] | |
| log.info(f"Combined clip embeds shape: {clip_embeds.shape}") | |
| if force_offload: | |
| clip_vision.model.to(offload_device) | |
| mm.soft_empty_cache() | |
| clip_embeds_dict = { | |
| "clip_embeds": clip_embeds, | |
| "negative_clip_embeds": negative_clip_embeds | |
| } | |
| return (clip_embeds_dict,) | |
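| # Minimal sketch (illustrative only) of the combine_embeds modes above on two dummy | |
| # embeddings; shapes are placeholders for the real (batch, tokens, dim) clip embeds. | |
| # 'average'/'sum' reduce over the stacked pair, 'concat' joins along the token dim and | |
| # 'batch' along the batch dim. | |
| def _example_combine_clip_embeds(strength_1=1.0, strength_2=0.5): | |
|     e1, e2 = torch.ones(1, 4, 8), torch.ones(1, 4, 8) | |
|     weighted = [e1 * strength_1, e2 * strength_2] | |
|     return { | |
|         "average": torch.mean(torch.stack(weighted), dim=0),  # (1, 4, 8) | |
|         "sum": torch.sum(torch.stack(weighted), dim=0),       # (1, 4, 8) | |
|         "concat": torch.cat(weighted, dim=1),                 # (1, 8, 8) | |
|         "batch": torch.cat(weighted, dim=0),                  # (2, 4, 8) | |
|     } | |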
| class WanVideoRealisDanceLatents: | |
| def INPUT_TYPES(s): | |
| return {"required": { | |
| "ref_latent": ("LATENT", {"tooltip": "Reference image to encode"}), | |
| "smpl_latent": ("LATENT", {"tooltip": "SMPL pose image to encode"}), | |
| "pose_cond_start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent of the SMPL model"}), | |
| "pose_cond_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent of the SMPL model"}), | |
| }, | |
| "optional": { | |
| "hamer_latent": ("LATENT", {"tooltip": "Hamer hand pose image to encode"}), | |
| }, | |
| } | |
| RETURN_TYPES = ("REALISDANCELATENTS",) | |
| RETURN_NAMES = ("realisdance_latents",) | |
| FUNCTION = "process" | |
| CATEGORY = "WanVideoWrapper" | |
| def process(self, ref_latent, smpl_latent, pose_cond_start_percent, pose_cond_end_percent, hamer_latent=None): | |
| if hamer_latent is None: | |
| hamer = torch.zeros_like(smpl_latent["samples"]) | |
| else: | |
| hamer = hamer_latent["samples"] | |
| pose_latent = torch.cat((smpl_latent["samples"], hamer), dim=1) | |
| realisdance_latents = { | |
| "ref_latent": ref_latent["samples"], | |
| "pose_latent": pose_latent, | |
| "pose_cond_start_percent": pose_cond_start_percent, | |
| "pose_cond_end_percent": pose_cond_end_percent, | |
| } | |
| return (realisdance_latents,) | |
| class WanVideoImageToVideoEncode: | |
| def INPUT_TYPES(s): | |
| return {"required": { | |
| "vae": ("WANVAE",), | |
| "width": ("INT", {"default": 832, "min": 64, "max": 2048, "step": 8, "tooltip": "Width of the image to encode"}), | |
| "height": ("INT", {"default": 480, "min": 64, "max": 29048, "step": 8, "tooltip": "Height of the image to encode"}), | |
| "num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}), | |
| "noise_aug_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Strength of noise augmentation, helpful for I2V where some noise can add motion and give sharper results"}), | |
| "start_latent_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional latent multiplier, helpful for I2V where lower values allow for more motion"}), | |
| "end_latent_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional latent multiplier, helpful for I2V where lower values allow for more motion"}), | |
| "force_offload": ("BOOLEAN", {"default": True}), | |
| }, | |
| "optional": { | |
| "clip_embeds": ("WANVIDIMAGE_CLIPEMBEDS", {"tooltip": "Clip vision encoded image"}), | |
| "start_image": ("IMAGE", {"tooltip": "Image to encode"}), | |
| "end_image": ("IMAGE", {"tooltip": "end frame"}), | |
| "control_embeds": ("WANVIDIMAGE_EMBEDS", {"tooltip": "Control signal for the Fun -model"}), | |
| "fun_or_fl2v_model": ("BOOLEAN", {"default": True, "tooltip": "Enable when using official FLF2V or Fun model"}), | |
| "temporal_mask": ("MASK", {"tooltip": "mask"}), | |
| "extra_latents": ("LATENT", {"tooltip": "Extra latents to add to the input front, used for Skyreels A2 reference images"}), | |
| "tiled_vae": ("BOOLEAN", {"default": False, "tooltip": "Use tiled VAE encoding for reduced memory use"}), | |
| "realisdance_latents": ("REALISDANCELATENTS", {"tooltip": "RealisDance latents"}), | |
| } | |
| } | |
| RETURN_TYPES = ("WANVIDIMAGE_EMBEDS",) | |
| RETURN_NAMES = ("image_embeds",) | |
| FUNCTION = "process" | |
| CATEGORY = "WanVideoWrapper" | |
| def process(self, vae, width, height, num_frames, force_offload, noise_aug_strength, | |
| start_latent_strength, end_latent_strength, start_image=None, end_image=None, control_embeds=None, fun_or_fl2v_model=False, | |
| temporal_mask=None, extra_latents=None, clip_embeds=None, tiled_vae=False, realisdance_latents=None): | |
| device = mm.get_torch_device() | |
| offload_device = mm.unet_offload_device() | |
| patch_size = (1, 2, 2) | |
| H = height | |
| W = width | |
| lat_h = H // 8 | |
| lat_w = W // 8 | |
| num_frames = ((num_frames - 1) // 4) * 4 + 1 | |
| two_ref_images = start_image is not None and end_image is not None | |
| base_frames = num_frames + (1 if two_ref_images and not fun_or_fl2v_model else 0) | |
| if temporal_mask is None: | |
| mask = torch.zeros(1, base_frames, lat_h, lat_w, device=device) | |
| if start_image is not None: | |
| mask[:, 0:start_image.shape[0]] = 1 # First frame | |
| if end_image is not None: | |
| mask[:, -end_image.shape[0]:] = 1 # End frame if exists | |
| else: | |
| mask = common_upscale(temporal_mask.unsqueeze(1).to(device), lat_w, lat_h, "nearest", "disabled").squeeze(1) | |
| if mask.shape[0] > base_frames: | |
| mask = mask[:base_frames] | |
| elif mask.shape[0] < base_frames: | |
| mask = torch.cat([mask, torch.zeros(base_frames - mask.shape[0], lat_h, lat_w, device=device)]) | |
| mask = mask.unsqueeze(0).to(device) | |
| # Repeat first frame and optionally end frame | |
| start_mask_repeated = torch.repeat_interleave(mask[:, 0:1], repeats=4, dim=1) # T, C, H, W | |
| if end_image is not None and not fun_or_fl2v_model: | |
| end_mask_repeated = torch.repeat_interleave(mask[:, -1:], repeats=4, dim=1) # T, C, H, W | |
| mask = torch.cat([start_mask_repeated, mask[:, 1:-1], end_mask_repeated], dim=1) | |
| else: | |
| mask = torch.cat([start_mask_repeated, mask[:, 1:]], dim=1) | |
| # Reshape mask into groups of 4 frames | |
| mask = mask.view(1, mask.shape[1] // 4, 4, lat_h, lat_w) # 1, T, C, H, W | |
| mask = mask.movedim(1, 2)[0]# C, T, H, W | |
| # Resize and rearrange the input image dimensions | |
| if start_image is not None: | |
| resized_start_image = common_upscale(start_image.movedim(-1, 1), W, H, "lanczos", "disabled").movedim(0, 1) | |
| resized_start_image = resized_start_image * 2 - 1 | |
| if noise_aug_strength > 0.0: | |
| resized_start_image = add_noise_to_reference_video(resized_start_image, ratio=noise_aug_strength) | |
| if end_image is not None: | |
| resized_end_image = common_upscale(end_image.movedim(-1, 1), W, H, "lanczos", "disabled").movedim(0, 1) | |
| resized_end_image = resized_end_image * 2 - 1 | |
| if noise_aug_strength > 0.0: | |
| resized_end_image = add_noise_to_reference_video(resized_end_image, ratio=noise_aug_strength) | |
| # Concatenate image with zero frames and encode | |
| vae.to(device) | |
| if temporal_mask is None: | |
| if start_image is not None and end_image is None: | |
| zero_frames = torch.zeros(3, num_frames-start_image.shape[0], H, W, device=device) | |
| concatenated = torch.cat([resized_start_image.to(device), zero_frames], dim=1) | |
| elif start_image is None and end_image is not None: | |
| zero_frames = torch.zeros(3, num_frames-end_image.shape[0], H, W, device=device) | |
| concatenated = torch.cat([zero_frames, resized_end_image.to(device)], dim=1) | |
| elif start_image is None and end_image is None: | |
| concatenated = torch.zeros(3, num_frames, H, W, device=device) | |
| else: | |
| if fun_or_fl2v_model: | |
| zero_frames = torch.zeros(3, num_frames-(start_image.shape[0]+end_image.shape[0]), H, W, device=device) | |
| else: | |
| zero_frames = torch.zeros(3, num_frames-1, H, W, device=device) | |
| concatenated = torch.cat([resized_start_image.to(device), zero_frames, resized_end_image.to(device)], dim=1) | |
| else: | |
| temporal_mask = common_upscale(temporal_mask.unsqueeze(1), W, H, "nearest", "disabled").squeeze(1) | |
| concatenated = resized_start_image[:,:num_frames] * temporal_mask[:num_frames].unsqueeze(0) | |
| y = vae.encode([concatenated.to(device=device, dtype=vae.dtype)], device, end_=(end_image is not None and not fun_or_fl2v_model),tiled=tiled_vae)[0] | |
| has_ref = False | |
| if extra_latents is not None: | |
| samples = extra_latents["samples"].squeeze(0) | |
| y = torch.cat([samples, y], dim=1) | |
| mask = torch.cat([torch.ones_like(mask[:, 0:samples.shape[1]]), mask], dim=1) | |
| num_frames += samples.shape[1] * 4 | |
| has_ref = True | |
| y[:, :1] *= start_latent_strength | |
| y[:, -1:] *= end_latent_strength | |
| if control_embeds is None: | |
| y = torch.cat([mask, y]) | |
| else: | |
| if end_image is None: | |
| y[:, 1:] = 0 | |
| elif start_image is None: | |
| y[:, -1:] = 0 | |
| else: | |
| y[:, 1:-1] = 0 # doesn't seem to work anyway though... | |
| # Calculate maximum sequence length | |
| patches_per_frame = lat_h * lat_w // (patch_size[1] * patch_size[2]) | |
| frames_per_stride = (num_frames - 1) // 4 + (2 if end_image is not None and not fun_or_fl2v_model else 1) | |
| max_seq_len = frames_per_stride * patches_per_frame | |
| if realisdance_latents is not None: | |
| realisdance_latents["ref_latent_neg"] = vae.encode(torch.zeros(1, 3, 1, H, W, device=device, dtype=vae.dtype), device) | |
| vae.model.clear_cache() | |
| if force_offload: | |
| vae.model.to(offload_device) | |
| mm.soft_empty_cache() | |
| gc.collect() | |
| image_embeds = { | |
| "image_embeds": y, | |
| "clip_context": clip_embeds.get("clip_embeds", None) if clip_embeds is not None else None, | |
| "negative_clip_context": clip_embeds.get("negative_clip_embeds", None) if clip_embeds is not None else None, | |
| "max_seq_len": max_seq_len, | |
| "num_frames": num_frames, | |
| "lat_h": lat_h, | |
| "lat_w": lat_w, | |
| "control_embeds": control_embeds["control_embeds"] if control_embeds is not None else None, | |
| "end_image": resized_end_image if end_image is not None else None, | |
| "fun_or_fl2v_model": fun_or_fl2v_model, | |
| "has_ref": has_ref, | |
| "realisdance_latents": realisdance_latents | |
| } | |
| return (image_embeds,) | |
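| # Shape sketch (illustrative only, dummy tensors): for 81 frames at 480x832 the grouped mask | |
| # above is (4, 21, 60, 104) and the VAE latent is (16, 21, 60, 104) (16 channels, matching the | |
| # target_shape used elsewhere in this file), so the "image_embeds" conditioning passed on is | |
| # their channel-wise concat with 20 channels. | |
| def _example_i2v_conditioning_layout(start_latent_strength=1.0, end_latent_strength=1.0): | |
|     mask = torch.zeros(4, 21, 60, 104) | |
|     y = torch.zeros(16, 21, 60, 104) | |
|     y[:, :1] *= start_latent_strength   # scales the first latent frame | |
|     y[:, -1:] *= end_latent_strength    # scales the last latent frame | |
|     return torch.cat([mask, y]).shape   # torch.Size([20, 21, 60, 104]) | |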
| class WanVideoEmptyEmbeds: | |
| def INPUT_TYPES(s): | |
| return {"required": { | |
| "width": ("INT", {"default": 832, "min": 64, "max": 2048, "step": 8, "tooltip": "Width of the image to encode"}), | |
| "height": ("INT", {"default": 480, "min": 64, "max": 29048, "step": 8, "tooltip": "Height of the image to encode"}), | |
| "num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}), | |
| }, | |
| "optional": { | |
| "control_embeds": ("WANVIDIMAGE_EMBEDS", {"tooltip": "control signal for the Fun -model"}), | |
| } | |
| } | |
| RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", ) | |
| RETURN_NAMES = ("image_embeds",) | |
| FUNCTION = "process" | |
| CATEGORY = "WanVideoWrapper" | |
| def process(self, num_frames, width, height, control_embeds=None): | |
| vae_stride = (4, 8, 8) | |
| target_shape = (16, (num_frames - 1) // vae_stride[0] + 1, | |
| height // vae_stride[1], | |
| width // vae_stride[2]) | |
| embeds = { | |
| "target_shape": target_shape, | |
| "num_frames": num_frames, | |
| "control_embeds": control_embeds["control_embeds"] if control_embeds is not None else None, | |
| } | |
| return (embeds,) | |
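| # Worked example (sketch only) of the target_shape arithmetic above for the defaults: | |
| # 81 frames at 832x480 gives a 16-channel latent of 21 frames at 60x104. | |
| def _example_empty_target_shape(num_frames=81, width=832, height=480): | |
|     vae_stride = (4, 8, 8) | |
|     return (16, (num_frames - 1) // vae_stride[0] + 1, height // vae_stride[1], width // vae_stride[2]) | |
|     # -> (16, 21, 60, 104) | |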
| class WanVideoMiniMaxRemoverEmbeds: | |
| def INPUT_TYPES(s): | |
| return {"required": { | |
| "width": ("INT", {"default": 832, "min": 64, "max": 2048, "step": 8, "tooltip": "Width of the image to encode"}), | |
| "height": ("INT", {"default": 480, "min": 64, "max": 29048, "step": 8, "tooltip": "Height of the image to encode"}), | |
| "num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}), | |
| "latents": ("LATENT", {"tooltip": "Encoded latents to use as control signals"}), | |
| "mask_latents": ("LATENT", {"tooltip": "Encoded latents to use as mask"}), | |
| }, | |
| } | |
| RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", ) | |
| RETURN_NAMES = ("image_embeds",) | |
| FUNCTION = "process" | |
| CATEGORY = "WanVideoWrapper" | |
| def process(self, num_frames, width, height, latents, mask_latents): | |
| vae_stride = (4, 8, 8) | |
| target_shape = (16, (num_frames - 1) // vae_stride[0] + 1, | |
| height // vae_stride[1], | |
| width // vae_stride[2]) | |
| embeds = { | |
| "target_shape": target_shape, | |
| "num_frames": num_frames, | |
| "minimax_latents": latents["samples"].squeeze(0), | |
| "minimax_mask_latents": mask_latents["samples"].squeeze(0), | |
| } | |
| return (embeds,) | |
| # region phantom | |
| class WanVideoPhantomEmbeds: | |
| def INPUT_TYPES(s): | |
| return {"required": { | |
| "num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}), | |
| "phantom_latent_1": ("LATENT", {"tooltip": "reference latents for the phantom model"}), | |
| "phantom_cfg_scale": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "CFG scale for the extra phantom cond pass"}), | |
| "phantom_start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent of the phantom model"}), | |
| "phantom_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent of the phantom model"}), | |
| }, | |
| "optional": { | |
| "phantom_latent_2": ("LATENT", {"tooltip": "reference latents for the phantom model"}), | |
| "phantom_latent_3": ("LATENT", {"tooltip": "reference latents for the phantom model"}), | |
| "phantom_latent_4": ("LATENT", {"tooltip": "reference latents for the phantom model"}), | |
| "vace_embeds": ("WANVIDIMAGE_EMBEDS", {"tooltip": "VACE embeds"}), | |
| } | |
| } | |
| RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", ) | |
| RETURN_NAMES = ("image_embeds",) | |
| FUNCTION = "process" | |
| CATEGORY = "WanVideoWrapper" | |
| def process(self, num_frames, phantom_cfg_scale, phantom_start_percent, phantom_end_percent, phantom_latent_1, phantom_latent_2=None, phantom_latent_3=None, phantom_latent_4=None, vace_embeds=None): | |
| vae_stride = (4, 8, 8) | |
| samples = phantom_latent_1["samples"].squeeze(0) | |
| if phantom_latent_2 is not None: | |
| samples = torch.cat([samples, phantom_latent_2["samples"].squeeze(0)], dim=1) | |
| if phantom_latent_3 is not None: | |
| samples = torch.cat([samples, phantom_latent_3["samples"].squeeze(0)], dim=1) | |
| if phantom_latent_4 is not None: | |
| samples = torch.cat([samples, phantom_latent_4["samples"].squeeze(0)], dim=1) | |
| C, T, H, W = samples.shape | |
| log.info(f"Phantom latents shape: {samples.shape}") | |
| target_shape = (16, (num_frames - 1) // vae_stride[0] + 1 + T, | |
| H * 8 // vae_stride[1], | |
| W * 8 // vae_stride[2]) | |
| embeds = { | |
| "target_shape": target_shape, | |
| "num_frames": num_frames, | |
| "phantom_latents": samples, | |
| "phantom_cfg_scale": phantom_cfg_scale, | |
| "phantom_start_percent": phantom_start_percent, | |
| "phantom_end_percent": phantom_end_percent, | |
| } | |
| if vace_embeds is not None: | |
| vace_input = { | |
| "vace_context": vace_embeds["vace_context"], | |
| "vace_scale": vace_embeds["vace_scale"], | |
| "has_ref": vace_embeds["has_ref"], | |
| "vace_start_percent": vace_embeds["vace_start_percent"], | |
| "vace_end_percent": vace_embeds["vace_end_percent"], | |
| "vace_seq_len": vace_embeds["vace_seq_len"], | |
| "additional_vace_inputs": vace_embeds["additional_vace_inputs"], | |
| } | |
| embeds.update(vace_input) | |
| return (embeds,) | |
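| # Sketch with placeholder tensors (assuming each reference image encodes to a single latent | |
| # frame): the phantom reference latents are concatenated along the time dim, and target_shape | |
| # above grows by that count on top of the video's own (num_frames - 1) // 4 + 1 latent frames. | |
| def _example_phantom_target_shape(num_frames=81): | |
|     ref1 = torch.zeros(16, 1, 60, 104) | |
|     ref2 = torch.zeros(16, 1, 60, 104) | |
|     samples = torch.cat([ref1, ref2], dim=1)  # (16, 2, 60, 104) | |
|     C, T, H, W = samples.shape | |
|     return (16, (num_frames - 1) // 4 + 1 + T, H * 8 // 8, W * 8 // 8)  # (16, 23, 60, 104) | |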
| class WanVideoControlEmbeds: | |
| def INPUT_TYPES(s): | |
| return {"required": { | |
| "latents": ("LATENT", {"tooltip": "Encoded latents to use as control signals"}), | |
| "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent of the control signal"}), | |
| "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent of the control signal"}), | |
| }, | |
| "optional": { | |
| "fun_ref_image": ("LATENT", {"tooltip": "Reference latent for the Fun 1.1 -model"}), | |
| } | |
| } | |
| RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", ) | |
| RETURN_NAMES = ("image_embeds",) | |
| FUNCTION = "process" | |
| CATEGORY = "WanVideoWrapper" | |
| def process(self, latents, start_percent, end_percent, fun_ref_image=None): | |
| samples = latents["samples"].squeeze(0) | |
| C, T, H, W = samples.shape | |
| num_frames = (T - 1) * 4 + 1 | |
| seq_len = math.ceil((H * W) / 4 * ((num_frames - 1) // 4 + 1)) | |
| embeds = { | |
| "max_seq_len": seq_len, | |
| "target_shape": samples.shape, | |
| "num_frames": num_frames, | |
| "control_embeds": { | |
| "control_images": samples, | |
| "start_percent": start_percent, | |
| "end_percent": end_percent, | |
| "fun_ref_image": fun_ref_image["samples"][:,:, 0] if fun_ref_image is not None else None, | |
| } | |
| } | |
| return (embeds,) | |
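| # Worked example (sketch only) of the sequence-length formula above: control latents of shape | |
| # (C, 21, 60, 104) correspond to 81 pixel frames and a transformer sequence of 32760 tokens | |
| # (1560 2x2 patches per latent frame, 21 latent frames). | |
| def _example_control_seq_len(): | |
|     C, T, H, W = 16, 21, 60, 104 | |
|     num_frames = (T - 1) * 4 + 1                                  # 81 | |
|     return math.ceil((H * W) / 4 * ((num_frames - 1) // 4 + 1))   # 32760 | |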
| class WanVideoSLG: | |
| def INPUT_TYPES(s): | |
| return {"required": { | |
| "blocks": ("STRING", {"default": "10", "tooltip": "Blocks to skip uncond on, separated by comma, index starts from 0"}), | |
| "start_percent": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent of the control signal"}), | |
| "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent of the control signal"}), | |
| }, | |
| } | |
| RETURN_TYPES = ("SLGARGS", ) | |
| RETURN_NAMES = ("slg_args",) | |
| FUNCTION = "process" | |
| CATEGORY = "WanVideoWrapper" | |
| DESCRIPTION = "Skips uncond on the selected blocks" | |
| def process(self, blocks, start_percent, end_percent): | |
| slg_block_list = [int(x.strip()) for x in blocks.split(",")] | |
| slg_args = { | |
| "blocks": slg_block_list, | |
| "start_percent": start_percent, | |
| "end_percent": end_percent, | |
| } | |
| return (slg_args,) | |
| #region VACE | |
| class WanVideoVACEEncode: | |
| def INPUT_TYPES(s): | |
| return {"required": { | |
| "vae": ("WANVAE",), | |
| "width": ("INT", {"default": 832, "min": 64, "max": 2048, "step": 8, "tooltip": "Width of the image to encode"}), | |
| "height": ("INT", {"default": 480, "min": 64, "max": 29048, "step": 8, "tooltip": "Height of the image to encode"}), | |
| "num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}), | |
| "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}), | |
| "vace_start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent of the steps to apply VACE"}), | |
| "vace_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent of the steps to apply VACE"}), | |
| }, | |
| "optional": { | |
| "input_frames": ("IMAGE",), | |
| "ref_images": ("IMAGE",), | |
| "input_masks": ("MASK",), | |
| "prev_vace_embeds": ("WANVIDIMAGE_EMBEDS",), | |
| "tiled_vae": ("BOOLEAN", {"default": False, "tooltip": "Use tiled VAE encoding for reduced memory use"}), | |
| }, | |
| } | |
| RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", ) | |
| RETURN_NAMES = ("vace_embeds",) | |
| FUNCTION = "process" | |
| CATEGORY = "WanVideoWrapper" | |
| def process(self, vae, width, height, num_frames, strength, vace_start_percent, vace_end_percent, input_frames=None, ref_images=None, input_masks=None, prev_vace_embeds=None, tiled_vae=False): | |
| self.device = mm.get_torch_device() | |
| offload_device = mm.unet_offload_device() | |
| self.vae = vae.to(self.device) | |
| self.vae_stride = (4, 8, 8) | |
| width = (width // 16) * 16 | |
| height = (height // 16) * 16 | |
| target_shape = (16, (num_frames - 1) // self.vae_stride[0] + 1, | |
| height // self.vae_stride[1], | |
| width // self.vae_stride[2]) | |
| # vace context encode | |
| if input_frames is None: | |
| input_frames = torch.zeros((1, 3, num_frames, height, width), device=self.device, dtype=self.vae.dtype) | |
| else: | |
| input_frames = input_frames[:num_frames] | |
| input_frames = common_upscale(input_frames.clone().movedim(-1, 1), width, height, "lanczos", "disabled").movedim(1, -1) | |
| input_frames = input_frames.to(self.vae.dtype).to(self.device).unsqueeze(0).permute(0, 4, 1, 2, 3) # B, C, T, H, W | |
| input_frames = input_frames * 2 - 1 | |
| if input_masks is None: | |
| input_masks = torch.ones_like(input_frames, device=self.device) | |
| else: | |
| print("input_masks shape", input_masks.shape) | |
| input_masks = input_masks[:num_frames] | |
| input_masks = common_upscale(input_masks.clone().unsqueeze(1), width, height, "nearest-exact", "disabled").squeeze(1) | |
| input_masks = input_masks.to(self.vae.dtype).to(self.device) | |
| input_masks = input_masks.unsqueeze(-1).unsqueeze(0).permute(0, 4, 1, 2, 3).repeat(1, 3, 1, 1, 1) # B, C, T, H, W | |
| if ref_images is not None: | |
| # Create padded image | |
| if ref_images.shape[0] > 1: | |
| ref_images = torch.cat([ref_images[i] for i in range(ref_images.shape[0])], dim=1).unsqueeze(0) | |
| B, H, W, C = ref_images.shape | |
| current_aspect = W / H | |
| target_aspect = width / height | |
| if current_aspect > target_aspect: | |
| # Image is wider than target, pad height | |
| new_h = int(W / target_aspect) | |
| pad_h = (new_h - H) // 2 | |
| padded = torch.ones(ref_images.shape[0], new_h, W, ref_images.shape[3], device=ref_images.device, dtype=ref_images.dtype) | |
| padded[:, pad_h:pad_h+H, :, :] = ref_images | |
| ref_images = padded | |
| elif current_aspect < target_aspect: | |
| # Image is taller than target, pad width | |
| new_w = int(H * target_aspect) | |
| pad_w = (new_w - W) // 2 | |
| padded = torch.ones(ref_images.shape[0], H, new_w, ref_images.shape[3], device=ref_images.device, dtype=ref_images.dtype) | |
| padded[:, :, pad_w:pad_w+W, :] = ref_images | |
| ref_images = padded | |
| ref_images = common_upscale(ref_images.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1) | |
| ref_images = ref_images.to(self.vae.dtype).to(self.device).unsqueeze(0).permute(0, 4, 1, 2, 3).unsqueeze(0) | |
| ref_images = ref_images * 2 - 1 | |
| z0 = self.vace_encode_frames(input_frames, ref_images, masks=input_masks, tiled_vae=tiled_vae) | |
| self.vae.model.clear_cache() | |
| m0 = self.vace_encode_masks(input_masks, ref_images) | |
| z = self.vace_latent(z0, m0) | |
| self.vae.to(offload_device) | |
| vace_input = { | |
| "vace_context": z, | |
| "vace_scale": strength, | |
| "has_ref": ref_images is not None, | |
| "num_frames": num_frames, | |
| "target_shape": target_shape, | |
| "vace_start_percent": vace_start_percent, | |
| "vace_end_percent": vace_end_percent, | |
| "vace_seq_len": math.ceil((z[0].shape[2] * z[0].shape[3]) / 4 * z[0].shape[1]), | |
| "additional_vace_inputs": [], | |
| } | |
| if prev_vace_embeds is not None: | |
| if "additional_vace_inputs" in prev_vace_embeds and prev_vace_embeds["additional_vace_inputs"]: | |
| vace_input["additional_vace_inputs"] = prev_vace_embeds["additional_vace_inputs"].copy() | |
| vace_input["additional_vace_inputs"].append(prev_vace_embeds) | |
| return (vace_input,) | |
| def vace_encode_frames(self, frames, ref_images, masks=None, tiled_vae=False): | |
| if ref_images is None: | |
| ref_images = [None] * len(frames) | |
| else: | |
| assert len(frames) == len(ref_images) | |
| if masks is None: | |
| latents = self.vae.encode(frames, device=self.device, tiled=tiled_vae) | |
| else: | |
| inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)] | |
| reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)] | |
| inactive = self.vae.encode(inactive, device=self.device, tiled=tiled_vae) | |
| reactive = self.vae.encode(reactive, device=self.device, tiled=tiled_vae) | |
| latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)] | |
| self.vae.model.clear_cache() | |
| cat_latents = [] | |
| for latent, refs in zip(latents, ref_images): | |
| if refs is not None: | |
| if masks is None: | |
| ref_latent = self.vae.encode(refs, device=self.device, tiled=tiled_vae) | |
| else: | |
| print("refs shape", refs.shape)#torch.Size([3, 1, 512, 512]) | |
| ref_latent = self.vae.encode(refs, device=self.device, tiled=tiled_vae) | |
| ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent] | |
| assert all([x.shape[1] == 1 for x in ref_latent]) | |
| latent = torch.cat([*ref_latent, latent], dim=1) | |
| cat_latents.append(latent) | |
| return cat_latents | |
| def vace_encode_masks(self, masks, ref_images=None): | |
| if ref_images is None: | |
| ref_images = [None] * len(masks) | |
| else: | |
| assert len(masks) == len(ref_images) | |
| result_masks = [] | |
| for mask, refs in zip(masks, ref_images): | |
| c, depth, height, width = mask.shape | |
| new_depth = int((depth + 3) // self.vae_stride[0]) | |
| height = 2 * (int(height) // (self.vae_stride[1] * 2)) | |
| width = 2 * (int(width) // (self.vae_stride[2] * 2)) | |
| # reshape | |
| mask = mask[0, :, :, :] | |
| mask = mask.view( | |
| depth, height, self.vae_stride[1], width, self.vae_stride[1] | |
| ) # depth, height, 8, width, 8 | |
| mask = mask.permute(2, 4, 0, 1, 3) # 8, 8, depth, height, width | |
| mask = mask.reshape( | |
| self.vae_stride[1] * self.vae_stride[2], depth, height, width | |
| ) # 8*8, depth, height, width | |
| # interpolation | |
| mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0) | |
| if refs is not None: | |
| length = len(refs) | |
| mask_pad = torch.zeros_like(mask[:, :length, :, :]) | |
| mask = torch.cat((mask_pad, mask), dim=1) | |
| result_masks.append(mask) | |
| return result_masks | |
| def vace_latent(self, z, m): | |
| return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)] | |
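| # Shape sketch (illustrative only, mirrors vace_encode_masks above): an 81-frame 480x832 | |
| # pixel mask is regrouped over the 8x8 spatial stride into 64 channels, and its depth is | |
| # reduced to (81 + 3) // 4 = 21 latent frames by nearest interpolation. | |
| def _example_vace_mask_shapes(): | |
|     mask = torch.ones(1, 81, 480, 832)                       # (C, depth, H, W) | |
|     depth, height, width = 81, 480 // 8, 832 // 8 | |
|     m = mask[0].view(depth, height, 8, width, 8)             # (81, 60, 8, 104, 8) | |
|     m = m.permute(2, 4, 0, 1, 3).reshape(64, depth, height, width) | |
|     m = F.interpolate(m.unsqueeze(0), size=((depth + 3) // 4, height, width), mode='nearest-exact').squeeze(0) | |
|     return m.shape                                           # torch.Size([64, 21, 60, 104]) | |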
| class WanVideoVACEStartToEndFrame: | |
| def INPUT_TYPES(s): | |
| return {"required": { | |
| "num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}), | |
| "empty_frame_level": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "White level of empty frame to use"}), | |
| }, | |
| "optional": { | |
| "start_image": ("IMAGE",), | |
| "end_image": ("IMAGE",), | |
| "control_images": ("IMAGE",), | |
| "inpaint_mask": ("MASK", {"tooltip": "Inpaint mask to use for the empty frames"}), | |
| }, | |
| } | |
| RETURN_TYPES = ("IMAGE", "MASK", ) | |
| RETURN_NAMES = ("images", "masks",) | |
| FUNCTION = "process" | |
| CATEGORY = "WanVideoWrapper" | |
| DESCRIPTION = "Helper node to create start/end frame batch and masks for VACE" | |
| def process(self, num_frames, empty_frame_level, start_image=None, end_image=None, control_images=None, inpaint_mask=None): | |
| B, H, W, C = start_image.shape if start_image is not None else end_image.shape | |
| device = start_image.device if start_image is not None else end_image.device | |
| masks = torch.ones((num_frames, H, W), device=device) | |
| if control_images is not None: | |
| control_images = common_upscale(control_images.movedim(-1, 1), W, H, "lanczos", "disabled").movedim(1, -1) | |
| if start_image is not None and end_image is not None: | |
| if start_image.shape != end_image.shape: | |
| end_image = common_upscale(end_image.movedim(-1, 1), W, H, "lanczos", "disabled").movedim(1, -1) | |
| if control_images is None: | |
| empty_frames = torch.ones((num_frames - start_image.shape[0] - end_image.shape[0], H, W, 3), device=device) * empty_frame_level | |
| else: | |
| empty_frames = control_images[start_image.shape[0]:num_frames - end_image.shape[0]] | |
| out_batch = torch.cat([start_image, empty_frames, end_image], dim=0) | |
| masks[0:start_image.shape[0]] = 0 | |
| masks[-end_image.shape[0]:] = 0 | |
| elif start_image is not None: | |
| if control_images is None: | |
| empty_frames = torch.ones((num_frames - start_image.shape[0], H, W, 3), device=device) * empty_frame_level | |
| else: | |
| empty_frames = control_images[start_image.shape[0]:num_frames] | |
| out_batch = torch.cat([start_image, empty_frames], dim=0) | |
| masks[0:start_image.shape[0]] = 0 | |
| elif end_image is not None: | |
| if control_images is None: | |
| empty_frames = torch.ones((num_frames - end_image.shape[0], H, W, 3), device=device) * empty_frame_level | |
| else: | |
| empty_frames = control_images[:num_frames - end_image.shape[0]] | |
| out_batch = torch.cat([empty_frames, end_image], dim=0) | |
| masks[-end_image.shape[0]:] = 0 | |
| if inpaint_mask is not None: | |
| inpaint_mask = common_upscale(inpaint_mask.unsqueeze(1), W, H, "nearest-exact", "disabled").squeeze(1).to(device) | |
| if inpaint_mask.shape[0] > num_frames: | |
| inpaint_mask = inpaint_mask[:num_frames] | |
| elif inpaint_mask.shape[0] < num_frames: | |
| inpaint_mask = inpaint_mask.repeat(num_frames // inpaint_mask.shape[0] + 1, 1, 1)[:num_frames] | |
| empty_mask = torch.ones_like(masks, device=device) | |
| masks = inpaint_mask * empty_mask | |
| return (out_batch.cpu().float(), masks.cpu().float()) | |
| #region context options | |
| class WanVideoContextOptions: | |
| def INPUT_TYPES(s): | |
| return {"required": { | |
| "context_schedule": (["uniform_standard", "uniform_looped", "static_standard"],), | |
| "context_frames": ("INT", {"default": 81, "min": 2, "max": 1000, "step": 1, "tooltip": "Number of pixel frames in the context, NOTE: the latent space has 4 frames in 1"} ), | |
| "context_stride": ("INT", {"default": 4, "min": 4, "max": 100, "step": 1, "tooltip": "Context stride as pixel frames, NOTE: the latent space has 4 frames in 1"} ), | |
| "context_overlap": ("INT", {"default": 16, "min": 4, "max": 100, "step": 1, "tooltip": "Context overlap as pixel frames, NOTE: the latent space has 4 frames in 1"} ), | |
| "freenoise": ("BOOLEAN", {"default": True, "tooltip": "Shuffle the noise"}), | |
| "verbose": ("BOOLEAN", {"default": False, "tooltip": "Print debug output"}), | |
| }, | |
| "optional": { | |
| "vae": ("WANVAE",), | |
| } | |
| } | |
| RETURN_TYPES = ("WANVIDCONTEXT", ) | |
| RETURN_NAMES = ("context_options",) | |
| FUNCTION = "process" | |
| CATEGORY = "WanVideoWrapper" | |
| DESCRIPTION = "Context options for WanVideo, allows splitting the video into context windows and attemps blending them for longer generations than the model and memory otherwise would allow." | |
| def process(self, context_schedule, context_frames, context_stride, context_overlap, freenoise, verbose, image_cond_start_step=6, image_cond_window_count=2, vae=None): | |
| context_options = { | |
| "context_schedule":context_schedule, | |
| "context_frames":context_frames, | |
| "context_stride":context_stride, | |
| "context_overlap":context_overlap, | |
| "freenoise":freenoise, | |
| "verbose":verbose, | |
| "vae": vae, | |
| } | |
| return (context_options,) | |
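| # Rough arithmetic sketch (the actual window scheduling is done in the sampler, so this is | |
| # only an estimate): with the 4-to-1 pixel-to-latent packing, the defaults of 81 context | |
| # frames / 16 overlap give ~21 latent frames per window with ~4 latent frames of overlap, | |
| # so e.g. a 161-frame video is covered by roughly 3 overlapping windows. | |
| def _example_context_window_estimate(total_frames=161, context_frames=81, context_overlap=16): | |
|     latent_total = (total_frames - 1) // 4 + 1      # 41 | |
|     latent_window = (context_frames - 1) // 4 + 1   # 21 | |
|     latent_overlap = context_overlap // 4           # 4 | |
|     return max(1, math.ceil((latent_total - latent_overlap) / (latent_window - latent_overlap)))  # ~3 | |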
| class CreateCFGScheduleFloatList: | |
| def INPUT_TYPES(s): | |
| return {"required": { | |
| "steps": ("INT", {"default": 30, "min": 2, "max": 1000, "step": 1, "tooltip": "Number of steps to schedule cfg for"} ), | |
| "cfg_scale_start": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 30.0, "step": 0.01, "round": 0.01, "tooltip": "CFG scale to use for the steps"}), | |
| "cfg_scale_end": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 30.0, "step": 0.01, "round": 0.01, "tooltip": "CFG scale to use for the steps"}), | |
| "interpolation": (["linear", "ease_in", "ease_out"], {"default": "linear", "tooltip": "Interpolation method to use for the cfg scale"}), | |
| "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "round": 0.01,"tooltip": "Start percent of the steps to apply cfg"}), | |
| "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "round": 0.01,"tooltip": "End percent of the steps to apply cfg"}), | |
| } | |
| } | |
| RETURN_TYPES = ("FLOAT", ) | |
| RETURN_NAMES = ("float_list",) | |
| FUNCTION = "process" | |
| CATEGORY = "WanVideoWrapper" | |
| DESCRIPTION = "Helper node to generate a list of floats that can be used to schedule cfg scale for the steps, outside the set range cfg is set to 1.0" | |
| def process(self, steps, cfg_scale_start, cfg_scale_end, interpolation, start_percent, end_percent): | |
| # Create a list of floats for the cfg schedule | |
| cfg_list = [1.0] * steps | |
| start_idx = min(int(steps * start_percent), steps - 1) | |
| end_idx = min(int(steps * end_percent), steps - 1) | |
| for i in range(start_idx, end_idx + 1): | |
| if i >= steps: | |
| break | |
| if end_idx == start_idx: | |
| t = 0 | |
| else: | |
| t = (i - start_idx) / (end_idx - start_idx) | |
| if interpolation == "linear": | |
| factor = t | |
| elif interpolation == "ease_in": | |
| factor = t * t | |
| elif interpolation == "ease_out": | |
| factor = t * (2 - t) | |
| cfg_list[i] = round(cfg_scale_start + factor * (cfg_scale_end - cfg_scale_start), 2) | |
| # If start_percent > 0, always include the first step | |
| if start_percent > 0: | |
| cfg_list[0] = 1.0 | |
| return (cfg_list,) | |
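| # Worked example (sketch only, mirrors the loop above): 10 steps, cfg 6.0 -> 2.0, linear, | |
| # active between 20% and 80% of the steps. | |
| def _example_cfg_schedule(): | |
|     steps, start, end = 10, 6.0, 2.0 | |
|     start_idx, end_idx = int(steps * 0.2), int(steps * 0.8)   # 2, 8 | |
|     cfg_list = [1.0] * steps | |
|     for i in range(start_idx, end_idx + 1): | |
|         t = (i - start_idx) / (end_idx - start_idx) | |
|         cfg_list[i] = round(start + t * (end - start), 2) | |
|     cfg_list[0] = 1.0  # start_percent > 0 keeps the first step at cfg 1.0 | |
|     return cfg_list  # [1.0, 1.0, 6.0, 5.33, 4.67, 4.0, 3.33, 2.67, 2.0, 1.0] | |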
| class WanVideoFlowEdit: | |
| def INPUT_TYPES(s): | |
| return {"required": { | |
| "source_embeds": ("WANVIDEOTEXTEMBEDS", ), | |
| "skip_steps": ("INT", {"default": 4, "min": 0}), | |
| "drift_steps": ("INT", {"default": 0, "min": 0}), | |
| "drift_flow_shift": ("FLOAT", {"default": 3.0, "min": 1.0, "max": 30.0, "step": 0.01}), | |
| "source_cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}), | |
| "drift_cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}), | |
| }, | |
| "optional": { | |
| "source_image_embeds": ("WANVIDIMAGE_EMBEDS", ), | |
| } | |
| } | |
| RETURN_TYPES = ("FLOWEDITARGS", ) | |
| RETURN_NAMES = ("flowedit_args",) | |
| FUNCTION = "process" | |
| CATEGORY = "WanVideoWrapper" | |
| DESCRIPTION = "Flowedit options for WanVideo" | |
| def process(self, **kwargs): | |
| return (kwargs,) | |
| class WanVideoLoopArgs: | |
| def INPUT_TYPES(s): | |
| return {"required": { | |
| "shift_skip": ("INT", {"default": 6, "min": 0, "tooltip": "Skip step of latent shift"}), | |
| "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent of the looping effect"}), | |
| "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent of the looping effect"}), | |
| }, | |
| } | |
| RETURN_TYPES = ("LOOPARGS", ) | |
| RETURN_NAMES = ("loop_args",) | |
| FUNCTION = "process" | |
| CATEGORY = "WanVideoWrapper" | |
| DESCRIPTION = "Looping through latent shift as shown in https://github.com/YisuiTT/Mobius/" | |
| def process(self, **kwargs): | |
| return (kwargs,) | |
| class WanVideoExperimentalArgs: | |
| def INPUT_TYPES(s): | |
| return {"required": { | |
| "video_attention_split_steps": ("STRING", {"default": "", "tooltip": "Steps to split self attention when using multiple prompts"}), | |
| "cfg_zero_star": ("BOOLEAN", {"default": False, "tooltip": "https://github.com/WeichenFan/CFG-Zero-star"}), | |
| "use_zero_init": ("BOOLEAN", {"default": False}), | |
| "zero_star_steps": ("INT", {"default": 0, "min": 0, "tooltip": "Steps to split self attention when using multiple prompts"}), | |
| "use_fresca": ("BOOLEAN", {"default": False, "tooltip": "https://github.com/WikiChao/FreSca"}), | |
| "fresca_scale_low": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), | |
| "fresca_scale_high": ("FLOAT", {"default": 1.25, "min": 0.0, "max": 10.0, "step": 0.01}), | |
| "fresca_freq_cutoff": ("INT", {"default": 20, "min": 0, "max": 10000, "step": 1}), | |
| }, | |
| } | |
| RETURN_TYPES = ("EXPERIMENTALARGS", ) | |
| RETURN_NAMES = ("exp_args",) | |
| FUNCTION = "process" | |
| CATEGORY = "WanVideoWrapper" | |
| DESCRIPTION = "Experimental stuff" | |
| EXPERIMENTAL = True | |
| def process(self, **kwargs): | |
| return (kwargs,) | |
| #region Sampler | |
| class WanVideoSampler: | |
| def INPUT_TYPES(s): | |
| return { | |
| "required": { | |
| "model": ("WANVIDEOMODEL",), | |
| "image_embeds": ("WANVIDIMAGE_EMBEDS", ), | |
| "steps": ("INT", {"default": 30, "min": 1}), | |
| "cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}), | |
| "shift": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 1000.0, "step": 0.01}), | |
| "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), | |
| "force_offload": ("BOOLEAN", {"default": True, "tooltip": "Moves the model to the offload device after sampling"}), | |
| "scheduler": (["unipc", "unipc/beta", "dpm++", "dpm++/beta","dpm++_sde", "dpm++_sde/beta", "euler", "euler/beta", "euler/accvideo", "deis", "lcm", "lcm/beta", "flowmatch_causvid", "flowmatch_distill"], | |
| { | |
| "default": 'unipc' | |
| }), | |
| "riflex_freq_index": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1, "tooltip": "Frequency index for RIFLEX, disabled when 0, default 6. Allows for new frames to be generated after without looping"}), | |
| }, | |
| "optional": { | |
| "text_embeds": ("WANVIDEOTEXTEMBEDS", ), | |
| "samples": ("LATENT", {"tooltip": "init Latents to use for video2video process"} ), | |
| "denoise_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), | |
| "feta_args": ("FETAARGS", ), | |
| "context_options": ("WANVIDCONTEXT", ), | |
| "cache_args": ("CACHEARGS", ), | |
| "flowedit_args": ("FLOWEDITARGS", ), | |
| "batched_cfg": ("BOOLEAN", {"default": False, "tooltip": "Batch cond and uncond for faster sampling, possibly faster on some hardware, uses more memory"}), | |
| "slg_args": ("SLGARGS", ), | |
| "rope_function": (["default", "comfy"], {"default": "comfy", "tooltip": "Comfy's RoPE implementation doesn't use complex numbers and can thus be compiled, that should be a lot faster when using torch.compile"}), | |
| "loop_args": ("LOOPARGS", ), | |
| "experimental_args": ("EXPERIMENTALARGS", ), | |
| "sigmas": ("SIGMAS", ), | |
| "unianimate_poses": ("UNIANIMATE_POSE", ), | |
| "fantasytalking_embeds": ("FANTASYTALKING_EMBEDS", ), | |
| "uni3c_embeds": ("UNI3C_EMBEDS", ), | |
| } | |
| } | |
| RETURN_TYPES = ("LATENT", ) | |
| RETURN_NAMES = ("samples",) | |
| FUNCTION = "process" | |
| CATEGORY = "WanVideoWrapper" | |
| def process(self, model, image_embeds, shift, steps, cfg, seed, scheduler, riflex_freq_index, text_embeds=None, | |
| force_offload=True, samples=None, feta_args=None, denoise_strength=1.0, context_options=None, | |
| cache_args=None, teacache_args=None, flowedit_args=None, batched_cfg=False, slg_args=None, rope_function="default", loop_args=None, | |
| experimental_args=None, sigmas=None, unianimate_poses=None, fantasytalking_embeds=None, uni3c_embeds=None): | |
| patcher = model | |
| model = model.model | |
| transformer = model.diffusion_model | |
| dtype = model["dtype"] | |
| control_lora = model["control_lora"] | |
| transformer_options = patcher.model_options.get("transformer_options", None) | |
| device = mm.get_torch_device() | |
| offload_device = mm.unet_offload_device() | |
| steps = int(steps/denoise_strength) | |
| if text_embeds is None: | |
| text_embeds = { | |
| "prompt_embeds": [], | |
| "negative_prompt_embeds": [], | |
| } | |
| if isinstance(cfg, list): | |
| if steps != len(cfg): | |
| log.info(f"Received {len(cfg)} cfg values, but only {steps} steps. Setting step count to match.") | |
| steps = len(cfg) | |
| timesteps = None | |
| if 'unipc' in scheduler: | |
| sample_scheduler = FlowUniPCMultistepScheduler(shift=shift) | |
| if sigmas is None: | |
| sample_scheduler.set_timesteps(steps, device=device, shift=shift, use_beta_sigmas=('beta' in scheduler)) | |
| else: | |
| sample_scheduler.sigmas = sigmas.to(device) | |
| sample_scheduler.timesteps = (sample_scheduler.sigmas[:-1] * 1000).to(torch.int64).to(device) | |
| sample_scheduler.num_inference_steps = len(sample_scheduler.timesteps) | |
| elif scheduler in ['euler/beta', 'euler']: | |
| sample_scheduler = FlowMatchEulerDiscreteScheduler(shift=shift, use_beta_sigmas=(scheduler == 'euler/beta')) | |
| if flowedit_args: #seems to work better | |
| timesteps, _ = retrieve_timesteps(sample_scheduler, device=device, sigmas=get_sampling_sigmas(steps, shift)) | |
| else: | |
| sample_scheduler.set_timesteps(steps, device=device, sigmas=sigmas.tolist() if sigmas is not None else None) | |
| elif scheduler in ['euler/accvideo']: | |
| if steps != 50: | |
| raise Exception("Steps must be set to 50 for accvideo scheduler, 10 actual steps are used") | |
| sample_scheduler = FlowMatchEulerDiscreteScheduler(shift=shift, use_beta_sigmas=(scheduler == 'euler/beta')) | |
| sample_scheduler.set_timesteps(steps, device=device, sigmas=sigmas.tolist() if sigmas is not None else None) | |
| start_latent_list = [0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50] | |
| sample_scheduler.sigmas = sample_scheduler.sigmas[start_latent_list] | |
| steps = len(start_latent_list) - 1 | |
| sample_scheduler.timesteps = timesteps = sample_scheduler.timesteps[start_latent_list[:steps]] | |
| elif 'dpm++' in scheduler: | |
| if 'sde' in scheduler: | |
| algorithm_type = "sde-dpmsolver++" | |
| else: | |
| algorithm_type = "dpmsolver++" | |
| sample_scheduler = FlowDPMSolverMultistepScheduler(shift=shift, algorithm_type=algorithm_type) | |
| if sigmas is None: | |
| sample_scheduler.set_timesteps(steps, device=device, use_beta_sigmas=('beta' in scheduler)) | |
| else: | |
| sample_scheduler.sigmas = sigmas.to(device) | |
| sample_scheduler.timesteps = (sample_scheduler.sigmas[:-1] * 1000).to(torch.int64).to(device) | |
| sample_scheduler.num_inference_steps = len(sample_scheduler.timesteps) | |
| elif scheduler == 'deis': | |
| sample_scheduler = DEISMultistepScheduler(use_flow_sigmas=True, prediction_type="flow_prediction", flow_shift=shift) | |
| sample_scheduler.set_timesteps(steps, device=device) | |
| sample_scheduler.sigmas[-1] = 1e-6 | |
| elif 'lcm' in scheduler: | |
| sample_scheduler = FlowMatchLCMScheduler(shift=shift, use_beta_sigmas=(scheduler == 'lcm/beta')) | |
| sample_scheduler.set_timesteps(steps, device=device, sigmas=sigmas.tolist() if sigmas is not None else None) | |
| elif 'flowmatch_causvid' in scheduler: | |
| if transformer.dim == 5120: | |
| denoising_list = [999, 934, 862, 756, 603, 410, 250, 140, 74] | |
| else: | |
| if steps != 4: | |
| raise ValueError("CausVid 1.3B schedule is only for 4 steps") | |
| denoising_list = [1000, 750, 500, 250] | |
| sample_scheduler = FlowMatchScheduler(num_inference_steps=steps, shift=shift, sigma_min=0, extra_one_step=True) | |
| sample_scheduler.timesteps = torch.tensor(denoising_list)[:steps].to(device) | |
| sample_scheduler.sigmas = torch.cat([sample_scheduler.timesteps / 1000, torch.tensor([0.0], device=device)]) | |
| elif 'flowmatch_distill' in scheduler: | |
| sample_scheduler = FlowMatchScheduler( | |
| shift=shift, sigma_min=0.0, extra_one_step=True | |
| ) | |
| sample_scheduler.set_timesteps(1000, training=True) | |
| denoising_step_list = torch.tensor([999, 750, 500, 250] , dtype=torch.long) | |
| temp_timesteps = torch.cat((sample_scheduler.timesteps.cpu(), torch.tensor([0], dtype=torch.float32))) | |
| denoising_step_list = temp_timesteps[1000 - denoising_step_list] | |
| print("denoising_step_list: ", denoising_step_list) | |
| #denoising_step_list = [999, 750, 500, 250] | |
| if steps != 4: | |
| raise ValueError("This scheduler is only for 4 steps") | |
| #sample_scheduler = FlowMatchScheduler(num_inference_steps=steps, shift=shift, sigma_min=0, extra_one_step=True) | |
| sample_scheduler.timesteps = torch.tensor(denoising_step_list)[:steps].to(device) | |
| sample_scheduler.sigmas = torch.cat([sample_scheduler.timesteps / 1000, torch.tensor([0.0], device=device)]) | |
| if timesteps is None: | |
| timesteps = sample_scheduler.timesteps | |
| log.info(f"timesteps: {timesteps}") | |
| if denoise_strength < 1.0: | |
| steps = int(steps * denoise_strength) | |
| timesteps = timesteps[-(steps + 1):] | |
| seed_g = torch.Generator(device=torch.device("cpu")) | |
| seed_g.manual_seed(seed) | |
| control_latents = control_camera_latents = clip_fea = clip_fea_neg = end_image = recammaster = camera_embed = unianim_data = None | |
| vace_data = vace_context = vace_scale = None | |
| fun_or_fl2v_model = has_ref = drop_last = False | |
| phantom_latents = None | |
| fun_ref_image = None | |
| image_cond = image_embeds.get("image_embeds", None) | |
| ATI_tracks = None | |
| add_cond = attn_cond = attn_cond_neg = None | |
| if image_cond is not None: | |
| log.info(f"image_cond shape: {image_cond.shape}") | |
| #ATI tracks | |
| if transformer_options is not None: | |
| ATI_tracks = transformer_options.get("ati_tracks", None) | |
| if ATI_tracks is not None: | |
| from .ATI.motion_patch import patch_motion | |
| topk = transformer_options.get("ati_topk", 2) | |
| temperature = transformer_options.get("ati_temperature", 220.0) | |
| ati_start_percent = transformer_options.get("ati_start_percent", 0.0) | |
| ati_end_percent = transformer_options.get("ati_end_percent", 1.0) | |
| image_cond_ati = patch_motion(ATI_tracks.to(image_cond.device, image_cond.dtype), image_cond, topk=topk, temperature=temperature) | |
| log.info(f"ATI tracks shape: {ATI_tracks.shape}") | |
| realisdance_latents = image_embeds.get("realisdance_latents", None) | |
| if realisdance_latents is not None: | |
| add_cond = realisdance_latents["pose_latent"] | |
| attn_cond = realisdance_latents["ref_latent"] | |
| attn_cond_neg = realisdance_latents["ref_latent_neg"] | |
| add_cond_start_percent = realisdance_latents["pose_cond_start_percent"] | |
| add_cond_end_percent = realisdance_latents["pose_cond_end_percent"] | |
| end_image = image_embeds.get("end_image", None) | |
| lat_h = image_embeds.get("lat_h", None) | |
| lat_w = image_embeds.get("lat_w", None) | |
| if lat_h is None or lat_w is None: | |
| raise ValueError("Clip encoded image embeds must be provided for I2V (Image to Video) model") | |
| fun_or_fl2v_model = image_embeds.get("fun_or_fl2v_model", False) | |
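| # I2V latent noise: 16 channels, (num_frames - 1) // 4 + 1 latent frames, plus one extra frame when an end image is used without a Fun/FLF2V model | |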
| noise = torch.randn( | |
| 16, | |
| (image_embeds["num_frames"] - 1) // 4 + (2 if end_image is not None and not fun_or_fl2v_model else 1), | |
| lat_h, | |
| lat_w, | |
| dtype=torch.float32, | |
| generator=seed_g, | |
| device=torch.device("cpu")) | |
| seq_len = image_embeds["max_seq_len"] | |
| clip_fea = image_embeds.get("clip_context", None) | |
| if clip_fea is not None: | |
| clip_fea = clip_fea.to(dtype) | |
| clip_fea_neg = image_embeds.get("negative_clip_context", None) | |
| if clip_fea_neg is not None: | |
| clip_fea_neg = clip_fea_neg.to(dtype) | |
| control_embeds = image_embeds.get("control_embeds", None) | |
| if control_embeds is not None: | |
| if transformer.in_dim not in [48, 32]: | |
| raise ValueError("Control signal only works with Fun-Control model") | |
| control_latents = control_embeds.get("control_images", None) | |
| control_camera_latents = control_embeds.get("control_camera_latents", None) | |
| control_camera_start_percent = control_embeds.get("control_camera_start_percent", 0.0) | |
| control_camera_end_percent = control_embeds.get("control_camera_end_percent", 1.0) | |
| control_start_percent = control_embeds.get("start_percent", 0.0) | |
| control_end_percent = control_embeds.get("end_percent", 1.0) | |
| drop_last = image_embeds.get("drop_last", False) | |
| has_ref = image_embeds.get("has_ref", False) | |
| else: #t2v | |
| target_shape = image_embeds.get("target_shape", None) | |
| if target_shape is None: | |
| raise ValueError("Empty image embeds must be provided for T2V (Text to Video") | |
| has_ref = image_embeds.get("has_ref", False) | |
| vace_context = image_embeds.get("vace_context", None) | |
| vace_scale = image_embeds.get("vace_scale", None) | |
| if not isinstance(vace_scale, list): | |
| vace_scale = [vace_scale] * (steps+1) | |
| vace_start_percent = image_embeds.get("vace_start_percent", 0.0) | |
| vace_end_percent = image_embeds.get("vace_end_percent", 1.0) | |
| vace_seqlen = image_embeds.get("vace_seq_len", None) | |
| vace_additional_embeds = image_embeds.get("additional_vace_inputs", []) | |
| if vace_context is not None: | |
| vace_data = [ | |
| {"context": vace_context, | |
| "scale": vace_scale, | |
| "start": vace_start_percent, | |
| "end": vace_end_percent, | |
| "seq_len": vace_seqlen | |
| } | |
| ] | |
| if len(vace_additional_embeds) > 0: | |
| for i in range(len(vace_additional_embeds)): | |
| if vace_additional_embeds[i].get("has_ref", False): | |
| has_ref = True | |
| vace_scale = vace_additional_embeds[i]["vace_scale"] | |
| if not isinstance(vace_scale, list): | |
| vace_scale = [vace_scale] * (steps+1) | |
| vace_data.append({ | |
| "context": vace_additional_embeds[i]["vace_context"], | |
| "scale": vace_scale, | |
| "start": vace_additional_embeds[i]["vace_start_percent"], | |
| "end": vace_additional_embeds[i]["vace_end_percent"], | |
| "seq_len": vace_additional_embeds[i]["vace_seq_len"] | |
| }) | |
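| # T2V latent noise from target_shape, with one extra latent frame when a reference (has_ref) is present | |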
| noise = torch.randn( | |
| target_shape[0], | |
| target_shape[1] + 1 if has_ref else target_shape[1], | |
| target_shape[2], | |
| target_shape[3], | |
| dtype=torch.float32, | |
| device=torch.device("cpu"), | |
| generator=seed_g) | |
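| # Sequence length in patch tokens: (lat_h * lat_w) / 4 per frame (2x2 spatial patches) times the number of latent frames | |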
| seq_len = math.ceil((noise.shape[2] * noise.shape[3]) / 4 * noise.shape[1]) | |
| recammaster = image_embeds.get("recammaster", None) | |
| if recammaster is not None: | |
| camera_embed = recammaster.get("camera_embed", None) | |
| recam_latents = recammaster.get("source_latents", None) | |
| orig_noise_len = noise.shape[1] | |
| log.info(f"RecamMaster camera embed shape: {camera_embed.shape}") | |
| log.info(f"RecamMaster source video shape: {recam_latents.shape}") | |
| seq_len *= 2 | |
| control_embeds = image_embeds.get("control_embeds", None) | |
| if control_embeds is not None: | |
| control_latents = control_embeds.get("control_images", None) | |
| if control_latents is not None: | |
| control_latents = control_latents.to(device) | |
| control_camera_latents = control_embeds.get("control_camera_latents", None) | |
| control_camera_start_percent = control_embeds.get("control_camera_start_percent", 0.0) | |
| control_camera_end_percent = control_embeds.get("control_camera_end_percent", 1.0) | |
| if control_camera_latents is not None: | |
| control_camera_latents = control_camera_latents.to(device) | |
| if control_lora: | |
| image_cond = control_latents.to(device) | |
| if not patcher.model.is_patched: | |
| log.info("Re-loading control LoRA...") | |
| patcher = apply_lora(patcher, device, device, low_mem_load=False) | |
| patcher.model.is_patched = True | |
| else: | |
| if transformer.in_dim not in [48, 32]: | |
| raise ValueError("Control signal only works with Fun-Control model") | |
| image_cond = torch.zeros_like(noise).to(device) #fun control | |
| clip_fea = None | |
| fun_ref_image = control_embeds.get("fun_ref_image", None) | |
| control_start_percent = control_embeds.get("start_percent", 0.0) | |
| control_end_percent = control_embeds.get("end_percent", 1.0) | |
| else: | |
| if transformer.in_dim == 36: #fun inp | |
| mask_latents = torch.tile( | |
| torch.zeros_like(noise[:1]), [4, 1, 1, 1] | |
| ) | |
| masked_video_latents_input = torch.zeros_like(noise) | |
| image_cond = torch.cat([mask_latents, masked_video_latents_input], dim=0).to(device) | |
| phantom_latents = image_embeds.get("phantom_latents", None) | |
| phantom_cfg_scale = image_embeds.get("phantom_cfg_scale", None) | |
| if not isinstance(phantom_cfg_scale, list): | |
| phantom_cfg_scale = [phantom_cfg_scale] * (steps +1) | |
| phantom_start_percent = image_embeds.get("phantom_start_percent", 0.0) | |
| phantom_end_percent = image_embeds.get("phantom_end_percent", 1.0) | |
| if phantom_latents is not None: | |
| phantom_latents = phantom_latents.to(device) | |
| latent_video_length = noise.shape[1] | |
| if unianimate_poses is not None: | |
| transformer.dwpose_embedding.to(device, model["dtype"]) | |
| dwpose_data = unianimate_poses["pose"].to(device, model["dtype"]) | |
| dwpose_data = torch.cat([dwpose_data[:,:,:1].repeat(1,1,3,1,1), dwpose_data], dim=2) | |
| dwpose_data = transformer.dwpose_embedding(dwpose_data) | |
| log.info(f"UniAnimate pose embed shape: {dwpose_data.shape}") | |
| if dwpose_data.shape[2] > latent_video_length: | |
| log.warning(f"UniAnimate pose embed length {dwpose_data.shape[2]} is longer than the video length {latent_video_length}, truncating") | |
| dwpose_data = dwpose_data[:,:, :latent_video_length] | |
| elif dwpose_data.shape[2] < latent_video_length: | |
| log.warning(f"UniAnimate pose embed length {dwpose_data.shape[2]} is shorter than the video length {latent_video_length}, padding with last pose") | |
| pad_len = latent_video_length - dwpose_data.shape[2] | |
| pad = dwpose_data[:,:,:1].repeat(1,1,pad_len,1,1) | |
| dwpose_data = torch.cat([dwpose_data, pad], dim=2) | |
| dwpose_data_flat = rearrange(dwpose_data, 'b c f h w -> b (f h w) c').contiguous() | |
| random_ref_dwpose_data = None | |
| if image_cond is not None: | |
| transformer.randomref_embedding_pose.to(device) | |
| random_ref_dwpose = unianimate_poses.get("ref", None) | |
| if random_ref_dwpose is not None: | |
| random_ref_dwpose_data = transformer.randomref_embedding_pose( | |
| random_ref_dwpose.to(device) | |
| ).unsqueeze(2).to(model["dtype"]) # [1, 20, 104, 60] | |
| unianim_data = { | |
| "dwpose": dwpose_data_flat, | |
| "random_ref": random_ref_dwpose_data.squeeze(0) if random_ref_dwpose_data is not None else None, | |
| "strength": unianimate_poses["strength"], | |
| "start_percent": unianimate_poses["start_percent"], | |
| "end_percent": unianimate_poses["end_percent"] | |
| } | |
| audio_proj = None | |
| if fantasytalking_embeds is not None: | |
| audio_proj = fantasytalking_embeds["audio_proj"].to(device) | |
| audio_context_lens = fantasytalking_embeds["audio_context_lens"] | |
| audio_scale = fantasytalking_embeds["audio_scale"] | |
| audio_cfg_scale = fantasytalking_embeds["audio_cfg_scale"] | |
| if not isinstance(audio_cfg_scale, list): | |
| audio_cfg_scale = [audio_cfg_scale] * (steps +1) | |
| log.info(f"Audio proj shape: {audio_proj.shape}, audio context lens: {audio_context_lens}") | |
| minimax_latents = minimax_mask_latents = None | |
| minimax_latents = image_embeds.get("minimax_latents", None) | |
| minimax_mask_latents = image_embeds.get("minimax_mask_latents", None) | |
| if minimax_latents is not None: | |
| log.info(f"minimax_latents: {minimax_latents.shape}") | |
| log.info(f"minimax_mask_latents: {minimax_mask_latents.shape}") | |
| minimax_latents = minimax_latents.to(device, dtype) | |
| minimax_mask_latents = minimax_mask_latents.to(device, dtype) | |
| is_looped = False | |
| if context_options is not None: | |
| def create_window_mask(noise_pred_context, c, latent_video_length, context_overlap, looped=False): | |
| window_mask = torch.ones_like(noise_pred_context) | |
| # Blend the left edge for every chunk except the first (and for the wrap-around chunk in loop mode) | |
| if min(c) > 0 or (looped and max(c) == latent_video_length - 1): | |
| ramp_up = torch.linspace(0, 1, context_overlap, device=noise_pred_context.device) | |
| ramp_up = ramp_up.view(1, -1, 1, 1) | |
| window_mask[:, :context_overlap] = ramp_up | |
| # Blend the right edge for every chunk except the last (and for the first chunk in loop mode) | |
| if max(c) < latent_video_length - 1 or (looped and min(c) == 0): | |
| ramp_down = torch.linspace(1, 0, context_overlap, device=noise_pred_context.device) | |
| ramp_down = ramp_down.view(1, -1, 1, 1) | |
| window_mask[:, -context_overlap:] = ramp_down | |
| return window_mask | |
| context_schedule = context_options["context_schedule"] | |
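| # Context window sizes are given in pixel-space frames; convert to latent frames (temporal compression factor of 4) | |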
| context_frames = (context_options["context_frames"] - 1) // 4 + 1 | |
| context_stride = context_options["context_stride"] // 4 | |
| context_overlap = context_options["context_overlap"] // 4 | |
| context_vae = context_options.get("vae", None) | |
| if context_vae is not None: | |
| context_vae.to(device) | |
| self.window_tracker = WindowTracker(verbose=context_options["verbose"]) | |
| # Get total number of prompts | |
| num_prompts = len(text_embeds["prompt_embeds"]) | |
| log.info(f"Number of prompts: {num_prompts}") | |
| # Calculate which section this context window belongs to | |
| section_size = latent_video_length / num_prompts | |
| log.info(f"Section size: {section_size}") | |
| is_looped = context_schedule == "uniform_looped" | |
| seq_len = math.ceil((noise.shape[2] * noise.shape[3]) / 4 * context_frames) | |
| if context_options["freenoise"]: | |
| log.info("Applying FreeNoise") | |
| # code from AnimateDiff-Evolved by Kosinkadink (https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved) | |
| delta = context_frames - context_overlap | |
| for start_idx in range(0, latent_video_length-context_frames, delta): | |
| place_idx = start_idx + context_frames | |
| if place_idx >= latent_video_length: | |
| break | |
| end_idx = place_idx - 1 | |
| if end_idx + delta >= latent_video_length: | |
| final_delta = latent_video_length - place_idx | |
| list_idx = torch.tensor(list(range(start_idx,start_idx+final_delta)), device=torch.device("cpu"), dtype=torch.long) | |
| list_idx = list_idx[torch.randperm(final_delta, generator=seed_g)] | |
| noise[:, place_idx:place_idx + final_delta, :, :] = noise[:, list_idx, :, :] | |
| break | |
| list_idx = torch.tensor(list(range(start_idx,start_idx+delta)), device=torch.device("cpu"), dtype=torch.long) | |
| list_idx = list_idx[torch.randperm(delta, generator=seed_g)] | |
| noise[:, place_idx:place_idx + delta, :, :] = noise[:, list_idx, :, :] | |
| log.info(f"Context schedule enabled: {context_frames} frames, {context_stride} stride, {context_overlap} overlap") | |
| from .context import get_context_scheduler | |
| context = get_context_scheduler(context_schedule) | |
| if samples is not None: | |
| input_samples = samples["samples"].squeeze(0).to(noise) | |
| if input_samples.shape[1] != noise.shape[1]: | |
| input_samples = torch.cat([input_samples[:, :1].repeat(1, noise.shape[1] - input_samples.shape[1], 1, 1), input_samples], dim=1) | |
| original_image = input_samples.to(device) | |
| if denoise_strength < 1.0: | |
| latent_timestep = timesteps[:1].to(noise) | |
| noise = noise * latent_timestep / 1000 + (1 - latent_timestep / 1000) * input_samples | |
| mask = samples.get("mask", None) | |
| if mask is not None: | |
| if mask.shape[2] != noise.shape[1]: | |
| mask = torch.cat([torch.zeros(1, noise.shape[0], noise.shape[1] - mask.shape[2], noise.shape[2], noise.shape[3]), mask], dim=2) | |
| latent = noise.to(device) | |
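| # RoPE setup: "comfy" mode lets the model's rope_embedder handle RIFLEx internally; otherwise precompute rotary frequencies here with the optional RIFLEx index | |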
| freqs = None | |
| transformer.rope_embedder.k = None | |
| transformer.rope_embedder.num_frames = None | |
| if rope_function=="comfy": | |
| transformer.rope_embedder.k = riflex_freq_index | |
| transformer.rope_embedder.num_frames = latent_video_length | |
| else: | |
| d = transformer.dim // transformer.num_heads | |
| freqs = torch.cat([ | |
| rope_params(1024, d - 4 * (d // 6), L_test=latent_video_length, k=riflex_freq_index), | |
| rope_params(1024, 2 * (d // 6)), | |
| rope_params(1024, 2 * (d // 6)) | |
| ], | |
| dim=1) | |
| if not isinstance(cfg, list): | |
| cfg = [cfg] * (steps +1) | |
| log.info(f"Seq len: {seq_len}") | |
| pbar = ProgressBar(steps) | |
| if args.preview_method in [LatentPreviewMethod.Auto, LatentPreviewMethod.Latent2RGB]: #default for latent2rgb | |
| from latent_preview import prepare_callback | |
| else: | |
| from .latent_preview import prepare_callback #custom for tiny VAE previews | |
| callback = prepare_callback(patcher, steps) | |
| #blockswap init | |
| if transformer_options is not None: | |
| block_swap_args = transformer_options.get("block_swap_args", None) | |
| if block_swap_args is not None: | |
| transformer.use_non_blocking = block_swap_args.get("use_non_blocking", True) | |
| for name, param in transformer.named_parameters(): | |
| if "block" not in name: | |
| param.data = param.data.to(device) | |
| if "control_adapter" in name: | |
| param.data = param.data.to(device) | |
| elif block_swap_args["offload_txt_emb"] and "txt_emb" in name: | |
| param.data = param.data.to(offload_device, non_blocking=transformer.use_non_blocking) | |
| elif block_swap_args["offload_img_emb"] and "img_emb" in name: | |
| param.data = param.data.to(offload_device, non_blocking=transformer.use_non_blocking) | |
| transformer.block_swap( | |
| block_swap_args["blocks_to_swap"] - 1 , | |
| block_swap_args["offload_txt_emb"], | |
| block_swap_args["offload_img_emb"], | |
| vace_blocks_to_swap = block_swap_args.get("vace_blocks_to_swap", None), | |
| ) | |
| elif model["auto_cpu_offload"]: | |
| for module in transformer.modules(): | |
| if hasattr(module, "offload"): | |
| module.offload() | |
| if hasattr(module, "onload"): | |
| module.onload() | |
| elif model["manual_offloading"]: | |
| transformer.to(device) | |
| #controlnet | |
| controlnet_latents = controlnet = None | |
| if transformer_options is not None: | |
| controlnet = transformer_options.get("controlnet", None) | |
| if controlnet is not None: | |
| self.controlnet = controlnet["controlnet"] | |
| controlnet_start = controlnet["controlnet_start"] | |
| controlnet_end = controlnet["controlnet_end"] | |
| controlnet_latents = controlnet["control_latents"] | |
| controlnet["controlnet_weight"] = controlnet["controlnet_strength"] | |
| controlnet["controlnet_stride"] = controlnet["control_stride"] | |
| #uni3c | |
| pcd_data = None | |
| if uni3c_embeds is not None: | |
| transformer.controlnet = uni3c_embeds["controlnet"] | |
| pcd_data = { | |
| "render_latent": uni3c_embeds["render_latent"], | |
| "render_mask": uni3c_embeds["render_mask"], | |
| "camera_embedding": uni3c_embeds["camera_embedding"], | |
| "controlnet_weight": uni3c_embeds["controlnet_weight"], | |
| "start": uni3c_embeds["start"], | |
| "end": uni3c_embeds["end"], | |
| } | |
| #feta | |
| if feta_args is not None and latent_video_length > 1: | |
| set_enhance_weight(feta_args["weight"]) | |
| feta_start_percent = feta_args["start_percent"] | |
| feta_end_percent = feta_args["end_percent"] | |
| if context_options is not None: | |
| set_num_frames(context_frames) | |
| else: | |
| set_num_frames(latent_video_length) | |
| enable_enhance() | |
| else: | |
| feta_args = None | |
| disable_enhance() | |
| # Initialize Cache if enabled | |
| transformer.enable_teacache = transformer.enable_magcache = False | |
| if teacache_args is not None: #for backward compatibility on old workflows | |
| cache_args = teacache_args | |
| if cache_args is not None: | |
| transformer.cache_device = cache_args["cache_device"] | |
| if cache_args["cache_type"] == "TeaCache": | |
| log.info(f"TeaCache: Using cache device: {transformer.cache_device}") | |
| transformer.teacache_state.clear_all() | |
| transformer.enable_teacache = True | |
| transformer.rel_l1_thresh = cache_args["rel_l1_thresh"] | |
| transformer.teacache_start_step = cache_args["start_step"] | |
| transformer.teacache_end_step = len(timesteps)-1 if cache_args["end_step"] == -1 else cache_args["end_step"] | |
| transformer.teacache_use_coefficients = cache_args["use_coefficients"] | |
| transformer.teacache_mode = cache_args["mode"] | |
| elif cache_args["cache_type"] == "MagCache": | |
| log.info(f"MagCache: Using cache device: {transformer.cache_device}") | |
| transformer.magcache_state.clear_all() | |
| transformer.enable_magcache = True | |
| transformer.magcache_start_step = cache_args["start_step"] | |
| transformer.magcache_end_step = len(timesteps)-1 if cache_args["end_step"] == -1 else cache_args["end_step"] | |
| transformer.magcache_thresh = cache_args["magcache_thresh"] | |
| transformer.magcache_K = cache_args["magcache_K"] | |
| if slg_args is not None: | |
| assert not batched_cfg, "Batched cfg is not supported with SLG" | |
| transformer.slg_blocks = slg_args["blocks"] | |
| transformer.slg_start_percent = slg_args["start_percent"] | |
| transformer.slg_end_percent = slg_args["end_percent"] | |
| else: | |
| transformer.slg_blocks = None | |
| self.cache_state = [None, None] | |
| if phantom_latents is not None: | |
| log.info(f"Phantom latents shape: {phantom_latents.shape}") | |
| self.cache_state = [None, None, None] | |
| self.cache_state_source = [None, None] | |
| self.cache_states_context = [] | |
| if flowedit_args is not None: | |
| source_embeds = flowedit_args["source_embeds"] | |
| source_image_embeds = flowedit_args.get("source_image_embeds", image_embeds) | |
| source_image_cond = source_image_embeds.get("image_embeds", None) | |
| source_clip_fea = source_image_embeds.get("clip_fea", clip_fea) | |
| if source_image_cond is not None: | |
| source_image_cond = source_image_cond.to(dtype) | |
| skip_steps = flowedit_args["skip_steps"] | |
| drift_steps = flowedit_args["drift_steps"] | |
| source_cfg = flowedit_args["source_cfg"] | |
| if not isinstance(source_cfg, list): | |
| source_cfg = [source_cfg] * (steps +1) | |
| drift_cfg = flowedit_args["drift_cfg"] | |
| if not isinstance(drift_cfg, list): | |
| drift_cfg = [drift_cfg] * (steps +1) | |
| x_init = samples["samples"].clone().squeeze(0).to(device) | |
| x_tgt = samples["samples"].squeeze(0).to(device) | |
| sample_scheduler = FlowMatchEulerDiscreteScheduler( | |
| num_train_timesteps=1000, | |
| shift=flowedit_args["drift_flow_shift"], | |
| use_dynamic_shifting=False) | |
| sampling_sigmas = get_sampling_sigmas(steps, flowedit_args["drift_flow_shift"]) | |
| drift_timesteps, _ = retrieve_timesteps( | |
| sample_scheduler, | |
| device=device, | |
| sigmas=sampling_sigmas) | |
| if drift_steps > 0: | |
| drift_timesteps = torch.cat([drift_timesteps, torch.tensor([0]).to(drift_timesteps.device)]).to(drift_timesteps.device) | |
| timesteps[-drift_steps:] = drift_timesteps[-drift_steps:] | |
| use_cfg_zero_star = use_fresca = False | |
| if experimental_args is not None: | |
| video_attention_split_steps = experimental_args.get("video_attention_split_steps", []) | |
| if video_attention_split_steps: | |
| transformer.video_attention_split_steps = [int(x.strip()) for x in video_attention_split_steps.split(",")] | |
| else: | |
| transformer.video_attention_split_steps = [] | |
| use_zero_init = experimental_args.get("use_zero_init", True) | |
| use_cfg_zero_star = experimental_args.get("cfg_zero_star", False) | |
| zero_star_steps = experimental_args.get("zero_star_steps", 0) | |
| use_fresca = experimental_args.get("use_fresca", False) | |
| if use_fresca: | |
| fresca_scale_low = experimental_args.get("fresca_scale_low", 1.0) | |
| fresca_scale_high = experimental_args.get("fresca_scale_high", 1.25) | |
| fresca_freq_cutoff = experimental_args.get("fresca_freq_cutoff", 20) | |
| #region model pred | |
| def predict_with_cfg(z, cfg_scale, positive_embeds, negative_embeds, timestep, idx, image_cond=None, clip_fea=None, | |
| control_latents=None, vace_data=None, unianim_data=None, audio_proj=None, control_camera_latents=None, add_cond=None, cache_state=None): | |
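| # Runs the transformer for the conditional branch and, when needed, the unconditional (and phantom/audio) branches, then combines them into one noise prediction with CFG | |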
| z = z.to(dtype) | |
| with torch.autocast(device_type=mm.get_autocast_device(device), dtype=dtype, enabled=("fp8" in model["quantization"])): | |
| if use_cfg_zero_star and (idx <= zero_star_steps) and use_zero_init: | |
| return z*0, None | |
| nonlocal patcher | |
| current_step_percentage = idx / len(timesteps) | |
| control_lora_enabled = False | |
| image_cond_input = None | |
| if control_latents is not None: | |
| if control_lora: | |
| control_lora_enabled = True | |
| else: | |
| if (control_start_percent <= current_step_percentage <= control_end_percent) or \ | |
| (control_end_percent > 0 and idx == 0 and current_step_percentage >= control_start_percent): | |
| image_cond_input = torch.cat([control_latents.to(z), image_cond.to(z)]) | |
| else: | |
| image_cond_input = torch.cat([torch.zeros_like(image_cond, dtype=dtype), image_cond.to(z)]) | |
| if fun_ref_image is not None: | |
| fun_ref_input = fun_ref_image.to(z) | |
| else: | |
| fun_ref_input = torch.zeros_like(z, dtype=z.dtype)[:, 0].unsqueeze(1) | |
| #fun_ref_input = None | |
| if control_lora: | |
| if not control_start_percent <= current_step_percentage <= control_end_percent: | |
| control_lora_enabled = False | |
| if patcher.model.is_patched: | |
| log.info("Unloading LoRA...") | |
| patcher.unpatch_model(device) | |
| patcher.model.is_patched = False | |
| else: | |
| image_cond_input = control_latents.to(z) | |
| if not patcher.model.is_patched: | |
| log.info("Loading LoRA...") | |
| patcher = apply_lora(patcher, device, device, low_mem_load=False) | |
| patcher.model.is_patched = True | |
| elif ATI_tracks is not None and ((ati_start_percent <= current_step_percentage <= ati_end_percent) or | |
| (ati_end_percent > 0 and idx == 0 and current_step_percentage >= ati_start_percent)): | |
| image_cond_input = image_cond_ati.to(z) | |
| else: | |
| image_cond_input = image_cond.to(z) if image_cond is not None else None | |
| if control_camera_latents is not None: | |
| if (control_camera_start_percent <= current_step_percentage <= control_camera_end_percent) or \ | |
| (control_end_percent > 0 and idx == 0 and current_step_percentage >= control_camera_start_percent): | |
| control_camera_input = control_camera_latents.to(z) | |
| else: | |
| control_camera_input = None | |
| if recammaster is not None: | |
| z = torch.cat([z, recam_latents.to(z)], dim=1) | |
| use_phantom = False | |
| if phantom_latents is not None: | |
| if (phantom_start_percent <= current_step_percentage <= phantom_end_percent) or \ | |
| (phantom_end_percent > 0 and idx == 0 and current_step_percentage >= phantom_start_percent): | |
| z_pos = torch.cat([z[:,:-phantom_latents.shape[1]], phantom_latents.to(z)], dim=1) | |
| z_phantom_img = torch.cat([z[:,:-phantom_latents.shape[1]], phantom_latents.to(z)], dim=1) | |
| z_neg = torch.cat([z[:,:-phantom_latents.shape[1]], torch.zeros_like(phantom_latents).to(z)], dim=1) | |
| use_phantom = True | |
| if cache_state is not None and len(cache_state) != 3: | |
| cache_state.append(None) | |
| if not use_phantom: | |
| z_pos = z_neg = z | |
| if controlnet_latents is not None: | |
| if (controlnet_start <= current_step_percentage < controlnet_end): | |
| self.controlnet.to(device) | |
| controlnet_states = self.controlnet( | |
| hidden_states=z.unsqueeze(0).to(device, self.controlnet.dtype), | |
| timestep=timestep, | |
| encoder_hidden_states=positive_embeds[0].unsqueeze(0).to(device, self.controlnet.dtype), | |
| attention_kwargs=None, | |
| controlnet_states=controlnet_latents.to(device, self.controlnet.dtype), | |
| return_dict=False, | |
| )[0] | |
| if isinstance(controlnet_states, (tuple, list)): | |
| controlnet["controlnet_states"] = [x.to(z) for x in controlnet_states] | |
| else: | |
| controlnet["controlnet_states"] = controlnet_states.to(z) | |
| add_cond_input = None | |
| if add_cond is not None: | |
| if (add_cond_start_percent <= current_step_percentage <= add_cond_end_percent) or \ | |
| (add_cond_end_percent > 0 and idx == 0 and current_step_percentage >= add_cond_start_percent): | |
| add_cond_input = add_cond | |
| if minimax_latents is not None: | |
| z_pos = z_neg = torch.cat([z, minimax_latents, minimax_mask_latents], dim=0) | |
| base_params = { | |
| 'seq_len': seq_len, | |
| 'device': device, | |
| 'freqs': freqs, | |
| 't': timestep, | |
| 'current_step': idx, | |
| 'control_lora_enabled': control_lora_enabled, | |
| 'camera_embed': camera_embed, | |
| 'unianim_data': unianim_data, | |
| 'fun_ref': fun_ref_input if fun_ref_image is not None else None, | |
| 'fun_camera': control_camera_input if control_camera_latents is not None else None, | |
| 'audio_proj': audio_proj if fantasytalking_embeds is not None else None, | |
| 'audio_context_lens': audio_context_lens if fantasytalking_embeds is not None else None, | |
| 'audio_scale': audio_scale if fantasytalking_embeds is not None else None, | |
| "pcd_data": pcd_data, | |
| "controlnet": controlnet, | |
| "add_cond": add_cond_input, | |
| "nag_params": text_embeds.get("nag_params", {}), | |
| "nag_context": text_embeds.get("nag_prompt_embeds", None), | |
| } | |
| batch_size = 1 | |
| if not math.isclose(cfg_scale, 1.0) and len(positive_embeds) > 1: | |
| negative_embeds = negative_embeds * len(positive_embeds) | |
| if not batched_cfg: | |
| #cond | |
| noise_pred_cond, cache_state_cond = transformer( | |
| [z_pos], context=positive_embeds, y=[image_cond_input] if image_cond_input is not None else None, | |
| clip_fea=clip_fea, is_uncond=False, current_step_percentage=current_step_percentage, | |
| pred_id=cache_state[0] if cache_state else None, | |
| vace_data=vace_data, attn_cond=attn_cond, | |
| **base_params | |
| ) | |
| noise_pred_cond = noise_pred_cond[0].to(intermediate_device) | |
| if math.isclose(cfg_scale, 1.0): | |
| if use_fresca: | |
| noise_pred_cond = fourier_filter( | |
| noise_pred_cond, | |
| scale_low=fresca_scale_low, | |
| scale_high=fresca_scale_high, | |
| freq_cutoff=fresca_freq_cutoff, | |
| ) | |
| return noise_pred_cond, [cache_state_cond] | |
| #uncond | |
| if fantasytalking_embeds is not None: | |
| if not math.isclose(audio_cfg_scale[idx], 1.0): | |
| base_params['audio_proj'] = None | |
| noise_pred_uncond, cache_state_uncond = transformer( | |
| [z_neg], context=negative_embeds, clip_fea=clip_fea_neg if clip_fea_neg is not None else clip_fea, | |
| y=[image_cond_input] if image_cond_input is not None else None, | |
| is_uncond=True, current_step_percentage=current_step_percentage, | |
| pred_id=cache_state[1] if cache_state else None, | |
| vace_data=vace_data, attn_cond=attn_cond_neg, | |
| **base_params | |
| ) | |
| noise_pred_uncond = noise_pred_uncond[0].to(intermediate_device) | |
| #phantom | |
| if use_phantom and not math.isclose(phantom_cfg_scale[idx], 1.0): | |
| noise_pred_phantom, cache_state_phantom = transformer( | |
| [z_phantom_img], context=negative_embeds, clip_fea=clip_fea_neg if clip_fea_neg is not None else clip_fea, | |
| y=[image_cond_input] if image_cond_input is not None else None, | |
| is_uncond=True, current_step_percentage=current_step_percentage, | |
| pred_id=cache_state[2] if cache_state else None, | |
| vace_data=None, | |
| **base_params | |
| ) | |
| noise_pred_phantom = noise_pred_phantom[0].to(intermediate_device) | |
| noise_pred = noise_pred_uncond + phantom_cfg_scale[idx] * (noise_pred_phantom - noise_pred_uncond) + cfg_scale * (noise_pred_cond - noise_pred_phantom) | |
| return noise_pred, [cache_state_cond, cache_state_uncond, cache_state_phantom] | |
| #fantasytalking | |
| if fantasytalking_embeds is not None: | |
| if not math.isclose(audio_cfg_scale[idx], 1.0): | |
| if cache_state is not None and len(cache_state) != 3: | |
| cache_state.append(None) | |
| base_params['audio_proj'] = None | |
| noise_pred_no_audio, cache_state_audio = transformer( | |
| [z_pos], context=positive_embeds, y=[image_cond_input] if image_cond_input is not None else None, | |
| clip_fea=clip_fea, is_uncond=False, current_step_percentage=current_step_percentage, | |
| pred_id=cache_state[2] if cache_state else None, | |
| vace_data=vace_data, | |
| **base_params | |
| ) | |
| noise_pred_no_audio = noise_pred_no_audio[0].to(intermediate_device) | |
| noise_pred = ( | |
| noise_pred_uncond | |
| + cfg_scale * (noise_pred_no_audio - noise_pred_uncond) | |
| + audio_cfg_scale[idx] * (noise_pred_cond - noise_pred_no_audio) | |
| ) | |
| return noise_pred, [cache_state_cond, cache_state_uncond, cache_state_audio] | |
| #batched | |
| else: | |
| cache_state_uncond = None | |
| [noise_pred_cond, noise_pred_uncond], cache_state_cond = transformer( | |
| [z] + [z], context=positive_embeds + negative_embeds, | |
| y=[image_cond_input] + [image_cond_input] if image_cond_input is not None else None, | |
| clip_fea=clip_fea.repeat(2,1,1), is_uncond=False, current_step_percentage=current_step_percentage, | |
| pred_id=cache_state[0] if cache_state else None, | |
| **base_params | |
| ) | |
| #cfg | |
| #https://github.com/WeichenFan/CFG-Zero-star/ | |
| if use_cfg_zero_star: | |
| alpha = optimized_scale( | |
| noise_pred_cond.view(batch_size, -1), | |
| noise_pred_uncond.view(batch_size, -1) | |
| ).view(batch_size, 1, 1, 1) | |
| else: | |
| alpha = 1.0 | |
| #https://github.com/WikiChao/FreSca | |
| if use_fresca: | |
| filtered_cond = fourier_filter( | |
| noise_pred_cond - noise_pred_uncond, | |
| scale_low=fresca_scale_low, | |
| scale_high=fresca_scale_high, | |
| freq_cutoff=fresca_freq_cutoff, | |
| ) | |
| noise_pred = noise_pred_uncond * alpha + cfg_scale * filtered_cond * alpha | |
| else: | |
| noise_pred = noise_pred_uncond * alpha + cfg_scale * (noise_pred_cond - noise_pred_uncond * alpha) | |
| return noise_pred, [cache_state_cond, cache_state_uncond] | |
| log.info(f"Sampling {(latent_video_length-1) * 4 + 1} frames at {latent.shape[3]*8}x{latent.shape[2]*8} with {steps} steps") | |
| intermediate_device = device | |
| # diff diff prep | |
| masks = None | |
| if samples is not None and mask is not None: | |
| mask = 1 - mask | |
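| # Differential diffusion: one boolean mask per step; regions whose (inverted) mask value exceeds the step threshold are reset to the re-noised original image | |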
| thresholds = torch.arange(len(timesteps), dtype=original_image.dtype) / len(timesteps) | |
| thresholds = thresholds.unsqueeze(1).unsqueeze(1).unsqueeze(1).unsqueeze(1).to(device) | |
| masks = mask.repeat(len(timesteps), 1, 1, 1, 1).to(device) | |
| masks = masks > thresholds | |
| latent_shift_loop = False | |
| if loop_args is not None: | |
| latent_shift_loop = True | |
| is_looped = True | |
| latent_skip = loop_args["shift_skip"] | |
| latent_shift_start_percent = loop_args["start_percent"] | |
| latent_shift_end_percent = loop_args["end_percent"] | |
| shift_idx = 0 | |
| #clear memory before sampling | |
| mm.unload_all_models() | |
| mm.soft_empty_cache() | |
| gc.collect() | |
| try: | |
| torch.cuda.reset_peak_memory_stats(device) | |
| except Exception: | |
| pass | |
| #region main loop start | |
| for idx, t in enumerate(tqdm(timesteps)): | |
| if flowedit_args is not None: | |
| if idx < skip_steps: | |
| continue | |
| # diff diff | |
| if masks is not None: | |
| if idx < len(timesteps) - 1: | |
| noise_timestep = timesteps[idx+1] | |
| image_latent = sample_scheduler.scale_noise( | |
| original_image, torch.tensor([noise_timestep]), noise.to(device) | |
| ) | |
| mask = masks[idx] | |
| mask = mask.to(latent) | |
| latent = image_latent * mask + latent * (1-mask) | |
| # end diff diff | |
| latent_model_input = latent.to(device) | |
| timestep = torch.tensor([t]).to(device) | |
| current_step_percentage = idx / len(timesteps) | |
| ### latent shift | |
| if latent_shift_loop: | |
| if latent_shift_start_percent <= current_step_percentage <= latent_shift_end_percent: | |
| latent_model_input = torch.cat([latent_model_input[:, shift_idx:]] + [latent_model_input[:, :shift_idx]], dim=1) | |
| #enhance-a-video | |
| if feta_args is not None and feta_start_percent <= current_step_percentage <= feta_end_percent: | |
| enable_enhance() | |
| else: | |
| disable_enhance() | |
| #flow-edit | |
| if flowedit_args is not None: | |
| sigma = t / 1000.0 | |
| sigma_prev = (timesteps[idx + 1] if idx < len(timesteps) - 1 else timesteps[-1]) / 1000.0 | |
| noise = torch.randn(x_init.shape, generator=seed_g, device=torch.device("cpu")) | |
| if idx < len(timesteps) - drift_steps: | |
| cfg = drift_cfg | |
| zt_src = (1-sigma) * x_init + sigma * noise.to(t) | |
| zt_tgt = x_tgt + zt_src - x_init | |
| #source | |
| if idx < len(timesteps) - drift_steps: | |
| if context_options is not None: | |
| counter = torch.zeros_like(zt_src, device=intermediate_device) | |
| vt_src = torch.zeros_like(zt_src, device=intermediate_device) | |
| context_queue = list(context(idx, steps, latent_video_length, context_frames, context_stride, context_overlap)) | |
| for c in context_queue: | |
| window_id = self.window_tracker.get_window_id(c) | |
| if cache_args is not None: | |
| current_teacache = self.window_tracker.get_teacache(window_id, self.cache_state) | |
| else: | |
| current_teacache = None | |
| prompt_index = min(int(max(c) / section_size), num_prompts - 1) | |
| if context_options["verbose"]: | |
| log.info(f"Prompt index: {prompt_index}") | |
| if len(source_embeds["prompt_embeds"]) > 1: | |
| positive = source_embeds["prompt_embeds"][prompt_index] | |
| else: | |
| positive = source_embeds["prompt_embeds"] | |
| partial_img_emb = None | |
| if source_image_cond is not None: | |
| partial_img_emb = source_image_cond[:, c, :, :] | |
| partial_img_emb[:, 0, :, :] = source_image_cond[:, 0, :, :].to(intermediate_device) | |
| partial_zt_src = zt_src[:, c, :, :] | |
| vt_src_context, new_teacache = predict_with_cfg( | |
| partial_zt_src, cfg[idx], | |
| positive, source_embeds["negative_prompt_embeds"], | |
| timestep, idx, partial_img_emb, source_clip_fea, | |
| control_latents, cache_state=current_teacache) | |
| if cache_args is not None: | |
| self.window_tracker.cache_states[window_id] = new_teacache | |
| window_mask = create_window_mask(vt_src_context, c, latent_video_length, context_overlap) | |
| vt_src[:, c, :, :] += vt_src_context * window_mask | |
| counter[:, c, :, :] += window_mask | |
| vt_src /= counter | |
| else: | |
| vt_src, self.cache_state_source = predict_with_cfg( | |
| zt_src, cfg[idx], | |
| source_embeds["prompt_embeds"], | |
| source_embeds["negative_prompt_embeds"], | |
| timestep, idx, source_image_cond, | |
| source_clip_fea, control_latents, | |
| cache_state=self.cache_state_source) | |
| else: | |
| if idx == len(timesteps) - drift_steps: | |
| x_tgt = zt_tgt | |
| zt_tgt = x_tgt | |
| vt_src = 0 | |
| #target | |
| if context_options is not None: | |
| counter = torch.zeros_like(zt_tgt, device=intermediate_device) | |
| vt_tgt = torch.zeros_like(zt_tgt, device=intermediate_device) | |
| context_queue = list(context(idx, steps, latent_video_length, context_frames, context_stride, context_overlap)) | |
| for c in context_queue: | |
| window_id = self.window_tracker.get_window_id(c) | |
| if cache_args is not None: | |
| current_teacache = self.window_tracker.get_teacache(window_id, self.cache_state) | |
| else: | |
| current_teacache = None | |
| prompt_index = min(int(max(c) / section_size), num_prompts - 1) | |
| if context_options["verbose"]: | |
| log.info(f"Prompt index: {prompt_index}") | |
| if len(text_embeds["prompt_embeds"]) > 1: | |
| positive = text_embeds["prompt_embeds"][prompt_index] | |
| else: | |
| positive = text_embeds["prompt_embeds"] | |
| partial_img_emb = None | |
| partial_control_latents = None | |
| if image_cond is not None: | |
| partial_img_emb = image_cond[:, c, :, :] | |
| partial_img_emb[:, 0, :, :] = image_cond[:, 0, :, :].to(intermediate_device) | |
| if control_latents is not None: | |
| partial_control_latents = control_latents[:, c, :, :] | |
| partial_zt_tgt = zt_tgt[:, c, :, :] | |
| vt_tgt_context, new_teacache = predict_with_cfg( | |
| partial_zt_tgt, cfg[idx], | |
| positive, text_embeds["negative_prompt_embeds"], | |
| timestep, idx, partial_img_emb, clip_fea, | |
| partial_control_latents, cache_state=current_teacache) | |
| if cache_args is not None: | |
| self.window_tracker.cache_states[window_id] = new_teacache | |
| window_mask = create_window_mask(vt_tgt_context, c, latent_video_length, context_overlap) | |
| vt_tgt[:, c, :, :] += vt_tgt_context * window_mask | |
| counter[:, c, :, :] += window_mask | |
| vt_tgt /= counter | |
| else: | |
| vt_tgt, self.cache_state = predict_with_cfg( | |
| zt_tgt, cfg[idx], | |
| text_embeds["prompt_embeds"], | |
| text_embeds["negative_prompt_embeds"], | |
| timestep, idx, image_cond, clip_fea, control_latents, | |
| cache_state=self.cache_state) | |
| v_delta = vt_tgt - vt_src | |
| x_tgt = x_tgt.to(torch.float32) | |
| v_delta = v_delta.to(torch.float32) | |
| x_tgt = x_tgt + (sigma_prev - sigma) * v_delta | |
| x0 = x_tgt | |
| #context windowing | |
| elif context_options is not None: | |
| counter = torch.zeros_like(latent_model_input, device=intermediate_device) | |
| noise_pred = torch.zeros_like(latent_model_input, device=intermediate_device) | |
| context_queue = list(context(idx, steps, latent_video_length, context_frames, context_stride, context_overlap)) | |
| for c in context_queue: | |
| window_id = self.window_tracker.get_window_id(c) | |
| if cache_args is not None: | |
| current_teacache = self.window_tracker.get_teacache(window_id, self.cache_state) | |
| else: | |
| current_teacache = None | |
| prompt_index = min(int(max(c) / section_size), num_prompts - 1) | |
| if context_options["verbose"]: | |
| log.info(f"Prompt index: {prompt_index}") | |
| # Use the appropriate prompt for this section | |
| if len(text_embeds["prompt_embeds"]) > 1: | |
| positive = text_embeds["prompt_embeds"][prompt_index] | |
| else: | |
| positive = text_embeds["prompt_embeds"] | |
| partial_img_emb = None | |
| partial_control_latents = None | |
| if image_cond is not None: | |
| partial_img_emb = image_cond[:, c] | |
| partial_img_emb[:, 0] = image_cond[:, 0].to(intermediate_device) | |
| if control_latents is not None: | |
| partial_control_latents = control_latents[:, c] | |
| partial_control_camera_latents = None | |
| if control_camera_latents is not None: | |
| partial_control_camera_latents = control_camera_latents[:, :, c] | |
| partial_vace_context = None | |
| if vace_data is not None: | |
| window_vace_data = [] | |
| for vace_entry in vace_data: | |
| partial_context = vace_entry["context"][0][:, c] | |
| if has_ref: | |
| partial_context[:, 0] = vace_entry["context"][0][:, 0] | |
| window_vace_data.append({ | |
| "context": [partial_context], | |
| "scale": vace_entry["scale"], | |
| "start": vace_entry["start"], | |
| "end": vace_entry["end"], | |
| "seq_len": vace_entry["seq_len"] | |
| }) | |
| partial_vace_context = window_vace_data | |
| partial_audio_proj = None | |
| if fantasytalking_embeds is not None: | |
| partial_audio_proj = audio_proj[:, c] | |
| partial_latent_model_input = latent_model_input[:, c] | |
| partial_unianim_data = None | |
| if unianim_data is not None: | |
| partial_dwpose = dwpose_data[:, :, c] | |
| partial_dwpose_flat = rearrange(partial_dwpose, 'b c f h w -> b (f h w) c') | |
| partial_unianim_data = { | |
| "dwpose": partial_dwpose_flat, | |
| "random_ref": unianim_data["random_ref"], | |
| "strength": unianimate_poses["strength"], | |
| "start_percent": unianimate_poses["start_percent"], | |
| "end_percent": unianimate_poses["end_percent"] | |
| } | |
| partial_add_cond = None | |
| if add_cond is not None: | |
| partial_add_cond = add_cond[:, :, c].to(device, dtype) | |
| noise_pred_context, new_teacache = predict_with_cfg( | |
| partial_latent_model_input, | |
| cfg[idx], positive, | |
| text_embeds["negative_prompt_embeds"], | |
| timestep, idx, partial_img_emb, clip_fea, partial_control_latents, partial_vace_context, partial_unianim_data, partial_audio_proj, | |
| partial_control_camera_latents, partial_add_cond, | |
| current_teacache) | |
| if cache_args is not None: | |
| self.window_tracker.cache_states[window_id] = new_teacache | |
| window_mask = create_window_mask(noise_pred_context, c, latent_video_length, context_overlap, looped=is_looped) | |
| noise_pred[:, c] += noise_pred_context * window_mask | |
| counter[:, c] += window_mask | |
| noise_pred /= counter | |
| #region normal inference | |
| else: | |
| noise_pred, self.cache_state = predict_with_cfg( | |
| latent_model_input, | |
| cfg[idx], | |
| text_embeds["prompt_embeds"], | |
| text_embeds["negative_prompt_embeds"], | |
| timestep, idx, image_cond, clip_fea, control_latents, vace_data, unianim_data, audio_proj, control_camera_latents, add_cond, | |
| cache_state=self.cache_state) | |
| if latent_shift_loop: | |
| #reverse latent shift | |
| if latent_shift_start_percent <= current_step_percentage <= latent_shift_end_percent: | |
| noise_pred = torch.cat([noise_pred[:, latent_video_length - shift_idx:]] + [noise_pred[:, :latent_video_length - shift_idx]], dim=1) | |
| shift_idx = (shift_idx + latent_skip) % latent_video_length | |
| if flowedit_args is None: | |
| latent = latent.to(intermediate_device) | |
| step_args = { | |
| "generator": seed_g, | |
| } | |
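| # DEIS and the basic FlowMatch scheduler do not take a generator in step(), so drop it for those | |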
| if isinstance(sample_scheduler, (DEISMultistepScheduler, FlowMatchScheduler)): | |
| step_args.pop("generator", None) | |
| temp_x0 = sample_scheduler.step( | |
| noise_pred[:, :orig_noise_len].unsqueeze(0) if recammaster is not None else noise_pred.unsqueeze(0), | |
| t, | |
| latent[:, :orig_noise_len].unsqueeze(0) if recammaster is not None else latent.unsqueeze(0), | |
| #return_dict=False, | |
| **step_args)[0] | |
| latent = temp_x0.squeeze(0) | |
| x0 = latent.to(device) | |
| if callback is not None: | |
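| # Preview latent: flow-matching x0 estimate x_t - sigma * v, trimming any auxiliary latents (RecamMaster / Phantom) | |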
| if recammaster is not None: | |
| callback_latent = (latent_model_input[:, :orig_noise_len].to(device) - noise_pred[:, :orig_noise_len].to(device) * t.to(device) / 1000).detach().permute(1,0,2,3) | |
| elif phantom_latents is not None: | |
| callback_latent = (latent_model_input[:,:-phantom_latents.shape[1]].to(device) - noise_pred[:,:-phantom_latents.shape[1]].to(device) * t.to(device) / 1000).detach().permute(1,0,2,3) | |
| else: | |
| callback_latent = (latent_model_input.to(device) - noise_pred.to(device) * t.to(device) / 1000).detach().permute(1,0,2,3) | |
| callback(idx, callback_latent, None, steps) | |
| else: | |
| pbar.update(1) | |
| del latent_model_input, timestep | |
| else: | |
| if callback is not None: | |
| callback_latent = (zt_tgt.to(device) - vt_tgt.to(device) * t.to(device) / 1000).detach().permute(1,0,2,3) | |
| callback(idx, callback_latent, None, steps) | |
| else: | |
| pbar.update(1) | |
| if phantom_latents is not None: | |
| x0 = x0[:,:-phantom_latents.shape[1]] | |
| if cache_args is not None: | |
| cache_type = cache_args["cache_type"] | |
| states = transformer.teacache_state.states if cache_type == "TeaCache" else transformer.magcache_state.states | |
| state_names = { | |
| 0: "conditional", | |
| 1: "unconditional" | |
| } | |
| for pred_id, state in states.items(): | |
| name = state_names.get(pred_id, f"prediction_{pred_id}") | |
| if 'skipped_steps' in state: | |
| log.info(f"{cache_type} skipped: {len(state['skipped_steps'])} {name} steps: {state['skipped_steps']}") | |
| transformer.teacache_state.clear_all() | |
| transformer.magcache_state.clear_all() | |
| del states | |
| if force_offload: | |
| if model["manual_offloading"]: | |
| transformer.to(offload_device) | |
| mm.soft_empty_cache() | |
| gc.collect() | |
| try: | |
| print_memory(device) | |
| torch.cuda.reset_peak_memory_stats(device) | |
| except Exception: | |
| pass | |
| return ({ | |
| "samples": x0.unsqueeze(0).cpu(), "looped": is_looped, "end_image": end_image if not fun_or_fl2v_model else None, "has_ref": has_ref, "drop_last": drop_last, | |
| }, ) | |
| class WindowTracker: | |
| def __init__(self, verbose=False): | |
| self.window_map = {} # Maps frame sequence to persistent ID | |
| self.next_id = 0 | |
| self.cache_states = {} # Maps persistent ID to teacache state | |
| self.verbose = verbose | |
| def get_window_id(self, frames): | |
| key = tuple(sorted(frames)) # Order-independent frame sequence | |
| if key not in self.window_map: | |
| self.window_map[key] = self.next_id | |
| if self.verbose: | |
| log.info(f"New window pattern {key} -> ID {self.next_id}") | |
| self.next_id += 1 | |
| return self.window_map[key] | |
| def get_teacache(self, window_id, base_state): | |
| if window_id not in self.cache_states: | |
| if self.verbose: | |
| log.info(f"Initializing persistent teacache for window {window_id}") | |
| self.cache_states[window_id] = base_state.copy() | |
| return self.cache_states[window_id] | |
| #region VideoDecode | |
| class WanVideoDecode: | |
| def INPUT_TYPES(s): | |
| return {"required": { | |
| "vae": ("WANVAE",), | |
| "samples": ("LATENT",), | |
| "enable_vae_tiling": ("BOOLEAN", {"default": False, "tooltip": ( | |
| "Drastically reduces memory use but will introduce seams at tile stride boundaries. " | |
| "The location and number of seams is dictated by the tile stride size. " | |
| "The visibility of seams can be controlled by increasing the tile size. " | |
| "Seams become less obvious at 1.5x stride and are barely noticeable at 2x stride size. " | |
| "Which is to say if you use a stride width of 160, the seams are barely noticeable with a tile width of 320." | |
| )}), | |
| "tile_x": ("INT", {"default": 272, "min": 40, "max": 2048, "step": 8, "tooltip": "Tile width in pixels. Smaller values use less VRAM but will make seams more obvious."}), | |
| "tile_y": ("INT", {"default": 272, "min": 40, "max": 2048, "step": 8, "tooltip": "Tile height in pixels. Smaller values use less VRAM but will make seams more obvious."}), | |
| "tile_stride_x": ("INT", {"default": 144, "min": 32, "max": 2040, "step": 8, "tooltip": "Tile stride width in pixels. Smaller values use less VRAM but will introduce more seams."}), | |
| "tile_stride_y": ("INT", {"default": 128, "min": 32, "max": 2040, "step": 8, "tooltip": "Tile stride height in pixels. Smaller values use less VRAM but will introduce more seams."}), | |
| }, | |
| } | |
| def VALIDATE_INPUTS(s, tile_x, tile_y, tile_stride_x, tile_stride_y): | |
| if tile_x <= tile_stride_x: | |
| return "Tile width must be larger than the tile stride width." | |
| if tile_y <= tile_stride_y: | |
| return "Tile height must be larger than the tile stride height." | |
| return True | |
| RETURN_TYPES = ("IMAGE",) | |
| RETURN_NAMES = ("images",) | |
| FUNCTION = "decode" | |
| CATEGORY = "WanVideoWrapper" | |
| def decode(self, vae, samples, enable_vae_tiling, tile_x, tile_y, tile_stride_x, tile_stride_y): | |
| device = mm.get_torch_device() | |
| offload_device = mm.unet_offload_device() | |
| mm.soft_empty_cache() | |
| latents = samples["samples"] | |
| end_image = samples.get("end_image", None) | |
| has_ref = samples.get("has_ref", False) | |
| drop_last = samples.get("drop_last", False) | |
| is_looped = samples.get("looped", False) | |
| vae.to(device) | |
| latents = latents.to(device = device, dtype = vae.dtype) | |
| mm.soft_empty_cache() | |
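| # Strip auxiliary latent frames before decoding: the leading reference frame (has_ref) and/or the trailing frame flagged by drop_last | |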
| if has_ref: | |
| latents = latents[:, :, 1:] | |
| if drop_last: | |
| latents = latents[:, :, :-1] | |
| #if is_looped: | |
| # latents = torch.cat([latents[:, :, :warmup_latent_count],latents], dim=2) | |
| if type(vae).__name__ == "TAEHV": | |
| images = vae.decode_video(latents.permute(0, 2, 1, 3, 4))[0].permute(1, 0, 2, 3) | |
| images = torch.clamp(images, 0.0, 1.0) | |
| images = images.permute(1, 2, 3, 0).cpu().float() | |
| return (images,) | |
| else: | |
| if end_image is not None: | |
| enable_vae_tiling = False | |
| images = vae.decode(latents, device=device, end_=(end_image is not None), tiled=enable_vae_tiling, tile_size=(tile_x//8, tile_y//8), tile_stride=(tile_stride_x//8, tile_stride_y//8))[0] | |
| vae.model.clear_cache() | |
| #images = (images - images.min()) / (images.max() - images.min()) | |
| images = torch.clamp(images, -1.0, 1.0) | |
| images = (images + 1.0) / 2.0 | |
| if is_looped: | |
| #images = images[:, warmup_latent_count * 4:] | |
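| # Blend the loop seam: re-decode the wrap-around region (last 3 + first 2 latent frames) and splice it in at the start of the video | |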
| temp_latents = torch.cat([latents[:, :, -3:]] + [latents[:, :, :2]], dim=2) | |
| temp_images = vae.decode(temp_latents, device=device, end_=(end_image is not None), tiled=enable_vae_tiling, tile_size=(tile_x//8, tile_y//8), tile_stride=(tile_stride_x//8, tile_stride_y//8))[0] | |
| temp_images = (temp_images - temp_images.min()) / (temp_images.max() - temp_images.min()) | |
| out = temp_images[:, 9:] | |
| out = torch.cat([out, images[:, 5:]], dim=1) | |
| images = out | |
| if end_image is not None: | |
| #end_image = (end_image - end_image.min()) / (end_image.max() - end_image.min()) | |
| #image[:, -1] = end_image[:, 0].to(image) #not sure about this | |
| images = images[:, 0:-1] | |
| vae.model.clear_cache() | |
| vae.to(offload_device) | |
| mm.soft_empty_cache() | |
| images = torch.clamp(images, 0.0, 1.0) | |
| images = images.permute(1, 2, 3, 0).cpu().float() | |
| return (images,) | |
| #region VideoEncode | |
| class WanVideoEncode: | |
| def INPUT_TYPES(s): | |
| return {"required": { | |
| "vae": ("WANVAE",), | |
| "image": ("IMAGE",), | |
| "enable_vae_tiling": ("BOOLEAN", {"default": False, "tooltip": "Drastically reduces memory use but may introduce seams"}), | |
| "tile_x": ("INT", {"default": 272, "min": 64, "max": 2048, "step": 1, "tooltip": "Tile size in pixels, smaller values use less VRAM, may introduce more seams"}), | |
| "tile_y": ("INT", {"default": 272, "min": 64, "max": 2048, "step": 1, "tooltip": "Tile size in pixels, smaller values use less VRAM, may introduce more seams"}), | |
| "tile_stride_x": ("INT", {"default": 144, "min": 32, "max": 2048, "step": 32, "tooltip": "Tile stride in pixels, smaller values use less VRAM, may introduce more seams"}), | |
| "tile_stride_y": ("INT", {"default": 128, "min": 32, "max": 2048, "step": 32, "tooltip": "Tile stride in pixels, smaller values use less VRAM, may introduce more seams"}), | |
| }, | |
| "optional": { | |
| "noise_aug_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Strength of noise augmentation, helpful for leapfusion I2V where some noise can add motion and give sharper results"}), | |
| "latent_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional latent multiplier, helpful for leapfusion I2V where lower values allow for more motion"}), | |
| "mask": ("MASK", ), | |
| } | |
| } | |
| RETURN_TYPES = ("LATENT",) | |
| RETURN_NAMES = ("samples",) | |
| FUNCTION = "encode" | |
| CATEGORY = "WanVideoWrapper" | |
| def encode(self, vae, image, enable_vae_tiling, tile_x, tile_y, tile_stride_x, tile_stride_y, noise_aug_strength=0.0, latent_strength=1.0, mask=None): | |
| device = mm.get_torch_device() | |
| offload_device = mm.unet_offload_device() | |
| vae.to(device) | |
| image = image.clone() | |
| B, H, W, C = image.shape | |
| if W % 16 != 0 or H % 16 != 0: | |
| new_height = (H // 16) * 16 | |
| new_width = (W // 16) * 16 | |
| log.warning(f"Image size {W}x{H} is not divisible by 16, resizing to {new_width}x{new_height}") | |
| image = common_upscale(image.movedim(-1, 1), new_width, new_height, "lanczos", "disabled").movedim(1, -1) | |
| image = image.to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3) # B, C, T, H, W | |
| if noise_aug_strength > 0.0: | |
| image = add_noise_to_reference_video(image, ratio=noise_aug_strength) | |
| if isinstance(vae, TAEHV): | |
| latents = vae.encode_video(image.permute(0, 2, 1, 3, 4), parallel=False)# B, T, C, H, W | |
| latents = latents.permute(0, 2, 1, 3, 4) | |
| else: | |
| latents = vae.encode(image * 2.0 - 1.0, device=device, tiled=enable_vae_tiling, tile_size=(tile_x//8, tile_y//8), tile_stride=(tile_stride_x//8, tile_stride_y//8)) | |
| vae.model.clear_cache() | |
| if latent_strength != 1.0: | |
| latents *= latent_strength | |
| log.info(f"encoded latents shape {latents.shape}") | |
| latent_mask = None | |
| if mask is None: | |
| vae.to(offload_device) | |
| else: | |
| #latent_mask = mask.clone().to(vae.dtype).to(device) * 2.0 - 1.0 | |
| #latent_mask = latent_mask.unsqueeze(0).unsqueeze(0).repeat(1, 3, 1, 1, 1) | |
| #latent_mask = vae.encode(latent_mask, device=device, tiled=enable_vae_tiling, tile_size=(tile_x, tile_y), tile_stride=(tile_stride_x, tile_stride_y)) | |
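| # Resize the pixel-space mask to the latent temporal/spatial resolution with trilinear interpolation | |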
| target_h, target_w = latents.shape[3:] | |
| mask = torch.nn.functional.interpolate( | |
| mask.unsqueeze(0).unsqueeze(0), # Add batch and channel dims [1,1,T,H,W] | |
| size=(latents.shape[2], target_h, target_w), | |
| mode='trilinear', | |
| align_corners=False | |
| ).squeeze(0) # Remove batch dim, keep channel dim | |
| # Add batch & channel dims for final output | |
| latent_mask = mask.unsqueeze(0).repeat(1, latents.shape[1], 1, 1, 1) | |
| log.info(f"latent mask shape {latent_mask.shape}") | |
| vae.to(offload_device) | |
| mm.soft_empty_cache() | |
| return ({"samples": latents, "mask": latent_mask},) | |
| NODE_CLASS_MAPPINGS = { | |
| "WanVideoSampler": WanVideoSampler, | |
| "WanVideoDecode": WanVideoDecode, | |
| "WanVideoTextEncode": WanVideoTextEncode, | |
| "WanVideoTextEncodeSingle": WanVideoTextEncodeSingle, | |
| "WanVideoModelLoader": WanVideoModelLoader, | |
| "WanVideoVAELoader": WanVideoVAELoader, | |
| "LoadWanVideoT5TextEncoder": LoadWanVideoT5TextEncoder, | |
| "WanVideoImageClipEncode": WanVideoImageClipEncode,#deprecated | |
| "WanVideoClipVisionEncode": WanVideoClipVisionEncode, | |
| "WanVideoImageToVideoEncode": WanVideoImageToVideoEncode, | |
| "LoadWanVideoClipTextEncoder": LoadWanVideoClipTextEncoder, | |
| "WanVideoEncode": WanVideoEncode, | |
| "WanVideoBlockSwap": WanVideoBlockSwap, | |
| "WanVideoTorchCompileSettings": WanVideoTorchCompileSettings, | |
| "WanVideoEmptyEmbeds": WanVideoEmptyEmbeds, | |
| "WanVideoLoraSelect": WanVideoLoraSelect, | |
| "WanVideoLoraBlockEdit": WanVideoLoraBlockEdit, | |
| "WanVideoEnhanceAVideo": WanVideoEnhanceAVideo, | |
| "WanVideoContextOptions": WanVideoContextOptions, | |
| "WanVideoTeaCache": WanVideoTeaCache, | |
| "WanVideoMagCache": WanVideoMagCache, | |
| "WanVideoVRAMManagement": WanVideoVRAMManagement, | |
| "WanVideoTextEmbedBridge": WanVideoTextEmbedBridge, | |
| "WanVideoFlowEdit": WanVideoFlowEdit, | |
| "WanVideoControlEmbeds": WanVideoControlEmbeds, | |
| "WanVideoSLG": WanVideoSLG, | |
| "WanVideoTinyVAELoader": WanVideoTinyVAELoader, | |
| "WanVideoLoopArgs": WanVideoLoopArgs, | |
| "WanVideoImageResizeToClosest": WanVideoImageResizeToClosest, | |
| "WanVideoSetBlockSwap": WanVideoSetBlockSwap, | |
| "WanVideoExperimentalArgs": WanVideoExperimentalArgs, | |
| "WanVideoVACEEncode": WanVideoVACEEncode, | |
| "WanVideoVACEStartToEndFrame": WanVideoVACEStartToEndFrame, | |
| "WanVideoVACEModelSelect": WanVideoVACEModelSelect, | |
| "WanVideoPhantomEmbeds": WanVideoPhantomEmbeds, | |
| "CreateCFGScheduleFloatList": CreateCFGScheduleFloatList, | |
| "WanVideoRealisDanceLatents": WanVideoRealisDanceLatents, | |
| "WanVideoApplyNAG": WanVideoApplyNAG, | |
| "WanVideoMiniMaxRemoverEmbeds": WanVideoMiniMaxRemoverEmbeds, | |
| "WanVideoLoraSelectMulti": WanVideoLoraSelectMulti | |
| } | |
| NODE_DISPLAY_NAME_MAPPINGS = { | |
| "WanVideoSampler": "WanVideo Sampler", | |
| "WanVideoDecode": "WanVideo Decode", | |
| "WanVideoTextEncode": "WanVideo TextEncode", | |
| "WanVideoTextEncodeSingle": "WanVideo TextEncodeSingle", | |
| "WanVideoTextImageEncode": "WanVideo TextImageEncode (IP2V)", | |
| "WanVideoModelLoader": "WanVideo Model Loader", | |
| "WanVideoVAELoader": "WanVideo VAE Loader", | |
| "LoadWanVideoT5TextEncoder": "Load WanVideo T5 TextEncoder", | |
| "WanVideoImageClipEncode": "WanVideo ImageClip Encode (Deprecated)", | |
| "WanVideoClipVisionEncode": "WanVideo ClipVision Encode", | |
| "WanVideoImageToVideoEncode": "WanVideo ImageToVideo Encode", | |
| "LoadWanVideoClipTextEncoder": "Load WanVideo Clip Encoder", | |
| "WanVideoEncode": "WanVideo Encode", | |
| "WanVideoBlockSwap": "WanVideo BlockSwap", | |
| "WanVideoTorchCompileSettings": "WanVideo Torch Compile Settings", | |
| "WanVideoEmptyEmbeds": "WanVideo Empty Embeds", | |
| "WanVideoLoraSelect": "WanVideo Lora Select", | |
| "WanVideoLoraBlockEdit": "WanVideo Lora Block Edit", | |
| "WanVideoEnhanceAVideo": "WanVideo Enhance-A-Video", | |
| "WanVideoContextOptions": "WanVideo Context Options", | |
| "WanVideoTeaCache": "WanVideo TeaCache", | |
| "WanVideoMagCache": "WanVideo MagCache", | |
| "WanVideoVRAMManagement": "WanVideo VRAM Management", | |
| "WanVideoTextEmbedBridge": "WanVideo TextEmbed Bridge", | |
| "WanVideoFlowEdit": "WanVideo FlowEdit", | |
| "WanVideoControlEmbeds": "WanVideo Control Embeds", | |
| "WanVideoSLG": "WanVideo SLG", | |
| "WanVideoTinyVAELoader": "WanVideo Tiny VAE Loader", | |
| "WanVideoLoopArgs": "WanVideo Loop Args", | |
| "WanVideoImageResizeToClosest": "WanVideo Image Resize To Closest", | |
| "WanVideoSetBlockSwap": "WanVideo Set BlockSwap", | |
| "WanVideoExperimentalArgs": "WanVideo Experimental Args", | |
| "WanVideoVACEEncode": "WanVideo VACE Encode", | |
| "WanVideoVACEStartToEndFrame": "WanVideo VACE Start To End Frame", | |
| "WanVideoVACEModelSelect": "WanVideo VACE Model Select", | |
| "WanVideoPhantomEmbeds": "WanVideo Phantom Embeds", | |
| "CreateCFGScheduleFloatList": "WanVideo CFG Schedule Float List", | |
| "WanVideoRealisDanceLatents": "WanVideo RealisDance Latents", | |
| "WanVideoApplyNAG": "WanVideo Apply NAG", | |
| "WanVideoMiniMaxRemoverEmbeds": "WanVideo MiniMax Remover Embeds", | |
| "WanVideoLoraSelectMulti": "WanVideo Lora Select Multi" | |
| } | |