Elea Zhong committed
Commit e64ed84 · 1 Parent(s): 454ba5e

add debug functions

.gitignore ADDED
@@ -0,0 +1,19 @@
+ *.egg-info
+ .env
+ **/__pycache__/*
+ wandb/*
+ *.log
+ venv/*
+ .venv/*
+ keyfile
+ **/.ipynb_checkpoints/
+ **/.DS_Store/*
+ .idea/*
+ .vscode/*
+ latentanalysis_data/*
+ latentanalysis_scripts/*
+ test-images/*
+ weights/*
+ latentmask/latentanalysis/*
+ cache/*
+ docs/automated-documentation/*
app.py CHANGED
@@ -13,11 +13,12 @@ from diffusers import FlowMatchEulerDiscreteScheduler
  from huggingface_hub import hf_hub_download
  from safetensors.torch import load_file

- from optimization import optimize_pipeline_
+ from qwenimage.debug import ftimed
+ from qwenimage.optimization import optimize_pipeline_
  from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
  from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
  from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
- from prompt import build_camera_prompt
+ from qwenimage.prompt import build_camera_prompt

  # --- Model Loading ---
  dtype = torch.bfloat16
@@ -53,7 +54,8 @@ optimize_pipeline_(pipe, image=[Image.new("RGB", (1024, 1024)), Image.new("RGB",
  MAX_SEED = np.iinfo(np.int32).max


- @spaces.GPU
+ # @spaces.GPU
+ @ftimed
  def infer_camera_edit(
      image,
      rotate_deg,
@@ -111,7 +113,7 @@ css = '''#col-container { max-width: 800px; margin: 0 auto; }
  #examples{max-width: 800px; margin: 0 auto; }'''

  def reset_all():
-     return [0, 0, 0, 0, False, True]
+     return [0, 0, 0, 0, False]

  def end_reset():
      return False
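For local profiling, the commit swaps the ZeroGPU decorator on `infer_camera_edit` for the new timing decorator. A minimal sketch (hypothetical stacking, not part of this commit) of how the two could coexist if GPU allocation on Spaces is still wanted, assuming `qwenimage.debug` is importable:

```python
# Hypothetical sketch: keep ZeroGPU allocation and add timing on top.
# ftimed only prints while DEBUG is True in qwenimage/debug.py.
import spaces
from qwenimage.debug import ftimed

@spaces.GPU  # re-enable GPU allocation on Spaces (commented out in this commit)
@ftimed      # prints "Time taken by infer_camera_edit: ... seconds"
def infer_camera_edit(image, rotate_deg, *args, **kwargs):
    ...
```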
pyproject.toml ADDED
@@ -0,0 +1,14 @@
+ [build-system]
+ requires = ["setuptools>=61.0"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "qwenimage"
+ version = "0.1"
+
+ [tool.setuptools.packages.find]
+ where = ["."]
+ include = ["qwenimage*"]
+
+ [tool.setuptools.dynamic]
+ dependencies = {file = ["requirements/requirements.txt"]}
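This packaging is what lets app.py and the pipeline switch from flat imports (`from optimization import …`) to package imports (`from qwenimage.optimization import …`). A minimal smoke test, assuming the repository has been installed into the active environment (for example with an editable `pip install -e .`, which is not shown in this commit):

```python
# Hypothetical check that setuptools package discovery picked up the qwenimage modules.
import importlib

for mod in ("qwenimage.debug", "qwenimage.optimization", "qwenimage.prompt"):
    importlib.import_module(mod)  # raises ModuleNotFoundError if the include pattern is off
    print(f"{mod} imported OK")
```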
qwenimage/debug.py ADDED
@@ -0,0 +1,178 @@
+ import os
+ from pathlib import Path
+ import time
+ import uuid
+ import warnings
+ from functools import wraps
+ from typing import Callable, Literal
+
+ import numpy as np
+ from PIL import Image
+ import torch
+ from torchvision.utils import save_image
+
+ DEBUG = True
+
+ def ftimed(func=None):
+
+     def decorator(func):
+         @wraps(func)
+         def wrapper(*args, **kwargs):
+             if not DEBUG:
+                 return func(*args, **kwargs)
+             else:
+                 start_time = time.perf_counter()
+                 result = func(*args, **kwargs)
+                 end_time = time.perf_counter()
+                 print(f"Time taken by {func.__qualname__}: {end_time - start_time} seconds")
+                 return result
+         return wrapper
+
+
+     if func is None:
+         return decorator
+     else:
+         return decorator(func)
+
+
+ class ctimed:
+     """
+     Context manager for timing lines of code. Use like:
+     ```
+     with ctimed(name="Model Forward"):
+         y = model(x)
+     ```
+     """
+     def __init__(self, name=None):
+         self.name = name
+         self.start_time = None
+
+     def __enter__(self):
+         if DEBUG:
+             self.start_time = time.perf_counter()
+         return self
+
+     def __exit__(self, exc_type, exc_value, traceback):
+         if DEBUG:
+             end_time = time.perf_counter()
+             if self.name:
+                 print(f"Time taken by {self.name}: {end_time - self.start_time} seconds")
+             else:
+                 print(f"Time taken: {end_time - self.start_time} seconds")
+
+
+ def print_gpu_memory(clear_mem: Literal["pre", "post", None] = "pre"):
+     if not torch.cuda.is_available():
+         warnings.warn("Warning: CUDA device not available. Running on CPU.")
+         return
+     if clear_mem == "pre":
+         torch.cuda.empty_cache()
+     allocated = torch.cuda.memory_allocated()
+     reserved = torch.cuda.memory_reserved()
+     total = torch.cuda.get_device_properties(0).total_memory
+     print(f"Memory allocated: {allocated / (1024**2):.2f} MB")
+     print(f"Memory reserved: {reserved / (1024**2):.2f} MB")
+     print(f"Total memory: {total / (1024**2):.2f} MB")
+     if clear_mem == "post":
+         torch.cuda.empty_cache()
+
+ def cuda_empty_cache(func):
+     def wrapper(*args, **kwargs):
+         result = func(*args, **kwargs)
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+         return result
+     return wrapper
+
+ def print_first_param(module):
+     print(list(module.parameters())[0])
+
+ def fdebug(func=None, *, exclude=None):
+     if exclude is None:
+         exclude = []
+     elif isinstance(exclude, str):
+         exclude = [exclude]
+
+     def decorator(func):
+         @wraps(func)
+         def wrapper(*args, **kwargs):
+             arg_names = func.__code__.co_varnames[:func.__code__.co_argcount]
+             arg_vals = args[:len(arg_names)]
+             arg_vals = [
+                 (str(value)+str(value.shape) if isinstance(value, torch.Tensor) else value)
+                 for value in arg_vals
+             ]
+             args_pairs = ", ".join(f"{name}={value}" for name, value in zip(arg_names, arg_vals) if name not in exclude)
+             kwargs_pairs = ", ".join(f"{k}={v}" for k, v in kwargs.items() if k not in exclude)
+             all_args = ", ".join(filter(None, [args_pairs, kwargs_pairs]))
+             print(f"Calling {func.__name__}({all_args})")
+             result = func(*args, **kwargs)
+             print(f"{func.__name__} returned {str(result)+str(result.shape) if isinstance(result, torch.Tensor) else result}")
+             return result
+         return wrapper
+
+     if func is None:
+         return decorator
+     else:
+         return decorator(func)
+
+ class IncrementIndex:
+     def __init__(self, max:int=100):
+         self.retry_max = max
+         self.retries = 0
+
+     def __call__(self, index):
+         if self.retries > self.retry_max:
+             raise RuntimeError(f"Retried too many times, max:{self.retry_max}")
+         else:
+             self.retries += 1
+             index += 1
+             return index
+
+ _identity = lambda x: x
+
+ def fretry(func=None, *, exceptions=(Exception,), mod_args:tuple[Callable|None, ...]=tuple(), mod_kwargs:dict[str,Callable|None]=dict()):
+     def decorator(func):
+         @wraps(func)
+         def fretry_wrapper(*args, **kwargs):
+             try:
+                 out = func(*args, **kwargs)
+             except exceptions as e:
+                 new_args = []
+                 for i, arg in enumerate(args):
+                     if i < len(mod_args):
+                         mod_func = mod_args[i] or _identity
+                         new_args.append(mod_func(arg))
+                     else:
+                         new_args.append(arg)
+                 new_kwargs = {}
+                 for k, kwarg in kwargs.items():
+                     if k in mod_kwargs:
+                         mod_func = mod_kwargs[k] or _identity
+                         new_kwargs[k] = mod_func(kwarg)
+                 kwargs.update(new_kwargs)
+
+                 import traceback
+                 traceback.print_exc()
+                 warnings.warn(
+                     f"Function {func} failed due to {e} with inputs {args}, {kwargs}, "
+                     f"retrying with modified inputs {new_args}, {new_kwargs}"
+                 )
+                 out = fretry_wrapper(*new_args, **new_kwargs)
+             return out
+         return fretry_wrapper
+
+     if func is None:
+         return decorator
+     else:
+         return decorator(func)
+
+
+ def texam(t: torch.Tensor):
+     print(f"Shape: {tuple(t.shape)}")
+     if t.dtype.is_floating_point or t.dtype.is_complex:
+         mean_val = t.mean().item()
+     else:
+         mean_val = "N/A"
+     print(f"Min: {t.min().item()}, Max: {t.max().item()}, Mean: {mean_val}")
+     print(f"Device: {t.device}, Dtype: {t.dtype}, Requires Grad: {t.requires_grad}")
optimization.py → qwenimage/optimization.py RENAMED
@@ -1,5 +1,6 @@
  """
  """
+ import os

  from typing import Any
  from typing import Callable
@@ -66,5 +67,7 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
  )

  return spaces.aoti_compile(exported, INDUCTOR_CONFIGS)
+
+

  spaces.aoti_apply(compile_transformer(), pipeline.transformer)
qwenimage/pipeline_qwenimage_edit_plus.py CHANGED
@@ -18,7 +18,10 @@ from typing import Any, Callable, Dict, List, Optional, Union

  import numpy as np
  import torch
- from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
+ # from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
+ from transformers.models.qwen2 import Qwen2Tokenizer
+ from transformers.models.qwen2_vl import Qwen2VLProcessor

  from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
  from diffusers.loaders import QwenImageLoraLoaderMixin
@@ -29,6 +32,8 @@ from diffusers.utils.torch_utils import randn_tensor
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline
  from diffusers.pipelines.qwenimage.pipeline_output import QwenImagePipelineOutput

+ from qwenimage.debug import ctimed, ftimed
+

  if is_torch_xla_available():
      import torch_xla.core.xla_model as xm
@@ -284,6 +289,7 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
  return prompt_embeds, encoder_attention_mask

  # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.encode_prompt
+ @ftimed
  def encode_prompt(
  self,
  prompt: Union[str, List[str]],
@@ -627,265 +633,271 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
  [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
  returning a tuple, the first element is a list with the generated images.
  """
- image_size = image[-1].size if isinstance(image, list) else image.size
- calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1])
- height = height or calculated_height
- width = width or calculated_width
-
- multiple_of = self.vae_scale_factor * 2
- width = width // multiple_of * multiple_of
- height = height // multiple_of * multiple_of
-
- # 1. Check inputs. Raise error if not correct
- self.check_inputs(
- prompt,
- height,
- width,
- negative_prompt=negative_prompt,
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=negative_prompt_embeds,
- prompt_embeds_mask=prompt_embeds_mask,
- negative_prompt_embeds_mask=negative_prompt_embeds_mask,
- callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
- max_sequence_length=max_sequence_length,
- )

- self._guidance_scale = guidance_scale
- self._attention_kwargs = attention_kwargs
- self._current_timestep = None
- self._interrupt = False

- # 2. Define call parameters
- if prompt is not None and isinstance(prompt, str):
- batch_size = 1
- elif prompt is not None and isinstance(prompt, list):
- batch_size = len(prompt)
- else:
- batch_size = prompt_embeds.shape[0]
-
- device = self._execution_device
- # 3. Preprocess image
- if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
- if not isinstance(image, list):
- image = [image]
- condition_image_sizes = []
- condition_images = []
- vae_image_sizes = []
- vae_images = []
- for img in image:
- image_width, image_height = img.size
- condition_width, condition_height = calculate_dimensions(
- CONDITION_IMAGE_SIZE, image_width / image_height
  )
- vae_width, vae_height = calculate_dimensions(VAE_IMAGE_SIZE, image_width / image_height)
- condition_image_sizes.append((condition_width, condition_height))
- vae_image_sizes.append((vae_width, vae_height))
- condition_images.append(self.image_processor.resize(img, condition_height, condition_width))
- vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2))
-
- has_neg_prompt = negative_prompt is not None or (
- negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
- )
-
- if true_cfg_scale > 1 and not has_neg_prompt:
- logger.warning(
- f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided."
- )
- elif true_cfg_scale <= 1 and has_neg_prompt:
- logger.warning(
- " negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1"
- )

- do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
- prompt_embeds, prompt_embeds_mask = self.encode_prompt(
- image=condition_images,
- prompt=prompt,
- prompt_embeds=prompt_embeds,
- prompt_embeds_mask=prompt_embeds_mask,
- device=device,
- num_images_per_prompt=num_images_per_prompt,
- max_sequence_length=max_sequence_length,
- )
- if do_true_cfg:
- negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
  image=condition_images,
- prompt=negative_prompt,
- prompt_embeds=negative_prompt_embeds,
- prompt_embeds_mask=negative_prompt_embeds_mask,
  device=device,
  num_images_per_prompt=num_images_per_prompt,
  max_sequence_length=max_sequence_length,
  )

- # 4. Prepare latent variables
- num_channels_latents = self.transformer.config.in_channels // 4
- latents, image_latents = self.prepare_latents(
- vae_images,
- batch_size * num_images_per_prompt,
- num_channels_latents,
- height,
- width,
- prompt_embeds.dtype,
- device,
- generator,
- latents,
- )
- img_shapes = [
- [
- (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2),
- *[
- (1, vae_height // self.vae_scale_factor // 2, vae_width // self.vae_scale_factor // 2)
- for vae_width, vae_height in vae_image_sizes
- ],
- ]
- ] * batch_size
-
- # 5. Prepare timesteps
- sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
- image_seq_len = latents.shape[1]
- mu = calculate_shift(
- image_seq_len,
- self.scheduler.config.get("base_image_seq_len", 256),
- self.scheduler.config.get("max_image_seq_len", 4096),
- self.scheduler.config.get("base_shift", 0.5),
- self.scheduler.config.get("max_shift", 1.15),
- )
- timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler,
- num_inference_steps,
- device,
- sigmas=sigmas,
- mu=mu,
- )
- num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
- self._num_timesteps = len(timesteps)
-
- # handle guidance
- if self.transformer.config.guidance_embeds and guidance_scale is None:
- raise ValueError("guidance_scale is required for guidance-distilled model.")
- elif self.transformer.config.guidance_embeds:
- guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
- guidance = guidance.expand(latents.shape[0])
- elif not self.transformer.config.guidance_embeds and guidance_scale is not None:
- logger.warning(
- f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled."
  )
- guidance = None
- elif not self.transformer.config.guidance_embeds and guidance_scale is None:
- guidance = None
-
- if self.attention_kwargs is None:
- self._attention_kwargs = {}
-
- txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
-
- image_rotary_emb = self.transformer.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
- if do_true_cfg:
- negative_txt_seq_lens = (
- negative_prompt_embeds_mask.sum(dim=1).tolist()
- if negative_prompt_embeds_mask is not None
- else None
  )
- uncond_image_rotary_emb = self.transformer.pos_embed(
- img_shapes, negative_txt_seq_lens, device=latents.device
  )
- else:
- uncond_image_rotary_emb = None
-
- # 6. Denoising loop
- self.scheduler.set_begin_index(0)
- with self.progress_bar(total=num_inference_steps) as progress_bar:
- for i, t in enumerate(timesteps):
- if self.interrupt:
- continue
-
- self._current_timestep = t
-
- latent_model_input = latents
- if image_latents is not None:
- latent_model_input = torch.cat([latents, image_latents], dim=1)
-
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
- timestep = t.expand(latents.shape[0]).to(latents.dtype)
- with self.transformer.cache_context("cond"):
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- timestep=timestep / 1000,
- guidance=guidance,
- encoder_hidden_states_mask=prompt_embeds_mask,
- encoder_hidden_states=prompt_embeds,
- image_rotary_emb=image_rotary_emb,
- attention_kwargs=self.attention_kwargs,
- return_dict=False,
- )[0]
- noise_pred = noise_pred[:, : latents.size(1)]
-
- if do_true_cfg:
- with self.transformer.cache_context("uncond"):
- neg_noise_pred = self.transformer(
- hidden_states=latent_model_input,
- timestep=timestep / 1000,
- guidance=guidance,
- encoder_hidden_states_mask=negative_prompt_embeds_mask,
- encoder_hidden_states=negative_prompt_embeds,
- image_rotary_emb=uncond_image_rotary_emb,
- attention_kwargs=self.attention_kwargs,
- return_dict=False,
- )[0]
- neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
- comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
-
- cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
- noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
- noise_pred = comb_pred * (cond_norm / noise_norm)
-
- # compute the previous noisy sample x_t -> x_t-1
- latents_dtype = latents.dtype
- latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
-
- if latents.dtype != latents_dtype:
- if torch.backends.mps.is_available():
- # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
- latents = latents.to(latents_dtype)
-
- if callback_on_step_end is not None:
- callback_kwargs = {}
- for k in callback_on_step_end_tensor_inputs:
- callback_kwargs[k] = locals()[k]
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
- latents = callback_outputs.pop("latents", latents)
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-
- # call the callback, if provided
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
- progress_bar.update()
-
- if XLA_AVAILABLE:
- xm.mark_step()
-
- self._current_timestep = None
- if output_type == "latent":
- image = latents
- else:
- latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
- latents = latents.to(self.vae.dtype)
- latents_mean = (
- torch.tensor(self.vae.config.latents_mean)
- .view(1, self.vae.config.z_dim, 1, 1, 1)
- .to(latents.device, latents.dtype)
- )
- latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
- latents.device, latents.dtype
- )
- latents = latents / latents_std + latents_mean
- image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
- image = self.image_processor.postprocess(image, output_type=output_type)

- # Offload all models
- self.maybe_free_model_hooks()

- if not return_dict:
- return (image,)
+ with ctimed("Preprocessing"):
+ image_size = image[-1].size if isinstance(image, list) else image.size
+ calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1])
+ height = height or calculated_height
+ width = width or calculated_width
+
+ multiple_of = self.vae_scale_factor * 2
+ width = width // multiple_of * multiple_of
+ height = height // multiple_of * multiple_of
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ negative_prompt_embeds_mask=negative_prompt_embeds_mask,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )

+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # 3. Preprocess image
+ if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
+ if not isinstance(image, list):
+ image = [image]
+ condition_image_sizes = []
+ condition_images = []
+ vae_image_sizes = []
+ vae_images = []
+ for img in image:
+ image_width, image_height = img.size
+ condition_width, condition_height = calculate_dimensions(
+ CONDITION_IMAGE_SIZE, image_width / image_height
+ )
+ vae_width, vae_height = calculate_dimensions(VAE_IMAGE_SIZE, image_width / image_height)
+ condition_image_sizes.append((condition_width, condition_height))
+ vae_image_sizes.append((vae_width, vae_height))
+ condition_images.append(self.image_processor.resize(img, condition_height, condition_width))
+ vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2))
+
+ has_neg_prompt = negative_prompt is not None or (
+ negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
+ )

+ if true_cfg_scale > 1 and not has_neg_prompt:
+ logger.warning(
+ f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided."
+ )
+ elif true_cfg_scale <= 1 and has_neg_prompt:
+ logger.warning(
+ " negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1"
  )

+ with ctimed("Encode Prompt"):
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
+ prompt_embeds, prompt_embeds_mask = self.encode_prompt(
  image=condition_images,
+ prompt=prompt,
+ prompt_embeds=prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
  device=device,
  num_images_per_prompt=num_images_per_prompt,
  max_sequence_length=max_sequence_length,
  )
+ if do_true_cfg:
+ negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
+ image=condition_images,
+ prompt=negative_prompt,
+ prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=negative_prompt_embeds_mask,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )

+ with ctimed("Prep gen"):
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents, image_latents = self.prepare_latents(
+ vae_images,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
  )
+ img_shapes = [
+ [
+ (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2),
+ *[
+ (1, vae_height // self.vae_scale_factor // 2, vae_width // self.vae_scale_factor // 2)
+ for vae_width, vae_height in vae_image_sizes
+ ],
+ ]
+ ] * batch_size
+
+ # 5. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = latents.shape[1]
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
  )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
  )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # handle guidance
+ if self.transformer.config.guidance_embeds and guidance_scale is None:
+ raise ValueError("guidance_scale is required for guidance-distilled model.")
+ elif self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ elif not self.transformer.config.guidance_embeds and guidance_scale is not None:
+ logger.warning(
+ f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled."
+ )
+ guidance = None
+ elif not self.transformer.config.guidance_embeds and guidance_scale is None:
+ guidance = None
+
+ if self.attention_kwargs is None:
+ self._attention_kwargs = {}
+
+ txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
+
+ image_rotary_emb = self.transformer.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
+ if do_true_cfg:
+ negative_txt_seq_lens = (
+ negative_prompt_embeds_mask.sum(dim=1).tolist()
+ if negative_prompt_embeds_mask is not None
+ else None
+ )
+ uncond_image_rotary_emb = self.transformer.pos_embed(
+ img_shapes, negative_txt_seq_lens, device=latents.device
+ )
+ else:
+ uncond_image_rotary_emb = None
+
+ with ctimed("loop"):
+ # 6. Denoising loop
+ self.scheduler.set_begin_index(0)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ with ctimed(f"loop {i}"):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+
+ latent_model_input = latents
+ if image_latents is not None:
+ latent_model_input = torch.cat([latents, image_latents], dim=1)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+ with self.transformer.cache_context("cond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ encoder_hidden_states_mask=prompt_embeds_mask,
+ encoder_hidden_states=prompt_embeds,
+ image_rotary_emb=image_rotary_emb,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_pred[:, : latents.size(1)]
+
+ if do_true_cfg:
+ with self.transformer.cache_context("uncond"):
+ neg_noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ encoder_hidden_states_mask=negative_prompt_embeds_mask,
+ encoder_hidden_states=negative_prompt_embeds,
+ image_rotary_emb=uncond_image_rotary_emb,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ )[0]
+ neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
+ comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
+ cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
+ noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
+ noise_pred = comb_pred * (cond_norm / noise_norm)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ with ctimed("Post (vae)"):
+ self._current_timestep = None
+ if output_type == "latent":
+ image = latents
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
+ image = self.image_processor.postprocess(image, output_type=output_type)

+ # Offload all models
+ self.maybe_free_model_hooks()

+ if not return_dict:
+ return (image,)

  return QwenImagePipelineOutput(images=image)
prompt.py → qwenimage/prompt.py RENAMED
File without changes