# QingyanBai's picture
# Upload 750 files
# a42ebba verified
import importlib.metadata
import torch
import logging
from tqdm import tqdm
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
log = logging.getLogger(__name__)
from accelerate.utils import set_module_tensor_to_device
def check_diffusers_version():
    """Ensure the installed ``diffusers`` package is at least version 0.31.0.

    Versions are compared numerically, component by component, using only the
    leading dotted-number part of each version string (so ``"0.31.0.dev0"`` is
    treated as ``0.31.0``). A plain string comparison would rank ``"0.100.0"``
    below ``"0.31.0"`` and ``"0.4.0"`` above it.

    Raises:
        AssertionError: If ``diffusers`` is not installed, or its version is
            older than the required one.
    """
    import re  # local import: only this check needs it

    required_version = '0.31.0'

    def release_tuple(text):
        # Extract the leading "N.N.N" part and turn it into a tuple of ints.
        found = re.match(r"\d+(?:\.\d+)*", str(text).strip())
        if found is None:
            return None
        return tuple(int(part) for part in found.group(0).split("."))

    try:
        version = importlib.metadata.version('diffusers')
    except importlib.metadata.PackageNotFoundError as err:
        raise AssertionError("diffusers is not installed.") from err

    installed = release_tuple(version)
    if installed is None:
        # Unparseable version string: nothing meaningful to compare against,
        # so do not block the caller.
        return

    required = release_tuple(required_version)
    # Pad to equal length so that "0.31" and "0.31.0" compare as equal.
    width = max(len(installed), len(required))
    installed += (0,) * (width - len(installed))
    required += (0,) * (width - len(required))

    if installed < required:
        raise AssertionError(f"diffusers version {version} is installed, but version {required_version} or higher is required.")
def print_memory(device):
    """Log current-allocated, peak-allocated and peak-reserved CUDA memory for ``device``.

    All three figures are queried from ``torch.cuda`` and reported in units of
    1024**3 bytes through the module logger at INFO level.
    """
    bytes_per_unit = 1024**3

    # Query everything first, then emit the three log lines.
    # The local names appear in the log output via the ``{name=...}`` f-string
    # form, so they are part of the message text.
    memory = torch.cuda.memory_allocated(device) / bytes_per_unit
    max_memory = torch.cuda.max_memory_allocated(device) / bytes_per_unit
    max_reserved = torch.cuda.max_memory_reserved(device) / bytes_per_unit

    log.info(f"Allocated memory: {memory=:.3f} GB")
    log.info(f"Max allocated memory: {max_memory=:.3f} GB")
    log.info(f"Max reserved memory: {max_reserved=:.3f} GB")
    #memory_summary = torch.cuda.memory_summary(device=device, abbreviated=False)
    #log.info(f"Memory Summary:\n{memory_summary}")
def get_module_memory_mb(module):
    """Return the total size of ``module``'s parameters in megabytes.

    Each parameter contributes ``nelement() * element_size()`` bytes; the sum
    is divided by 1024 * 1024. Only ``module.parameters()`` is counted.
    """
    total_bytes = sum(
        p.nelement() * p.element_size()
        for p in module.parameters()
        if p.data is not None
    )
    return total_bytes / (1024 * 1024)  # bytes -> MB
