Update pipeline_glide.py
pipeline_glide.py  CHANGED  (+66 -131)
@@ -24,20 +24,16 @@ import torch.utils.checkpoint
from torch import nn

import tqdm
-from diffusers.models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
-from diffusers.pipeline_utils import DiffusionPipeline
-from diffusers.schedulers import ClassifierFreeGuidanceScheduler, DDIMScheduler
from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig, GPT2Tokenizer
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from transformers.modeling_utils import PreTrainedModel
-from transformers.utils import (
-    ...
-)
+from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward, replace_return_docstrings
+
+from ..models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
+from ..pipeline_utils import DiffusionPipeline
+from ..schedulers import DDPMScheduler, DDIMScheduler
+from ..utils import logging


#####################
@@ -719,7 +715,7 @@ class GLIDE(DiffusionPipeline):
    def __init__(
        self,
        text_unet: GLIDETextToImageUNetModel,
-        text_noise_scheduler: ClassifierFreeGuidanceScheduler,
+        text_noise_scheduler: DDPMScheduler,
        text_encoder: CLIPTextModel,
        tokenizer: GPT2Tokenizer,
        upscale_unet: GLIDESuperResUNetModel,
@@ -735,100 +731,28 @@ class GLIDE(DiffusionPipeline):
            upscale_noise_scheduler=upscale_noise_scheduler,
        )

-    def q_posterior_mean_variance(self, scheduler, x_start, x_t, t):
-        """
-        Compute the mean and variance of the diffusion posterior:
-
-            q(x_{t-1} | x_t, x_0)
-
-        """
-        assert x_start.shape == x_t.shape
-        posterior_mean = (
-            _extract_into_tensor(scheduler.posterior_mean_coef1, t, x_t.shape) * x_start
-            + _extract_into_tensor(scheduler.posterior_mean_coef2, t, x_t.shape) * x_t
-        )
-        posterior_variance = _extract_into_tensor(scheduler.posterior_variance, t, x_t.shape)
-        posterior_log_variance_clipped = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, x_t.shape)
-        assert (
-            posterior_mean.shape[0]
-            == posterior_variance.shape[0]
-            == posterior_log_variance_clipped.shape[0]
-            == x_start.shape[0]
-        )
-        return posterior_mean, posterior_variance, posterior_log_variance_clipped
-
-    def p_mean_variance(self, model, scheduler, x, t, transformer_out=None, low_res=None, clip_denoised=True):
-        """
-        Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
-        the initial x, x_0.
-
-        :param model: the model, which takes a signal and a batch of timesteps
-                      as input.
-        :param x: the [N x C x ...] tensor at time t.
-        :param t: a 1-D Tensor of timesteps.
-        :param clip_denoised: if True, clip the denoised signal into [-1, 1].
-        :param model_kwargs: if not None, a dict of extra keyword arguments to
-            pass to the model. This can be used for conditioning.
-        :return: a dict with the following keys:
-                 - 'mean': the model mean output.
-                 - 'variance': the model variance output.
-                 - 'log_variance': the log of 'variance'.
-                 - 'pred_xstart': the prediction for x_0.
-        """
-        B, C = x.shape[:2]
-        assert t.shape == (B,)
-        if transformer_out is None:
-            # super-res model
-            model_output = model(x, t, low_res)
-        else:
-            # text2image model
-            model_output = model(x, t, transformer_out)
-
-        assert model_output.shape == (B, C * 2, *x.shape[2:])
-        model_output, model_var_values = torch.split(model_output, C, dim=1)
-        min_log = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, x.shape)
-        max_log = _extract_into_tensor(np.log(scheduler.betas), t, x.shape)
-        # The model_var_values is [-1, 1] for [min_var, max_var].
-        frac = (model_var_values + 1) / 2
-        model_log_variance = frac * max_log + (1 - frac) * min_log
-        model_variance = torch.exp(model_log_variance)
-
-        pred_xstart = self._predict_xstart_from_eps(scheduler, x_t=x, t=t, eps=model_output)
-        if clip_denoised:
-            pred_xstart = pred_xstart.clamp(-1, 1)
-        model_mean, _, _ = self.q_posterior_mean_variance(scheduler, x_start=pred_xstart, x_t=x, t=t)
-
-        assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
-        return model_mean, model_variance, model_log_variance, pred_xstart
-
-    def _predict_xstart_from_eps(self, scheduler, x_t, t, eps):
-        assert x_t.shape == eps.shape
-        return (
-            _extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
-            - _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
-        )
-
-    def _predict_eps_from_xstart(self, scheduler, x_t, t, pred_xstart):
-        return (
-            _extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
-        ) / _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
-
    @torch.no_grad()
-    def __call__(...):
+    def __call__(
+        self,
+        prompt,
+        generator=None,
+        torch_device=None,
+        num_inference_steps_upscale=50,
+        guidance_scale=3.0,
+        eta=0.0,
+        upsample_temp=0.997,
+    ):
+
        torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        self.text_unet.to(torch_device)
        self.text_encoder.to(torch_device)
        self.upscale_unet.to(torch_device)

-
-        guidance_scale = 3.0
-
-        def text_model_fn(x_t, ts, transformer_out, **kwargs):
+        def text_model_fn(x_t, timesteps, transformer_out, **kwargs):
            half = x_t[: len(x_t) // 2]
            combined = torch.cat([half, half], dim=0)
-            model_out = self.text_unet(combined, ts, transformer_out, **kwargs)
+            model_out = self.text_unet(combined, timesteps, transformer_out, **kwargs)
            eps, rest = model_out[:, :3], model_out[:, 3:]
            cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
            half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
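The new text_model_fn above implements classifier-free guidance: the batch holds a conditional and an unconditional half, and the unconditional noise prediction is pushed toward the conditional one by guidance_scale. A minimal, self-contained sketch of that combination (tensor shapes and values are illustrative assumptions, not taken from the commit):

    import torch

    guidance_scale = 3.0
    # stand-in for the model's noise prediction on a [conditional, unconditional] pair
    eps = torch.randn(2, 3, 64, 64)

    cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
    # move the prediction away from the unconditional direction, toward the conditional one
    guided_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
    print(guided_eps.shape)  # torch.Size([1, 3, 64, 64])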
@@ -837,71 +761,82 @@ class GLIDE(DiffusionPipeline):

        # 1. Sample gaussian noise
        batch_size = 2  # second image is empty for classifier-free guidance
-        image = ...
+        image = torch.randn(
+            (
+                batch_size,
+                self.text_unet.in_channels,
+                self.text_unet.resolution,
+                self.text_unet.resolution,
+            ),
+            generator=generator,
+        ).to(torch_device)

        # 2. Encode tokens
-        # an empty input is needed to guide the model away from ...
+        # an empty input is needed to guide the model away from it
        inputs = self.tokenizer([prompt, ""], padding="max_length", max_length=128, return_tensors="pt")
        input_ids = inputs["input_ids"].to(torch_device)
        attention_mask = inputs["attention_mask"].to(torch_device)
        transformer_out = self.text_encoder(input_ids, attention_mask).last_hidden_state

        # 3. Run the text2image generation step
-
-        for ...
-            ...
+        num_prediction_steps = len(self.text_noise_scheduler)
+        for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
+            with torch.no_grad():
+                time_input = torch.tensor([t] * image.shape[0], device=torch_device)
+                model_output = text_model_fn(image, time_input, transformer_out)
+                noise_residual, model_var_values = torch.split(model_output, 3, dim=1)
+
+            min_log = self.text_noise_scheduler.get_variance(t, "fixed_small_log")
+            max_log = self.text_noise_scheduler.get_variance(t, "fixed_large_log")
+            # The model_var_values is [-1, 1] for [min_var, max_var].
+            frac = (model_var_values + 1) / 2
+            model_log_variance = frac * max_log + (1 - frac) * min_log
+
+            pred_prev_image = self.upscale_noise_scheduler.step(
+                noise_residual, image, t, num_inference_steps_upscale, eta, use_clipped_residual=True
            )
-        noise = ...
+            noise = torch.randn(image.shape, generator=generator).to(torch_device)
+            variance = torch.exp(0.5 * model_log_variance) * noise
+
+            # set current image to prev_image: x_t -> x_t-1
+            image = pred_prev_image + variance

        # 4. Run the upscaling step
        batch_size = 1
        image = image[:1]
        low_res = ((image + 1) * 127.5).round() / 127.5 - 1
-        eta = 0.0
-
-        # Tune this parameter to control the sharpness of 256x256 images.
-        # A value of 1.0 is sharper, but sometimes results in grainy artifacts.
-        upsample_temp = 0.997

        # Sample gaussian noise to begin loop
        image = torch.randn(
-            (
-                ...
+            (
+                batch_size,
+                self.upscale_unet.in_channels // 2,
+                self.upscale_unet.resolution,
+                self.upscale_unet.resolution,
+            ),
            generator=generator,
-        )
-        image = image ...
-
-        # Notation (<variable name> -> <name in paper>)
-        # - pred_noise_t -> e_theta(x_t, t)
-        # - pred_original_image -> f_theta(x_t, t) or x_0
-        # - std_dev_t -> sigma_t
-        # - eta -> η
-        # - pred_image_direction -> "direction pointing to x_t"
-        # - pred_prev_image -> "x_t-1"
+        ).to(torch_device)
+        image = image * upsample_temp
+
+        num_trained_timesteps = self.upscale_noise_scheduler.timesteps
+        inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps_upscale)
+
        for t in tqdm.tqdm(reversed(range(num_inference_steps_upscale)), total=num_inference_steps_upscale):
            # 1. predict noise residual
            with torch.no_grad():
-                time_input = torch.tensor([t] * image.shape[0], device=torch_device)
+                time_input = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
                model_output = self.upscale_unet(image, time_input, low_res)
                noise_residual, pred_variance = torch.split(model_output, 3, dim=1)

            # 2. predict previous mean of image x_t-1
            pred_prev_image = self.upscale_noise_scheduler.step(
-                noise_residual, image, t, num_inference_steps_upscale, eta
+                noise_residual, image, t, num_inference_steps_upscale, eta, use_clipped_residual=True
            )

            # 3. optionally sample variance
            variance = 0
            if eta > 0:
-                noise = torch.randn(image.shape, generator=generator).to(torch_device)
+                noise = torch.randn(image.shape, generator=generator).to(torch_device)
                variance = (
                    self.upscale_noise_scheduler.get_variance(t, num_inference_steps_upscale).sqrt() * eta * noise
                )
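In the new text2image loop, the network predicts both a noise residual and per-pixel variance values in [-1, 1], which are mapped onto a log-variance between the scheduler's "fixed_small_log" and "fixed_large_log" bounds. A standalone sketch of that interpolation, using made-up bound values for illustration only:

    import torch

    min_log = torch.tensor(-9.2)  # e.g. log of the clipped posterior variance at step t (assumed value)
    max_log = torch.tensor(-6.5)  # e.g. log of beta_t (assumed value)
    model_var_values = torch.tensor([[0.25]])  # network output; one value per pixel in the real model

    frac = (model_var_values + 1) / 2  # map [-1, 1] -> [0, 1]
    model_log_variance = frac * max_log + (1 - frac) * min_log
    sigma = torch.exp(0.5 * model_log_variance)  # standard deviation that scales the sampled noise

    noise = torch.randn_like(model_var_values)
    variance = sigma * noise  # the term added to the predicted previous image
    print(float(sigma))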
@@ -909,6 +844,6 @@ class GLIDE(DiffusionPipeline):
            # 4. set current image to prev_image: x_t -> x_t-1
            image = pred_prev_image + variance

-        image = image.permute(0, 2, 3, 1)
+        image = image.clamp(-1, 1).permute(0, 2, 3, 1)

        return image
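After the final clamp and permute, the pipeline returns a [batch, height, width, channels] tensor with values in [-1, 1]. A hedged sketch of the usual downstream conversion to 8-bit pixels (an assumption about how callers consume the output, not part of this file):

    import torch

    image = torch.rand(1, 256, 256, 3) * 2 - 1  # stand-in for the returned tensor

    # same [-1, 1] <-> [0, 255] mapping the pipeline itself uses when building low_res
    pixels = ((image + 1) * 127.5).round().clamp(0, 255).to(torch.uint8)
    print(pixels.shape, pixels.dtype)  # torch.Size([1, 256, 256, 3]) torch.uint8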