Spaces: Running on Zero
Elea Zhong committed on
Commit · e64ed84
Parent(s): 454ba5e
add debug functions
Browse files
- .gitignore +19 -0
- app.py +6 -4
- pyproject.toml +14 -0
- qwenimage/debug.py +178 -0
- optimization.py → qwenimage/optimization.py +3 -0
- qwenimage/pipeline_qwenimage_edit_plus.py +258 -246
- prompt.py → qwenimage/prompt.py +0 -0
.gitignore
ADDED
@@ -0,0 +1,19 @@
+*.egg-info
+.env
+**/__pycache__/*
+wandb/*
+*.log
+venv/*
+.venv/*
+keyfile
+**/.ipynb_checkpoints/
+**/.DS_Store/*
+.idea/*
+.vscode/*
+latentanalysis_data/*
+latentanalysis_scripts/*
+test-images/*
+weights/*
+latentmask/latentanalysis/*
+cache/*
+docs/automated-documentation/*
app.py
CHANGED
@@ -13,11 +13,12 @@ from diffusers import FlowMatchEulerDiscreteScheduler
 from huggingface_hub import hf_hub_download
 from safetensors.torch import load_file
 
-from optimization import optimize_pipeline_
+from qwenimage.debug import ftimed
+from qwenimage.optimization import optimize_pipeline_
 from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
 from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
 from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
-from prompt import build_camera_prompt
+from qwenimage.prompt import build_camera_prompt
 
 # --- Model Loading ---
 dtype = torch.bfloat16
@@ -53,7 +54,8 @@ optimize_pipeline_(pipe, image=[Image.new("RGB", (1024, 1024)), Image.new("RGB",
 MAX_SEED = np.iinfo(np.int32).max
 
 
-@spaces.GPU
+# @spaces.GPU
+@ftimed
 def infer_camera_edit(
     image,
     rotate_deg,
@@ -111,7 +113,7 @@ css = '''#col-container { max-width: 800px; margin: 0 auto; }
 #examples{max-width: 800px; margin: 0 auto; }'''
 
 def reset_all():
-    return [0, 0, 0, 0, False
+    return [0, 0, 0, 0, False]
 
 def end_reset():
     return False
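The net effect in app.py: the ZeroGPU decorator is commented out, the Space's inference entry point is timed instead, and reset_all gets its missing closing bracket. A minimal sketch of the new decorator arrangement follows; the function body and arguments are stand-ins, not the real app.py implementation.

```
# Sketch only: with @spaces.GPU disabled in this commit, the entry point is a plain
# function; @ftimed (added in qwenimage/debug.py below) prints its wall-clock runtime.
from qwenimage.debug import ftimed

# @spaces.GPU                      # disabled by this commit
@ftimed
def infer_camera_edit(image, rotate_deg):
    # stand-in body; the real function builds the camera prompt and runs the pipeline
    return image

infer_camera_edit(None, 0.0)       # prints "Time taken by infer_camera_edit: ... seconds"
```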
pyproject.toml
ADDED
@@ -0,0 +1,14 @@
+[build-system]
+requires = ["setuptools>=61.0"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "qwenimage"
+version = "0.1"
+
+[tool.setuptools.packages.find]
+where = ["."]
+include = ["qwenimage*"]
+
+[tool.setuptools.dynamic]
+dependencies = {file = ["requirements/requirements.txt"]}
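This packaging metadata is what lets the rest of the commit import everything under the qwenimage namespace. A sketch of the assumed workflow: install the repo as a package (for example an editable `pip install -e .` from the repo root), after which the imports used in app.py resolve; dependencies are declared dynamically from requirements/requirements.txt.

```
# Assumed to resolve after installing the repo as a package (editable install assumed):
from qwenimage.debug import ftimed, ctimed
from qwenimage.optimization import optimize_pipeline_
from qwenimage.prompt import build_camera_prompt
```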
qwenimage/debug.py
ADDED
@@ -0,0 +1,178 @@
+import os
+from pathlib import Path
+import time
+import uuid
+import warnings
+from functools import wraps
+from typing import Callable, Literal
+
+import numpy as np
+from PIL import Image
+import torch
+from torchvision.utils import save_image
+
+DEBUG = True
+
+def ftimed(func=None):
+
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            if not DEBUG:
+                return func(*args, **kwargs)
+            else:
+                start_time = time.perf_counter()
+                result = func(*args, **kwargs)
+                end_time = time.perf_counter()
+                print(f"Time taken by {func.__qualname__}: {end_time - start_time} seconds")
+                return result
+        return wrapper
+
+
+    if func is None:
+        return decorator
+    else:
+        return decorator(func)
+
+
+class ctimed:
+    """
+    Context manager for timing lines of code. Use like:
+    ```
+    with ctimed(name="Model Forward"):
+        y = model(x)
+    ```
+    """
+    def __init__(self, name=None):
+        self.name = name
+        self.start_time = None
+
+    def __enter__(self):
+        if DEBUG:
+            self.start_time = time.perf_counter()
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        if DEBUG:
+            end_time = time.perf_counter()
+            if self.name:
+                print(f"Time taken by {self.name}: {end_time - self.start_time} seconds")
+            else:
+                print(f"Time taken: {end_time - self.start_time} seconds")
+
+
+def print_gpu_memory(clear_mem: Literal["pre", "post", None] = "pre"):
+    if not torch.cuda.is_available():
+        warnings.warn("Warning: CUDA device not available. Running on CPU.")
+        return
+    if clear_mem == "pre":
+        torch.cuda.empty_cache()
+    allocated = torch.cuda.memory_allocated()
+    reserved = torch.cuda.memory_reserved()
+    total = torch.cuda.get_device_properties(0).total_memory
+    print(f"Memory allocated: {allocated / (1024**2):.2f} MB")
+    print(f"Memory reserved: {reserved / (1024**2):.2f} MB")
+    print(f"Total memory: {total / (1024**2):.2f} MB")
+    if clear_mem == "post":
+        torch.cuda.empty_cache()
+
+def cuda_empty_cache(func):
+    def wrapper(*args, **kwargs):
+        result = func(*args, **kwargs)
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        return result
+    return wrapper
+
+def print_first_param(module):
+    print(list(module.parameters())[0])
+
+def fdebug(func=None, *, exclude=None):
+    if exclude is None:
+        exclude = []
+    elif isinstance(exclude, str):
+        exclude = [exclude]
+
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            arg_names = func.__code__.co_varnames[:func.__code__.co_argcount]
+            arg_vals = args[:len(arg_names)]
+            arg_vals = [
+                (str(value)+str(value.shape) if isinstance(value, torch.Tensor) else value)
+                for value in arg_vals
+            ]
+            args_pairs = ", ".join(f"{name}={value}" for name, value in zip(arg_names, arg_vals) if name not in exclude)
+            kwargs_pairs = ", ".join(f"{k}={v}" for k, v in kwargs.items() if k not in exclude)
+            all_args = ", ".join(filter(None, [args_pairs, kwargs_pairs]))
+            print(f"Calling {func.__name__}({all_args})")
+            result = func(*args, **kwargs)
+            print(f"{func.__name__} returned {str(result)+str(result.shape) if isinstance(result, torch.Tensor) else result}")
+            return result
+        return wrapper
+
+    if func is None:
+        return decorator
+    else:
+        return decorator(func)
+
+class IncrementIndex:
+    def __init__(self, max:int=100):
+        self.retry_max = max
+        self.retries = 0
+
+    def __call__(self, index):
+        if self.retries > self.retry_max:
+            raise RuntimeError(f"Retried too many times, max:{self.retry_max}")
+        else:
+            self.retries += 1
+            index += 1
+            return index
+
+_identity = lambda x: x
+
+def fretry(func=None, *, exceptions=(Exception,), mod_args:tuple[Callable|None, ...]=tuple(), mod_kwargs:dict[str,Callable|None]=dict()):
+    def decorator(func):
+        @wraps(func)
+        def fretry_wrapper(*args, **kwargs):
+            try:
+                out = func(*args, **kwargs)
+            except exceptions as e:
+                new_args = []
+                for i, arg in enumerate(args):
+                    if i < len(mod_args):
+                        mod_func = mod_args[i] or _identity
+                        new_args.append(mod_func(arg))
+                    else:
+                        new_args.append(arg)
+                new_kwargs = {}
+                for k, kwarg in kwargs.items():
+                    if k in mod_kwargs:
+                        mod_func = mod_kwargs[k] or _identity
+                        new_kwargs[k] = mod_func(kwarg)
+                kwargs.update(new_kwargs)
+
+                import traceback
+                traceback.print_exc()
+                warnings.warn(
+                    f"Function {func} failed due to {e} with inputs {args}, {kwargs}, "
+                    f"retrying with modified inputs {new_args}, {new_kwargs}"
+                )
+                out = fretry_wrapper(*new_args, **new_kwargs)
+            return out
+        return fretry_wrapper
+
+    if func is None:
+        return decorator
+    else:
+        return decorator(func)
+
+
+def texam(t: torch.Tensor):
+    print(f"Shape: {tuple(t.shape)}")
+    if t.dtype.is_floating_point or t.dtype.is_complex:
+        mean_val = t.mean().item()
+    else:
+        mean_val = "N/A"
+    print(f"Min: {t.min().item()}, Max: {t.max().item()}, Mean: {mean_val}")
+    print(f"Device: {t.device}, Dtype: {t.dtype}, Requires Grad: {t.requires_grad}")
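A short usage sketch of the helpers added here, assuming the repo is installed so that qwenimage.debug is importable and DEBUG is left at True; the function names and tensors below are illustrative, not part of the commit.

```
import torch

from qwenimage.debug import ftimed, ctimed, fretry, texam

@ftimed                               # prints "Time taken by add_noise: ... seconds"
def add_noise(x: torch.Tensor) -> torch.Tensor:
    return x + torch.randn_like(x)

@fretry(exceptions=(RuntimeError,))   # on failure, retries with (optionally modified) inputs
def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2

x = torch.randn(2, 3)
with ctimed(name="noise + double"):   # prints "Time taken by noise + double: ... seconds"
    y = double(add_noise(x))

texam(y)                              # prints shape, min/max/mean, device, dtype, requires_grad
```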
optimization.py → qwenimage/optimization.py
RENAMED
@@ -1,5 +1,6 @@
 """
 """
+import os
 
 from typing import Any
 from typing import Callable
@@ -66,5 +67,7 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
         )
 
         return spaces.aoti_compile(exported, INDUCTOR_CONFIGS)
+
+
 
     spaces.aoti_apply(compile_transformer(), pipeline.transformer)
qwenimage/pipeline_qwenimage_edit_plus.py
CHANGED
@@ -18,7 +18,10 @@ from typing import Any, Callable, Dict, List, Optional, Union
 
 import numpy as np
 import torch
-from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
+# from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
+from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
+from transformers.models.qwen2 import Qwen2Tokenizer
+from transformers.models.qwen2_vl import Qwen2VLProcessor
 
 from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
 from diffusers.loaders import QwenImageLoraLoaderMixin
@@ -29,6 +32,8 @@ from diffusers.utils.torch_utils import randn_tensor
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from diffusers.pipelines.qwenimage.pipeline_output import QwenImagePipelineOutput
 
+from qwenimage.debug import ctimed, ftimed
+
 
 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
@@ -284,6 +289,7 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
         return prompt_embeds, encoder_attention_mask
 
     # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.encode_prompt
+    @ftimed
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
@@ -627,265 +633,271 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
             [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
             returning a tuple, the first element is a list with the generated images.
         """
+        with ctimed("Preprocessing"):
+            image_size = image[-1].size if isinstance(image, list) else image.size
+            calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1])
+            height = height or calculated_height
+            width = width or calculated_width
+
+            multiple_of = self.vae_scale_factor * 2
+            width = width // multiple_of * multiple_of
+            height = height // multiple_of * multiple_of
+
+            # 1. Check inputs. Raise error if not correct
+            self.check_inputs(
+                prompt,
+                height,
+                width,
+                negative_prompt=negative_prompt,
+                prompt_embeds=prompt_embeds,
+                negative_prompt_embeds=negative_prompt_embeds,
+                prompt_embeds_mask=prompt_embeds_mask,
+                negative_prompt_embeds_mask=negative_prompt_embeds_mask,
+                callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+                max_sequence_length=max_sequence_length,
+            )

+            self._guidance_scale = guidance_scale
+            self._attention_kwargs = attention_kwargs
+            self._current_timestep = None
+            self._interrupt = False
+
+            # 2. Define call parameters
+            if prompt is not None and isinstance(prompt, str):
+                batch_size = 1
+            elif prompt is not None and isinstance(prompt, list):
+                batch_size = len(prompt)
+            else:
+                batch_size = prompt_embeds.shape[0]
+
+            device = self._execution_device
+            # 3. Preprocess image
+            if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
+                if not isinstance(image, list):
+                    image = [image]
+                condition_image_sizes = []
+                condition_images = []
+                vae_image_sizes = []
+                vae_images = []
+                for img in image:
+                    image_width, image_height = img.size
+                    condition_width, condition_height = calculate_dimensions(
+                        CONDITION_IMAGE_SIZE, image_width / image_height
+                    )
+                    vae_width, vae_height = calculate_dimensions(VAE_IMAGE_SIZE, image_width / image_height)
+                    condition_image_sizes.append((condition_width, condition_height))
+                    vae_image_sizes.append((vae_width, vae_height))
+                    condition_images.append(self.image_processor.resize(img, condition_height, condition_width))
+                    vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2))
+
+            has_neg_prompt = negative_prompt is not None or (
+                negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
+            )

+            if true_cfg_scale > 1 and not has_neg_prompt:
+                logger.warning(
+                    f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided."
+                )
+            elif true_cfg_scale <= 1 and has_neg_prompt:
+                logger.warning(
+                    " negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1"
+                )

+        with ctimed("Encode Prompt"):
+            do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
+            prompt_embeds, prompt_embeds_mask = self.encode_prompt(
+                image=condition_images,
+                prompt=prompt,
+                prompt_embeds=prompt_embeds,
+                prompt_embeds_mask=prompt_embeds_mask,
+                device=device,
+                num_images_per_prompt=num_images_per_prompt,
+                max_sequence_length=max_sequence_length,
+            )
+            if do_true_cfg:
+                negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
+                    image=condition_images,
+                    prompt=negative_prompt,
+                    prompt_embeds=negative_prompt_embeds,
+                    prompt_embeds_mask=negative_prompt_embeds_mask,
+                    device=device,
+                    num_images_per_prompt=num_images_per_prompt,
+                    max_sequence_length=max_sequence_length,
+                )

+        with ctimed("Prep gen"):
+            # 4. Prepare latent variables
+            num_channels_latents = self.transformer.config.in_channels // 4
+            latents, image_latents = self.prepare_latents(
+                vae_images,
+                batch_size * num_images_per_prompt,
+                num_channels_latents,
+                height,
+                width,
+                prompt_embeds.dtype,
+                device,
+                generator,
+                latents,
+            )
+            img_shapes = [
+                [
+                    (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2),
+                    *[
+                        (1, vae_height // self.vae_scale_factor // 2, vae_width // self.vae_scale_factor // 2)
+                        for vae_width, vae_height in vae_image_sizes
+                    ],
+                ]
+            ] * batch_size
+
+            # 5. Prepare timesteps
+            sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+            image_seq_len = latents.shape[1]
+            mu = calculate_shift(
+                image_seq_len,
+                self.scheduler.config.get("base_image_seq_len", 256),
+                self.scheduler.config.get("max_image_seq_len", 4096),
+                self.scheduler.config.get("base_shift", 0.5),
+                self.scheduler.config.get("max_shift", 1.15),
+            )
+            timesteps, num_inference_steps = retrieve_timesteps(
+                self.scheduler,
+                num_inference_steps,
+                device,
+                sigmas=sigmas,
+                mu=mu,
+            )
+            num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+            self._num_timesteps = len(timesteps)
+
+            # handle guidance
+            if self.transformer.config.guidance_embeds and guidance_scale is None:
+                raise ValueError("guidance_scale is required for guidance-distilled model.")
+            elif self.transformer.config.guidance_embeds:
+                guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+                guidance = guidance.expand(latents.shape[0])
+            elif not self.transformer.config.guidance_embeds and guidance_scale is not None:
+                logger.warning(
+                    f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled."
+                )
+                guidance = None
+            elif not self.transformer.config.guidance_embeds and guidance_scale is None:
+                guidance = None
+
+            if self.attention_kwargs is None:
+                self._attention_kwargs = {}
+
+            txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
+
+            image_rotary_emb = self.transformer.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
+            if do_true_cfg:
+                negative_txt_seq_lens = (
+                    negative_prompt_embeds_mask.sum(dim=1).tolist()
+                    if negative_prompt_embeds_mask is not None
+                    else None
+                )
+                uncond_image_rotary_emb = self.transformer.pos_embed(
+                    img_shapes, negative_txt_seq_lens, device=latents.device
+                )
+            else:
+                uncond_image_rotary_emb = None
+
+        with ctimed("loop"):
+            # 6. Denoising loop
+            self.scheduler.set_begin_index(0)
+            with self.progress_bar(total=num_inference_steps) as progress_bar:
+                for i, t in enumerate(timesteps):
+                    with ctimed(f"loop {i}"):
+                        if self.interrupt:
+                            continue
+
+                        self._current_timestep = t
+
+                        latent_model_input = latents
+                        if image_latents is not None:
+                            latent_model_input = torch.cat([latents, image_latents], dim=1)
+
+                        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+                        timestep = t.expand(latents.shape[0]).to(latents.dtype)
+                        with self.transformer.cache_context("cond"):
+                            noise_pred = self.transformer(
+                                hidden_states=latent_model_input,
+                                timestep=timestep / 1000,
+                                guidance=guidance,
+                                encoder_hidden_states_mask=prompt_embeds_mask,
+                                encoder_hidden_states=prompt_embeds,
+                                image_rotary_emb=image_rotary_emb,
+                                attention_kwargs=self.attention_kwargs,
+                                return_dict=False,
+                            )[0]
+                        noise_pred = noise_pred[:, : latents.size(1)]
+
+                        if do_true_cfg:
+                            with self.transformer.cache_context("uncond"):
+                                neg_noise_pred = self.transformer(
+                                    hidden_states=latent_model_input,
+                                    timestep=timestep / 1000,
+                                    guidance=guidance,
+                                    encoder_hidden_states_mask=negative_prompt_embeds_mask,
+                                    encoder_hidden_states=negative_prompt_embeds,
+                                    image_rotary_emb=uncond_image_rotary_emb,
+                                    attention_kwargs=self.attention_kwargs,
+                                    return_dict=False,
+                                )[0]
+                            neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
+                            comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
+                            cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
+                            noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
+                            noise_pred = comb_pred * (cond_norm / noise_norm)
+
+                        # compute the previous noisy sample x_t -> x_t-1
+                        latents_dtype = latents.dtype
+                        latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+                        if latents.dtype != latents_dtype:
+                            if torch.backends.mps.is_available():
+                                # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+                                latents = latents.to(latents_dtype)
+
+                        if callback_on_step_end is not None:
+                            callback_kwargs = {}
+                            for k in callback_on_step_end_tensor_inputs:
+                                callback_kwargs[k] = locals()[k]
+                            callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                            latents = callback_outputs.pop("latents", latents)
+                            prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+                        # call the callback, if provided
+                        if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                            progress_bar.update()
+
+                        if XLA_AVAILABLE:
+                            xm.mark_step()
+
+        with ctimed("Post (vae)"):
+            self._current_timestep = None
+            if output_type == "latent":
+                image = latents
+            else:
+                latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+                latents = latents.to(self.vae.dtype)
+                latents_mean = (
+                    torch.tensor(self.vae.config.latents_mean)
+                    .view(1, self.vae.config.z_dim, 1, 1, 1)
+                    .to(latents.device, latents.dtype)
+                )
+                latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+                    latents.device, latents.dtype
+                )
+                latents = latents / latents_std + latents_mean
+                image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
+                image = self.image_processor.postprocess(image, output_type=output_type)

+        # Offload all models
+        self.maybe_free_model_hooks()

+        if not return_dict:
+            return (image,)

         return QwenImagePipelineOutput(images=image)
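The pipeline change is mostly re-indentation: the body of __call__ is wrapped in ctimed blocks so each stage reports wall-clock time, and encode_prompt is decorated with @ftimed. A self-contained sketch of that instrumentation pattern follows; the stage names match the commit, while the sleeps stand in for the real encoding, denoising, and decoding work.

```
import time

from qwenimage.debug import ctimed

def work(seconds):
    time.sleep(seconds)               # stand-in for encoding / denoising / decoding

with ctimed("Preprocessing"):
    work(0.01)
with ctimed("Encode Prompt"):
    work(0.02)
with ctimed("Prep gen"):
    work(0.01)
with ctimed("loop"):
    for i in range(3):
        with ctimed(f"loop {i}"):     # per-step timing, as in the denoising loop
            work(0.01)
with ctimed("Post (vae)"):
    work(0.01)
```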
prompt.py → qwenimage/prompt.py
RENAMED
File without changes