# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Union

import torch
from torch import Tensor


@dataclass
class SchedulerOutput:
    r"""Represents a sample of a conditional-flow generated probability path.

    Attributes:
        alpha_t (Tensor): :math:`\alpha_t`, shape (...).
        sigma_t (Tensor): :math:`\sigma_t`, shape (...).
        d_alpha_t (Tensor): :math:`\frac{\partial}{\partial t}\alpha_t`, shape (...).
        d_sigma_t (Tensor): :math:`\frac{\partial}{\partial t}\sigma_t`, shape (...).
    """

    alpha_t: Tensor = field(metadata={"help": "alpha_t"})
    sigma_t: Tensor = field(metadata={"help": "sigma_t"})
    d_alpha_t: Tensor = field(metadata={"help": "Derivative of alpha_t."})
    d_sigma_t: Tensor = field(metadata={"help": "Derivative of sigma_t."})


class Scheduler(ABC):
    """Base Scheduler class."""

    @abstractmethod
    def __call__(self, t: Tensor) -> SchedulerOutput:
        r"""
        Args:
            t (Tensor): times in [0,1], shape (...).

        Returns:
            SchedulerOutput: :math:`\alpha_t,\sigma_t,\frac{\partial}{\partial t}\alpha_t,\frac{\partial}{\partial t}\sigma_t`
        """
        ...

    @abstractmethod
    def snr_inverse(self, snr: Tensor) -> Tensor:
        r"""
        Computes :math:`t` from the signal-to-noise ratio :math:`\frac{\alpha_t}{\sigma_t}`.

        Args:
            snr (Tensor): The signal-to-noise, shape (...)

        Returns:
            Tensor: t, shape (...)
        """
        ...


class ConvexScheduler(Scheduler):
    r"""Base class for schedulers of convex paths.

    Subclasses implement :meth:`__call__` and :meth:`kappa_inverse`. The default
    :meth:`snr_inverse` assumes :math:`\alpha_t = \kappa_t` and
    :math:`\sigma_t = 1 - \kappa_t`, which gives
    :math:`\kappa_t = \frac{\mathrm{snr}}{1 + \mathrm{snr}}`.
    """

    @abstractmethod
    def __call__(self, t: Tensor) -> SchedulerOutput:
        r"""Scheduler for convex paths.

        Args:
            t (Tensor): times in [0,1], shape (...).

        Returns:
            SchedulerOutput: :math:`\alpha_t,\sigma_t,\frac{\partial}{\partial t}\alpha_t,\frac{\partial}{\partial t}\sigma_t`
        """
        ...

    @abstractmethod
    def kappa_inverse(self, kappa: Tensor) -> Tensor:
        r"""
        Computes :math:`t` from :math:`\kappa_t`.

        Args:
            kappa (Tensor): :math:`\kappa`, shape (...)

        Returns:
            Tensor: t, shape (...)
        """
        ...

    def snr_inverse(self, snr: Tensor) -> Tensor:
        r"""
        Computes :math:`t` from the signal-to-noise ratio :math:`\frac{\alpha_t}{\sigma_t}`.

        Args:
            snr (Tensor): The signal-to-noise, shape (...)

        Returns:
            Tensor: t, shape (...)
        """
        # With alpha_t = kappa_t and sigma_t = 1 - kappa_t:
        # snr = kappa / (1 - kappa)  =>  kappa = snr / (1 + snr).
        kappa_t = snr / (1.0 + snr)
        return self.kappa_inverse(kappa=kappa_t)


class CondOTScheduler(ConvexScheduler):
    r"""CondOT Scheduler.

    :math:`\alpha_t = t` and :math:`\sigma_t = 1 - t`, so :math:`\kappa_t = t`.
    """

    def __call__(self, t: Tensor) -> SchedulerOutput:
        return SchedulerOutput(
            alpha_t=t,
            sigma_t=1 - t,
            d_alpha_t=torch.ones_like(t),
            d_sigma_t=-torch.ones_like(t),
        )

    def kappa_inverse(self, kappa: Tensor) -> Tensor:
        return kappa


class PolynomialConvexScheduler(ConvexScheduler):
    r"""Polynomial Scheduler.

    :math:`\alpha_t = t^n` and :math:`\sigma_t = 1 - t^n`.

    Args:
        n (float | int): Positive exponent of the schedule.
    """

    def __init__(self, n: Union[float, int]) -> None:
        # NOTE: validation uses ``assert`` (AssertionError), so these checks are
        # skipped when Python runs with -O.
        assert isinstance(
            n, (float, int)
        ), f"`n` must be a float or int. Got {type(n)=}."
        assert n > 0, f"`n` must be positive. Got {n=}."
        self.n = n

    def __call__(self, t: Tensor) -> SchedulerOutput:
        return SchedulerOutput(
            alpha_t=t**self.n,
            sigma_t=1 - t**self.n,
            d_alpha_t=self.n * (t ** (self.n - 1)),
            d_sigma_t=-self.n * (t ** (self.n - 1)),
        )

    def kappa_inverse(self, kappa: Tensor) -> Tensor:
        return torch.pow(kappa, 1.0 / self.n)