def apply_lora(model, device_to, transformer_load_device, params_to_keep=None, dtype=None, base_dtype=None, state_dict=None, low_mem_load=False):
    """Apply the model's registered weight patches module by module.

    Every module of ``model.model`` that owns parameters directly (and has no
    parameters in sub-modules), or that has a ``comfy_cast_weights`` attribute,
    is visited in reverse name order. For each of its parameters
    ``model.patch_weight_to_device`` is called. Modules whose
    ``comfy_patched_weights`` flag is already true are skipped, and the flag
    is set on every visited module afterwards.

    When ``low_mem_load`` is true, each parameter is first materialised on
    ``transformer_load_device`` from ``state_dict`` (via accelerate's
    ``set_module_tensor_to_device``), patched, and then re-set on
    ``transformer_load_device`` with the chosen dtype. A final pass loads any
    ``diffusion_model`` parameter that is still not on that device.

    Args:
        model: Patcher object exposing ``model`` (with ``named_modules`` and,
            for low-memory loading, ``diffusion_model``),
            ``patch_weight_to_device``, ``patches_uuid`` and
            ``current_weight_patches_uuid``.
        device_to: Device passed to ``patch_weight_to_device``.
        transformer_load_device: Target device for low-memory loading.
        params_to_keep: Iterable of substrings; names containing any of them
            use ``base_dtype`` instead of ``dtype``. ``None`` means none.
        dtype: Default dtype for low-memory loading.
        base_dtype: Dtype for names matched by ``params_to_keep``.
        state_dict: Mapping of parameter keys (relative to
            ``diffusion_model``) to tensors; used only when ``low_mem_load``.
        low_mem_load: Enable the per-parameter load/patch/re-set sequence.

    Returns:
        The same ``model`` object.

    NOTE(review): in low-memory mode, modules whose name does not start with
    ``"diffusion_model."`` cannot be addressed inside ``diffusion_model`` and
    are skipped here instead of reusing a key left over from a previous
    module. Presumably no such module carries parameters - confirm.
    """
    prefix = "diffusion_model."
    keep_keywords = params_to_keep if params_to_keep is not None else ()

    def pick_dtype(target_name):
        # patch_embedding weights are always kept in float32.
        if "patch_embedding" in target_name:
            return torch.float32
        if any(keyword in target_name for keyword in keep_keywords):
            return base_dtype
        return dtype

    # Collect modules that own all of their parameters directly.
    to_load = []
    for module_name, module in model.model.named_modules():
        own_params = [p_name for p_name, _ in module.named_parameters(recurse=False)]
        # skip random weights in non leaf modules
        has_nested_params = any(
            p_name not in own_params
            for p_name, _ in module.named_parameters(recurse=True)
        )
        if has_nested_params:
            continue
        if hasattr(module, "comfy_cast_weights") or len(own_params) > 0:
            to_load.append((module_name, module, own_params))

    to_load.sort(key=lambda entry: entry[0], reverse=True)

    for module_name, module, param_names in tqdm(to_load, desc="Loading model and applying LoRA weights:", leave=True):
        if getattr(module, "comfy_patched_weights", False) == True:
            continue

        # torch compiled modules have this prefix
        module_name = module_name.replace("._orig_mod.", ".")

        if low_mem_load:
            dtype_to_use = pick_dtype(module_name)
            addressable = module_name.startswith(prefix)

        for param_name in param_names:
            full_key = "{}.{}".format(module_name, param_name)

            if low_mem_load:
                if not addressable:
                    log.debug("Skipping %s: not inside diffusion_model", full_key)
                    continue
                key = full_key[len(prefix):]
                try:
                    set_module_tensor_to_device(model.model.diffusion_model, key, device=transformer_load_device, dtype=dtype_to_use, value=state_dict[key])
                except Exception as err:
                    # Best effort: a weight that cannot be loaded is not patched.
                    log.debug("Could not load %s before patching: %s", key, err)
                    continue

            model.patch_weight_to_device(full_key, device_to=device_to)

            if low_mem_load:
                # Re-set the patched weight on the load device with the chosen dtype.
                try:
                    set_module_tensor_to_device(model.model.diffusion_model, key, device=transformer_load_device, dtype=dtype_to_use, value=model.model.diffusion_model.state_dict()[key])
                except Exception as err:
                    log.debug("Could not re-set %s after patching: %s", key, err)
                    continue

        module.comfy_patched_weights = True

    model.current_weight_patches_uuid = model.patches_uuid

    if low_mem_load:
        # Load whatever is still not on the target device.
        for name, param in model.model.diffusion_model.named_parameters():
            if param.device == transformer_load_device:
                continue
            try:
                set_module_tensor_to_device(model.model.diffusion_model, name, device=transformer_load_device, dtype=pick_dtype(name), value=state_dict[name])
            except Exception as err:
                log.debug("Could not load remaining parameter %s: %s", name, err)
                continue

    return model
