|
|
import math |
|
|
import torch |
|
|
|
|
|
|
|
|
def tor_expmap(x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
    """Exponential map on the flat torus.

    Moves the angles ``x`` by the tangent vector ``u`` and wraps the result
    back into the fundamental domain ``[0, 2*pi)``.

    Args:
        x: Base-point angles.
        u: Tangent vector (angular displacement), broadcastable with ``x``.

    Returns:
        ``(x + u)`` reduced modulo ``2*pi``.
    """
    moved = x + u
    # Tensor ``%`` follows Python's sign convention, so negative angles land
    # in [0, 2*pi) as well.
    return moved % (2.0 * math.pi)
|
|
|
|
|
def tor_logmap(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Logarithmic map on the flat torus.

    Returns the tangent vector at ``x`` pointing to ``y`` along the shortest
    way around each circle, i.e. the signed angular difference ``y - x``
    wrapped into ``[-pi, pi]``.

    Args:
        x: Base-point angles.
        y: Target angles, broadcastable with ``x``.

    Returns:
        Wrapped difference ``atan2(sin(y - x), cos(y - x))``.
    """
    delta = y - x
    # atan2 of (sin, cos) folds any raw difference back onto the principal
    # branch without an explicit modulo.
    return torch.atan2(delta.sin(), delta.cos())
|
|
|
|
|
def tor_projx(x: torch.Tensor) -> torch.Tensor:
    """Project arbitrary angles onto the torus domain ``[0, 2*pi)``.

    Args:
        x: Angles of any real value.

    Returns:
        ``x`` reduced modulo ``2*pi`` (non-negative for tensor inputs).
    """
    period = 2.0 * math.pi
    return x % period
|
|
|
|
|
def tor_random_uniform(
    *size, dtype=None, device=None, generator=None
) -> torch.Tensor:
    """Sample points uniformly from the flat torus ``[0, 2*pi)^d``.

    Args:
        *size: Output shape, forwarded unchanged to :func:`torch.rand`
            (separate ints or a single tuple/list/``torch.Size``).
        dtype: Optional floating-point dtype of the result.
        device: Optional device on which to allocate the result.
        generator: Optional :class:`torch.Generator` used for sampling, so
            callers can draw reproducible samples without touching the
            global RNG state. ``None`` (the default) uses the global RNG,
            exactly as before this parameter existed.

    Returns:
        Tensor of shape ``size`` with entries in ``[0, 2*pi)`` (up to
        floating-point rounding).
    """
    unit = torch.rand(*size, dtype=dtype, device=device, generator=generator)
    # Keep the original ``* 2 * math.pi`` evaluation order so results on the
    # default path are bit-identical to earlier versions.
    return unit * 2 * math.pi
|
|
|
|
|
def tor_uniform_logprob(x: torch.Tensor) -> torch.Tensor:
    """Log-density of the uniform distribution on the torus ``[0, 2*pi)^d``.

    The last axis of ``x`` is treated as the ``d`` angular coordinates; the
    density is ``(2*pi)^(-d)`` everywhere, so the value does not depend on
    the entries of ``x``.

    Args:
        x: Points on the torus; only its shape, dtype and device are used.

    Returns:
        Tensor shaped like ``x`` without its last axis, filled with
        ``-d * log(2*pi)`` and sharing ``x``'s dtype and device.
    """
    n_angles = x.shape[-1]
    log_density = -n_angles * math.log(2 * math.pi)
    # select(-1, 0) drops the last axis, giving the per-point output shape.
    return torch.full_like(x.select(-1, 0), log_density)
|
|
|
|
|
def tor_geodesic_t(t, angles_1, angles_0):
    """Evaluate the torus geodesic from ``angles_0`` towards ``angles_1``.

    The path follows the shortest signed angular difference on each circle
    (via ``tor_logmap``), scaled by ``t``, and is wrapped back into
    ``[0, 2*pi)`` by ``tor_expmap``. At ``t = 0`` the result is ``angles_0``
    wrapped into ``[0, 2*pi)``; at ``t = 1`` it coincides with ``angles_1``
    modulo ``2*pi``.

    Args:
        t: Interpolation time(s); must be broadcastable against the angle
            tensors (e.g. shape ``(batch, 1, 1)`` for ``(batch, n, d)``
            angles — NOTE(review): typical range presumably ``[0, 1]``).
        angles_1: End-point angles.
        angles_0: Start-point angles.

    Returns:
        Angles of the point reached at time ``t``.
    """
    velocity = tor_logmap(angles_0, angles_1)
    return tor_expmap(angles_0, t * velocity)
|
|
|
|
|
if __name__ == "__main__":
    # Quick smoke run: move 20% of the way along the geodesic between two
    # batches of random torus points (5 angles each) and show the result.
    # The first sample is the geodesic end point and the second the start,
    # matching the argument order of tor_geodesic_t(t, angles_1, angles_0).
    end_angles = tor_random_uniform((2, 3, 5))
    start_angles = tor_random_uniform((2, 3, 5))
    # Shape (2, 1, 1) so the times broadcast over the (2, 3, 5) angles.
    times = torch.ones((2, 1, 1)) * 0.2
    interpolated = tor_geodesic_t(times, end_angles, start_angles)
    print(interpolated)
    print(interpolated.shape)