| diff --git a/config/locomotion.py b/config/locomotion.py |
| deleted file mode 100644 |
| index 4410bb1..0000000 |
| --- a/config/locomotion.py |
| +++ /dev/null |
| @@ -1,70 +0,0 @@ |
| -import socket |
| - |
| -from diffuser.utils import watch |
| - |
| -#------------------------ base ------------------------# |
| - |
| -## automatically make experiment names for planning |
| -## by labelling folders with these args |
| - |
| -diffusion_args_to_watch = [ |
| - ('prefix', ''), |
| - ('horizon', 'H'), |
| - ('n_diffusion_steps', 'T'), |
| -] |
| - |
| -base = { |
| - 'diffusion': { |
| - ## model |
| - 'model': 'models.TemporalUnet', |
| - 'diffusion': 'models.GaussianDiffusion', |
| - 'horizon': 32, |
| - 'n_diffusion_steps': 100, |
| - 'action_weight': 10, |
| - 'loss_weights': None, |
| - 'loss_discount': 1, |
| - 'predict_epsilon': False, |
| - 'dim_mults': (1, 4, 8), |
| - 'renderer': 'utils.MuJoCoRenderer', |
| - |
| - ## dataset |
| - 'loader': 'datasets.SequenceDataset', |
| - 'normalizer': 'LimitsNormalizer', |
| - 'preprocess_fns': [], |
| - 'clip_denoised': True, |
| - 'use_padding': True, |
| - 'max_path_length': 1000, |
| - |
| - ## serialization |
| - 'logbase': 'logs', |
| - 'prefix': 'diffusion/', |
| - 'exp_name': watch(diffusion_args_to_watch), |
| - |
| - ## training |
| - 'n_steps_per_epoch': 10000, |
| - 'loss_type': 'l2', |
| - 'n_train_steps': 1e6, |
| - 'batch_size': 32, |
| - 'learning_rate': 2e-4, |
| - 'gradient_accumulate_every': 2, |
| - 'ema_decay': 0.995, |
| - 'save_freq': 1000, |
| - 'sample_freq': 1000, |
| - 'n_saves': 5, |
| - 'save_parallel': False, |
| - 'n_reference': 8, |
| - 'n_samples': 2, |
| - 'bucket': None, |
| - 'device': 'cuda', |
| - }, |
| -} |
| - |
| -#------------------------ overrides ------------------------# |
| - |
| -## put environment-specific overrides here |
| - |
| -halfcheetah_medium_expert_v2 = { |
| - 'diffusion': { |
| - 'horizon': 16, |
| - }, |
| -} |
| diff --git a/config/maze2d.py b/config/maze2d.py |
| index a06ac7f..0a8d22a 100644 |
| --- a/config/maze2d.py |
| +++ b/config/maze2d.py |
| @@ -34,11 +34,11 @@ base = { |
| 'model': 'models.TemporalUnet', |
| 'diffusion': 'models.GaussianDiffusion', |
| 'horizon': 256, |
| - 'n_diffusion_steps': 256, |
| + 'n_diffusion_steps': 512, |
| 'action_weight': 1, |
| 'loss_weights': None, |
| 'loss_discount': 1, |
| - 'predict_epsilon': False, |
| + 'predict_epsilon': True, |
| 'dim_mults': (1, 4, 8), |
| 'renderer': 'utils.Maze2dRenderer', |
| |
| @@ -57,14 +57,14 @@ base = { |
| 'exp_name': watch(diffusion_args_to_watch), |
| |
| ## training |
| - 'n_steps_per_epoch': 10000, |
| - 'loss_type': 'l2', |
| - 'n_train_steps': 2e6, |
| - 'batch_size': 32, |
| - 'learning_rate': 2e-4, |
| - 'gradient_accumulate_every': 2, |
| + 'n_steps_per_epoch': 60000, |
| + 'loss_type': 'spline', |
| + 'n_train_steps': 6e4, |
| + 'batch_size': 1, |
| + 'learning_rate': 5e-6, |
| + 'gradient_accumulate_every': 8, |
| 'ema_decay': 0.995, |
| - 'save_freq': 1000, |
| + 'save_freq': 2000, |
| 'sample_freq': 1000, |
| 'n_saves': 50, |
| 'save_parallel': False, |
| @@ -89,7 +89,6 @@ base = { |
| 'prefix': 'plans/release', |
| 'exp_name': watch(plan_args_to_watch), |
| 'suffix': '0', |
| - |
| 'conditional': False, |
| |
| ## loading |
| @@ -122,10 +121,10 @@ maze2d_umaze_v1 = { |
| maze2d_large_v1 = { |
| 'diffusion': { |
| 'horizon': 384, |
| - 'n_diffusion_steps': 256, |
| + 'n_diffusion_steps': 16, |
| }, |
| 'plan': { |
| 'horizon': 384, |
| - 'n_diffusion_steps': 256, |
| + 'n_diffusion_steps': 16, |
| }, |
| } |
| diff --git a/diffuser/datasets/buffer.py b/diffuser/datasets/buffer.py |
| index 1ad2106..5991f01 100644 |
| --- a/diffuser/datasets/buffer.py |
| +++ b/diffuser/datasets/buffer.py |
| @@ -9,7 +9,7 @@ class ReplayBuffer: |
| |
| def __init__(self, max_n_episodes, max_path_length, termination_penalty): |
| self._dict = { |
| - 'path_lengths': np.zeros(max_n_episodes, dtype=np.int), |
| + 'path_lengths': np.zeros(max_n_episodes, dtype=np.int_), |
| } |
| self._count = 0 |
| self.max_n_episodes = max_n_episodes |
| diff --git a/diffuser/datasets/sequence.py b/diffuser/datasets/sequence.py |
| index 356c540..73c1b04 100644 |
| --- a/diffuser/datasets/sequence.py |
| +++ b/diffuser/datasets/sequence.py |
| @@ -83,6 +83,7 @@ class SequenceDataset(torch.utils.data.Dataset): |
| actions = self.fields.normed_actions[path_ind, start:end] |
| |
| conditions = self.get_conditions(observations) |
| + |
| trajectories = np.concatenate([actions, observations], axis=-1) |
| batch = Batch(trajectories, conditions) |
| return batch |
| diff --git a/diffuser/models/diffusion.py b/diffuser/models/diffusion.py |
| index fae4cfd..461680a 100644 |
| --- a/diffuser/models/diffusion.py |
| +++ b/diffuser/models/diffusion.py |
| @@ -2,6 +2,7 @@ import numpy as np |
| import torch |
| from torch import nn |
| import pdb |
| +import matplotlib.pyplot as plt |
| |
| import diffuser.utils as utils |
| from .helpers import ( |
| @@ -9,6 +10,7 @@ from .helpers import ( |
| extract, |
| apply_conditioning, |
| Losses, |
| + catmull_rom_spline_with_rotation, |
| ) |
| |
| class GaussianDiffusion(nn.Module): |
| @@ -26,6 +28,7 @@ class GaussianDiffusion(nn.Module): |
| betas = cosine_beta_schedule(n_timesteps) |
| alphas = 1. - betas |
| alphas_cumprod = torch.cumprod(alphas, axis=0) |
| + print(f"Alphas Cumprod: {alphas_cumprod}") |
| alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]]) |
| |
| self.n_timesteps = int(n_timesteps) |
| @@ -73,7 +76,7 @@ class GaussianDiffusion(nn.Module): |
| ''' |
| self.action_weight = action_weight |
| |
| - dim_weights = torch.ones(self.transition_dim, dtype=torch.float32) |
| + dim_weights = torch.ones(self.transition_dim, dtype=torch.float64) |
| |
| ## set loss coefficients for dimensions of observation |
| if weights_dict is None: weights_dict = {} |
| @@ -97,18 +100,16 @@ class GaussianDiffusion(nn.Module): |
| otherwise, model predicts x0 directly |
| ''' |
| if self.predict_epsilon: |
| - return ( |
| - extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - |
| - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise |
| - ) |
| + return noise |
| else: |
| return noise |
| |
| def q_posterior(self, x_start, x_t, t): |
| posterior_mean = ( |
| extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + |
| - extract(self.posterior_mean_coef2, t, x_t.shape) * x_t |
| + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t[:, :, self.action_dim:] |
| ) |
| + |
| posterior_variance = extract(self.posterior_variance, t, x_t.shape) |
| posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) |
| return posterior_mean, posterior_variance, posterior_log_variance_clipped |
| @@ -129,7 +130,7 @@ class GaussianDiffusion(nn.Module): |
| def p_sample(self, x, cond, t): |
| b, *_, device = *x.shape, x.device |
| model_mean, _, model_log_variance = self.p_mean_variance(x=x, cond=cond, t=t) |
| - noise = torch.randn_like(x) |
| + noise = torch.randn_like(x[:, :, self.action_dim:]) |
| # no noise when t == 0 |
| nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) |
| return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise |
| @@ -139,22 +140,59 @@ class GaussianDiffusion(nn.Module): |
| device = self.betas.device |
| |
| batch_size = shape[0] |
| - x = torch.randn(shape, device=device) |
| - x = apply_conditioning(x, cond, self.action_dim) |
| + # x = torch.randn(shape, device=device, dtype=torch.float64) |
| + # Extract known indices and values |
| + known_indices = np.array(list(cond.keys()), dtype=int) |
| + |
| + # candidate_no x batch_size x dim |
| + known_values = np.stack([c.cpu().numpy() for c in cond.values()], axis=0) |
| + known_values = np.moveaxis(known_values, 0, 1) |
| + |
| + # Sort the timepoints |
| + sorted_indices = np.argsort(known_indices) |
| + known_indices = known_indices[sorted_indices] |
| + known_values = known_values[:, sorted_indices] |
| + |
| + # Build the structured spline guess |
| + catmull_spline_trajectory = np.array([ |
| + catmull_rom_spline_with_rotation(known_values[b, :, :-1], known_indices, shape[1]) |
| + for b in range(batch_size) |
| + ]) |
| + catmull_spline_trajectory = torch.tensor( |
| + catmull_spline_trajectory, |
| + dtype=torch.float64, |
| + device=device |
| + ) |
| + |
| + |
| + if self.predict_epsilon: |
| + x = torch.randn((shape[0], shape[1], self.observation_dim), device=device, dtype=torch.float64) |
| + cond_residual = {k: torch.zeros_like(v)[:, :-1] for k, v in cond.items()} |
| + is_cond = torch.zeros((shape[0], shape[1], 1), device=device, dtype=torch.float64) |
| + is_cond[:, known_indices, :] = 1.0 |
| |
| if return_diffusion: diffusion = [x] |
| |
| - progress = utils.Progress(self.n_timesteps) if verbose else utils.Silent() |
| + # progress = utils.Progress(self.n_timesteps) if verbose else utils.Silent() |
| for i in reversed(range(0, self.n_timesteps)): |
| + if self.predict_epsilon: |
| + x = torch.cat([catmull_spline_trajectory, is_cond, x], dim=-1) |
| + |
| timesteps = torch.full((batch_size,), i, device=device, dtype=torch.long) |
| - x = self.p_sample(x, cond, timesteps) |
| - x = apply_conditioning(x, cond, self.action_dim) |
| + x = self.p_sample(x, cond_residual, timesteps) |
| + |
| + x = apply_conditioning(x, cond_residual, 0) |
| |
| - progress.update({'t': i}) |
| + if return_diffusion: diffusion.append(x) |
| |
| - if return_diffusion: diffusion.append(x) |
| + x = catmull_spline_trajectory + x |
| |
| - progress.close() |
| + |
| + |
| + # Normalize the quaternions |
| + # x[:, :, 3:7] = x[:, :, 3:7] / torch.norm(x[:, :, 3:7], dim=-1, keepdim=True) |
| + |
| + # progress.close() |
| |
| if return_diffusion: |
| return x, torch.stack(diffusion, dim=1) |
| @@ -167,7 +205,7 @@ class GaussianDiffusion(nn.Module): |
| conditions : [ (time, state), ... ] |
| ''' |
| device = self.betas.device |
| - batch_size = len(cond[0]) |
| + batch_size = len(next(iter(cond.values()))) |
| horizon = horizon or self.horizon |
| shape = (batch_size, horizon, self.transition_dim) |
| |
| @@ -175,38 +213,106 @@ class GaussianDiffusion(nn.Module): |
| |
| #------------------------------------------ training ------------------------------------------# |
| |
| - def q_sample(self, x_start, t, noise=None): |
| + def q_sample(self, x_start, t, spline=None, noise=None): |
| + x_start_noise = x_start[:, : , :-1] |
| + x_start_is_cond = x_start[:, :, [-1]] |
| + |
| + if spline is None: |
| + spline = torch.randn_like(x_start_noise) |
| if noise is None: |
| - noise = torch.randn_like(x_start) |
| + noise = torch.randn_like(x_start_noise) |
| |
| - sample = ( |
| - extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + |
| - extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise |
| - ) |
| + alpha = extract(self.sqrt_alphas_cumprod, t, x_start.shape) |
| + oneminusalpha = extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) |
| + |
| + # Weighted combination of x_0 and the spline |
| + out = alpha * x_start_noise + oneminusalpha * noise |
| + |
| + # Concatenate the binary feature and the spline as the conditioning |
| + out = torch.cat([spline, x_start_is_cond, out], dim=-1) |
| |
| - return sample |
| + return out |
| |
| def p_losses(self, x_start, cond, t): |
| - noise = torch.randn_like(x_start) |
| + batch_size, horizon, _ = x_start.shape |
| + # Extract known indices and values |
| + known_indices = np.array(list(cond.keys()), dtype=int) |
| + |
| + # candidate_no x batch_size x dim |
| + known_values = np.stack([c.cpu().numpy() for c in cond.values()], axis=0) |
| + known_values = np.moveaxis(known_values, 0, 1) |
| + |
| + # Sort the timepoints |
| + sorted_indices = np.argsort(known_indices) |
| + known_indices = known_indices[sorted_indices] |
| + known_values = known_values[:, sorted_indices] |
| + |
| + # Build your structured guess |
| + catmull_spline_trajectory = np.array([ |
| + catmull_rom_spline_with_rotation(known_values[b, :, :-1], known_indices, horizon) |
| + for b in range(batch_size) |
| + ]) |
| + catmull_spline_trajectory = torch.tensor( |
| + catmull_spline_trajectory, |
| + dtype=torch.float64, |
| + device=x_start.device |
| + ) |
| |
| - x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) |
| - x_noisy = apply_conditioning(x_noisy, cond, self.action_dim) |
| + # Plot the quaternions |
| + # plt.plot(x_start[0, :, 3].cpu().numpy()) |
| + # plt.plot(catmull_spline_trajectory[0, :, 3].cpu().numpy()) |
| + # plt.legend(["x_start", "catmull_spline"]) |
| + # plt.show() |
| + # raise Exception |
| |
| - x_recon = self.model(x_noisy, cond, t) |
| - x_recon = apply_conditioning(x_recon, cond, self.action_dim) |
| |
| - assert noise.shape == x_recon.shape |
| + if not self.predict_epsilon: |
| + # Forward diffuse with the structured trajectory |
| + x_noisy = self.q_sample( |
| + x_start, |
| + t, |
| + spline=catmull_spline_trajectory, |
| + ) |
| + x_noisy = apply_conditioning(x_noisy, cond, self.action_dim) |
| |
| - if self.predict_epsilon: |
| - loss, info = self.loss_fn(x_recon, noise) |
| + # Reverse pass guess |
| + x_recon = self.model(x_noisy, cond, t) |
| + x_recon = apply_conditioning(x_recon, cond, self.action_dim) |
| + |
| + # Then x_recon is the predicted x_0, compare to the true x_0 |
| + loss, info = self.loss_fn(x_recon, x_start, cond) |
| else: |
| - loss, info = self.loss_fn(x_recon, x_start) |
| + residual = x_start.clone() |
| + |
| + residual[:, :, :-1] -= catmull_spline_trajectory |
| + |
| + |
| + cond_residual = {k: torch.zeros_like(v)[:, :-1] for k, v in cond.items()} |
| + |
| + x_noisy = self.q_sample( |
| + residual, |
| + t, |
| + spline=catmull_spline_trajectory, |
| + ) |
| + x_noisy = apply_conditioning(x_noisy, cond_residual, self.action_dim) |
| + |
| + # Reverse pass guess |
| + x_recon = self.model(x_noisy, cond, t) |
| + x_recon = apply_conditioning(x_recon, cond_residual, 0) |
| + |
| + x_recon = x_recon + catmull_spline_trajectory |
| + |
| + loss, info = self.loss_fn(x_recon, x_start[:, :, :-1], cond) |
| |
| return loss, info |
| |
| def loss(self, x, cond): |
| batch_size = len(x) |
| t = torch.randint(0, self.n_timesteps, (batch_size,), device=x.device).long() |
| + # t = torch.randint(1, 2, (batch_size,), device=x.device).long() |
| + # x = x.double() |
| + # cond = {k: v.double() for k, v in cond.items()} |
| + # print(f"Time: {t.item()}") |
| return self.p_losses(x, cond, t) |
| |
| def forward(self, cond, *args, **kwargs): |
| diff --git a/diffuser/models/helpers.py b/diffuser/models/helpers.py |
| index d39f35d..9f43ef8 100644 |
| --- a/diffuser/models/helpers.py |
| +++ b/diffuser/models/helpers.py |
| @@ -1,11 +1,11 @@ |
| import math |
| +import json |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| -import einops |
| from einops.layers.torch import Rearrange |
| -import pdb |
| +from pytorch3d.transforms import quaternion_to_matrix, quaternion_to_axis_angle |
| |
| import diffuser.utils as utils |
| |
| @@ -30,7 +30,7 @@ class SinusoidalPosEmb(nn.Module): |
| class Downsample1d(nn.Module): |
| def __init__(self, dim): |
| super().__init__() |
| - self.conv = nn.Conv1d(dim, dim, 3, 2, 1) |
| + self.conv = nn.Conv1d(dim, dim, 3, 2, 1).to(torch.float64) |
| |
| def forward(self, x): |
| return self.conv(x) |
| @@ -38,7 +38,7 @@ class Downsample1d(nn.Module): |
| class Upsample1d(nn.Module): |
| def __init__(self, dim): |
| super().__init__() |
| - self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1) |
| + self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1).to(torch.float64) |
| |
| def forward(self, x): |
| return self.conv(x) |
| @@ -52,9 +52,9 @@ class Conv1dBlock(nn.Module): |
| super().__init__() |
| |
| self.block = nn.Sequential( |
| - nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), |
| + nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2).to(torch.float64), |
| Rearrange('batch channels horizon -> batch channels 1 horizon'), |
| - nn.GroupNorm(n_groups, out_channels), |
| + nn.GroupNorm(n_groups, out_channels).to(torch.float64), |
| Rearrange('batch channels 1 horizon -> batch channels horizon'), |
| nn.Mish(), |
| ) |
| @@ -72,7 +72,7 @@ def extract(a, t, x_shape): |
| out = a.gather(-1, t) |
| return out.reshape(b, *((1,) * (len(x_shape) - 1))) |
| |
| -def cosine_beta_schedule(timesteps, s=0.008, dtype=torch.float32): |
| +def cosine_beta_schedule(timesteps, s=0.008, dtype=torch.float64): |
| """ |
| cosine schedule |
| as proposed in https://openreview.net/forum?id=-NEXDKk8gZ |
| @@ -157,9 +157,979 @@ class ValueL2(ValueLoss): |
| def _loss(self, pred, targ): |
| return F.mse_loss(pred, targ, reduction='none') |
| |
| +class GeodesicL2Loss(nn.Module): |
| + def __init__(self, *args): |
| + super().__init__() |
| + pass |
| + |
| + def _loss(self, pred, targ): |
| + # Compute L2 loss for the first three dimensions |
| + l2_loss = F.mse_loss(pred[..., :3], targ[..., :3], reduction='mean') |
| + |
| + # Normalize to unit quaternions for the last four dimensions |
| + pred_quat = pred[..., 3:] / pred[..., 3:].norm(dim=-1, keepdim=True) |
| + targ_quat = targ[..., 3:] / targ[..., 3:].norm(dim=-1, keepdim=True) |
| + |
| + assert not torch.isnan(pred_quat).any(), "Pred Quat has NaNs" |
| + assert not torch.isnan(targ_quat).any(), "Targ Quat has NaNs" |
| + |
| + # Compute dot product for the quaternions |
| + dot_product = torch.sum(pred_quat * targ_quat, dim=-1) |
| + dot_product = torch.clamp(torch.abs(dot_product), -1.0, 1.0) |
| + |
| + # Compute geodesic loss for the quaternions |
| + geodesic_loss = 2 * torch.acos(dot_product).mean() |
| + |
| + assert not torch.isnan(geodesic_loss).any(), "Geodesic Loss has NaNs" |
| + assert not torch.isnan(l2_loss).any(), "L2 Loss has NaNs" |
| + |
| + return l2_loss + geodesic_loss, l2_loss, geodesic_loss |
| + |
| + def forward(self, pred, targ): |
| + loss, l2, geodesic = self._loss(pred, targ) |
| + |
| + info = { |
| + 'l2': l2.item(), |
| + 'geodesic': geodesic.item(), |
| + } |
| + |
| + return loss, info |
| + |
| +class RotationTranslationLoss(nn.Module): |
| + def __init__(self, *args): |
| + super().__init__() |
| + pass |
| + |
| + def _loss(self, pred, targ, cond=None): |
| + |
| + # Make sure the dtype is float64 |
| + pred = pred.to(torch.float64) |
| + targ = targ.to(torch.float64) |
| + |
| + eps = 1e-8 |
| + |
| + pred_trans = pred[..., :3] |
| + pred_quat = pred[..., 3:7] |
| + targ_trans = targ[..., :3] |
| + targ_quat = targ[..., 3:7] |
| + |
| + l2_loss = F.mse_loss(pred_trans, targ_trans, reduction='mean') |
| + |
| + # Calculate the geodesic loss |
| + pred_n = pred_quat.norm(dim=-1, keepdim=True).clamp(min=eps) |
| + targ_n = targ_quat.norm(dim=-1, keepdim=True).clamp(min=eps) |
| + |
| + pred_quat_norm = pred_quat / pred_n |
| + targ_quat_norm = targ_quat / targ_n |
| + |
| + |
| + dot_product = torch.sum(pred_quat_norm * targ_quat_norm, dim=-1).clamp(min=-1.0 + eps, max=1.0 - eps) |
| + quaternion_dist = 1 - (dot_product ** 2).mean() |
| + |
| + # Calculate the rotation error |
| + pred_rot = quaternion_to_matrix(pred_quat_norm).reshape(-1, 3, 3) |
| + targ_rot = quaternion_to_matrix(targ_quat_norm).reshape(-1, 3, 3) |
| + |
| + r2r1 = pred_rot @ targ_rot.permute(0, 2, 1) |
| + trace = torch.diagonal(r2r1, dim1=-2, dim2=-1).sum(-1) |
| + trace = torch.clamp((trace - 1) / 2, -1.0 + eps, 1.0 - eps) |
| + geodesic_loss = torch.acos(trace).mean() |
| + |
| + # Add a smoothness and acceleration term to the positions and quaternions |
| + alpha = 1.0 |
| + smoothness_loss = F.mse_loss(pred[:, 1:, :7].reshape(-1, 7), pred[:, :-1, :7].reshape(-1, 7), reduction='mean') |
| + acceleration_loss = F.mse_loss(pred[:, 2:, :7].reshape(-1, 7), 2 * pred[:, 1:-1, :7].reshape(-1, 7) - pred[:, :-2, :7].reshape(-1, 7), reduction='mean') |
| + |
| + l2_multiplier = 10.0 |
| + |
| + loss = l2_multiplier * l2_loss + quaternion_dist + geodesic_loss + alpha * (smoothness_loss + acceleration_loss) |
| + |
| + dtw = DynamicTimeWarpingLoss() |
| + dtw_loss, _ = dtw.forward(pred_trans.reshape(-1, 3), targ_trans.reshape(-1, 3)) |
| + |
| + hausdorff = HausdorffDistanceLoss() |
| + hausdorff_loss, _ = hausdorff.forward(pred_trans.reshape(-1, 3), targ_trans.reshape(-1, 3)) |
| + |
| + frec = FrechetDistanceLoss() |
| + frechet_loss, _ = frec.forward(pred_trans.reshape(-1, 3), targ_trans.reshape(-1, 3)) |
| + |
| + chamfer = ChamferDistanceLoss() |
| + chamfer_loss, _ = chamfer.forward(pred_trans.reshape(-1, 3), targ_trans.reshape(-1, 3)) |
| + |
| + return loss, l2_loss, geodesic_loss, quaternion_dist, dtw_loss, hausdorff_loss, frechet_loss, chamfer_loss |
| + |
| + |
| + def forward(self, pred, targ, cond=None): |
| + loss, err_t, err_geo, err_r, err_dtw, err_hausdorff, err_frechet, err_chamfer = self._loss(pred, targ, cond) |
| + |
| + info = { |
| + 'rot. error': err_r.item(), |
| + 'geodesic error': err_geo.item(), |
| + 'trans. error': err_t.item(), |
| + 'dtw': err_dtw.item(), |
| + 'hausdorff': err_hausdorff.item(), |
| + 'frechet': err_frechet.item(), |
| + 'chamfer': err_chamfer.item(), |
| + } |
| + |
| + return loss, info |
| + |
| +class SplineLoss(nn.Module): |
| + def __init__(self, *args): |
| + super().__init__() |
| + self.scales = json.load(open('scene_scale.json')) |
| + |
| + def compute_spline_coeffs(self, trans): |
| + p0 = trans[:, :-3, :] |
| + p1 = trans[:, 1:-2, :] |
| + p2 = trans[:, 2:-1, :] |
| + p3 = trans[:, 3:, :] |
| + |
| + # Tangent approximations |
| + m1 = 0.5 * (-p0 + p2) |
| + m2 = 0.5 * (-p1 + p3) |
| + |
| + # Cubic spline coefficients for each dimension |
| + a = (2 * p1 - 2 * p2 + m1 + m2) |
| + b = (-3 * p1 + 3 * p2 - 2 * m1 - m2) |
| + c = (m1) |
| + d = (p1) |
| + |
| + return torch.stack([a, b, c, d], dim=-1) |
| + |
| + def q_normalize(self, q): |
| + return q / q.norm(p=2, dim=-1, keepdim=True).clamp(min=1e-12) |
| + |
| + def q_conjugate(self, q): |
| + w, x, y, z = q[..., 0], q[..., 1], q[..., 2], q[..., 3] |
| + return torch.stack([w, -x, -y, -z], dim=-1) |
| + |
| + def q_multiply(self, q1, q2): |
| + """ |
| + q1*q2. |
| + """ |
| + w1, x1, y1, z1 = q1.unbind(-1) |
| + w2, x2, y2, z2 = q2.unbind(-1) |
| + w = w1*w2 - x1*x2 - y1*y2 - z1*z2 |
| + x = w1*x2 + x1*w2 + y1*z2 - z1*y2 |
| + y = w1*y2 - x1*z2 + y1*w2 + z1*x2 |
| + z = w1*z2 + x1*y2 - y1*x2 + z1*w2 |
| + return torch.stack([w, x, y, z], dim=-1) |
| + |
| + def q_inverse(self, q): |
| + return self.q_conjugate(self.q_normalize(q)) |
| + |
| + def q_log(self, q): |
| + """ |
| + Quaternion logarithm for a unit quaternion |
| + Only returns the imaginary part |
| + """ |
| + q = self.q_normalize(q) |
| + w = q[..., 0] |
| + xyz = q[..., 1:] # shape [..., 3] |
| + mag_v = xyz.norm(p=2, dim=-1) |
| + eps = 1e-12 |
| + angle = torch.acos(w.clamp(-1.0 + eps, 1.0 - eps)) |
| + |
| + # We do a safe-guard against zero for sin(angle) |
| + small_mask = (mag_v < 1e-12) | (angle < 1e-12) |
| + # Where small_mask is True => near identity => log(q) ~ 0 |
| + log_val = torch.zeros_like(xyz) |
| + |
| + # Normal case |
| + scale = angle / mag_v.clamp(min=1e-12) |
| + normal_case = scale.unsqueeze(-1) * xyz |
| + |
| + log_val = torch.where( |
| + small_mask.unsqueeze(-1), |
| + torch.zeros_like(xyz), |
| + normal_case |
| + ) |
| + return log_val |
| + |
| + def q_exp(self, v): |
| + """ |
| + Quaternion exponential |
| + """ |
| + norm_v = v.norm(p=2, dim=-1) |
| + small_mask = norm_v < 1e-12 |
| + |
| + w = torch.cos(norm_v) |
| + sin_v = torch.sin(norm_v) |
| + scale = torch.where( |
| + small_mask, |
| + torch.zeros_like(norm_v), # if zero, sin(0)/0 => 0 |
| + sin_v / norm_v.clamp(min=1e-12) |
| + ) |
| + xyz = scale.unsqueeze(-1) * v |
| + |
| + # For small angles, we approximate cos(norm_v) ~ 1, sin(norm_v)/norm_v ~ 1 |
| + w = torch.where( |
| + small_mask, |
| + torch.ones_like(w), |
| + w |
| + ) |
| + return torch.cat([w.unsqueeze(-1), xyz], dim=-1) |
| + |
| + def q_slerp(self, q1, q2, t): |
| + """ |
| + Spherical linear interpolation from q1 to q2 at t in [0,1]. |
| + Both q1, q2 assumed normalized. |
| + q1, q2, t can be 1D or broadcastable shapes, but typically 1D. |
| + """ |
| + q1 = self.q_normalize(q1) |
| + q2 = self.q_normalize(q2) |
| + dot = (q1 * q2).sum(dim=-1, keepdim=True) # the dot product |
| + |
| + eps = 1e-12 |
| + dot = dot.clamp(-1.0 + eps, 1.0 - eps) |
| + |
| + flip_mask = dot < 0.0 |
| + if flip_mask.any(): |
| + q2 = torch.where(flip_mask, -q2, q2) |
| + dot = torch.where(flip_mask, -dot, dot) |
| + |
| + # If they're very close, do a simple linear interpolation |
| + close_mask = dot.squeeze(-1) > 0.9995 |
| + # Using an epsilon to avoid potential issues close to 1.0 |
| + |
| + # Branch 1: Very close |
| + # linear LERP |
| + lerp_val = (1.0 - t) * q1 + t * q2 |
| + lerp_val = self.q_normalize(lerp_val) |
| + |
| + # Branch 2: Standard SLERP |
| + theta_0 = torch.acos(dot) |
| + sin_theta_0 = torch.sin(theta_0) |
| + theta = theta_0 * t |
| + s1 = torch.sin(theta_0 - theta) / sin_theta_0.clamp(min=1e-12) |
| + s2 = torch.sin(theta) / sin_theta_0.clamp(min=1e-12) |
| + slerp_val = s1 * q1 + s2 * q2 |
| + slerp_val = self.q_normalize(slerp_val) |
| + |
| + # Combine |
| + return torch.where( |
| + close_mask.unsqueeze(-1), |
| + lerp_val, |
| + slerp_val |
| + ) |
| + |
| + def compute_uniform_tangent(self, q_im1, q_i, q_ip1): |
| + """ |
| + Computes a 'Catmull–Rom-like' tangent T_i for quaternion q_i, |
| + given neighbors q_im1, q_i, q_ip1. |
| + |
| + T_i = q_i * exp( -0.25 * [ log(q_i^-1 q_ip1) + log(q_i^-1 q_im1) ] ) |
| + """ |
| + q_im1 = self.q_normalize(q_im1) |
| + q_i = self.q_normalize(q_i) |
| + q_ip1 = self.q_normalize(q_ip1) |
| + |
| + inv_qi = self.q_inverse(q_i) |
| + r1 = self.q_multiply(inv_qi, q_ip1) |
| + r2 = self.q_multiply(inv_qi, q_im1) |
| + |
| + lr1 = self.q_log(r1) |
| + lr2 = self.q_log(r2) |
| + |
| + m = -0.25 * (lr1 + lr2) |
| + exp_m = self.q_exp(m) |
| + return self.q_multiply(q_i, exp_m) |
| + |
| + def compute_all_uniform_tangents(self, quats): |
| + """ |
| + Vectorized version that computes tangents T_i for all keyframe quaternions at once. |
| + quats shape: [N,4], N >= 2 |
| + Returns shape [N,4]. |
| + """ |
| + q_im1 = torch.cat([quats[[0]], quats[:-1]], dim=0) # q_im1[0] = q0 |
| + q_ip1 = torch.cat([quats[1:], quats[[-1]]], dim=0) # q_ip1[N-1]= q_{N-1} |
| + |
| + return self.compute_uniform_tangent(q_im1, quats, q_ip1) |
| + |
| + def squad(self, q0, a, b, q1, t): |
| + """ |
| + Shoemake's "squad" interpolation for quaternion splines: |
| + squad(q0, a, b, q1; t) = slerp( slerp(q0, q1; t), |
| + slerp(a, b; t), |
| + 2t(1-t) ) |
| + where a, b are tangential control quaternions for q0, q1. |
| + """ |
| + s1 = self.q_slerp(q0, q1, t) |
| + s2 = self.q_slerp(a, b, t) |
| + alpha = 2.0*t*(1.0 - t) |
| + return self.q_slerp(s1, s2, alpha) |
| + |
| + def uniform_cr_spline(self, quats, num_samples_per_segment=10): |
| + """ |
| + Given a list of keyframe quaternions quats (each a torch 1D tensor [4]), |
| + compute a "Uniform Catmull–Rom–like" quaternion spline through them. |
| + |
| + Returns: |
| + A list (Python list) of interpolated quaternions (torch tensors), |
| + including all segment endpoints. |
| + |
| + Each interior qi gets a tangent T_i using neighbors q_{i-1}, q_i, q_{i+1}. |
| + For boundary tangents, we replicate the end quaternions. |
| + """ |
| + n = quats.shape[0] |
| + if n < 2: |
| + return quats.unsqueeze(0) # not enough quats to interpolate |
| + |
| + # Precompute tangents |
| + tangents = self.compute_all_uniform_tangents(quats) |
| + |
| + # Interpolate each segment [qi, q_{i+1}] |
| + q0 = quats[:-1].unsqueeze(1) |
| + q1 = quats[1:].unsqueeze(1) |
| + a = tangents[:-1].unsqueeze(1) |
| + b = tangents[1:].unsqueeze(1) |
| + |
| + t_vals = torch.linspace(0.0, 1.0, num_samples_per_segment, device=quats.device, dtype=quats.dtype) |
| + t_vals = t_vals.view(1, -1, 1) |
| + |
| + out = self.squad(q0, a, b, q1, t_vals) |
| + return out |
| + |
| + |
| + def forward(self, pred, targ, cond=None, scene_id=None, norm_params=None): |
| + loss, err_t, err_smooth, err_geo, err_r, err_dtw, err_hausdorff, err_frechet, err_chamfer = self._loss(pred, targ, cond, scene_id, norm_params) |
| + |
| + info = { |
| + 'trans. error': err_t.item(), |
| + 'smoothness error': err_smooth.item(), |
| + # 'dtw': err_dtw.item(), |
| + # 'hausdorff': err_hausdorff.item(), |
| + # 'frechet': err_frechet.item(), |
| + # 'chamfer': err_chamfer.item(), |
| + 'quat. dist.': err_r.item(), |
| + 'geodesic dist.': err_geo.item(), |
| + } |
| + |
| + return loss, info |
| + |
| + def _loss(self, pred, targ, cond=None, scene_id=None, norm_params=None): |
| + def poly_eval(coeffs, x): |
| + """ |
| + Evaluates a polynomial (with highest-degree term first) at points x. |
| + coeffs: 2D tensor of shape [num_polynomials, degree + 1], highest-degree term first. |
| + x: 1D tensor of points at which to evaluate the polynomial. |
| + Returns: |
| + 2D tensor of shape [num_polynomials, len(x)], containing p(x). |
| + """ |
| + x_powers = torch.stack([x**i for i in range(coeffs.shape[-1] - 1, -1, -1)], dim=-1) |
| + x_powers = x_powers.to(torch.float64).to(coeffs.device) |
| + y = torch.matmul(coeffs, x_powers.T) |
| + return y |
| + |
| + # Make sure the dtype is float64 |
| + pred = pred.to(torch.float64) |
| + targ = targ.to(torch.float64) |
| + |
| + # Rescale the translations |
| + if scene_id is not None and norm_params is not None: |
| + scene_id = scene_id.item() |
| + scene_scale = self.scales[str(scene_id)] |
| + scene_scale = norm_params['scale'][0] * scene_scale |
| + pred[..., :3] = pred[..., :3] * scene_scale |
| + targ[..., :3] = targ[..., :3] * scene_scale |
| + # print(pred[..., :3].max(), targ[..., :3].max()) |
| + |
| + # We only consider interpolated points for loss calculation |
| + candidate_idxs = sorted(cond.keys()) |
| + pred = pred[:, candidate_idxs[0] : candidate_idxs[-1] + 1, :] |
| + targ = targ[:, candidate_idxs[0] : candidate_idxs[-1] + 1, :] |
| + |
| + pred_trans = pred[..., :3] |
| + pred_quat = pred[..., 3:7] |
| + targ_trans = targ[..., :3] |
| + targ_quat = targ[..., 3:7] |
| + |
| + pred_coeffs = self.compute_spline_coeffs(pred_trans) |
| + targ_coeffs = self.compute_spline_coeffs(targ_trans) |
| + |
| + n_points = 2000 |
| + |
| + # Distribute sample points among intervals |
| + dists = torch.norm(targ_trans[:, 1:, :] - targ_trans[:, :-1, :], dim=-1).reshape(-1) |
| + dists_c = torch.zeros(len(candidate_idxs) - 1, device=pred.device) |
| + for i in range(len(candidate_idxs) - 1): |
| + dists_c[i] = dists[candidate_idxs[i]:candidate_idxs[i+1]].sum() |
| + |
| + weights_c = dists_c / dists_c.sum() |
| + scaled_c = weights_c * n_points |
| + points_c = torch.floor(scaled_c).int() |
| + |
| + while points_c.sum() < n_points: |
| + idx = torch.argmax(scaled_c - points_c) |
| + points_c[idx] += 1 |
| + |
| + # Calculate the spline loss |
| + sample_points = 50 |
| + x = torch.linspace(0, 1, sample_points, device=pred.device) |
| + pred_spline = poly_eval(pred_coeffs, x).permute(0, 1, 3, 2).reshape(-1, sample_points, 3) |
| + targ_spline = poly_eval(targ_coeffs, x).permute(0, 1, 3, 2).reshape(-1, sample_points, 3) |
| + |
| + indexes = [] |
| + start_idx = candidate_idxs[0] |
| + for c, (idx_i0, idx_i1) in enumerate(zip(candidate_idxs[:-1], candidate_idxs[1:])): |
| + p = points_c[c] |
| + total_dist = dists_c[c] |
| + dist_arr = dists[idx_i0 - start_idx : idx_i1 - start_idx] |
| + |
| + step_distances = (dist_arr / sample_points).repeat_interleave(sample_points) |
| + cumul_distances = step_distances.cumsum(dim=0) |
| + |
| + dist_per_pick = total_dist / p |
| + pick_targets = torch.arange(1, p + 1, device=dists.device) * dist_per_pick |
| + |
| + pick_idxs = torch.searchsorted(cumul_distances, pick_targets, right=True) |
| + pick_idxs = torch.clamp(pick_idxs, max=len(cumul_distances) - 1) |
| + |
| + |
| + indexes_1d = torch.zeros_like(step_distances) |
| + indexes_1d[pick_idxs] = 1 |
| + |
| + indexes_2d = indexes_1d.view(len(dist_arr), sample_points) |
| + |
| + indexes.append(indexes_2d) |
| + |
| + indexes = torch.cat(indexes)[1: -1] # The first and last candidates don't have spline representations |
| + |
| + indexes_trans = torch.stack([indexes for _ in range(3)], dim=-1) |
| + indexes_quat = torch.stack([indexes for _ in range(4)], dim=-1) |
| + |
| + indexes_trans = indexes_trans.to(torch.bool) |
| + indexes_quat = indexes_quat.to(torch.bool) |
| + |
| + pred_trans_selected_values = pred_spline[indexes_trans] |
| + targ_trans_selected_values = targ_spline[indexes_trans] |
| + |
| + pred_trans_selected_values = pred_trans_selected_values.reshape(-1, 3) |
| + targ_trans_selected_values = targ_trans_selected_values.reshape(-1, 3) |
| + |
| + # Calculate the loss for quaternions |
| + pred_quat = pred_quat / pred_quat.norm(dim=-1, keepdim=True).clamp(min=1e-8) |
| + targ_quat = targ_quat / targ_quat.norm(dim=-1, keepdim=True).clamp(min=1e-8) |
| + |
| + targ_quat_spline = self.uniform_cr_spline(targ_quat.reshape(-1, 4), num_samples_per_segment=sample_points) |
| + pred_quat_spline = self.uniform_cr_spline(pred_quat.reshape(-1, 4), num_samples_per_segment=sample_points) |
| + |
| + |
| + targ_quat_spline = targ_quat_spline[1:-1] |
| + pred_quat_spline = pred_quat_spline[1:-1] |
| + |
| + |
| + pred_quat_selected_values = pred_quat_spline[indexes_quat] |
| + targ_quat_selected_values = targ_quat_spline[indexes_quat] |
| + |
| + pred_quat_selected_values = pred_quat_selected_values.reshape(-1, 4) |
| + targ_quat_selected_values = targ_quat_selected_values.reshape(-1, 4) |
| + |
| + # Calculate the geodesic loss |
| + pred_rot = quaternion_to_matrix(pred_quat_selected_values).reshape(-1, 3, 3) |
| + targ_rot = quaternion_to_matrix(targ_quat_selected_values).reshape(-1, 3, 3) |
| + |
| + eps = 1e-12 |
| + r2r1 = pred_rot @ targ_rot.permute(0, 2, 1) |
| + trace = torch.diagonal(r2r1, dim1=-2, dim2=-1).sum(-1) |
| + trace = torch.clamp((trace - 1) / 2, -1.0 + eps, 1.0 - eps) |
| + geodesic_loss = torch.acos(trace).mean() |
| + |
| + # Calculate the rotation error |
| + dot_product = torch.sum(pred_quat_selected_values * targ_quat_selected_values, dim=-1).clamp(min=-1.0 + eps, max=1.0 - eps) |
| + quaternion_dist = 1 - (dot_product ** 2).mean() |
| + |
| + # Calculate the L2 loss |
| + l2_loss = F.mse_loss(pred_trans_selected_values, targ_trans_selected_values, reduction='mean') |
| + |
| + # Calculate the smoothness loss for translation and quaternion |
| + smoothness_multiplier = 10 ** 2 # Empirically determined multiplier for smoothness loss |
| + weight_acceleration = 0.1 |
| + weight_jerk = 0.05 |
| + |
| + pos_acc = pred_trans_selected_values[2:, :] - 2 * pred_trans_selected_values[1:-1, :] + pred_trans_selected_values[:-2, :] |
| + pos_jerk = pred_trans_selected_values[3:, :] - 3 * pred_trans_selected_values[2:-1, :] + 3 * pred_trans_selected_values[1:-2, :] - pred_trans_selected_values[:-3, :] |
| + |
| + pos_acceleration_loss = torch.mean(pos_acc ** 2) |
| + pos_jerk_loss = torch.mean(pos_jerk ** 2) |
| + |
| + q0 = pred_quat_selected_values[:-1, :] |
| + q1 = pred_quat_selected_values[1:, :] |
| + sign = torch.where((q0 * q1).sum(dim=-1) < 0, -1.0, 1.0) |
| + q1 = sign.unsqueeze(-1) * q1 |
| + |
| + dq = self.q_multiply(q1, self.q_inverse(q0)) |
| + theta = 2 * torch.acos(torch.clamp(dq[..., 0], -1.0 + 1e-8, 1.0 - 1e-8)) |
| + |
| + rot_acc = theta[2:] - 2*theta[1:-1] + theta[:-2] |
| + rot_jerk = theta[3:] - 3*theta[2:-1] + 3*theta[1:-2] - theta[:-3] |
| + |
| + rot_acceleration_loss = torch.mean(rot_acc ** 2) |
| + rot_jerk_loss = torch.mean(rot_jerk ** 2) |
| + |
| + alpha_rot = 0.1 # <-- tune this (e.g. 0.1 … 10) |
| + |
| + |
| + acceleration_loss = pos_acceleration_loss + alpha_rot * rot_acceleration_loss |
| + jerk_loss = pos_jerk_loss + alpha_rot * rot_jerk_loss |
| + |
| + smoothness_loss = ( |
| + weight_acceleration * acceleration_loss |
| + + weight_jerk * jerk_loss |
| + ) * smoothness_multiplier |
| + |
| + |
| + # Calculate the spline loss |
| + l2_multiplier = 10.0 |
| + spline_loss = l2_multiplier * (l2_loss + smoothness_loss) + geodesic_loss + quaternion_dist |
| + |
| + dtw_loss, hausdorff_loss, frechet_loss, chamfer_loss = None, None, None, None |
| + |
| + # Uncomment these lines if you want to use the other losses |
| + ''' |
| + dtw = DynamicTimeWarpingLoss() |
| + dtw_loss, _ = dtw.forward(pred_trans_selected_values.reshape(-1, 3), targ_trans_selected_values.reshape(-1, 3)) |
| + |
| + hausdorff = HausdorffDistanceLoss() |
| + hausdorff_loss, _ = hausdorff.forward(pred_trans_selected_values.reshape(-1, 3), targ_trans_selected_values.reshape(-1, 3)) |
| + |
| + frec = FrechetDistanceLoss() |
| + frechet_loss, _ = frec.forward(pred_trans_selected_values.reshape(-1, 3), targ_trans_selected_values.reshape(-1, 3)) |
| + |
| + chamfer = ChamferDistanceLoss() |
| + chamfer_loss, _ = chamfer.forward(pred_trans_selected_values.reshape(-1, 3), targ_trans_selected_values.reshape(-1, 3)) |
| + ''' |
| + |
| + return spline_loss, l2_multiplier * l2_loss, l2_multiplier * smoothness_loss, geodesic_loss, quaternion_dist, dtw_loss, hausdorff_loss, frechet_loss, chamfer_loss |
| + |
| + |
class DynamicTimeWarpingLoss(nn.Module):
    """Batch-averaged dynamic time warping (DTW) distance between sequences.

    The DTW distance is the minimum total Euclidean cost over monotone
    alignments of the two sequences, where each DP step comes from an
    insertion (i-1, j), a deletion (i, j-1) or a match (i-1, j-1).

    The DP table is kept as Python lists of scalar tensors instead of being
    written in place into a preallocated tensor. In-place writes would modify
    values that ``torch.min`` saved for the backward pass, and backpropagating
    through the loss would then raise an autograd error. The DP itself is a
    Python double loop (O(n*m) steps), so it is meant for short sequences.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _pairwise_cost(seq1: torch.Tensor, seq2: torch.Tensor) -> torch.Tensor:
        """Return the (n, m) matrix of Euclidean distances between time steps.

        A 1-D input is treated as a sequence of scalars. Dimensions after the
        first are flattened into one feature vector. This matches computing
        ``torch.norm(seq1[i] - seq2[j], p=2)`` for every pair.
        """
        a = seq1.unsqueeze(-1) if seq1.dim() == 1 else seq1.flatten(1)
        b = seq2.unsqueeze(-1) if seq2.dim() == 1 else seq2.flatten(1)
        return torch.norm(a.unsqueeze(1) - b.unsqueeze(0), p=2, dim=-1)

    def _dtw_distance(self, seq1: torch.Tensor, seq2: torch.Tensor) -> torch.Tensor:
        """
        Computes the DTW distance between two 2D tensors (T x D),
        where T is sequence length and D is feature dimension.

        Returns a 0-dim tensor. It is ``inf`` when exactly one sequence is
        empty and 0 when both are empty.
        """
        m = seq2.size(0)
        cost = self._pairwise_cost(seq1, seq2)

        inf = cost.new_tensor(float('inf'))
        # Row 0 of the accumulated-cost table: dist[0, 0] = 0, dist[0, j>0] = inf.
        prev_row = [cost.new_zeros(())] + [inf] * m
        for i in range(1, seq1.size(0) + 1):
            row = [inf]  # dist[i, 0] = inf for i > 0
            for j in range(1, m + 1):
                best = torch.min(
                    torch.min(
                        prev_row[j],      # Insertion
                        row[j - 1],       # Deletion
                    ),
                    prev_row[j - 1],      # Match
                )
                row.append(cost[i - 1, j - 1] + best)
            prev_row = row

        return prev_row[m]

    def _loss(self, pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
        """
        Compute the average DTW loss over a batch of sequences.

        pred, targ shapes: (batch_size, T, D)

        Raises:
            ValueError: if the batch sizes differ.
        """
        if pred.size(0) != targ.size(0):
            raise ValueError("Batch sizes must match.")

        distances = [self._dtw_distance(seq1, seq2) for seq1, seq2 in zip(pred, targ)]
        return torch.stack(distances).mean()

    def forward(self, pred: torch.Tensor, targ: torch.Tensor):
        """
        Returns a tuple: (loss, info_dict),
        where loss is a scalar tensor and info_dict maps 'dtw' to the loss
        value as a Python float.
        """
        loss = self._loss(pred, targ)

        info = {
            'dtw': loss.item()
        }

        return loss, info
| + |
class HausdorffDistanceLoss(nn.Module):
    """Batch-averaged symmetric Hausdorff distance between point sets.

    The pairwise Euclidean cost matrix is computed in one broadcast operation
    rather than with a Python double loop.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _pairwise_cost(set1: torch.Tensor, set2: torch.Tensor) -> torch.Tensor:
        """Return the (n, m) matrix of Euclidean distances between points.

        A 1-D input is treated as a set of scalars. Dimensions after the first
        are flattened into one feature vector. This matches computing
        ``torch.norm(set1[i] - set2[j], p=2)`` for every pair.
        """
        a = set1.unsqueeze(-1) if set1.dim() == 1 else set1.flatten(1)
        b = set2.unsqueeze(-1) if set2.dim() == 1 else set2.flatten(1)
        return torch.norm(a.unsqueeze(1) - b.unsqueeze(0), p=2, dim=-1)

    def _hausdorff_distance(self, set1: torch.Tensor, set2: torch.Tensor) -> torch.Tensor:
        """
        Computes the Hausdorff distance between two 2D tensors (N x D),
        where N is the number of points and D is the feature dimension.

        The Hausdorff distance H(A,B) between two sets A and B is defined as:
            H(A, B) = max( h(A, B), h(B, A) ),
        where
            h(A, B) = max_{a in A} min_{b in B} d(a, b).

        Here, d(a, b) is the Euclidean distance between points a and b.
        """
        cost = self._pairwise_cost(set1, set2)

        # h(set1, set2): farthest set1 point from its nearest set2 point.
        forward_hausdorff = cost.min(dim=1).values.max()
        # h(set2, set1): farthest set2 point from its nearest set1 point.
        backward_hausdorff = cost.min(dim=0).values.max()

        return torch.max(forward_hausdorff, backward_hausdorff)

    def _loss(self, pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
        """
        Compute the average Hausdorff distance over a batch of point sets.

        pred, targ shapes: (batch_size, N, D)

        Raises:
            ValueError: if the batch sizes differ.
        """
        if pred.size(0) != targ.size(0):
            raise ValueError("Batch sizes must match.")

        distances = [self._hausdorff_distance(set1, set2) for set1, set2 in zip(pred, targ)]
        return torch.stack(distances).mean()

    def forward(self, pred: torch.Tensor, targ: torch.Tensor):
        """
        Returns a tuple: (loss, info_dict),
        where loss is a scalar tensor and info_dict maps 'hausdorff' to the
        loss value as a Python float.
        """
        loss = self._loss(pred, targ)

        info = {
            'hausdorff': loss.item()
        }

        return loss, info
| + |
class FrechetDistanceLoss(nn.Module):
    """Batch-averaged discrete Fréchet distance between sequences.

    The DP table is kept as Python lists of scalar tensors instead of being
    written in place into a preallocated tensor. In-place writes would modify
    values that ``torch.max``/``torch.min`` saved for the backward pass, and
    backpropagating through the loss would then raise an autograd error. The
    DP is a Python double loop (O(n*m) steps), so it is meant for short
    sequences.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _pairwise_cost(seq1: torch.Tensor, seq2: torch.Tensor) -> torch.Tensor:
        """Return the (n, m) matrix of Euclidean distances between time steps.

        A 1-D input is treated as a sequence of scalars. Dimensions after the
        first are flattened into one feature vector. This matches computing
        ``torch.norm(seq1[i] - seq2[j], p=2)`` for every pair.
        """
        a = seq1.unsqueeze(-1) if seq1.dim() == 1 else seq1.flatten(1)
        b = seq2.unsqueeze(-1) if seq2.dim() == 1 else seq2.flatten(1)
        return torch.norm(a.unsqueeze(1) - b.unsqueeze(0), p=2, dim=-1)

    def _frechet_distance(self, seq1: torch.Tensor, seq2: torch.Tensor) -> torch.Tensor:
        """
        Computes the (discrete) Fréchet distance between two 2D tensors (T x D),
        where T is the sequence length and D is the feature dimension.

        The DP table "ca" is filled as:

            ca[i, j] = max( d(seq1[i], seq2[j]),
                            min(ca[i-1, j], ca[i, j-1], ca[i-1, j-1]) )

        The first row and first column are running maxima of the cost along
        that edge. Here, d(seq1[i], seq2[j]) is the Euclidean distance.

        Raises:
            ValueError: if either sequence is empty.
        """
        n, m = seq1.size(0), seq2.size(0)
        if n == 0 or m == 0:
            raise ValueError("Fréchet distance is undefined for empty sequences.")

        cost = self._pairwise_cost(seq1, seq2)

        # First row (i = 0): running max along seq2.
        prev_row = [cost[0, 0]]
        for j in range(1, m):
            prev_row.append(torch.max(prev_row[j - 1], cost[0, j]))

        for i in range(1, n):
            # First column (j = 0): running max along seq1.
            row = [torch.max(prev_row[0], cost[i, 0])]
            for j in range(1, m):
                reachable = torch.min(
                    torch.min(prev_row[j], row[j - 1]),
                    prev_row[j - 1],
                )
                row.append(torch.max(cost[i, j], reachable))
            prev_row = row

        return prev_row[m - 1]

    def _loss(self, pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
        """
        Compute the average Fréchet distance over a batch of sequences.

        pred, targ shapes: (batch_size, T, D)

        Raises:
            ValueError: if the batch sizes differ.
        """
        if pred.size(0) != targ.size(0):
            raise ValueError("Batch sizes must match.")

        distances = [self._frechet_distance(seq1, seq2) for seq1, seq2 in zip(pred, targ)]
        return torch.stack(distances).mean()

    def forward(self, pred: torch.Tensor, targ: torch.Tensor):
        """
        Returns a tuple: (loss, info_dict),
        where loss is a scalar tensor and info_dict maps 'frechet' to the loss
        value as a Python float.
        """
        loss = self._loss(pred, targ)
        info = {
            'frechet': loss.item()
        }
        return loss, info
| + |
class ChamferDistanceLoss(nn.Module):
    """Batch-averaged symmetric Chamfer distance between point sets.

    The pairwise Euclidean cost matrix is computed in one broadcast operation
    rather than with a Python double loop.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _pairwise_cost(set1: torch.Tensor, set2: torch.Tensor) -> torch.Tensor:
        """Return the (n, m) matrix of Euclidean distances between points.

        A 1-D input is treated as a set of scalars. Dimensions after the first
        are flattened into one feature vector. This matches computing
        ``torch.norm(set1[i] - set2[j], p=2)`` for every pair.
        """
        a = set1.unsqueeze(-1) if set1.dim() == 1 else set1.flatten(1)
        b = set2.unsqueeze(-1) if set2.dim() == 1 else set2.flatten(1)
        return torch.norm(a.unsqueeze(1) - b.unsqueeze(0), p=2, dim=-1)

    def _chamfer_distance(self, set1: torch.Tensor, set2: torch.Tensor) -> torch.Tensor:
        """
        Computes the symmetrical Chamfer distance between
        two 2D tensors (N x D), where N is the number of points
        and D is the feature dimension.

        The Chamfer distance between two point sets A and B is defined here as:

            d_chamfer(A, B) = 1/|A| ∑_{a ∈ A} min_{b ∈ B} ‖a - b‖₂
                            + 1/|B| ∑_{b ∈ B} min_{a ∈ A} ‖b - a‖₂,

        where ‖·‖₂ is the Euclidean distance.
        """
        cost = self._pairwise_cost(set1, set2)

        # Mean distance from each set1 point to its nearest set2 point.
        forward_mean = cost.min(dim=1).values.mean()
        # Mean distance from each set2 point to its nearest set1 point.
        backward_mean = cost.min(dim=0).values.mean()

        return forward_mean + backward_mean

    def _loss(self, pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
        """
        Compute the average Chamfer distance over a batch of point sets.

        pred, targ shapes: (batch_size, N, D)

        Raises:
            ValueError: if the batch sizes differ.
        """
        if pred.size(0) != targ.size(0):
            raise ValueError("Batch sizes must match.")

        distances = [self._chamfer_distance(set1, set2) for set1, set2 in zip(pred, targ)]
        return torch.stack(distances).mean()

    def forward(self, pred: torch.Tensor, targ: torch.Tensor):
        """
        Returns a tuple: (loss, info_dict),
        where 'loss' is a scalar tensor and 'info_dict' maps 'chamfer' to the
        loss value as a Python float.
        """
        loss = self._loss(pred, targ)
        info = {
            'chamfer': loss.item()
        }
        return loss, info
| + |
| + |
def slerp(q1, q2, t):
    """Spherical linear interpolation between two quaternions.

    Both inputs are normalised first. ``q2`` is negated when the two
    quaternions have a negative dot product, so the interpolation follows the
    shorter arc. When the quaternions are nearly parallel (dot > 0.9995), a
    normalised linear interpolation is used instead. This avoids dividing by
    the sine of a tiny angle.

    Args:
        q1, q2: 4-component quaternions (array-like).
        t: interpolation parameter; t=0 gives the normalised ``q1``.

    Returns:
        np.ndarray: unit quaternion at parameter ``t``.
    """
    q1 = q1 / np.linalg.norm(q1)
    q2 = q2 / np.linalg.norm(q2)
    dot = np.dot(q1, q2)

    # Take the shorter arc.
    if dot < 0.0:
        q2 = -q2
        dot = -dot

    # Nearly parallel: fall back to normalised linear interpolation.
    if dot > 0.9995:
        result = q1 + t * (q2 - q1)
        return result / np.linalg.norm(result)

    theta = np.arccos(dot) * t

    # Unit quaternion orthogonal to q1 in the q1-q2 plane.
    q3 = q2 - q1 * dot
    q3 = q3 / np.linalg.norm(q3)
    return q1 * np.cos(theta) + q3 * np.sin(theta)


def catmull_rom_spline_with_rotation(control_points, timepoints, horizon):
    """Densify sparse keyframe poses into a ``horizon``-step trajectory.

    Columns 0:3 of each control point are interpolated as a position and
    columns 3:7 as a quaternion.
    - Positions: uniform Catmull-Rom between interior keyframes. The first
      and last segments lack a neighbour on one side, so they are linear.
    - Quaternions: always slerped between the two keyframes bounding the
      current segment.
    - Outside the keyframes: the first keyframe pose is held for every step
      before ``timepoints[0]``, and the last pose for every step after
      ``timepoints[-1]``.
    - Extra columns: any columns beyond 7 are emitted as zeros.

    Args:
        control_points: array of shape (K, D) with K >= 1 and D >= 7, one row
            per keyframe, in the same order as ``timepoints``.
        timepoints: K increasing integer time indices of the keyframes. The
            output has ``horizon`` rows when
            ``0 <= timepoints[0]`` and ``timepoints[-1] < horizon``.
        horizon: total number of output time steps.

    Returns:
        np.ndarray of shape (horizon, D), float64.

    Raises:
        ValueError: if ``control_points`` is not 2-D with at least 7 columns,
            is empty, or does not match ``timepoints`` in length.
    """
    control_points = np.asarray(control_points, dtype=np.float64)
    timepoints = [int(tp) for tp in timepoints]

    if control_points.ndim != 2 or control_points.shape[1] < 7:
        raise ValueError(
            f"control_points must have shape (K, D>=7), got {control_points.shape}"
        )
    n_keys, pose_dim = control_points.shape
    if n_keys == 0:
        raise ValueError("at least one control point is required")
    if len(timepoints) != n_keys:
        raise ValueError(
            f"got {len(timepoints)} timepoints for {n_keys} control points"
        )

    def held_pose(k):
        # Copy of keyframe k's position + quaternion, reused for held steps.
        return control_points[k, :7].copy()

    def linear_pose(a, b, t):
        # Linear position interpolation with slerped rotation.
        position = a[:3] + t * (b[:3] - a[:3])
        return np.concatenate([position, slerp(a[3:7], b[3:7], t)])

    # Hold the first keyframe before its timepoint.
    spline_points = [held_pose(0) for _ in range(timepoints[0])]

    if n_keys == 1:
        spline_points.append(held_pose(0))
    else:
        # First segment: linear, both endpoints included.
        for t in np.linspace(0, 1, timepoints[1] - timepoints[0] + 1):
            spline_points.append(linear_pose(control_points[0], control_points[1], t))

        # Interior segments: Catmull-Rom positions, slerped rotations.
        # t = 0 is skipped because the previous segment already emitted it.
        for i in range(1, n_keys - 2):
            p0, p1, p2, p3 = (control_points[k, :3] for k in range(i - 1, i + 3))
            q1, q2 = control_points[i, 3:7], control_points[i + 1, 3:7]
            for t in np.linspace(0, 1, timepoints[i + 1] - timepoints[i] + 1)[1:]:
                position = 0.5 * (
                    (2 * p1)
                    + (-p0 + p2) * t
                    + (2 * p0 - 5 * p1 + 4 * p2 - p3) * t ** 2
                    + (-p0 + 3 * p1 - 3 * p2 + p3) * t ** 3
                )
                spline_points.append(np.concatenate([position, slerp(q1, q2, t)]))

        # Last segment: linear. With exactly two keyframes it is the same
        # segment as the first one, which has already been emitted.
        if n_keys > 2:
            for t in np.linspace(0, 1, timepoints[-1] - timepoints[-2] + 1)[1:]:
                spline_points.append(
                    linear_pose(control_points[-2], control_points[-1], t)
                )

    # Hold the last keyframe until the horizon is filled.
    spline_points.extend(held_pose(-1) for _ in range(timepoints[-1] + 1, horizon))

    trajectory = np.stack(spline_points, axis=0)

    # Columns beyond the 7 pose columns are not interpolated; emit zeros.
    if pose_dim > 7:
        trajectory = np.concatenate(
            [trajectory, np.zeros((trajectory.shape[0], pose_dim - 7))], axis=1
        )

    return trajectory
| + |
def catmull_rom_loss(trajectories, conditions, loss_fc):
    '''
    Loss of a Catmull-Rom keyframe-interpolation baseline.

    Each batch element's conditioned keyframes are densified with
    `catmull_rom_spline_with_rotation` into a full-length trajectory, which is
    scored against `trajectories` with `loss_fc(spline_points, trajectories)`.

    Args:
        trajectories: tensor [ batch_size x horizon x transition ].
        conditions: dict mapping a time index to a tensor [ batch_size x dim ]
            holding that keyframe for every batch element.
        loss_fc: callable taking (prediction, target) tensors.

    Returns:
        Whatever `loss_fc` returns.

    Raises:
        ValueError: if the interpolated trajectories do not have the same shape
            as `trajectories`.
    '''
    batch_size, horizon, _ = trajectories.shape

    # Keyframe time indices and their values as [ batch_size x n_keyframes x dim ].
    # detach() so that condition tensors which require grad can go to numpy.
    known_indices = np.array(list(conditions.keys()), dtype=int)
    known_values = np.stack(
        [c.detach().cpu().numpy() for c in conditions.values()], axis=1
    )

    # The spline expects keyframes in time order.
    order = np.argsort(known_indices)
    known_indices = known_indices[order]
    known_values = known_values[:, order]

    spline_points = np.stack([
        catmull_rom_spline_with_rotation(known_values[b], known_indices, horizon)
        for b in range(batch_size)
    ])

    spline_points = torch.as_tensor(
        spline_points, dtype=torch.float64, device=trajectories.device
    )
    # Explicit check (not `assert`) so it still runs under `python -O`.
    if spline_points.shape != trajectories.shape:
        raise ValueError(f"Shape mismatch: {spline_points.shape} != {trajectories.shape}")
    return loss_fc(spline_points, trajectories)
| + |
# Registry of loss classes keyed by name. It is presumably looked up with a
# config's `loss_type` string (e.g. 'l2'); confirm against the callers.
Losses = dict(
    l1=WeightedL1,
    l2=WeightedL2,
    value_l1=ValueL1,
    value_l2=ValueL2,
    geodesic_l2=GeodesicL2Loss,
    rotation_translation=RotationTranslationLoss,
    spline=SplineLoss,
)
| diff --git a/diffuser/models/temporal.py b/diffuser/models/temporal.py |
| index e0b9e5c..0f7854a 100644 |
| --- a/diffuser/models/temporal.py |
| +++ b/diffuser/models/temporal.py |
| @@ -17,18 +17,18 @@ class ResidualTemporalBlock(nn.Module): |
| super().__init__() |
| |
| self.blocks = nn.ModuleList([ |
| - Conv1dBlock(inp_channels, out_channels, kernel_size), |
| - Conv1dBlock(out_channels, out_channels, kernel_size), |
| + Conv1dBlock(inp_channels, out_channels, kernel_size).to(dtype=torch.float64), |
| + Conv1dBlock(out_channels, out_channels, kernel_size).to(dtype=torch.float64), |
| ]) |
| |
| self.time_mlp = nn.Sequential( |
| nn.Mish(), |
| - nn.Linear(embed_dim, out_channels), |
| + nn.Linear(embed_dim, out_channels).to(dtype=torch.float64), |
| Rearrange('batch t -> batch t 1'), |
| - ) |
| + ).to(dtype=torch.float64) |
| |
| - self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1) \ |
| - if inp_channels != out_channels else nn.Identity() |
| + self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1).to(dtype=torch.float64) \ |
| + if inp_channels != out_channels else nn.Identity().to(dtype=torch.float64) |
| |
| def forward(self, x, t): |
| ''' |
| @@ -37,7 +37,8 @@ class ResidualTemporalBlock(nn.Module): |
| returns: |
| out : [ batch_size x out_channels x horizon ] |
| ''' |
| - out = self.blocks[0](x) + self.time_mlp(t) |
| + |
| + out = self.blocks[0](x) + self.time_mlp(t.double()) |
| out = self.blocks[1](out) |
| return out + self.residual_conv(x) |
| |
| @@ -49,11 +50,11 @@ class TemporalUnet(nn.Module): |
| transition_dim, |
| cond_dim, |
| dim=32, |
| - dim_mults=(1, 2, 4, 8), |
| + dim_mults=(1, 2, 4), |
| ): |
| super().__init__() |
| |
| - dims = [transition_dim, *map(lambda m: dim * m, dim_mults)] |
| + dims = [(transition_dim + cond_dim), *map(lambda m: dim * m, dim_mults)] |
| in_out = list(zip(dims[:-1], dims[1:])) |
| print(f'[ models/temporal ] Channel dimensions: {in_out}') |
| |
| @@ -100,7 +101,7 @@ class TemporalUnet(nn.Module): |
| |
| self.final_conv = nn.Sequential( |
| Conv1dBlock(dim, dim, kernel_size=5), |
| - nn.Conv1d(dim, transition_dim, 1), |
| + nn.Conv1d(dim, transition_dim, 1).to(dtype=torch.float64), |
| ) |
| |
| def forward(self, x, cond, time): |
| @@ -129,7 +130,6 @@ class TemporalUnet(nn.Module): |
| x = upsample(x) |
| |
| x = self.final_conv(x) |
| - |
| x = einops.rearrange(x, 'b t h -> b h t') |
| return x |
| |
| diff --git a/diffuser/utils/arrays.py b/diffuser/utils/arrays.py |
| index c3a9d24..96a7093 100644 |
| --- a/diffuser/utils/arrays.py |
| +++ b/diffuser/utils/arrays.py |
| @@ -54,7 +54,7 @@ def batchify(batch): |
| 1) converting np arrays to torch tensors and |
| 2) and ensuring that everything has a batch dimension |
| ''' |
| - fn = lambda x: to_torch(x[None]) |
| + fn = lambda x: to_torch(x[None], dtype=torch.float64) |
| |
| batched_vals = [] |
| for field in batch._fields: |
| diff --git a/diffuser/utils/serialization.py b/diffuser/utils/serialization.py |
| index 6cc9db9..039eb64 100644 |
| --- a/diffuser/utils/serialization.py |
| +++ b/diffuser/utils/serialization.py |
| @@ -19,7 +19,7 @@ def mkdir(savepath): |
| return False |
| |
| def get_latest_epoch(loadpath): |
| - states = glob.glob1(os.path.join(*loadpath), 'state_*') |
| + states = glob.glob1(os.path.join(loadpath), 'state_*') |
| latest_epoch = -1 |
| for state in states: |
| epoch = int(state.replace('state_', '').replace('.pt', '')) |
| diff --git a/diffuser/utils/training.py b/diffuser/utils/training.py |
| index be3556e..c21e0f0 100644 |
| --- a/diffuser/utils/training.py |
| +++ b/diffuser/utils/training.py |
| @@ -4,16 +4,24 @@ import numpy as np |
| import torch |
| import einops |
| import pdb |
| +from tqdm import tqdm |
| +import wandb |
| +from pytorch3d.transforms import axis_angle_to_quaternion |
| |
| from .arrays import batch_to_device, to_np, to_device, apply_dict |
| from .timer import Timer |
| from .cloud import sync_logs |
| +from ..models.helpers import catmull_rom_spline_with_rotation |
| |
def cycle(dl):
    """Yield items from ``dl`` forever, re-iterating it after each full pass."""
    while True:
        yield from dl
