Upload 4 files
- cog.yaml +11 -0
- image_to_image.py +281 -0
- predict.py +136 -0
- script/download-weights +18 -0
cog.yaml
ADDED
@@ -0,0 +1,11 @@
build:
  gpu: true
  cuda: "11.6.2"
  python_version: "3.10"
  python_packages:
    - "diffusers==0.2.4"
    - "torch==1.12.1 --extra-index-url=https://download.pytorch.org/whl/cu116"
    - "ftfy==6.1.1"
    - "scipy==1.9.0"
    - "transformers==4.21.1"
predict: "predict.py:Predictor"
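
Taken together, this configuration asks Cog to build a GPU image with CUDA 11.6.2 and Python 3.10, install the pinned diffusers, torch, ftfy, scipy and transformers versions, and serve predictions from the Predictor class in predict.py. The model weights are not baked in by this file: script/download-weights (below) fetches them into the diffusers-cache directory that predict.py later loads from.
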
image_to_image.py
ADDED
@@ -0,0 +1,281 @@
import inspect
from typing import List, Optional, Union, Tuple

import numpy as np
import torch
from PIL import Image
from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DiffusionPipeline,
    PNDMScheduler,
    LMSDiscreteScheduler,
    UNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer


def preprocess_init_image(image: Image, width: int, height: int):
    image = image.resize((width, height), resample=Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return 2.0 * image - 1.0


def preprocess_mask(mask: Image, width: int, height: int):
    mask = mask.convert("L")
    mask = mask.resize((width // 8, height // 8), resample=Image.LANCZOS)
    mask = np.array(mask).astype(np.float32) / 255.0
    mask = np.tile(mask, (4, 1, 1))
    mask = mask[None].transpose(0, 1, 2, 3)  # add a batch dimension (the transpose is an identity permutation)
    mask = torch.from_numpy(mask)
    return mask


class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
    """
    From https://github.com/huggingface/diffusers/pull/241
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
    ):
        super().__init__()
        scheduler = scheduler.set_format("pt")
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        init_image: Optional[torch.FloatTensor],
        mask: Optional[torch.FloatTensor],
        width: int,
        height: int,
        prompt_strength: float = 0.8,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
    ) -> Image:
        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(
                f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
            )

        if prompt_strength < 0 or prompt_strength > 1:
            raise ValueError(
                f"The value of prompt_strength should be in [0.0, 1.0] but is {prompt_strength}"
            )

        if mask is not None and init_image is None:
            raise ValueError(
                "If mask is defined, then init_image also needs to be defined"
            )

        if width % 8 != 0 or height % 8 != 0:
            raise ValueError("Width and height must both be divisible by 8")

        # set timesteps
        accepts_offset = "offset" in set(
            inspect.signature(self.scheduler.set_timesteps).parameters.keys()
        )
        extra_set_kwargs = {}
        offset = 0
        if accepts_offset:
            offset = 1
            extra_set_kwargs["offset"] = 1

        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

        if init_image is not None:
            init_latents_orig, latents, init_timestep = self.latents_from_init_image(
                init_image,
                prompt_strength,
                offset,
                num_inference_steps,
                batch_size,
                generator,
            )
        else:
            latents = torch.randn(
                (batch_size, self.unet.in_channels, height // 8, width // 8),
                generator=generator,
                device=self.device,
            )
            init_timestep = num_inference_steps

        do_classifier_free_guidance = guidance_scale > 1.0
        text_embeddings = self.embed_text(
            prompt, do_classifier_free_guidance, batch_size
        )

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler; it will be ignored by other schedulers.
        # eta corresponds to η in the DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be in [0, 1]
        accepts_eta = "eta" in set(
            inspect.signature(self.scheduler.step).parameters.keys()
        )
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        mask_noise = torch.randn(latents.shape, generator=generator, device=self.device)

        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
        if isinstance(self.scheduler, LMSDiscreteScheduler):
            latents = latents * self.scheduler.sigmas[0]

        t_start = max(num_inference_steps - init_timestep + offset, 0)
        for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = (
                torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            )

            if isinstance(self.scheduler, LMSDiscreteScheduler):
                sigma = self.scheduler.sigmas[i]
                latent_model_input = latent_model_input / ((sigma ** 2 + 1) ** 0.5)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input, t, encoder_hidden_states=text_embeddings
            )["sample"]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )

            # compute the previous noisy sample x_t -> x_t-1
            if isinstance(self.scheduler, LMSDiscreteScheduler):
                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs)[
                    "prev_sample"
                ]
            else:
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)[
                    "prev_sample"
                ]

            # replace the unmasked part with original latents, with added noise
            if mask is not None:
                timesteps = self.scheduler.timesteps[t_start + i]
                timesteps = torch.tensor(
                    [timesteps] * batch_size, dtype=torch.long, device=self.device
                )
                noisy_init_latents = self.scheduler.add_noise(init_latents_orig, mask_noise, timesteps)
                latents = noisy_init_latents * mask + latents * (1 - mask)

        # scale and decode the image latents with vae
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents)

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()

        # run safety checker
        safety_checker_input = self.feature_extractor(
            self.numpy_to_pil(image), return_tensors="pt"
        ).to(self.device)
        image, has_nsfw_concept = self.safety_checker(
            images=image, clip_input=safety_checker_input.pixel_values
        )

        image = self.numpy_to_pil(image)

        return {"sample": image, "nsfw_content_detected": has_nsfw_concept}

    def latents_from_init_image(
        self,
        init_image: torch.FloatTensor,
        prompt_strength: float,
        offset: int,
        num_inference_steps: int,
        batch_size: int,
        generator: Optional[torch.Generator],
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, int]:
        # encode the init image into latents and scale the latents
        init_latents = self.vae.encode(init_image.to(self.device)).sample()
        init_latents = 0.18215 * init_latents
        init_latents_orig = init_latents

        # expand init_latents to the batch size
        init_latents = torch.cat([init_latents] * batch_size)

        # get the original timestep using init_timestep
        init_timestep = int(num_inference_steps * prompt_strength) + offset
        init_timestep = min(init_timestep, num_inference_steps)
        timesteps = self.scheduler.timesteps[-init_timestep]
        timesteps = torch.tensor(
            [timesteps] * batch_size, dtype=torch.long, device=self.device
        )

        # add noise to latents using the timesteps
        noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)

        return init_latents_orig, init_latents, init_timestep

    def embed_text(
        self,
        prompt: Union[str, List[str]],
        do_classifier_free_guidance: bool,
        batch_size: int,
    ) -> torch.FloatTensor:
        # get prompt text embeddings
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]

        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            max_length = text_input.input_ids.shape[-1]
            uncond_input = self.tokenizer(
                [""] * batch_size,
                padding="max_length",
                max_length=max_length,
                return_tensors="pt",
            )
            uncond_embeddings = self.text_encoder(
                uncond_input.input_ids.to(self.device)
            )[0]

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        return text_embeddings
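
A quick worked example of the scheduling logic above: prompt_strength controls how far back the init latents are noised and therefore how many denoising steps actually run. The sketch below just repeats the arithmetic from latents_from_init_image and __call__; the concrete numbers assume the defaults (50 steps, strength 0.8) and a scheduler whose set_timesteps accepts an offset, which the code probes for with inspect.

# Illustrative arithmetic only; mirrors __call__ / latents_from_init_image above.
num_inference_steps = 50
prompt_strength = 0.8
offset = 1  # used when the scheduler's set_timesteps accepts an "offset" kwarg

init_timestep = min(int(num_inference_steps * prompt_strength) + offset, num_inference_steps)
t_start = max(num_inference_steps - init_timestep + offset, 0)

print(init_timestep)                  # 41: the init latents are noised this many steps back
print(t_start)                        # 10: index of the first timestep in the denoising loop
print(num_inference_steps - t_start)  # 40: denoising steps actually executed
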
predict.py
ADDED
@@ -0,0 +1,136 @@
import os
from typing import Optional, List

import torch
import torch.nn as nn
from torch import autocast
from diffusers import PNDMScheduler, LMSDiscreteScheduler
from PIL import Image
from cog import BasePredictor, Input, Path

from image_to_image import (
    StableDiffusionImg2ImgPipeline,
    preprocess_init_image,
    preprocess_mask,
)


def patch_conv(**patch):
    cls = torch.nn.Conv2d
    init = cls.__init__

    def __init__(self, *args, **kwargs):
        return init(self, *args, **kwargs, **patch)

    cls.__init__ = __init__


patch_conv(padding_mode='circular')

MODEL_CACHE = "diffusers-cache"


class Predictor(BasePredictor):
    def setup(self):
        """Load the model into memory to make running multiple predictions efficient"""
        print("Loading pipeline...")
        scheduler = PNDMScheduler(
            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
        )
        self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            scheduler=scheduler,
            revision="fp16",
            torch_dtype=torch.float16,
            cache_dir=MODEL_CACHE,
            local_files_only=True,
        ).to("cuda")

    @torch.inference_mode()
    @torch.cuda.amp.autocast()
    def predict(
        self,
        prompt: str = Input(description="Input prompt", default=""),
        width: int = Input(
            description="Width of output image. Maximum size is 1024x768 or 768x1024 because of memory limits",
            choices=[128, 256, 512, 768, 1024],
            default=512,
        ),
        height: int = Input(
            description="Height of output image. Maximum size is 1024x768 or 768x1024 because of memory limits",
            choices=[128, 256, 512, 768, 1024],
            default=512,
        ),
        init_image: Path = Input(
            description="Initial image to generate variations of. Will be resized to the specified width and height",
            default=None,
        ),
        mask: Path = Input(
            description="Black and white image to use as mask for inpainting over init_image. Black pixels are inpainted and white pixels are preserved. Experimental feature, tends to work better with prompt strength of 0.5-0.7",
            default=None,
        ),
        prompt_strength: float = Input(
            description="Prompt strength when using init image. 1.0 corresponds to full destruction of information in init image",
            default=0.8,
        ),
        num_outputs: int = Input(
            description="Number of images to output", choices=[1, 4], default=1
        ),
        num_inference_steps: int = Input(
            description="Number of denoising steps", ge=1, le=500, default=50
        ),
        guidance_scale: float = Input(
            description="Scale for classifier-free guidance", ge=1, le=20, default=7.5
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed", default=None
        ),
    ) -> List[Path]:
        """Run a single prediction on the model"""
        if seed is None:
            seed = int.from_bytes(os.urandom(2), "big")
        print(f"Using seed: {seed}")

        if width == height == 1024:
            raise ValueError(
                "Maximum size is 1024x768 or 768x1024 pixels, because of memory limits. Please select a lower width or height."
            )

        if init_image:
            init_image = Image.open(init_image).convert("RGB")
            init_image = preprocess_init_image(init_image, width, height).to("cuda")

            # use PNDM with init images
            scheduler = PNDMScheduler(
                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
            )
        else:
            # use LMS without init images
            scheduler = LMSDiscreteScheduler(
                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
            )

        self.pipe.scheduler = scheduler

        if mask:
            mask = Image.open(mask).convert("RGB")
            mask = preprocess_mask(mask, width, height).to("cuda")

        generator = torch.Generator("cuda").manual_seed(seed)
        output = self.pipe(
            prompt=[prompt] * num_outputs if prompt is not None else None,
            init_image=init_image,
            mask=mask,
            width=width,
            height=height,
            prompt_strength=prompt_strength,
            guidance_scale=guidance_scale,
            generator=generator,
            num_inference_steps=num_inference_steps,
        )
        if any(output["nsfw_content_detected"]):
            raise Exception("NSFW content detected, please try a different prompt")

        output_paths = []
        for i, sample in enumerate(output["sample"]):
            output_path = f"/tmp/out-{i}.png"
            sample.save(output_path)
            output_paths.append(Path(output_path))

        return output_paths
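
One detail worth noting in predict.py: patch_conv(padding_mode='circular') monkey-patches torch.nn.Conv2d at import time, so every convolution constructed afterwards (including the UNet and VAE layers loaded in setup()) uses circular padding, the usual trick for making the generated images tile seamlessly. A minimal sketch of the effect, illustrative only and assuming the module's dependencies are installed:

import torch
import predict  # importing runs patch_conv(padding_mode='circular') above

conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
print(conv.padding_mode)  # 'circular' (PyTorch's default is 'zeros')
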
script/download-weights
ADDED
@@ -0,0 +1,18 @@
#!/usr/bin/env python

import os
import sys

import torch
from diffusers import StableDiffusionPipeline

os.makedirs("diffusers-cache", exist_ok=True)


pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    cache_dir="diffusers-cache",
    revision="fp16",
    torch_dtype=torch.float16,
    use_auth_token=sys.argv[1],
)