| """ |
| GPU-Native Eye Image Processor for Color Fundus Photography (CFP) Images. |
| |
| This module implements a fully PyTorch-based image processor that: |
| 1. Localizes the eye/fundus region using gradient-based radial symmetry |
| 2. Crops to a border-minimized square centered on the eye |
| 3. Applies CLAHE for contrast enhancement |
| 4. Outputs tensors compatible with Hugging Face vision models |
| |
| Constraints: |
| - PyTorch only (no OpenCV, PIL, NumPy in runtime) |
| - CUDA-compatible, batch-friendly, deterministic |
| """ |
|
|
| from typing import Dict, List, Optional, Union |
| import math |
|
|
| import torch |
| import torch.nn.functional as F |
| from transformers.image_processing_utils import BaseImageProcessor |
| from transformers.image_processing_base import BatchFeature |
|
|
| |
| try: |
| from PIL import Image |
| PIL_AVAILABLE = True |
| except ImportError: |
| PIL_AVAILABLE = False |
|
|
| try: |
| import numpy as np |
| NUMPY_AVAILABLE = True |
| except ImportError: |
| NUMPY_AVAILABLE = False |
|
|
|
|
| |
| |
| |
|
|
def _pil_to_tensor(image: "Image.Image") -> torch.Tensor:
    """Turn one PIL image into a float32 (3, H, W) tensor with values in [0, 1].

    Non-RGB images are converted to RGB first. When NumPy is importable the
    pixel data goes through a float32 array; otherwise the pixels are read
    with ``getdata()`` and reshaped manually.

    Raises:
        ImportError: If PIL is not installed.
    """
    if not PIL_AVAILABLE:
        raise ImportError("PIL is required to process PIL Images")

    rgb_image = image if image.mode == "RGB" else image.convert("RGB")

    if not NUMPY_AVAILABLE:
        # Pure-PyTorch fallback: getdata() yields one (R, G, B) tuple per pixel
        # in row-major order, so the flat list reshapes to (H, W, 3).
        cols, rows = rgb_image.size
        flat_pixels = torch.tensor(list(rgb_image.getdata()), dtype=torch.float32)
        hwc = flat_pixels.view(rows, cols, 3) / 255.0
        return hwc.permute(2, 0, 1)

    # NumPy path: (H, W, 3) float array scaled to [0, 1], then HWC -> CHW.
    hwc_array = np.array(rgb_image, dtype=np.float32) / 255.0
    return torch.from_numpy(hwc_array).permute(2, 0, 1)
|
|
|
|
def _numpy_to_tensor(arr: "np.ndarray") -> torch.Tensor:
    """Turn one NumPy image into a float32 channel-first tensor.

    A 2-D array is treated as grayscale (H, W) and gains a channel axis.
    A 3-D array whose last axis has size 1, 3 or 4 is treated as HWC and
    transposed to CHW; any other 3-D array is passed through unchanged.
    ``uint8`` data is divided by 255; other dtypes are only cast to float32.
    The result never shares memory with ``arr``.

    Raises:
        ImportError: If NumPy is not installed.
    """
    if not NUMPY_AVAILABLE:
        raise ImportError("NumPy is required to process numpy arrays")

    # Grayscale (H, W) -> (H, W, 1) so it takes the channels-last branch below.
    data = arr[..., None] if arr.ndim == 2 else arr

    looks_channels_last = data.ndim == 3 and data.shape[-1] in [1, 3, 4]
    if looks_channels_last:
        data = data.transpose(2, 0, 1)

    if data.dtype == np.uint8:
        data = data.astype(np.float32) / 255.0
    elif data.dtype != np.float32:
        data = data.astype(np.float32)

    # copy() detaches from the caller's buffer and makes the transposed view
    # contiguous, which torch.from_numpy can then wrap.
    return torch.from_numpy(data.copy())
