| import numpy as np |
| import PIL |
| from PIL import Image |
| import torch |
|
|
|
|
def pil_to_tensor(
    img: Image.Image,
    target_image_size=512,
    lock_ratio=True,
    center_crop=True,
    padding=False,
    standardize=True,
    **kwarg
) -> torch.Tensor:
    """Convert a PIL image into a resized, channel-first float tensor.

    Args:
        img: Source image. It is converted to RGB first if it uses another mode.
        target_image_size: Requested output size. An int ``s`` means ``(s, s)``.
            A negative int means "keep the image's current size". Anything else
            is unpacked as a ``(width, height)`` pair.
        lock_ratio: If True, scale while preserving the aspect ratio, then apply
            ``center_crop`` or ``padding``. If False, stretch the image directly
            to the target size.
        center_crop: Only used with ``lock_ratio``. Scale so the image covers the
            target box, then crop the centre. Takes precedence over ``padding``.
        padding: Only used with ``lock_ratio`` and no ``center_crop``. Scale so
            the image fits inside the target box, then paste it centred on a
            black canvas of the target size.
        standardize: If True, map pixel values from [0, 1] to [-1, 1].
        **kwarg: Ignored; accepted so callers can pass extra options.

    Returns:
        A float32 tensor of shape (3, height, width). With ``lock_ratio=True``
        and neither ``center_crop`` nor ``padding``, the scale is 1.0. In that
        case the image keeps its original size and the target size is ignored.
    """
    if img.mode != "RGB":
        img = img.convert("RGB")

    # Resolve the requested output size as (width, height).
    if isinstance(target_image_size, int):
        if target_image_size < 0:
            target_size = img.size
        else:
            target_size = (target_image_size, target_image_size)
    else:
        target_size = target_image_size

    if lock_ratio:
        original_width, original_height = img.size
        target_width, target_height = target_size

        scale_w = target_width / original_width
        scale_h = target_height / original_height

        if center_crop:
            # Cover the target box; the overflow is cropped away below.
            # NOTE(review): for extreme aspect ratios this upscales the whole
            # image before cropping, which can allocate a very large intermediate.
            scale = max(scale_w, scale_h)
        elif padding:
            # Fit inside the target box; the remainder is padded below.
            scale = min(scale_w, scale_h)
        else:
            scale = 1.0

        # Keep every side at least 1 px. Otherwise, in padding mode, a very
        # thin image (e.g. 4000x1 into 512x512) rounds one side to 0 and
        # Image.resize rejects the size.
        new_size = (
            max(1, round(original_width * scale)),
            max(1, round(original_height * scale)),
        )
        # Skip the resample when the size is unchanged (scale 1.0 or an exact fit).
        if new_size != img.size:
            img = img.resize(new_size, Image.LANCZOS)

        if center_crop:
            left = (img.width - target_width) // 2
            top = (img.height - target_height) // 2
            img = img.crop((left, top, left + target_width, top + target_height))
        elif padding:
            canvas = Image.new("RGB", target_size, (0, 0, 0))
            left = (target_width - img.width) // 2
            top = (target_height - img.height) // 2
            canvas.paste(img, (left, top))
            img = canvas
    else:
        # Aspect ratio is not preserved: stretch straight to the target size.
        img = img.resize(target_size, Image.LANCZOS)

    # HWC uint8 -> floats in [0, 1] (optionally [-1, 1]) -> CHW float32 tensor.
    np_img = np.array(img) / 255.0
    if standardize:
        np_img = np_img * 2 - 1
    return torch.from_numpy(np_img).permute(2, 0, 1).float()
|
|
|
|
def tensor_to_pil(chw_tensor: torch.Tensor, standardize=True, **kwarg) -> Image.Image:
    """Convert a channel-first image tensor back into an RGB PIL image.

    Reverses the value scaling done by ``pil_to_tensor``, but not its resizing
    or cropping.

    Args:
        chw_tensor: Image tensor laid out as (channels, height, width). Values
            are expected in [-1, 1] when ``standardize`` is True, otherwise in
            [0, 1]; anything outside that range is clamped. Single-, 3- and
            4-channel tensors are accepted, and the result is always converted
            to RGB.
        standardize: Whether ``chw_tensor`` uses the [-1, 1] range.
        **kwarg: Ignored; accepted so callers can pass extra options.

    Returns:
        An RGB ``PIL.Image.Image``.
    """
    # Detach from autograd, move to the CPU and upcast to float32.
    # Tensor.numpy() does not support bfloat16, and float32 keeps the
    # 255-level quantisation below accurate.
    detached = chw_tensor.detach().cpu().float()

    if standardize:
        unit_range = (torch.clamp(detached, -1.0, 1.0) + 1.0) / 2.0
    else:
        unit_range = torch.clamp(detached, 0.0, 1.0)

    hwc_array = unit_range.permute(1, 2, 0).numpy()

    # Round to the nearest level instead of truncating. Float error can leave a
    # value that should be an exact level slightly below that integer;
    # astype(uint8) would then floor it, making the tensor round trip lossy.
    image_array_uint8 = np.rint(hwc_array * 255.0).astype(np.uint8)

    # Image.fromarray cannot handle an (H, W, 1) array. Drop the channel axis
    # so single-channel tensors become a 2-D (mode "L") image.
    if image_array_uint8.shape[-1] == 1:
        image_array_uint8 = image_array_uint8[..., 0]

    pil_image = Image.fromarray(image_array_uint8)

    # Non-RGB modes (e.g. "L" or "RGBA") are normalised to RGB.
    if pil_image.mode != "RGB":
        pil_image = pil_image.convert("RGB")

    return pil_image
|
|