|
|
import numpy as np |
|
|
import torch |
|
|
import random |
|
|
import matplotlib |
|
|
import matplotlib.pyplot as plt |
|
|
import math |
|
|
import umap |
|
|
import scanpy as sc |
|
|
from sklearn.decomposition import PCA |
|
|
|
|
|
import ot as pot |
|
|
from tqdm import tqdm |
|
|
from functools import partial |
|
|
from typing import Optional |
|
|
|
|
|
from matplotlib.colors import LinearSegmentedColormap |
|
|
|
|
|
|
|
|
def set_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch so runs are reproducible.

    When CUDA is available, the CUDA generators are seeded as well and cuDNN
    is put into deterministic, non-benchmarking mode.

    Parameters:
        seed (int): Value handed to every random number generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Give up cuDNN autotuning speed in exchange for run-to-run determinism.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
|
|
|
|
|
|
|
|
def wasserstein_distance(
    x0: torch.Tensor,
    x1: torch.Tensor,
    method: Optional[str] = None,
    reg: float = 0.05,
    power: int = 1,
    **kwargs,
) -> float:
    """Compute the Wasserstein-1 or Wasserstein-2 distance between two samples.

    Both samples get uniform weights (``pot.unif``). Inputs with more than two
    dimensions are flattened to ``(n, -1)``. The cost matrix is the pairwise
    Euclidean distance from ``torch.cdist``, squared when ``power == 2``.

    Parameters:
        x0 (torch.Tensor): First sample set; dimension 0 indexes samples.
        x1 (torch.Tensor): Second sample set; dimension 0 indexes samples.
        method (str, optional): ``"exact"`` or ``None`` uses ``pot.emd2``;
            ``"sinkhorn"`` uses entropic ``pot.sinkhorn2``.
        reg (float): Entropic regularisation strength, only used by
            ``"sinkhorn"``.
        power (int): 1 for W1 (plain distances), 2 for W2 (squared distances,
            and the square root is taken of the optimal cost).
        **kwargs: Accepted for call-site compatibility; currently ignored.

    Returns:
        float: The optimal transport cost, or its square root when
        ``power == 2``.

    Raises:
        AssertionError: If ``power`` is neither 1 nor 2.
        ValueError: If ``method`` is not recognised.
    """
    assert power == 1 or power == 2
    if method == "exact" or method is None:
        ot_fn = pot.emd2
    elif method == "sinkhorn":
        ot_fn = partial(pot.sinkhorn2, reg=reg)
    else:
        raise ValueError(f"Unknown method: {method}")

    a, b = pot.unif(x0.shape[0]), pot.unif(x1.shape[0])
    if x0.dim() > 2:
        x0 = x0.reshape(x0.shape[0], -1)
    if x1.dim() > 2:
        x1 = x1.reshape(x1.shape[0], -1)
    M = torch.cdist(x0, x1)
    if power == 2:
        M = M**2
    # numItermax is documented by POT as an int. Passing the float 1e7 breaks
    # solvers that iterate with range(numItermax), such as newer Sinkhorn code.
    ret = ot_fn(a, b, M.detach().cpu().numpy(), numItermax=10_000_000)
    if power == 2:
        ret = math.sqrt(ret)
    return ret
|
|
|
|
|
|
|
|
def _as_numpy(values):
    """Return ``values`` as a NumPy array, detaching and moving torch tensors to CPU first."""
    if torch.is_tensor(values):
        return values.detach().cpu().numpy()
    return np.asarray(values)


def plot_lidar(ax, dataset, xs=None, S=25, branch_idx=None, max_points=512):
    """Draw a point cloud on a 3D axes, optionally overlaid with trajectories.

    Dataset points are drawn small and shaded grey by height; higher points
    are darker. When ``xs`` is given, ``S`` evenly spaced time steps of the
    first ``max_points`` trajectories are drawn on top. Their colour moves
    along a colormap from the first logged step (0) to the last (1).

    All points go into a single ``scatter`` call, sorted by z so that lower
    points are drawn first.

    Parameters:
        ax: A matplotlib 3D axes; ``scatter`` and ``set_*lim3d`` are called
            on it.
        dataset: Tensor or array of points; columns 0, 1, 2 are used as x, y, z.
        xs: Optional tensor/array indexed as ``xs[trajectory, step]``, with
            points of the same width as ``dataset`` rows.
        S (int): Number of time steps of ``xs`` to draw.
        branch_idx: ``0`` selects the colormap ending in ``#50B2D7``; any
            other value (including ``None``) selects the one ending in
            ``#D577FF``.
        max_points (int): Maximum number of trajectories drawn per step
            (default 512, the previously hard-coded value).

    Returns:
        The same ``ax``.
    """
    points = _as_numpy(dataset)
    z_coords = points[:, 2]
    z_min, z_max = z_coords.min(), z_coords.max()
    z_range = z_max - z_min
    # A flat cloud would divide by zero and produce NaN colours.
    if z_range > 0:
        z_norm = (z_coords - z_min) / z_range
    else:
        z_norm = np.zeros(z_coords.shape, dtype=float)

    grey = 0.95 - 0.7 * z_norm
    all_points = [points]
    # One opaque grey RGBA row per dataset point.
    all_colors = [np.column_stack([grey, grey, grey, np.ones_like(grey)])]
    all_sizes = [np.full(len(points), 0.1)]

    if xs is not None:
        if branch_idx == 0:
            colors = ["#05009E", "#A19EFF", "#50B2D7"]
        else:
            colors = ["#05009E", "#A19EFF", "#D577FF"]
        cmap = LinearSegmentedColormap.from_list("my_cmap", colors)

        T = xs.shape[1]
        steps_to_log = np.linspace(0, T - 1, S).astype(int)
        # Slice before converting so only the plotted trajectories are copied.
        traj = _as_numpy(xs[:max_points])[:, steps_to_log]
        n_traj = traj.shape[0]
        # (n, S, D) -> (S * n, D): all points of step 0, then step 1, ...
        traj = traj.transpose(1, 0, 2).reshape(-1, traj.shape[-1])

        # Position along the colormap; guard S == 1 to avoid dividing by zero.
        denom = max(len(steps_to_log) - 1, 1)
        fractions = np.repeat(np.arange(len(steps_to_log)) / denom, n_traj)

        all_points.append(traj)
        all_colors.append(cmap(fractions))
        all_sizes.append(np.full(len(traj), 0.8))

    combined_points = np.concatenate(all_points, axis=0)
    combined_colors = np.concatenate(all_colors, axis=0)
    combined_sizes = np.concatenate(all_sizes)

    # Sort by z so lower points are drawn before higher ones.
    order = np.argsort(combined_points[:, 2])
    combined_points = combined_points[order]
    combined_colors = combined_colors[order]
    combined_sizes = combined_sizes[order]

    ax.scatter(
        combined_points[:, 0],
        combined_points[:, 1],
        combined_points[:, 2],
        s=combined_sizes,
        c=combined_colors,
        depthshade=True,
    )
    ax.set_xlim3d(left=-4.8, right=4.8)
    ax.set_ylim3d(bottom=-4.8, top=4.8)
    ax.set_zlim3d(bottom=0.0, top=2.0)
    ax.set_zticks([0, 1.0, 2.0])
    ax.grid(False)
    # Hide the axes we drew on, not whichever axes pyplot considers current.
    ax.set_axis_off()
    return ax