# from https://github.com/cubiq/ComfyUI_IPAdapter_plus/blob/9d076a3df0d2763cef5510ec5ab807f6632c39f5/utils.py#L181
def split_tiles(embeds, num_split):
    """Cut every item of a 4-D tensor into a ``num_split`` x ``num_split`` grid of tiles.

    Dims 1 and 2 of ``embeds`` are divided (integer division) into
    ``num_split`` parts each. For every item along dim 0 the tiles are
    concatenated row-major along a new leading axis, and the per-item results
    are stacked, giving a 5-D tensor of shape
    ``(embeds.shape[0], num_split**2, H // num_split, W // num_split, embeds.shape[3])``.
    """
    _, height, width, _ = embeds.shape
    tile_h = height // num_split
    tile_w = width // num_split

    per_item = []
    for item in embeds:
        item = item.unsqueeze(0)
        tiles = []
        for row in range(num_split):
            rows = slice(row * tile_h, (row + 1) * tile_h)
            for col in range(num_split):
                cols = slice(col * tile_w, (col + 1) * tile_w)
                tiles.append(item[:, rows, cols, :])
        per_item.append(torch.cat(tiles, dim=0))

    return torch.stack(per_item, dim=0)
def merge_hiddenstates(x, tiles):
    """Merge per-tile hidden states back into one sequence per group of tiles.

    ``x`` is split along dim 0 into groups of ``tiles * tiles`` entries. In
    each group, token 0 of every tile is treated as a class token and the
    remaining tokens as a square patch grid. The class tokens are averaged,
    the patch grids are laid out on a (grid x grid) canvas, average-pooled
    back to a single tile's patch resolution, flattened, and prefixed with
    the averaged class token. Results of all groups are concatenated on dim 0.
    """
    merged_groups = []
    for group in x.split(tiles * tiles):
        channels = group.shape[-1]
        side = int((group.shape[1] - 1) ** 0.5)   # patches per tile edge
        grid = int(group.shape[0] ** 0.5)         # tiles per canvas edge

        # Averaged class token, shape [1, 1, channels]
        cls_mean = group[:, 0, :].mean(dim=0, keepdim=True).unsqueeze(0)

        # [grid, grid, side, side, C] -> [1, grid*side, grid*side, C]
        patches = group[:, 1:, :].reshape(grid, grid, side, side, channels)
        canvas = patches.permute(0, 2, 1, 3, 4).reshape(1, grid * side, grid * side, channels)

        # Pool the canvas down to one tile's patch resolution (channels-first for pooling)
        pooled = torch.nn.functional.adaptive_avg_pool2d(
            canvas.permute(0, 3, 1, 2), (side, side)
        ).permute(0, 2, 3, 1)
        patch_tokens = pooled.reshape(1, side * side, channels)

        merged_groups.append(torch.cat([cls_mean, patch_tokens], dim=1))

    return torch.cat(merged_groups, dim=0)