| |
def assert_no_nan_weights(model, check_inf=False):
    """Raise ``AssertionError`` if any parameter of ``model`` contains NaN.

    The check raises explicitly instead of using an ``assert`` statement.
    An ``assert`` is removed entirely under ``python -O``, which would let
    training continue silently with corrupted weights.

    Args:
        model: object exposing ``named_parameters()`` (e.g. an ``nn.Module``).
        check_inf: if True, also reject +/-inf values (default False keeps
            the original NaN-only behaviour).

    Raises:
        AssertionError: naming the first offending parameter.
    """
    for name, param in model.named_parameters():
        bad = ~torch.isfinite(param) if check_inf else torch.isnan(param)
        if bad.any():
            kind = "Non-finite value" if check_inf else "NaN"
            raise AssertionError(f"{kind} detected in parameter: {name}")
| + |
| class EMA(): |
| ''' |
| empirical moving average |
| @@ -71,13 +79,35 @@ class Trainer(object): |
| self.gradient_accumulate_every = gradient_accumulate_every |
| |
| self.dataset = dataset |
| - self.dataloader = cycle(torch.utils.data.DataLoader( |
| - self.dataset, batch_size=train_batch_size, num_workers=1, shuffle=True, pin_memory=True |
| + dataset_size = len(self.dataset) |
| + |
| + # Read the indices from the .txt file |
| + with open(os.path.join(results_folder, 'train_indices.txt'), 'r') as f: |
| + self.train_indices = f.read() |
| + self.train_indices = [int(i) for i in self.train_indices.split('\n') if i] |
| + |
| + with open(os.path.join(results_folder, 'val_indices.txt'), 'r') as f: |
| + self.val_indices = f.read() |
| + self.val_indices = [int(i) for i in self.val_indices.split('\n') if i] |
| + |
| + |
| + self.train_dataset = torch.utils.data.Subset(self.dataset, self.train_indices) |
| + self.val_dataset = torch.utils.data.Subset(self.dataset, self.val_indices) |
| + self.train_dataloader = cycle(torch.utils.data.DataLoader( |
| + self.train_dataset, batch_size=train_batch_size, num_workers=1, pin_memory=True, shuffle=False |
| + )) |
| + |
| + self.val_dataloader = cycle(torch.utils.data.DataLoader( |
| + self.val_dataset, batch_size=train_batch_size, num_workers=1, pin_memory=True, shuffle=False |
| )) |
| + |
| self.dataloader_vis = cycle(torch.utils.data.DataLoader( |
| self.dataset, batch_size=1, num_workers=0, shuffle=True, pin_memory=True |
| )) |
| self.renderer = renderer |
| + |
| + |
| + |
| self.optimizer = torch.optim.Adam(diffusion_model.parameters(), lr=train_lr) |
| |
| self.logdir = results_folder |
| @@ -88,6 +118,8 @@ class Trainer(object): |
| self.reset_parameters() |
| self.step = 0 |
| |
| + self.log_to_wandb = False |
| + |
| def reset_parameters(self): |
| self.ema_model.load_state_dict(self.model.state_dict()) |
| |
| @@ -102,36 +134,129 @@ class Trainer(object): |
| #-----------------------------------------------------------------------------# |
| |
| def train(self, n_train_steps): |
| - |
| + # Save the indices as .txt files |
| + with open(os.path.join(self.logdir, 'train_indices.txt'), 'w') as f: |
| + for idx in self.train_indices: |
| + f.write(f"{idx}\n") |
| + with open(os.path.join(self.logdir, 'val_indices.txt'), 'w') as f: |
| + for idx in self.val_indices: |
| + f.write(f"{idx}\n") |
| + |
| timer = Timer() |
| - for step in range(n_train_steps): |
| + torch.autograd.set_detect_anomaly(True) |
| + |
| + # Setup wandb |
| + if self.log_to_wandb: |
| + wandb.init( |
| + project='trajectory-generation', |
| + config={'lr': self.optimizer.param_groups[0]['lr'], 'batch_size': self.batch_size, 'gradient_accumulate_every': self.gradient_accumulate_every}, |
| + ) |
| + |
| + for step in tqdm(range(n_train_steps)): |
| + |
| + mean_train_loss = 0.0 |
| for i in range(self.gradient_accumulate_every): |
| - batch = next(self.dataloader) |
| + batch = next(self.train_dataloader) |
| batch = batch_to_device(batch) |
| - |
| - loss, infos = self.model.loss(*batch) |
| + |
| + loss, infos = self.model.loss(x=batch.trajectories, cond=batch.conditions) |
| loss = loss / self.gradient_accumulate_every |
| + mean_train_loss += loss.item() |
| loss.backward() |
| |
| + if self.log_to_wandb: |
| + wandb.log({ |
| + 'step': self.step, |
| + 'train/loss': mean_train_loss |
| + }) |
| + |
| + # torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) |
| + |
| self.optimizer.step() |
| self.optimizer.zero_grad() |
| |
| + assert_no_nan_weights(self.model) |
| + |
| if self.step % self.update_ema_every == 0: |
| self.step_ema() |
| |
| if self.step % self.save_freq == 0: |
| - label = self.step // self.label_freq * self.label_freq |
| + label = self.step |
| + print(f'Saving model at step {self.step}...') |
| self.save(label) |
| |
| if self.step % self.log_freq == 0: |
| - infos_str = ' | '.join([f'{key}: {val:8.4f}' for key, val in infos.items()]) |
| - print(f'{self.step}: {loss:8.4f} | {infos_str} | t: {timer():8.4f}') |
| + val_losses = [] |
| + lin_int_losses = [] |
| + |
| + val_infos_list = [] |
| + lin_int_infos_list = [] |
| + |
| + catmull_losses = [] |
| + catmull_infos_list = [] |
| + |
| + for _ in range(len(self.val_indices)): |
| + val_batch = next(self.val_dataloader) |
| + val_batch = batch_to_device(val_batch) |
| + |
| + traj = self.model.forward(val_batch.conditions, horizon=val_batch.trajectories.shape[1]) |
| + val_loss, val_infos = self.model.loss_fn(traj, val_batch.trajectories, cond=val_batch.conditions) |
| + |
| + val_losses.append(val_loss.item()) |
| + val_infos_list.append({key: val for key, val in val_infos.items()}) |
| + |
| + |
| + (lin_int_loss, lin_int_infos), lin_int_traj = self.linear_interpolation_loss( |
| + val_batch.trajectories, val_batch.conditions, self.model.loss_fn |
| + ) |
| + lin_int_losses.append(lin_int_loss.item()) |
| + lin_int_infos_list.append({key: val for key, val in lin_int_infos.items()}) |
| + |
| + (catmull_loss, catmull_infos), catmull_traj = self.catmull_rom_loss( |
| + val_batch.trajectories, val_batch.conditions, self.model.loss_fn |
| + ) |
| + |
| + catmull_losses.append(catmull_loss.item()) |
| + catmull_infos_list.append(catmull_infos) |
| + |
| + avg_val_loss = np.mean(val_losses) |
| + avg_lin_int_loss = np.mean(lin_int_losses) |
| + |
| + val_infos = {key: np.mean([info[key] for info in val_infos_list]) for key in val_infos_list[0].keys()} |
| + lin_int_infos = {key: np.mean([info[key] for info in lin_int_infos_list]) for key in lin_int_infos_list[0].keys()} |
| |
| - if self.step == 0 and self.sample_freq: |
| - self.render_reference(self.n_reference) |
| + avg_catmull_loss = np.mean(catmull_losses) |
| + catmull_infos = {key: np.mean([info[key] for info in catmull_infos_list]) for key in catmull_infos_list[0].keys()} |
| |
| - if self.sample_freq and self.step % self.sample_freq == 0: |
| - self.render_samples(n_samples=self.n_samples) |
| + val_infos_str = ' | '.join([f'{key}: {val:8.4f}' for key, val in val_infos.items()]) |
| + lin_int_infos_str = ' | '.join([f'{key}: {val:8.4f}' for key, val in lin_int_infos.items()]) |
| + catmull_infos_str = ' | '.join([f'{key}: {val:8.4f}' for key, val in catmull_infos.items()]) |
| + |
| + |
| + infos_str = ' | '.join([f'{key}: {val:8.4f}' for key, val in infos.items()]) |
| + print("Learning Rate: ", self.optimizer.param_groups[0]['lr']) |
| + print(f'Step {self.step}: {loss * self.gradient_accumulate_every:8.4f} | {infos_str} | t: {timer():8.4f}') |
| + print(f'Validation - {self.step}: {avg_val_loss:8.4f} | {val_infos_str} | t: {timer():8.4f}') |
| + print(f'Linear Interpolation Loss - {self.step}: {avg_lin_int_loss:8.4f} | {lin_int_infos_str} | t: {timer():8.4f}') |
| + print(f'Catmull Rom Loss - {self.step}: {avg_catmull_loss:8.4f} | {catmull_infos_str} | t: {timer():8.4f}') |
| + print() |
| + |
| + if self.log_to_wandb: |
| + wandb.log({ |
| + 'step': self.step, |
| + 'val/loss': avg_val_loss, |
| + 'val/linear_interp/loss': avg_lin_int_loss, |
| + 'val/linear_interp/quaternion dist.': lin_int_infos['quat. dist.'], |
| + 'val/linear_interp/euclidean dist.': lin_int_infos['trans. error'], |
| + 'val/linear_interp/geodesic loss': lin_int_infos['geodesic dist.'], |
| + 'val/catmull_rom/loss': avg_catmull_loss, |
| + 'val/catmull_rom/quaternion dist.': catmull_infos['quat. dist.'], |
| + 'val/catmull_rom/euclidean dist.': catmull_infos['trans. error'], |
| + 'val/catmull_rom/geodesic loss': catmull_infos['geodesic dist.'], |
| + 'val/quaternion dist.': val_infos['quat. dist.'], |
| + 'val/euclidean dist.': val_infos['trans. error'], |
| + 'val/geodesic loss': val_infos['geodesic dist.'], |
| + }) |
| |
| self.step += 1 |
| |
| @@ -186,15 +311,6 @@ class Trainer(object): |
| normed_observations = trajectories[:, :, self.dataset.action_dim:] |
| observations = self.dataset.normalizer.unnormalize(normed_observations, 'observations') |
| |
| - # from diffusion.datasets.preprocessing import blocks_cumsum_quat |
| - # # observations = conditions + blocks_cumsum_quat(deltas) |
| - # observations = conditions + deltas.cumsum(axis=1) |
| - |
| - #### @TODO: remove block-stacking specific stuff |
| - # from diffusion.datasets.preprocessing import blocks_euler_to_quat, blocks_add_kuka |
| - # observations = blocks_add_kuka(observations) |
| - #### |
| - |
| savepath = os.path.join(self.logdir, f'_sample-reference.png') |
| self.renderer.composite(savepath, observations) |
| |
| @@ -225,9 +341,6 @@ class Trainer(object): |
| # [ 1 x 1 x observation_dim ] |
| normed_conditions = to_np(batch.conditions[0])[:,None] |
| |
| - # from diffusion.datasets.preprocessing import blocks_cumsum_quat |
| - # observations = conditions + blocks_cumsum_quat(deltas) |
| - # observations = conditions + deltas.cumsum(axis=1) |
| |
| ## [ n_samples x (horizon + 1) x observation_dim ] |
| normed_observations = np.concatenate([ |
| @@ -238,10 +351,70 @@ class Trainer(object): |
| ## [ n_samples x (horizon + 1) x observation_dim ] |
| observations = self.dataset.normalizer.unnormalize(normed_observations, 'observations') |
| |
| - #### @TODO: remove block-stacking specific stuff |
| - # from diffusion.datasets.preprocessing import blocks_euler_to_quat, blocks_add_kuka |
| - # observations = blocks_add_kuka(observations) |
| - #### |
| - |
| savepath = os.path.join(self.logdir, f'sample-{self.step}-{i}.png') |
| self.renderer.composite(savepath, observations) |
| + |
def linear_interpolation_loss(self, trajectories, conditions, loss_fc, scene_id=None, norm_params=None):
    """Score a piecewise-linear interpolation baseline against ``trajectories``.

    The conditioned (known) timesteps in ``conditions`` are linearly
    interpolated, independently per batch element and per feature dimension,
    onto every integer timestep ``0 .. horizon - 1``. The interpolation is then
    passed to ``loss_fc`` as the prediction.

    Args:
        trajectories: tensor of shape ``(batch_size, horizon, transition)``
            (the shape is unpacked below).
        conditions: mapping ``{timestep: tensor}``. Each value is assumed to be
            ``(batch_size, transition)``; TODO confirm against the dataset.
        loss_fc: callable invoked as
            ``loss_fc(pred, target, cond=..., scene_id=..., norm_params=...)``.
        scene_id, norm_params: forwarded unchanged to ``loss_fc``.

    Returns:
        Tuple ``(loss_fc result, interpolated tensor)``. The interpolated tensor
        has the same shape as ``trajectories``, dtype ``float64``, and lives on
        ``trajectories.device``.

    Raises:
        ValueError: if ``conditions`` is empty (nothing to interpolate from).
    """
    batch_size, horizon, transition = trajectories.shape

    if not conditions:
        raise ValueError("conditions must contain at least one known timestep")

    # Known timesteps: (n_known,). Known values: (batch_size, n_known, transition).
    known_indices = np.array(list(conditions.keys()), dtype=int)
    known_values = np.stack([c.detach().cpu().numpy() for c in conditions.values()], axis=1)

    # np.interp requires increasing sample points; dict order need not be sorted.
    order = np.argsort(known_indices)
    known_indices = known_indices[order]
    known_values = known_values[:, order]

    # Evaluate exactly at the trajectory's integer timesteps. The previous
    # np.linspace(0, horizon, num=horizon) grid was non-integer and overshot the
    # last index, so known timesteps were not reproduced.
    time_steps = np.arange(horizon)

    # Build directly in (batch_size, horizon, transition) layout.
    linear_int_arr = np.stack([
        np.stack(
            [np.interp(time_steps, known_indices, known_values[b, :, dim]) for dim in range(transition)],
            axis=-1,
        )
        for b in range(batch_size)
    ])

    # Convert to tensor on the same device as trajectories.
    linear_int_tensor = torch.tensor(linear_int_arr, dtype=torch.float64, device=trajectories.device)

    return loss_fc(linear_int_tensor, trajectories, cond=conditions, scene_id=scene_id, norm_params=norm_params), linear_int_tensor
| + |
| + |
def catmull_rom_loss(self, trajectories, conditions, loss_fc, scene_id=None, norm_params=None):
    """Score a Catmull-Rom spline baseline against ``trajectories``.

    The conditioned timesteps are ordered chronologically. For each batch
    element, ``catmull_rom_spline_with_rotation(values, timesteps, horizon)``
    builds a spline through them, and the stacked splines are handed to
    ``loss_fc`` as the prediction.

    Args:
        trajectories: tensor of shape ``(batch_size, horizon, transition)``.
        conditions: mapping ``{timestep: tensor}``. Each value is assumed to be
            ``(batch_size, transition)``; TODO confirm against the dataset.
        loss_fc: callable invoked as
            ``loss_fc(pred, target, cond=..., scene_id=..., norm_params=...)``.
        scene_id, norm_params: forwarded unchanged to ``loss_fc``.

    Returns:
        Tuple ``(loss_fc result, spline tensor)``. The spline tensor is
        ``float64`` on ``trajectories.device`` and is asserted to match
        ``trajectories.shape``.
    """
    batch_size, horizon, _ = trajectories.shape

    # Timesteps: (n_known,). Values: (batch_size, n_known, dim).
    timesteps = np.array([*conditions], dtype=int)
    values = np.stack([v.cpu().numpy() for v in conditions.values()], axis=1)

    # Put the control points in chronological order.
    order = np.argsort(timesteps)
    timesteps, values = timesteps[order], values[:, order]

    per_batch = [
        catmull_rom_spline_with_rotation(values[b], timesteps, horizon)
        for b in range(batch_size)
    ]
    spline_points = torch.tensor(np.array(per_batch), dtype=torch.float64, device=trajectories.device)

    assert spline_points.shape == trajectories.shape, f"Shape mismatch: {spline_points.shape} != {trajectories.shape}"

    loss = loss_fc(spline_points, trajectories, cond=conditions, scene_id=scene_id, norm_params=norm_params)
    return loss, spline_points
| + |
| + |
| + |
| + |
| + |
| + |
| + |
| + |
| + |
| + |
| + |
| + |
| diff --git a/scripts/train.py b/scripts/train.py |
| index 2c5f299..6728d6f 100644 |
| --- a/scripts/train.py |
| +++ b/scripts/train.py |
| @@ -108,6 +108,7 @@ utils.report_parameters(model) |
| |
| print('Testing forward...', end=' ', flush=True) |
| batch = utils.batchify(dataset[0]) |
| + |
| loss, _ = diffusion.loss(*batch) |
| loss.backward() |
| print('✓') |