# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

from torch import Tensor

from flow_matching.path.scheduler.scheduler import Scheduler
from flow_matching.utils import ModelWrapper


class ScheduleTransformedModel(ModelWrapper):
    """
    Change of scheduler for a velocity model.

    This class wraps a velocity model trained with one scheduler and exposes
    the velocity field of a new scheduler, without retraining. The wrapped
    model itself is left unchanged; its input, time argument, and output are
    rescaled so that both schedulers share the same data coupling.

    Example:

    .. code-block:: python

        import torch
        from flow_matching.path.scheduler import CondOTScheduler, CosineScheduler, ScheduleTransformedModel
        from flow_matching.solver import ODESolver

        # Initialize the model and schedulers
        model = ...

        original_scheduler = CondOTScheduler()
        new_scheduler = CosineScheduler()

        # Create the transformed model
        transformed_model = ScheduleTransformedModel(
            velocity_model=model,
            original_scheduler=original_scheduler,
            new_scheduler=new_scheduler
        )

        # Set up the solver
        solver = ODESolver(velocity_model=transformed_model)

        x_0 = torch.randn([10, 2])  # Example initial condition

        # With the default return_intermediates=False, only the final state is returned.
        x_1 = solver.sample(
            time_grid=torch.tensor([0.0, 1.0]),
            x_init=x_0,
            step_size=1/1000,
        )

    Args:
        velocity_model (ModelWrapper): The original velocity model to be transformed.
        original_scheduler (Scheduler): The scheduler used by the original model. Must implement the snr_inverse function.
        new_scheduler (Scheduler): The new scheduler to be applied to the model.
    """

    def __init__(
        self,
        velocity_model: ModelWrapper,
        original_scheduler: Scheduler,
        new_scheduler: Scheduler,
    ):
        super().__init__(model=velocity_model)
        self.original_scheduler = original_scheduler
        self.new_scheduler = new_scheduler

        assert hasattr(self.original_scheduler, "snr_inverse") and callable(
            getattr(self.original_scheduler, "snr_inverse")
        ), "The original scheduler must have a callable 'snr_inverse' method."

    def forward(self, x: Tensor, t: Tensor, **extras) -> Tensor:
        r"""
        Compute the transformed marginal velocity field for a new scheduler.

        This method implements a post-training velocity scheduler change for
        affine conditional flows. It transforms a generating marginal velocity
        field :math:`u_t(x)` based on an original scheduler to a new marginal velocity
        field :math:`\bar{u}_r(x)` based on a different scheduler, while maintaining
        the same data coupling.

        The transformation is based on the scale-time (ST) transformation
        between the two conditional flows, defined as:

        .. math::

            \bar{X}_r = s_r X_{t_r},

        where :math:`X_t` and :math:`\bar{X}_r` are defined by their respective schedulers.
        The ST transformation is computed as:

        .. math::

            t_r = \rho^{-1}(\bar{\rho}(r)) \quad \text{and} \quad  s_r = \frac{\bar{\sigma}_r}{\sigma_{t_r}}.

        Here, :math:`\rho(t)` is the signal-to-noise ratio (SNR) defined as:

        .. math::

            \rho(t) = \frac{\alpha_t}{\sigma_t}.

        :math:`\bar{\rho}(r)` is similarly defined for the new scheduler.
        The marginal velocity for the new scheduler is then given by:

        .. math::

            \bar{u}_r(x) = \left(\frac{\dot{s}_r}{s_r}\right) x + s_r \dot{t}_r u_{t_r}\left(\frac{x}{s_r}\right).

        Args:
            x (Tensor): The input state at time :math:`r` under the new scheduler.
            t (Tensor): The time tensor (denoted as :math:`r` above).
            **extras: Additional arguments for the model.

        Returns:
            Tensor: The transformed velocity :math:`\bar{u}_r(x)`.
        """
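        # The wrapper is queried at time r of the new scheduler.
        r = t

        r_scheduler_output = self.new_scheduler(t=r)

        alpha_r = r_scheduler_output.alpha_t
        sigma_r = r_scheduler_output.sigma_t
        d_alpha_r = r_scheduler_output.d_alpha_t
        d_sigma_r = r_scheduler_output.d_sigma_t

        # t_r: the original-scheduler time whose SNR equals the new scheduler's SNR at r.
        t = self.original_scheduler.snr_inverse(alpha_r / sigma_r)

        t_scheduler_output = self.original_scheduler(t=t)

        alpha_t = t_scheduler_output.alpha_t
        sigma_t = t_scheduler_output.sigma_t
        d_alpha_t = t_scheduler_output.d_alpha_t
        d_sigma_t = t_scheduler_output.d_sigma_t

        # Scale factor s_r: new sigma at r divided by original sigma at t_r.
        s_r = sigma_r / sigma_t

        # dt_r = d(t_r)/dr, the ratio of the two SNR derivatives, where
        # d(alpha / sigma) = (sigma * d_alpha - alpha * d_sigma) / sigma^2.
        dt_r = (
            sigma_t
            * sigma_t
            * (sigma_r * d_alpha_r - alpha_r * d_sigma_r)
            / (sigma_r * sigma_r * (sigma_t * d_alpha_t - alpha_t * d_sigma_t))
        )

        # ds_r = d(s_r)/dr, from the quotient rule with t_r depending on r.
        ds_r = (sigma_t * d_sigma_r - sigma_r * d_sigma_t * dt_r) / (sigma_t * sigma_t)

        # Evaluate the original model at the rescaled state and matched time,
        # then map its velocity to the new scheduler.
        u_t = self.model(x=x / s_r, t=t, **extras)
        u_r = ds_r * x / s_r + dt_r * s_r * u_t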

        return u_r
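

# Illustrative sketch, not part of the library API. It evaluates the scale-time
# quantities used in ``ScheduleTransformedModel.forward`` on plain floats, using
# only the standard library, so the formulas can be checked by hand.
# Assumptions: the original scheduler is conditional OT (alpha_t = t,
# sigma_t = 1 - t) and the new scheduler is cosine (alpha_r = sin(pi * r / 2),
# sigma_r = cos(pi * r / 2)). The helper name below is hypothetical.
import math
from typing import Tuple


def _st_quantities_condot_to_cosine(r: float) -> Tuple[float, float, float, float]:
    """Return (t_r, s_r, dt_r, ds_r) for a scalar time r strictly inside (0, 1)."""
    half_pi = 0.5 * math.pi

    # New (cosine) scheduler and its time derivatives at r.
    alpha_r = math.sin(half_pi * r)
    sigma_r = math.cos(half_pi * r)
    d_alpha_r = half_pi * math.cos(half_pi * r)
    d_sigma_r = -half_pi * math.sin(half_pi * r)

    # Match SNRs. For conditional OT the SNR is t / (1 - t), so its inverse
    # maps a value y to y / (1 + y).
    snr = alpha_r / sigma_r
    t = snr / (1.0 + snr)

    # Original (conditional-OT) scheduler and its time derivatives at t_r.
    alpha_t, sigma_t = t, 1.0 - t
    d_alpha_t, d_sigma_t = 1.0, -1.0

    # Same expressions as in ``forward``, written for floats.
    s_r = sigma_r / sigma_t
    dt_r = (
        sigma_t
        * sigma_t
        * (sigma_r * d_alpha_r - alpha_r * d_sigma_r)
        / (sigma_r * sigma_r * (sigma_t * d_alpha_t - alpha_t * d_sigma_t))
    )
    ds_r = (sigma_t * d_sigma_r - sigma_r * d_sigma_t * dt_r) / (sigma_t * sigma_t)
    return t, s_r, dt_r, ds_r


# Under these assumptions, r = 0.5 gives an SNR of 1, hence t_r = 0.5 and
# s_r = sqrt(2). A useful sanity check for any r is that s_r * sigma_t equals
# sigma_r and s_r * alpha_t equals alpha_r, i.e. the rescaled original path
# coincides with the new path.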