class VPScheduler(Scheduler):
    r"""Variance Preserving Scheduler.

    With :math:`s = 1 - t`, :math:`b = \beta_{min}` and :math:`B = \beta_{max}`,
    the integrated noise rate is
    :math:`T(t) = \tfrac{1}{2} s^2 (B - b) + s\,b`, and
    :math:`\alpha_t = e^{-T/2}`, :math:`\sigma_t = \sqrt{1 - e^{-T}}`.

    Args:
        beta_min (float): :math:`b`, the rate at :math:`t = 1`.
        beta_max (float): :math:`B`, the rate at :math:`t = 0`.
    """

    def __init__(self, beta_min: float = 0.1, beta_max: float = 20.0) -> None:
        self.beta_min = beta_min
        self.beta_max = beta_max
        super().__init__()

    def __call__(self, t: Tensor) -> SchedulerOutput:
        b = self.beta_min
        B = self.beta_max
        T = 0.5 * (1 - t) ** 2 * (B - b) + (1 - t) * b
        dT = -(1 - t) * (B - b) - b

        alpha_t = torch.exp(-0.5 * T)
        # -expm1(-T) == 1 - exp(-T), but stays accurate as T -> 0 (t -> 1),
        # where the naive subtraction cancels catastrophically.
        sigma_t = torch.sqrt(-torch.expm1(-T))

        return SchedulerOutput(
            alpha_t=alpha_t,
            sigma_t=sigma_t,
            d_alpha_t=-0.5 * dT * alpha_t,
            d_sigma_t=0.5 * dT * torch.exp(-T) / sigma_t,
        )

    def snr_inverse(self, snr: Tensor) -> Tensor:
        r"""
        Computes :math:`t` from the signal-to-noise ratio :math:`\frac{\alpha_t}{\sigma_t}`.

        Args:
            snr (Tensor): The signal-to-noise, shape (...)

        Returns:
            Tensor: t, shape (...)
        """
        # snr^2 = e^{-T} / (1 - e^{-T})  =>  T = log(1 + 1 / snr^2).
        # log1p keeps precision for large snr, where snr^2 / (snr^2 + 1)
        # would round to 1 (and is NaN for snr = inf).
        T = torch.log1p(1.0 / snr**2)
        b = self.beta_min
        B = self.beta_max
        if B == b:
            # Constant rate: T = s * b is linear in s = 1 - t; the quadratic
            # formula below would divide by zero.
            s = T / b
        else:
            # Positive root of 0.5 * (B - b) * s^2 + b * s - T = 0.
            s = (torch.sqrt(b**2 + 2 * (B - b) * T) - b) / (B - b)
        return 1 - s


class LinearVPScheduler(Scheduler):
    r"""Linear Variance Preserving Scheduler.

    :math:`\alpha_t = t` and :math:`\sigma_t = \sqrt{1 - t^2}`, so that
    :math:`\alpha_t^2 + \sigma_t^2 = 1`.
    """

    def __call__(self, t: Tensor) -> SchedulerOutput:
        return SchedulerOutput(
            alpha_t=t,
            sigma_t=(1 - t**2) ** 0.5,
            d_alpha_t=torch.ones_like(t),
            d_sigma_t=-t / (1 - t**2) ** 0.5,
        )

    def snr_inverse(self, snr: Tensor) -> Tensor:
        # snr^2 = t^2 / (1 - t^2)  =>  t^2 = snr^2 / (1 + snr^2).
        return torch.sqrt(snr**2 / (1 + snr**2))


class CosineScheduler(Scheduler):
    r"""Cosine Scheduler.

    :math:`\alpha_t = \sin(\tfrac{\pi}{2} t)` and
    :math:`\sigma_t = \cos(\tfrac{\pi}{2} t)`.
    """

    def __call__(self, t: Tensor) -> SchedulerOutput:
        pi = torch.pi
        return SchedulerOutput(
            alpha_t=torch.sin(pi / 2 * t),
            sigma_t=torch.cos(pi / 2 * t),
            d_alpha_t=pi / 2 * torch.cos(pi / 2 * t),
            d_sigma_t=-pi / 2 * torch.sin(pi / 2 * t),
        )

    def snr_inverse(self, snr: Tensor) -> Tensor:
        # snr = tan(pi / 2 * t)  =>  t = 2 * atan(snr) / pi.
        return 2.0 * torch.atan(snr) / torch.pi