|
|
|
|
|
|
|
|
def plot_images_trajectory(trajectories, vae, processor, num_steps):
    """Decode latent trajectories with a VAE and show them as an image grid.

    Row ``i`` shows trajectory ``i``. The columns are ``num_steps`` step
    indices spread evenly (via ``torch.linspace``, truncated to int) from the
    first to the last step. Each latent is decoded with
    ``vae.decode(...).sample`` and converted with
    ``processor.postprocess(...)[0]``.

    Parameters:
        trajectories: Tensor indexed as ``trajectories[image, step]``. Each
            entry gets a leading batch dimension via ``unsqueeze(0)`` before
            decoding.
        vae: Object exposing ``decode(latent).sample``.
        processor: Object whose ``postprocess`` returns an indexable sequence
            of displayable images.
        num_steps (int): Number of columns (time steps) per trajectory.

    Returns:
        matplotlib.figure.Figure: The figure holding the grid.
    """
    num_images = trajectories.shape[0]
    last_step = trajectories.shape[1] - 1
    t_span = [int(t) for t in torch.linspace(0, last_step, num_steps)]

    # Decoding is inference only; avoid building an autograd graph.
    with torch.no_grad():
        decoded_images = [
            [
                processor.postprocess(
                    vae.decode(
                        trajectories[i_image, traj_step].unsqueeze(0)
                    ).sample.detach()
                )[0]
                for traj_step in t_span
            ]
            for i_image in range(num_images)
        ]

    # squeeze=False keeps `axes` 2-D even for a single row or column, so
    # axes[row, col] is valid for every grid shape.
    fig, axes = plt.subplots(
        num_images,
        num_steps,
        figsize=(num_steps * 2, num_images * 2),
        squeeze=False,
    )
    for img_idx, img_traj in enumerate(decoded_images):
        for step_idx, img in enumerate(img_traj):
            ax = axes[img_idx, step_idx]
            # imshow expects channel-last; move a leading 3-channel axis.
            if isinstance(img, np.ndarray) and img.shape[0] == 3:
                img = img.transpose(1, 2, 0)
            ax.imshow(img)
            ax.axis("off")
            if img_idx == 0:
                # Normalised time; a single-step trajectory is labelled 0.
                t_frac = t_span[step_idx] / t_span[-1] if t_span[-1] else 0.0
                ax.set_title(f"t={t_frac:.2f}")
    plt.tight_layout()
    return fig
|
|
|
|
|
|
|
|
def plot_growth(dataset, growth_nets, xs, output_file='plot.pdf'):
    """Plot growth results for a branched dataset (appears unfinished).

    NOTE(review): the visible code does only two things. It unpacks element 0
    and element 1 of the ``"x0"``, ``"x1_1"`` and ``"x1_2"`` entries of
    ``dataset`` into single-item lists, then calls ``plt.show()``.

    ``growth_nets``, ``xs`` and ``output_file`` are never used, and nothing is
    drawn or saved. The locals below are also unused. TODO: confirm whether
    the plotting/saving body is missing.
    """
    # Element 0 of each entry; presumably sample points — TODO confirm with caller.
    x0s = [dataset["x0"][0]]
    # Element 1 of each entry; the `w` prefix suggests weights — TODO confirm.
    w0s = [dataset["x0"][1]]
    # One list per target branch ("x1_1", "x1_2").
    x1s_list = [[dataset["x1_1"][0]], [dataset["x1_2"][0]]]
    w1s_list = [[dataset["x1_1"][1]], [dataset["x1_2"][1]]]

    plt.show()