from comfy.clip_vision import clip_preprocess, ClipVisionModel
def clip_encode_image_tiled(clip_vision, image, tiles=1, ratio=1.0):
    """Encode ``image`` with ``clip_vision``, optionally blending in tiled encodings.

    The full image is always encoded with ``encode_image_``. When ``tiles`` is
    greater than 1 (capped at 16), every image is additionally cut into a
    ``tiles`` x ``tiles`` grid with ``split_tiles``, each image's tiles are
    encoded, and the tile hidden states are merged back with
    ``merge_hiddenstates``.

    Args:
        clip_vision: Vision encoder accepted by ``encode_image_``.
        image: Image batch; must be 4-D, as required by ``split_tiles``.
        tiles: Number of tiles per image edge; values above 16 are clamped.
        ratio: Weight of the full-image embedding.

    Returns:
        The full-image hidden states when ``tiles <= 1``. Otherwise:
        * for more than one image, ``embeds * ratio + tiled * (1 - ratio)``;
        * for a single image, ``embeds * ratio`` concatenated with the tiled
          embedding along dim 0 (the tiled part is not scaled here).
    """
    embeds = encode_image_(clip_vision, image)
    tiles = min(tiles, 16)

    if tiles > 1:
        # split in tiles: one entry per input image, each holding tiles*tiles crops
        image_split = split_tiles(image, tiles)

        # Encode the tiles of every image and keep all of them. Collecting in
        # a list and concatenating once keeps every image's tiles; testing a
        # dict with hasattr() and calling torch.cat(a, b) would not.
        tile_states = [encode_image_(clip_vision, tile_batch) for tile_batch in image_split]
        tiled = merge_hiddenstates(torch.cat(tile_states, dim=0), tiles)

        if embeds.shape[0] > 1:  # if we have more than one image we need to average the embeddings for consistency
            embeds = embeds * ratio + tiled * (1 - ratio)
        else:  # otherwise we can concatenate them, they can be averaged later
            embeds = torch.cat([embeds * ratio, tiled])

    return embeds
def encode_image_(clip_vision, image):
    """Return hidden states for ``image`` from either supported encoder kind.

    A ``ClipVisionModel`` instance is asked for ``encode_image(image)`` and its
    ``last_hidden_state`` is returned. Any other object is expected to expose
    ``visual``; the image is first run through ``clip_preprocess`` (size 224,
    cropped) and cast to float.
    """
    if not isinstance(clip_vision, ClipVisionModel):
        pixel_values = clip_preprocess(image, size=224, crop=True).float()
        return clip_vision.visual(pixel_values)

    return clip_vision.encode_image(image).last_hidden_state
# Code based on https://github.com/WikiChao/FreSca (MIT License)
import torch
import torch.fft as fft
def fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20):
    """
    Apply frequency-dependent scaling to a tensor using a 2-D Fourier transform
    over its last two dimensions.

    Parameters:
        x: Input tensor of shape (..., H, W), e.g. (B, C, H, W)
        scale_low: Scaling factor for low-frequency components (default: 1.0)
        scale_high: Scaling factor for high-frequency components (default: 1.5)
        freq_cutoff: Number of frequency indices around center to consider as low-frequency (default: 20).
            If it exceeds half of H or W, the low-frequency window is clipped
            to the spectrum bounds.

    Returns:
        x_filtered: Filtered version of x in spatial domain with frequency-specific scaling applied,
            in the same dtype as the input.
    """
    # Preserve input dtype and device
    dtype, device = x.dtype, x.device

    # Convert to float32 for FFT computations
    x = x.to(torch.float32)

    # 1) Apply FFT and shift low frequencies to center
    x_freq = fft.fftn(x, dim=(-2, -1))
    x_freq = fft.fftshift(x_freq, dim=(-2, -1))

    # 2) Create a mask to scale frequencies differently.
    #    The mask is (H, W) and broadcasts over all leading dimensions.
    H, W = x_freq.shape[-2:]
    crow, ccol = H // 2, W // 2

    # Initialize mask with high-frequency scaling factor
    mask = torch.ones((H, W), device=device) * scale_high

    # Apply low-frequency scaling factor to center region. Slice starts are
    # clamped at 0: a negative start would be read as an offset from the end
    # and select the wrong window.
    row_start = max(crow - freq_cutoff, 0)
    col_start = max(ccol - freq_cutoff, 0)
    mask[
        row_start : crow + freq_cutoff,
        col_start : ccol + freq_cutoff,
    ] = scale_low

    # 3) Apply frequency-specific scaling
    x_freq = x_freq * mask

    # 4) Convert back to spatial domain
    x_freq = fft.ifftshift(x_freq, dim=(-2, -1))
    x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real

    # 5) Restore original dtype
    x_filtered = x_filtered.to(dtype)

    return x_filtered