|
|
|
|
def standardize_input(
    images: Union[torch.Tensor, List[torch.Tensor], "Image.Image", List["Image.Image"], "np.ndarray", List["np.ndarray"]],
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Convert heterogeneous image inputs to a (B, C, H, W) float32 tensor in [0, 1].

    Accepts torch.Tensor, PIL.Image, numpy.ndarray, or lists/tuples thereof.
    ``uint8`` inputs are scaled by 1/255; every other dtype is only cast to
    float32. The output is always clamped to [0, 1].

    Notes:
        - All images in a sequence must have the same spatial size (required by
          ``torch.stack``).
        - A single 3-D numpy array is treated as one HWC image when its last
          dimension is in {1, 3, 4}; otherwise it is assumed to be CHW.
        - A 2-D input is treated as one grayscale image and becomes (1, 1, H, W).
        - ``uint8`` tensors inside a sequence are scaled per element, so a
          sequence mixing PIL/numpy images with ``uint8`` tensors ends up on a
          single [0, 1] scale.

    Args:
        images: Input as:
            - torch.Tensor (C,H,W), (B,C,H,W), or list/tuple of tensors
            - PIL.Image.Image or list/tuple of PIL Images
            - numpy.ndarray (H,W,C), (B,H,W,C), or list/tuple of arrays
        device: Target device (defaults to the device the data already lives on).

    Returns:
        Tensor of shape (B, C, H, W) in float32, range [0, 1].

    Raises:
        TypeError: If ``images`` (or an element of it) has an unsupported type.
    """
    # Wrap single images so they share the per-element conversion path.
    if PIL_AVAILABLE and isinstance(images, Image.Image):
        images = [images]
    elif (
        NUMPY_AVAILABLE
        and isinstance(images, np.ndarray)
        and images.ndim == 3
        and images.shape[-1] in (1, 3, 4)
    ):
        images = [images]

    if isinstance(images, (list, tuple)):
        converted = []
        for img in images:
            if PIL_AVAILABLE and isinstance(img, Image.Image):
                converted.append(_pil_to_tensor(img))
            elif NUMPY_AVAILABLE and isinstance(img, np.ndarray):
                converted.append(_numpy_to_tensor(img))
            elif isinstance(img, torch.Tensor):
                t = img if img.dim() == 3 else img.squeeze(0)
                # Scale here rather than after stacking: stacking uint8 with
                # already-normalised float images would leave 0..255 values
                # that the final clamp saturates to 1.
                if t.dtype == torch.uint8:
                    t = t.float() / 255.0
                converted.append(t)
            else:
                raise TypeError(f"Unsupported image type: {type(img)}")
        images = torch.stack(converted)
    elif NUMPY_AVAILABLE and isinstance(images, np.ndarray):
        # 4-D arrays are batched channels-last: (B, H, W, C) -> (B, C, H, W).
        if images.ndim == 4:
            images = images.transpose(0, 3, 1, 2)
        if images.dtype == np.uint8:
            images = images.astype(np.float32) / 255.0
        # copy() makes the transposed view contiguous and detaches it from the caller.
        images = torch.from_numpy(images.copy())
    elif not isinstance(images, torch.Tensor):
        raise TypeError(f"Unsupported image type: {type(images)}")

    # Add the missing leading axes so the result is always 4-D.
    if images.dim() == 2:
        images = images.unsqueeze(0).unsqueeze(0)
    elif images.dim() == 3:
        images = images.unsqueeze(0)

    if device is not None:
        images = images.to(device)

    if images.dtype == torch.uint8:
        images = images.float() / 255.0
    elif images.dtype != torch.float32:
        images = images.float()

    return images.clamp(0.0, 1.0)
|
|
def standardize_mask_input(
    masks: Union[
        torch.Tensor,
        List[torch.Tensor],
        "Image.Image",
        List["Image.Image"],
        "np.ndarray",
        List["np.ndarray"],
    ],
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Convert heterogeneous mask inputs to a standardized (B, 1, H, W) tensor.

    Unlike ``standardize_input``, this preserves the original dtype (typically
    integer label values) and does **not** normalize to [0, 1].

    Accepts torch.Tensor, PIL.Image, numpy.ndarray, or lists/tuples thereof.
    A single 2-D input is treated as (H, W) and expanded to (1, 1, H, W);
    a 3-D input is treated as (B, H, W); a 4-D input is returned as is.

    Note: a single numpy array that is not 2-D is wrapped with
    ``torch.from_numpy`` and therefore shares memory with the source array.

    Args:
        masks: Input masks in any supported format.
        device: Target device.

    Returns:
        Tensor of shape (B, 1, H, W) with original dtype preserved.

    Raises:
        ImportError: If a PIL mask is given but NumPy is not installed.
        TypeError: If ``masks`` (or an element of it) has an unsupported type.
        ValueError: If the resulting tensor is not 2-, 3- or 4-dimensional.
    """
    # Wrap single masks so they share the per-element conversion path.
    if PIL_AVAILABLE and isinstance(masks, Image.Image):
        masks = [masks]
    elif NUMPY_AVAILABLE and isinstance(masks, np.ndarray) and masks.ndim == 2:
        masks = [masks]

    if isinstance(masks, (list, tuple)):
        converted = []
        for m in masks:
            if PIL_AVAILABLE and isinstance(m, Image.Image):
                # PIL masks are decoded through NumPy; fail with a clear error
                # instead of a NameError when NumPy is missing.
                if not NUMPY_AVAILABLE:
                    raise ImportError("NumPy is required to process PIL masks")
                converted.append(torch.from_numpy(np.array(m)))
            elif NUMPY_AVAILABLE and isinstance(m, np.ndarray):
                converted.append(torch.from_numpy(m))
            elif isinstance(m, torch.Tensor):
                converted.append(m)
            else:
                raise TypeError(f"Unsupported mask type: {type(m)}")
        masks = torch.stack(converted)
    elif NUMPY_AVAILABLE and isinstance(masks, np.ndarray):
        masks = torch.from_numpy(masks)
    elif not isinstance(masks, torch.Tensor):
        raise TypeError(f"Unsupported mask type: {type(masks)}")

    if masks.dim() == 2:
        # (H, W) -> (1, 1, H, W)
        masks = masks.unsqueeze(0).unsqueeze(0)
    elif masks.dim() == 3:
        # (B, H, W) -> (B, 1, H, W)
        masks = masks.unsqueeze(1)
    elif masks.dim() != 4:
        raise ValueError(f"Invalid mask shape: {masks.shape}")

    if device is not None:
        masks = masks.to(device)

    return masks
|
|
|
|
def rgb_to_grayscale(images: torch.Tensor) -> torch.Tensor:
    """Collapse RGB to one luminance channel with ITU-R BT.601 weights.

    Computes ``Y = 0.299 R + 0.587 G + 0.114 B`` per pixel.

    Args:
        images: Tensor of shape (B, 3, H, W) in any value range.

    Returns:
        Tensor of shape (B, 1, H, W) in the same value range as the input.
    """
    # Per-channel coefficients shaped to broadcast over (B, 3, H, W).
    luma_coeffs = torch.tensor(
        [0.299, 0.587, 0.114], dtype=images.dtype, device=images.device
    ).view(1, 3, 1, 1)

    return (images * luma_coeffs).sum(dim=1, keepdim=True)
|
|
|
|
| |
| |
| |
|
|
def create_sobel_kernels(device: torch.device, dtype: torch.dtype) -> tuple:
    """Build the pair of 3x3 Sobel kernels used for image gradients.

    Args:
        device: Device on which the kernels are created.
        dtype: Dtype of the kernels.

    Returns:
        Tuple ``(sobel_x, sobel_y)``; each kernel has shape (1, 1, 3, 3) so it
        can be passed directly to ``F.conv2d`` on single-channel input.
        ``sobel_x`` responds to horizontal intensity changes, ``sobel_y`` to
        vertical ones.
    """
    horizontal = torch.tensor(
        [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
        device=device,
        dtype=dtype,
    )
    # The vertical kernel is the transpose of the horizontal one.
    vertical = horizontal.t().contiguous()

    return horizontal.view(1, 1, 3, 3), vertical.view(1, 1, 3, 3)
|
|
|
|
def compute_gradients(grayscale: torch.Tensor) -> tuple:
    """Apply 3x3 Sobel filters and return both gradient components and their norm.

    The convolutions use ``padding=1`` (zero padding), so the outputs keep the
    spatial size of the input.

    Args:
        grayscale: Single-channel images of shape (B, 1, H, W).

    Returns:
        Tuple ``(grad_x, grad_y, grad_magnitude)``, each (B, 1, H, W), where
        ``grad_magnitude = sqrt(grad_x^2 + grad_y^2 + 1e-8)``.
    """
    kernel_x, kernel_y = create_sobel_kernels(grayscale.device, grayscale.dtype)

    horizontal = F.conv2d(grayscale, kernel_x, padding=1)
    vertical = F.conv2d(grayscale, kernel_y, padding=1)

    # The small constant keeps the square root strictly positive.
    magnitude = torch.sqrt(horizontal ** 2 + vertical ** 2 + 1e-8)

    return horizontal, vertical, magnitude
|
|
|
|
def compute_radial_symmetry_response(
    grayscale: torch.Tensor,
    grad_x: torch.Tensor,
    grad_y: torch.Tensor,
    grad_magnitude: torch.Tensor,
) -> torch.Tensor:
    """Compute a radial-symmetry response map for circular-region detection.

    The algorithm:
    1. Estimates an initial center as the center of mass of the squared
       inverse intensity ``(1 - grayscale)^2``.
    2. For each pixel, computes the dot product between the normalized gradient
       vector and the unit vector pointing toward the estimated center.
    3. Weights this alignment score by gradient magnitude and darkness.
    4. Smooths the response with a separable Gaussian whose size follows the
       image size (kernel_size = max(H, W) // 8, forced odd and at least 5;
       sigma = kernel_size / 6).

    The smoothing pads with reflection. Reflection padding requires the pad
    width to be smaller than the padded dimension, so on very small or very
    elongated images the affected axis falls back to replicate (edge) padding
    instead of raising.

    Args:
        grayscale: Grayscale images (B, 1, H, W) in [0, 1].
        grad_x: Horizontal gradient (B, 1, H, W).
        grad_y: Vertical gradient (B, 1, H, W).
        grad_magnitude: Gradient magnitude (B, 1, H, W).

    Returns:
        Smoothed radial symmetry response map (B, 1, H, W).
    """
    _, _, H, W = grayscale.shape
    device = grayscale.device
    dtype = grayscale.dtype

    # Pixel coordinate axes; they broadcast against (B, 1, H, W).
    y_coords = torch.arange(H, device=device, dtype=dtype).view(1, 1, H, 1)
    x_coords = torch.arange(W, device=device, dtype=dtype).view(1, 1, 1, W)

    # Darkness weight: squared inverse intensity.
    dark_weight = (1.0 - grayscale) ** 2
    weight_sum = dark_weight.sum(dim=(2, 3), keepdim=True) + 1e-8

    # Darkness-weighted center of mass, one (cx, cy) per image.
    cx_init = (dark_weight * x_coords).sum(dim=(2, 3), keepdim=True) / weight_sum
    cy_init = (dark_weight * y_coords).sum(dim=(2, 3), keepdim=True) / weight_sum

    # Unit vectors from every pixel toward the estimated center.
    dx_to_center = cx_init - x_coords
    dy_to_center = cy_init - y_coords
    dist_to_center = torch.sqrt(dx_to_center ** 2 + dy_to_center ** 2 + 1e-8)
    dx_norm = dx_to_center / dist_to_center
    dy_norm = dy_to_center / dist_to_center

    # Unit gradient direction.
    grad_norm = grad_magnitude + 1e-8
    gx_norm = grad_x / grad_norm
    gy_norm = grad_y / grad_norm

    # Cosine between the gradient and the direction to the center.
    radial_alignment = gx_norm * dx_norm + gy_norm * dy_norm
    response = radial_alignment * grad_magnitude * dark_weight

    # Gaussian smoothing kernel sized relative to the image.
    kernel_size = max(H, W) // 8
    if kernel_size % 2 == 0:
        kernel_size += 1
    kernel_size = max(kernel_size, 5)
    sigma = kernel_size / 6.0

    offsets = torch.arange(kernel_size, device=device, dtype=dtype) - kernel_size // 2
    gaussian_1d = torch.exp(-offsets ** 2 / (2 * sigma ** 2))
    gaussian_1d = gaussian_1d / gaussian_1d.sum()

    half = kernel_size // 2
    # 'reflect' is only valid when the pad is smaller than the dimension.
    mode_w = 'reflect' if half < W else 'replicate'
    mode_h = 'reflect' if half < H else 'replicate'

    # Separable blur: horizontal pass, then vertical pass.
    response = F.pad(response, (half, half, 0, 0), mode=mode_w)
    response = F.conv2d(response, gaussian_1d.view(1, 1, 1, kernel_size))
    response = F.pad(response, (0, 0, half, half), mode=mode_h)
    response = F.conv2d(response, gaussian_1d.view(1, 1, kernel_size, 1))

    return response
|
|
|
|
def soft_argmax_2d(response: torch.Tensor, temperature: float = 0.1) -> tuple:
    """Find the sub-pixel peak location in a response map via softmax-weighted coordinates.

    Divides the flattened response by ``temperature`` before applying softmax, then
    computes the weighted mean of the (x, y) coordinate grids. Lower temperature yields
    a sharper, more argmax-like result; higher temperature yields a broader average.

    ``temperature`` must be non-zero: dividing by zero produces non-finite
    logits and the returned coordinates become NaN.

    Args:
        response: Response map (B, 1, H, W). It may be non-contiguous
            (e.g. a transposed view).
        temperature: Softmax temperature. Default 0.1.

    Returns:
        Tuple of (cx, cy), each of shape (B,), in pixel coordinates.
    """
    B, _, H, W = response.shape
    device = response.device
    dtype = response.dtype

    # reshape (not view) so strided / non-contiguous maps are accepted.
    logits = response.reshape(B, -1) / temperature

    # Softmax over all pixels turns the map into a spatial probability mass.
    weights = F.softmax(logits, dim=1).view(B, 1, H, W)

    # Coordinate axes broadcast against (B, 1, H, W).
    y_coords = torch.arange(H, device=device, dtype=dtype).view(1, 1, H, 1)
    x_coords = torch.arange(W, device=device, dtype=dtype).view(1, 1, 1, W)

    # Expected coordinate under the softmax distribution: (B, 1) -> (B,).
    cx = (weights * x_coords).sum(dim=(2, 3)).squeeze(-1)
    cy = (weights * y_coords).sum(dim=(2, 3)).squeeze(-1)

    return cx, cy
|
|
|
|
def estimate_eye_center(
    images: torch.Tensor,
    softmax_temperature: float = 0.1,
) -> tuple:
    """Locate the fundus/eye disc center in every image of a batch.

    Stages: RGB -> grayscale -> Sobel gradients -> radial-symmetry response
    -> soft-argmax peak.

    Args:
        images: RGB images of shape (B, 3, H, W) in [0, 1].
        softmax_temperature: Temperature passed to the soft-argmax peak
            finder. Lower values localize more sharply; higher values average
            over a broader area. Default 0.1.

    Returns:
        Tuple of (cx, cy), each of shape (B,), in pixel coordinates.
    """
    gray = rgb_to_grayscale(images)

    # (grad_x, grad_y, grad_magnitude), in the order the response function expects.
    gradients = compute_gradients(gray)
    symmetry_map = compute_radial_symmetry_response(gray, *gradients)

    center_x, center_y = soft_argmax_2d(symmetry_map, temperature=softmax_temperature)
    return center_x, center_y
|
|
|
|
| |
| |
| |
|
|
def estimate_radius(
    images: torch.Tensor,
    cx: torch.Tensor,
    cy: torch.Tensor,
    num_radii: int = 100,
    num_angles: int = 36,
    min_radius_frac: float = 0.1,
    max_radius_frac: float = 0.5,
) -> torch.Tensor:
    """Estimate the fundus disc radius from radial intensity profiles.

    Grayscale intensity is sampled along ``num_angles`` rays leaving
    ``(cx, cy)`` at ``num_radii`` distances. Averaging over the rays gives a
    1-D profile per image; its first difference is multiplied by a weight that
    grows linearly from 0.5 to 1.5 with radius, and the radius at the most
    negative weighted difference is returned.

    Sampling uses ``F.grid_sample`` (bilinear, border padding,
    ``align_corners=True``).

    Args:
        images: RGB images (B, 3, H, W) in [0, 1].
        cx, cy: Center coordinates (B,) in pixel units.
        num_radii: Number of radial sample points. Default 100.
        num_angles: Number of angular sample rays. Default 36.
        min_radius_frac: Minimum search radius as fraction of min(H, W). Default 0.1.
        max_radius_frac: Maximum search radius as fraction of min(H, W). Default 0.5.

    Returns:
        Estimated radius for each image (B,), clamped to [min_radius, max_radius].
    """
    B, _, H, W = images.shape
    device, dtype = images.device, images.dtype

    gray = rgb_to_grayscale(images)

    # Search range in pixels, derived from the shorter image side.
    short_side = min(H, W)
    r_lo = int(min_radius_frac * short_side)
    r_hi = int(max_radius_frac * short_side)

    radii = torch.linspace(r_lo, r_hi, num_radii, device=device, dtype=dtype)
    # Drop the final sample so 0 and 2*pi are not both included.
    angles = torch.linspace(0, 2 * math.pi, num_angles + 1, device=device, dtype=dtype)[:-1]

    # Offsets for every (angle, radius) pair: shape (num_angles, num_radii).
    offset_x = torch.cos(angles).view(-1, 1) * radii
    offset_y = torch.sin(angles).view(-1, 1) * radii

    # Absolute sample positions: shape (B, num_angles, num_radii).
    center_x = cx.view(B, 1, 1).expand(B, num_angles, num_radii)
    center_y = cy.view(B, 1, 1).expand(B, num_angles, num_radii)
    pos_x = center_x + offset_x.unsqueeze(0)
    pos_y = center_y + offset_y.unsqueeze(0)

    # Pixel coordinates -> [-1, 1] as required by grid_sample (align_corners=True).
    grid = torch.stack(
        [2.0 * pos_x / (W - 1) - 1.0, 2.0 * pos_y / (H - 1) - 1.0],
        dim=-1,
    )

    samples = F.grid_sample(
        gray, grid, mode='bilinear', padding_mode='border', align_corners=True
    )

    # Mean over the angle axis: (B, 1, A, R) -> (B, R).
    profile = samples.mean(dim=2).squeeze(1)

    # First difference along the radius, weighted to favour larger radii.
    profile_step = profile[:, 1:] - profile[:, :-1]
    ramp = torch.linspace(0.5, 1.5, num_radii - 1, device=device, dtype=dtype)
    weighted_step = profile_step * ramp.unsqueeze(0)

    # Step i lies between radii[i] and radii[i + 1]; report the outer one.
    steepest = weighted_step.argmin(dim=1)
    radius = radii[steepest + 1]

    return radius.clamp(r_lo, r_hi)
|
|
|
|
| |
| |
| |
|
|
def compute_crop_box(
    cx: torch.Tensor,
    cy: torch.Tensor,
    radius: torch.Tensor,
    H: int,
    W: int,
    scale_factor: float = 1.1,
    allow_overflow: bool = False,
) -> tuple:
    """Derive a square crop box around the detected eye.

    The box extends ``radius * scale_factor`` from the center in each
    direction. With ``allow_overflow=True`` that raw box is returned unchanged,
    even if it leaves the image; how the overflow is filled is decided by the
    resampling step that consumes the box.

    Otherwise the box is clipped to ``[0, W-1] x [0, H-1]``, shrunk to the
    shorter of its two clipped sides, and re-centered on the clipped box so it
    stays square.

    Args:
        cx, cy: Detected eye center coordinates (B,).
        radius: Estimated disc radius (B,).
        H, W: Spatial dimensions of the source images.
        scale_factor: Padding multiplier applied to ``radius``. Default 1.1.
        allow_overflow: Skip clamping / squareness enforcement. Default False.

    Returns:
        Tuple of (x1, y1, x2, y2), each of shape (B,), in pixel coordinates.
    """
    reach = radius * scale_factor

    left, top = cx - reach, cy - reach
    right, bottom = cx + reach, cy + reach

    if allow_overflow:
        return left, top, right, bottom

    # Clip each edge against the image border it can cross.
    left = left.clamp(min=0)
    top = top.clamp(min=0)
    right = right.clamp(max=W - 1)
    bottom = bottom.clamp(max=H - 1)

    # After clipping the box may be rectangular; keep the shorter side.
    side = torch.minimum(right - left, bottom - top)

    # Re-center the square on the middle of the clipped box.
    mid_x = (left + right) / 2
    mid_y = (top + bottom) / 2

    left = (mid_x - side / 2).clamp(min=0)
    top = (mid_y - side / 2).clamp(min=0)
    right = (left + side).clamp(max=W - 1)
    bottom = (top + side).clamp(max=H - 1)

    return left, top, right, bottom
|
|
|
|
def batch_crop_and_resize(
    images: torch.Tensor,
    x1: torch.Tensor,
    y1: torch.Tensor,
    x2: torch.Tensor,
    y2: torch.Tensor,
    output_size: int,
    padding_mode: str = 'border',
) -> torch.Tensor:
    """Crop and resize images to a square using ``F.grid_sample`` (GPU-friendly).

    Builds a regular output grid in [0, 1]^2, maps it to the source rectangle
    [x1, x2] x [y1, y2] via affine scaling, normalizes to [-1, 1] for
    ``grid_sample``, and samples with bilinear interpolation (``align_corners=True``).

    Crop coordinates may extend beyond image bounds; the ``padding_mode``
    controls how out-of-bounds pixels are filled.

    The box coordinates are converted to the dtype and device of ``images``
    before the grid is built, because ``grid_sample`` requires the input and
    the grid to agree on both.

    Args:
        images: Input images (B, C, H, W).
        x1, y1, x2, y2: Crop box corners (B,). May exceed [0, W-1] / [0, H-1].
        output_size: Side length of the square output.
        padding_mode: ``'border'`` (repeat edge, default) or ``'zeros'`` (black fill).

    Returns:
        Cropped and resized images (B, C, output_size, output_size).
    """
    B, _, H, W = images.shape
    device = images.device
    dtype = images.dtype

    def _column(coord: torch.Tensor) -> torch.Tensor:
        # Match the image dtype/device and shape for broadcasting over the grid.
        return coord.to(device=device, dtype=dtype).reshape(B, 1, 1, 1)

    x1, y1, x2, y2 = (_column(c) for c in (x1, y1, x2, y2))

    # Unit-square output grid; out_x varies along columns, out_y along rows.
    out_coords = torch.linspace(0, 1, output_size, device=device, dtype=dtype)
    out_y, out_x = torch.meshgrid(out_coords, out_coords, indexing='ij')
    out_x = out_x.unsqueeze(0).unsqueeze(-1)  # (1, S, S, 1)
    out_y = out_y.unsqueeze(0).unsqueeze(-1)  # (1, S, S, 1)

    # Map the unit square onto each image's crop rectangle (pixel units).
    sample_x = x1 + out_x * (x2 - x1)
    sample_y = y1 + out_y * (y2 - y1)

    # Pixel units -> [-1, 1]; with align_corners=True, -1 and 1 are pixel centers.
    grid = torch.cat(
        [2.0 * sample_x / (W - 1) - 1.0, 2.0 * sample_y / (H - 1) - 1.0],
        dim=-1,
    )

    return F.grid_sample(
        images, grid, mode='bilinear', padding_mode=padding_mode, align_corners=True
    )
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
def batch_crop_and_resize_mask(
    masks: torch.Tensor,
    x1: torch.Tensor,
    y1: torch.Tensor,
    x2: torch.Tensor,
    y2: torch.Tensor,
    output_size: int,
    padding_mode: str = "zeros",
) -> torch.Tensor:
    """Crop and resize segmentation masks using nearest-neighbor sampling.

    Same spatial transform as ``batch_crop_and_resize`` but uses ``mode='nearest'``
    to preserve discrete label values. The output is rounded and cast to ``torch.long``
    to guard against floating-point drift in ``grid_sample``.

    The mask is sampled as float32, and the box coordinates are converted to
    float32 on the mask's device so the sampling grid always matches it.

    Args:
        masks: Integer label masks (B, 1, H, W) — any dtype (converted to float internally).
        x1, y1, x2, y2: Crop box corners (B,). May exceed image bounds.
        output_size: Side length of the square output.
        padding_mode: ``'zeros'`` (background = 0, default) or ``'border'`` (repeat edge).

    Returns:
        Cropped and resized masks (B, 1, output_size, output_size) as ``torch.long``.
    """
    B, _, H, W = masks.shape
    device = masks.device

    # grid_sample needs a floating-point input.
    masks_f = masks.float()
    grid_dtype = masks_f.dtype

    def _column(coord: torch.Tensor) -> torch.Tensor:
        # Match the sampled mask's dtype/device and shape for broadcasting.
        return coord.to(device=device, dtype=grid_dtype).reshape(B, 1, 1, 1)

    x1, y1, x2, y2 = (_column(c) for c in (x1, y1, x2, y2))

    # Unit-square output grid; out_x varies along columns, out_y along rows.
    coords = torch.linspace(0, 1, output_size, device=device, dtype=grid_dtype)
    out_y, out_x = torch.meshgrid(coords, coords, indexing="ij")
    out_x = out_x.unsqueeze(0).unsqueeze(-1)  # (1, S, S, 1)
    out_y = out_y.unsqueeze(0).unsqueeze(-1)  # (1, S, S, 1)

    # Map the unit square onto each mask's crop rectangle (pixel units).
    sample_x = x1 + out_x * (x2 - x1)
    sample_y = y1 + out_y * (y2 - y1)

    # Pixel units -> [-1, 1] for grid_sample with align_corners=True.
    grid = torch.cat(
        [2.0 * sample_x / (W - 1) - 1.0, 2.0 * sample_y / (H - 1) - 1.0],
        dim=-1,
    )

    cropped = F.grid_sample(
        masks_f,
        grid,
        mode="nearest",
        padding_mode=padding_mode,
        align_corners=True,
    )

    # Labels are integers; round before casting so tiny float error cannot
    # truncate a label down by one.
    return cropped.round().long()
|
|
| |
| |
| |
|
|
def _srgb_to_linear(rgb: torch.Tensor) -> torch.Tensor:
    """Decode sRGB-encoded values to linear light (IEC 61966-2-1).

    Values at or below 0.04045 use the linear segment ``rgb / 12.92``; values
    above it use the power segment ``((rgb + 0.055) / 1.055) ** 2.4``.
    """
    linear_segment = rgb / 12.92
    power_segment = ((rgb + 0.055) / 1.055) ** 2.4

    in_linear_segment = rgb <= 0.04045
    return torch.where(in_linear_segment, linear_segment, power_segment)
|
|
|
|
def _linear_to_srgb(linear: torch.Tensor) -> torch.Tensor:
    """Encode linear-light values as sRGB (IEC 61966-2-1).

    Values at or below 0.0031308 use the linear segment ``linear * 12.92``;
    values above it use ``1.055 * linear ** (1 / 2.4) - 0.055``.

    Both segments are evaluated for every element and the result is picked
    with ``torch.where``. Negative inputs therefore take the linear segment
    and come out negative; callers in this file clamp to [0, 1] beforehand.
    """
    linear_segment = linear * 12.92
    power_segment = 1.055 * (linear ** (1.0 / 2.4)) - 0.055

    in_linear_segment = linear <= 0.0031308
    return torch.where(in_linear_segment, linear_segment, power_segment)
|
|
|
|
def rgb_to_lab(images: torch.Tensor) -> tuple:
    """Convert sRGB images to rescaled CIE LAB channels (D65 white point).

    Chain: sRGB -> linear RGB -> CIE XYZ -> CIE LAB, followed by a rescale
    that brings every channel to roughly [0, 1]:

    - L in [0, 100]        ->  L / 100
    - a in ~[-128, 127]    ->  a / 256 + 0.5
    - b in ~[-128, 127]    ->  b / 256 + 0.5

    The rescaled values are **not** standard LAB; ``lab_to_rgb`` undoes the
    rescale and converts back to sRGB.

    Args:
        images: RGB images (B, 3, H, W) in [0, 1] sRGB.

    Returns:
        Tuple of (L, a, b_ch), each (B, 1, H, W):
            - L: Rescaled luminance in [0, 1].
            - a: Rescaled green–red chrominance, roughly [0, 1].
            - b_ch: Rescaled blue–yellow chrominance, roughly [0, 1].
    """
    linear = _srgb_to_linear(images)

    # Slicing (not indexing) keeps the channel axis: each is (B, 1, H, W).
    red = linear[:, 0:1, :, :]
    green = linear[:, 1:2, :, :]
    blue = linear[:, 2:3, :, :]

    # Linear RGB -> XYZ.
    x = 0.4124564 * red + 0.3575761 * green + 0.1804375 * blue
    y = 0.2126729 * red + 0.7151522 * green + 0.0721750 * blue
    z = 0.0193339 * red + 0.1191920 * green + 0.9503041 * blue

    # Normalise by the D65 reference white.
    white_x, white_y, white_z = 0.95047, 1.0, 1.08883
    x = x / white_x
    y = y / white_y
    z = z / white_z

    delta = 6.0 / 29.0
    delta_cubed = delta ** 3

    def _lab_f(t):
        # Cube root above the threshold, linear segment below it.
        cube_root = t ** (1.0 / 3.0)
        linear_part = t / (3.0 * delta ** 2) + 4.0 / 29.0
        return torch.where(t > delta_cubed, cube_root, linear_part)

    fx, fy, fz = _lab_f(x), _lab_f(y), _lab_f(z)

    lightness = 116.0 * fy - 16.0
    green_red = 500.0 * (fx - fy)
    blue_yellow = 200.0 * (fy - fz)

    # Rescale to approximately [0, 1] for downstream processing.
    return (
        lightness / 100.0,
        green_red / 256.0 + 0.5,
        blue_yellow / 256.0 + 0.5,
    )
|
|
|
|
def lab_to_rgb(L: torch.Tensor, a: torch.Tensor, b_ch: torch.Tensor) -> torch.Tensor:
    """Convert rescaled CIE LAB channels back to sRGB (inverse of ``rgb_to_lab``).

    Undoes the rescale (``L * 100``, ``(a - 0.5) * 256``, ``(b_ch - 0.5) * 256``)
    and then runs LAB -> XYZ -> linear RGB -> sRGB. Linear RGB is clamped to
    [0, 1] before encoding, and the encoded result is clamped again.

    Args:
        L: Rescaled luminance (B, 1, H, W) in [0, 1].
        a: Rescaled green–red chrominance (B, 1, H, W), roughly [0, 1].
        b_ch: Rescaled blue–yellow chrominance (B, 1, H, W), roughly [0, 1].

    Returns:
        sRGB images (B, 3, H, W) clamped to [0, 1].
    """
    # Back to standard LAB ranges.
    lightness = L * 100.0
    green_red = (a - 0.5) * 256.0
    blue_yellow = (b_ch - 0.5) * 256.0

    # LAB -> f(X), f(Y), f(Z).
    fy = (lightness + 16.0) / 116.0
    fx = green_red / 500.0 + fy
    fz = fy - blue_yellow / 200.0

    delta = 6.0 / 29.0

    def _inverse_f(t):
        # Cube above the threshold, linear segment below it.
        cubic = t ** 3
        linear_part = 3.0 * (delta ** 2) * (t - 4.0 / 29.0)
        return torch.where(t > delta, cubic, linear_part)

    # Scale by the D65 reference white.
    white_x, white_y, white_z = 0.95047, 1.0, 1.08883
    x = white_x * _inverse_f(fx)
    y = white_y * _inverse_f(fy)
    z = white_z * _inverse_f(fz)

    # XYZ -> linear RGB.
    red = 3.2404542 * x - 1.5371385 * y - 0.4985314 * z
    green = -0.9692660 * x + 1.8760108 * y + 0.0415560 * z
    blue = 0.0556434 * x - 0.2040259 * y + 1.0572252 * z

    # Out-of-gamut values are clipped before the sRGB encoding is applied.
    linear = torch.cat([red, green, blue], dim=1).clamp(0.0, 1.0)

    return _linear_to_srgb(linear).clamp(0.0, 1.0)
|
|
|
|
def compute_histogram(
    tensor: torch.Tensor,
    num_bins: int = 256,
) -> torch.Tensor:
    """Compute per-image histograms for a batch of single-channel images.

    Bins are uniformly spaced over [0, 1]. Each pixel is assigned to bin
    ``int(value * (num_bins - 1))`` (truncation, then clamped to the valid
    range), and the counts for the whole batch are accumulated with a single
    ``scatter_add_`` call.

    Note: Used by ``clahe_single_tile``. ``apply_clahe_vectorized`` builds its
    histograms inline with the same binning rule.

    Args:
        tensor: Input (B, 1, H, W) with values in [0, 1]. May be non-contiguous.
        num_bins: Number of histogram bins. Default 256.

    Returns:
        Histograms of shape (B, num_bins), dtype matching input.
    """
    B = tensor.shape[0]

    # One row of pixels per image; reshape also accepts strided inputs.
    flat = tensor.reshape(B, -1)

    bin_indices = (flat * (num_bins - 1)).long().clamp(0, num_bins - 1)

    # Each pixel adds 1 to its bin in its own image's row.
    histograms = torch.zeros(B, num_bins, device=tensor.device, dtype=tensor.dtype)
    histograms.scatter_add_(1, bin_indices, torch.ones_like(flat))

    return histograms
|
|
|
|
def clahe_single_tile(
    tile: torch.Tensor,
    clip_limit: float,
    num_bins: int = 256,
) -> torch.Tensor:
    """Build the clipped, min-max normalised CDF lookup table for one CLAHE tile.

    The tile histogram is capped at ``clip_limit * num_pixels / num_bins`` per
    bin, the clipped-off mass is spread evenly over all bins, and the
    cumulative sum is rescaled so its first entry maps to 0 and its last to
    (approximately) 1.

    Args:
        tile: Single-channel tile images (B, 1, tile_h, tile_w) in [0, 1].
        clip_limit: Relative clip limit (higher = less contrast limiting).
        num_bins: Number of histogram bins. Default 256.

    Returns:
        Normalised CDF lookup table (B, num_bins) in [0, 1].
    """
    _, _, tile_h, tile_w = tile.shape
    pixel_count = tile_h * tile_w

    counts = compute_histogram(tile, num_bins)

    # Cap every bin and measure how much mass was cut off per image.
    ceiling = clip_limit * pixel_count / num_bins
    overflow = (counts - ceiling).clamp(min=0).sum(dim=1, keepdim=True)
    counts = counts.clamp(max=ceiling)

    # Give each bin an equal share of the clipped mass.
    counts = counts + overflow / num_bins

    lut = counts.cumsum(dim=1)

    # Min-max normalise; the epsilon guards against a flat CDF.
    lowest = lut[:, 0:1]
    highest = lut[:, -1:]
    return (lut - lowest) / (highest - lowest + 1e-8)
|
|
|
|
def apply_clahe_vectorized(
    images: torch.Tensor,
    grid_size: int = 8,
    clip_limit: float = 2.0,
    num_bins: int = 256,
) -> torch.Tensor:
    """Fully-vectorized CLAHE (Contrast Limited Adaptive Histogram Equalisation).

    For RGB input, converts to CIE LAB via ``rgb_to_lab``, applies CLAHE to the
    L channel only, then converts back via ``lab_to_rgb``. For single-channel
    input, operates on the channel directly.

    Algorithm:
        1. Pads the luminance channel (bottom/right, reflect padding) so both
           spatial sizes are divisible by ``grid_size``.
        2. Quantises luminance to ``num_bins`` integer bins (values outside
           [0, 1] are clamped to the first/last bin).
        3. Splits the bin map into ``grid_size x grid_size`` non-overlapping
           tiles and builds one histogram per tile via ``scatter_add_``.
        4. Clips each histogram at ``clip_limit * num_pixels / num_bins`` and
           redistributes the excess uniformly across all bins (single pass).
        5. Computes the per-tile CDF and min-max normalises it to [0, 1].
        6. For each pixel, looks up its bin in the CDFs of the four
           surrounding tiles and blends them bilinearly by the pixel's
           distance to the tile centres.

    Args:
        images: Input images (B, C, H, W) in [0, 1]. C must be 1 or 3.
        grid_size: Tile grid resolution (tiles per axis), >= 1. Default 8.
        clip_limit: Relative clip limit for histogram clipping. Default 2.0.
        num_bins: Number of histogram bins, >= 1. Default 256.

    Returns:
        CLAHE-enhanced images (B, C, H, W). For single-channel input the values
        lie in [0, 1] and the dtype matches the input.

    Raises:
        ValueError: If ``images`` is not 4-D, has a channel count other than
            1 or 3, or if ``grid_size`` / ``num_bins`` is smaller than 1.
    """
    if images.dim() != 4:
        raise ValueError(
            f"apply_clahe_vectorized expects a (B, C, H, W) tensor, got {images.dim()} dimensions"
        )
    B, C, H, W = images.shape
    if C not in (1, 3):
        raise ValueError(f"apply_clahe_vectorized expects 1 or 3 channels, got {C}")
    if grid_size < 1:
        raise ValueError(f"grid_size must be >= 1, got {grid_size}")
    if num_bins < 1:
        raise ValueError(f"num_bins must be >= 1, got {num_bins}")

    device = images.device
    dtype = images.dtype

    # Isolate the luminance plane; chroma planes are carried through untouched.
    # NOTE(review): assumes rgb_to_lab returns L as (B, 1, H, W) scaled to
    # [0, 1] — confirm against its definition.
    if C == 3:
        L, a, b_ch = rgb_to_lab(images)
    else:
        # Nothing below mutates L in place, so no defensive copy is needed.
        L = images
        a = b_ch = None

    # Pad bottom/right so the image splits into equally sized tiles.
    pad_h = (grid_size - H % grid_size) % grid_size
    pad_w = (grid_size - W % grid_size) % grid_size
    if pad_h > 0 or pad_w > 0:
        L_padded = F.pad(L, (0, pad_w, 0, pad_h), mode='reflect')
    else:
        L_padded = L

    _, _, H_pad, W_pad = L_padded.shape
    tile_h = H_pad // grid_size
    tile_w = W_pad // grid_size
    num_pixels = tile_h * tile_w
    num_tiles = B * grid_size * grid_size

    # Quantise once; the same bin map feeds both the histograms and the lookup.
    max_bin = num_bins - 1
    bin_map = (L_padded * max_bin).long().clamp(0, max_bin)  # (B, 1, H_pad, W_pad)

    # Rearrange to one row of pixel bins per tile, ordered (batch, tile_row, tile_col).
    tile_bins = (
        bin_map.view(B, grid_size, tile_h, grid_size, tile_w)
        .permute(0, 1, 3, 2, 4)
        .reshape(num_tiles, num_pixels)
    )

    # Per-tile histograms, all tiles of all images in a single scatter.
    histograms = torch.zeros(num_tiles, num_bins, device=device, dtype=dtype)
    histograms.scatter_add_(1, tile_bins, torch.ones_like(tile_bins, dtype=dtype))

    # Contrast limiting: clip, then spread the clipped mass evenly over all bins.
    clip_value = clip_limit * num_pixels / num_bins
    excess = (histograms - clip_value).clamp(min=0).sum(dim=1, keepdim=True)
    histograms = histograms.clamp(max=clip_value) + excess / num_bins

    # Normalised CDF per tile; the epsilon guards against a zero-width range.
    cdfs = histograms.cumsum(dim=1)
    cdf_min = cdfs[:, 0:1]
    cdf_max = cdfs[:, -1:]
    cdfs = (cdfs - cdf_min) / (cdf_max - cdf_min + 1e-8)
    cdfs = cdfs.view(B, grid_size, grid_size, num_bins)

    # Continuous tile coordinates of every pixel centre (tile centres sit at
    # integer coordinates). Pixels outside the outermost tile centres are
    # clamped, so they reuse the border tiles' mapping.
    upper = max(grid_size - 1.001, 0.0)
    tile_y = ((torch.arange(H_pad, device=device, dtype=dtype) + 0.5) / tile_h - 0.5).clamp(0, upper)
    tile_x = ((torch.arange(W_pad, device=device, dtype=dtype) + 0.5) / tile_w - 0.5).clamp(0, upper)

    # Indices of the two neighbouring tiles along each axis. With a single tile
    # per axis both neighbours collapse onto tile 0.
    last_lower = max(grid_size - 2, 0)
    ty0 = tile_y.long().clamp(0, last_lower)
    tx0 = tile_x.long().clamp(0, last_lower)
    ty1 = (ty0 + 1).clamp(max=grid_size - 1)
    tx1 = (tx0 + 1).clamp(max=grid_size - 1)

    # Bilinear weights, kept in the input dtype so the output dtype is preserved.
    wy = (tile_y - ty0.to(dtype)).view(1, H_pad, 1)
    wx = (tile_x - tx0.to(dtype)).view(1, 1, W_pad)

    # Index tensors broadcast to (B, H_pad, W_pad) inside the advanced indexing.
    b_idx = torch.arange(B, device=device).view(B, 1, 1)
    row0 = ty0.view(1, H_pad, 1)
    row1 = ty1.view(1, H_pad, 1)
    col0 = tx0.view(1, 1, W_pad)
    col1 = tx1.view(1, 1, W_pad)
    bin_idx = bin_map[:, 0]  # (B, H_pad, W_pad)

    # Equalised value of each pixel according to its four surrounding tiles.
    v00 = cdfs[b_idx, row0, col0, bin_idx]
    v01 = cdfs[b_idx, row0, col1, bin_idx]
    v10 = cdfs[b_idx, row1, col0, bin_idx]
    v11 = cdfs[b_idx, row1, col1, bin_idx]

    L_out = (1 - wy) * (1 - wx) * v00 + (1 - wy) * wx * v01 + wy * (1 - wx) * v10 + wy * wx * v11
    L_out = L_out.unsqueeze(1)

    # Drop the padding again.
    if pad_h > 0 or pad_w > 0:
        L_out = L_out[:, :, :H, :W]

    if C == 3:
        return lab_to_rgb(L_out, a, b_ch)
    return L_out
|
|
|
|
| |
| |
| |
|
|
| |
# Per-channel mean / std (three values each) that normalize_images applies
# when called with mode='imagenet'.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
|
|
|
|
def resize_images(
    images: torch.Tensor,
    size: int,
    mode: str = 'bilinear',
    antialias: bool = True,
) -> torch.Tensor:
    """Resample a batch to a ``size x size`` square with ``F.interpolate``.

    Args:
        images: Batch of shape (B, C, H, W). Bilinear/bicubic interpolation
            needs a floating-point tensor.
        size: Side length of the square output.
        mode: Any interpolation mode accepted by ``F.interpolate``
            (``'bilinear'``, ``'bicubic'``, ``'nearest'``, ...). Default ``'bilinear'``.
        antialias: Whether to antialias; only honoured for bilinear/bicubic,
            forced off for every other mode. Default True.

    Returns:
        Tensor of shape (B, C, size, size).
    """
    target = (size, size)

    # Only the bilinear/bicubic kernels take align_corners / antialias.
    if mode in ['bilinear', 'bicubic']:
        return F.interpolate(
            images,
            size=target,
            mode=mode,
            align_corners=False,
            antialias=antialias,
        )

    return F.interpolate(
        images,
        size=target,
        mode=mode,
        align_corners=None,
        antialias=False,
    )
|
|
|
|
def normalize_images(
    images: torch.Tensor,
    mean: Optional[List[float]] = None,
    std: Optional[List[float]] = None,
    mode: str = 'imagenet',
) -> torch.Tensor:
    """Apply per-channel standardisation ``(image - mean) / std``.

    Args:
        images: Batch of shape (B, C, H, W) in [0, 1].
        mean: Per-channel means. Only used (and required) for ``mode='custom'``.
        std: Per-channel standard deviations. Only used (and required) for
            ``mode='custom'``.
        mode: ``'imagenet'`` applies ``IMAGENET_MEAN`` / ``IMAGENET_STD`` and
            ignores ``mean`` / ``std``; ``'custom'`` applies the supplied
            statistics; ``'none'`` returns the input unchanged.

    Returns:
        The normalised batch, or ``images`` itself when ``mode='none'``.

    Raises:
        ValueError: For an unrecognised ``mode``, or when ``mode='custom'``
            and ``mean`` or ``std`` is missing.
    """
    if mode == 'none':
        return images

    if mode == 'imagenet':
        mean, std = IMAGENET_MEAN, IMAGENET_STD
    elif mode != 'custom':
        raise ValueError(f"Unknown normalization mode: {mode}")
    elif mean is None or std is None:
        raise ValueError("Custom mode requires mean and std")

    def _per_channel(values):
        # Shape (1, C, 1, 1) so the statistics broadcast over batch and space.
        stats = torch.tensor(values, device=images.device, dtype=images.dtype)
        return stats.view(1, -1, 1, 1)

    centred = images - _per_channel(mean)
    return centred / _per_channel(std)
|
|
|
|
| |
| |
| |
|
|
class EyeCLAHEImageProcessor(BaseImageProcessor):
    """GPU-native Hugging Face image processor for Colour Fundus Photography (CFP).

    Processing pipeline (all steps optional via constructor flags):

    1. **Eye localisation** (``do_crop=True``): detects the fundus disc centre via
       gradient-based radial symmetry (dark-region centre-of-mass → Sobel gradients →
       radial alignment score → Gaussian smoothing → soft argmax) and estimates the
       disc radius from the strongest negative radial intensity gradient.
    2. **Square crop & resize**: crops a square region around the detected disc
       (``radius * crop_scale_factor``), optionally allowing overflow beyond image
       bounds (``allow_overflow``), then resamples to ``size x size``.
       When ``do_crop=False``, the whole image is resized directly.
    3. **CLAHE** (``do_clahe=True``): applies Contrast Limited Adaptive Histogram
       Equalisation to the CIE LAB luminance channel, using a fully-vectorized
       tile-based implementation with bilinear CDF interpolation.
    4. **Normalisation**: channel-wise ``(image - mean) / std`` with configurable
       mode (ImageNet, custom, or none).

    The processor also returns per-image coordinate-mapping scalars (``scale_x/y``,
    ``offset_x/y``) so that predictions in processed-image space can be mapped back
    to original pixel coordinates.

    All operations are pure PyTorch — no OpenCV, PIL, or NumPy at runtime — and are
    CUDA-compatible and batch-friendly.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        size: int = 224,
        crop_scale_factor: float = 1.1,
        clahe_grid_size: int = 8,
        clahe_clip_limit: float = 2.0,
        normalization_mode: str = "imagenet",
        custom_mean: Optional[List[float]] = None,
        custom_std: Optional[List[float]] = None,
        do_clahe: bool = True,
        do_crop: bool = True,
        min_radius_frac: float = 0.1,
        max_radius_frac: float = 0.5,
        allow_overflow: bool = False,
        softmax_temperature: float = 0.1,
        **kwargs,
    ):
        """
        Initialize the EyeCLAHEImageProcessor.

        Args:
            size: Output image size (square)
            crop_scale_factor: Scale factor for crop box (relative to detected radius)
            clahe_grid_size: Number of tiles for CLAHE
            clahe_clip_limit: Histogram clip limit for CLAHE
            normalization_mode: 'imagenet', 'none', or 'custom'
            custom_mean: Custom normalization mean (if mode='custom')
            custom_std: Custom normalization std (if mode='custom')
            do_clahe: Whether to apply CLAHE
            do_crop: Whether to perform eye-centered cropping
            min_radius_frac: Minimum radius as fraction of image size
            max_radius_frac: Maximum radius as fraction of image size
            allow_overflow: If True, allow crop box to extend beyond image bounds
                and fill missing regions with black. Useful for pre-cropped
                images where the fundus circle is partially cut off.
            softmax_temperature: Temperature for soft argmax in eye center detection.
                Lower values (0.01-0.1) give sharper peak detection, higher values
                (0.3-0.5) provide more averaging for noisy images. Default: 0.1.
            **kwargs: Forwarded to ``BaseImageProcessor.__init__``.
        """
        super().__init__(**kwargs)

        self.size = size
        self.crop_scale_factor = crop_scale_factor
        self.clahe_grid_size = clahe_grid_size
        self.clahe_clip_limit = clahe_clip_limit
        self.normalization_mode = normalization_mode
        self.custom_mean = custom_mean
        self.custom_std = custom_std
        self.do_clahe = do_clahe
        self.do_crop = do_crop
        self.min_radius_frac = min_radius_frac
        self.max_radius_frac = max_radius_frac
        self.allow_overflow = allow_overflow
        self.softmax_temperature = softmax_temperature

    def preprocess(
        self,
        images,
        masks=None,
        return_tensors: str = "pt",
        device: Optional[Union[str, torch.device]] = None,
        **kwargs,
    ) -> BatchFeature:
        """Run the full preprocessing pipeline on a batch of images.

        Accepts any combination of torch.Tensor, PIL.Image, or numpy.ndarray inputs
        (see ``standardize_input`` for format details). Optionally processes
        accompanying segmentation masks with matching spatial transforms.

        Args:
            images: Input images in any supported format.
            masks: Optional segmentation masks in any format accepted by
                ``standardize_mask_input``. Undergo the same crop/resize as images.
                Returned under the ``"mask"`` key; the key is omitted entirely
                when no masks are provided.
            return_tensors: Only ``"pt"`` is supported.
            device: Device for all tensor operations (e.g. ``"cuda:0"``).
                Defaults to the device of the input tensor (or of the first
                tensor in a list), or CPU for PIL/numpy.
            **kwargs: Accepted for call-signature compatibility; not used here.

        Returns:
            ``BatchFeature`` with keys:

            - ``pixel_values`` (B, 3, size, size): Processed float32 images.
            - ``mask`` (B, 1, size, size): Processed masks; present only when
              ``masks`` was given.
            - ``scale_x``, ``scale_y`` (B,): Per-image scale factors.
            - ``offset_x``, ``offset_y`` (B,): Per-image offsets.

            Coordinate mapping from processed → original pixel space::

                orig_x = offset_x + proc_x * scale_x
                orig_y = offset_y + proc_y * scale_y

        Raises:
            ValueError: If ``return_tensors`` is anything other than ``"pt"``.
        """
        if return_tensors != "pt":
            raise ValueError("Only 'pt' (PyTorch) tensors are supported")

        # Resolve the working device: explicit argument first, then the input's own.
        if device is not None:
            device = torch.device(device)
        elif isinstance(images, torch.Tensor):
            device = images.device
        elif isinstance(images, list) and len(images) > 0 and isinstance(images[0], torch.Tensor):
            device = images[0].device
        else:
            # PIL / numpy inputs carry no device information.
            device = torch.device('cpu')

        # Bring images (and masks) into batched tensor form on the chosen device.
        images = standardize_input(images, device)
        if masks is not None:
            masks = standardize_mask_input(masks, device)

        B, _, H_orig, W_orig = images.shape

        # Number of pixel intervals across the output. Clamped to 1 so a
        # 1-pixel output yields a finite scale instead of dividing by zero.
        out_span = max(self.size - 1, 1)

        if self.do_crop:
            # Locate the fundus disc and size it.
            cx, cy = estimate_eye_center(images, softmax_temperature=self.softmax_temperature)
            radius = estimate_radius(
                images, cx, cy,
                min_radius_frac=self.min_radius_frac,
                max_radius_frac=self.max_radius_frac,
            )

            # Square crop box around the disc, per image.
            x1, y1, x2, y2 = compute_crop_box(
                cx, cy, radius, H_orig, W_orig,
                scale_factor=self.crop_scale_factor,
                allow_overflow=self.allow_overflow,
            )

            # Processed pixel 0 maps to the box's top-left corner and the last
            # processed pixel to its bottom-right corner.
            scale_x = (x2 - x1) / out_span
            scale_y = (y2 - y1) / out_span
            offset_x = x1
            offset_y = y1

            # Out-of-image regions are black when overflow is allowed,
            # otherwise edge pixels are repeated.
            padding_mode = 'zeros' if self.allow_overflow else 'border'
            images = batch_crop_and_resize(images, x1, y1, x2, y2, self.size, padding_mode=padding_mode)

            if masks is not None:
                masks = batch_crop_and_resize_mask(
                    masks, x1, y1, x2, y2,
                    self.size,
                    padding_mode=padding_mode,
                )
        else:
            # Whole-image resize: the same scale/offset applies to every image.
            # NOTE(review): this maps processed pixel 0 → original 0 and the last
            # processed pixel → original W-1 / H-1, while resize_images
            # interpolates with align_corners=False (half-pixel centres), so the
            # mapping is slightly offset from the actual resampling — confirm
            # which convention is intended.
            scale_x = torch.full((B,), (W_orig - 1) / out_span, device=device, dtype=images.dtype)
            scale_y = torch.full((B,), (H_orig - 1) / out_span, device=device, dtype=images.dtype)
            offset_x = torch.zeros(B, device=device, dtype=images.dtype)
            offset_y = torch.zeros(B, device=device, dtype=images.dtype)
            images = resize_images(images, self.size)

            if masks is not None:
                # Nearest-neighbour keeps label values intact; round() guards
                # against float noise before casting back to integer labels.
                masks = resize_images(masks.float(), self.size, mode="nearest", antialias=False).round().long()

        # Contrast enhancement.
        if self.do_clahe:
            images = apply_clahe_vectorized(
                images,
                grid_size=self.clahe_grid_size,
                clip_limit=self.clahe_clip_limit,
            )

        # Channel-wise normalisation.
        images = normalize_images(
            images,
            mean=self.custom_mean,
            std=self.custom_std,
            mode=self.normalization_mode,
        )

        data = {
            "pixel_values": images,
            "scale_x": scale_x,
            "scale_y": scale_y,
            "offset_x": offset_x,
            "offset_y": offset_y,
        }
        if masks is not None:
            data["mask"] = masks
        return BatchFeature(data=data, tensor_type="pt")

    def __call__(
        self,
        images,
        **kwargs,
    ) -> BatchFeature:
        """Alias for ``preprocess`` — enables ``processor(images, ...)`` call syntax.

        ``images`` accepts every input format that ``preprocess`` accepts.
        """
        return self.preprocess(images, **kwargs)
|
|
|
|
| |
# Second public name bound to the same class
# (presumably kept for backward compatibility — confirm against callers).
EyeGPUImageProcessor = EyeCLAHEImageProcessor
|
|