| | import torch |
| |
|
| |
|
def weighting_function(x, samples, gamma):
    """Return Gaussian kernel weights between query points and samples.

    For every pair (b, n) the result is
    ``exp(-||x[b] - samples[n]||^2 / (2 * gamma^2))``.

    Args:
        x: Query points. The indexing below inserts a sample axis, so this is
            presumably a ``(B, D)`` tensor -- confirm against callers.
        samples: Reference points, presumably ``(N, D)`` with the same
            trailing dimension as ``x``.
        gamma: Kernel bandwidth (scalar or broadcastable tensor).

    Returns:
        Tensor of weights with the last (feature) axis reduced away, i.e.
        ``(B, N)`` for the shapes assumed above.
    """
    # Broadcast to every query/sample pair: (B, 1, D) - (1, N, D).
    offsets = x.unsqueeze(1) - samples.unsqueeze(0)
    sq_dist = offsets.pow(2).sum(dim=-1)
    two_gamma_sq = 2 * gamma**2
    return torch.exp(-sq_dist / two_gamma_sq)
| |
|
| |
|
def land_metric_tensor(x, samples, gamma, rho):
    """Return the reciprocal of the kernel-weighted squared-difference sum.

    For each query ``b`` and feature ``d`` this evaluates
    ``1 / (sum_n w[b, n] * (samples[n, d] - x[b, d])^2 + rho)``,
    where ``w`` comes from ``weighting_function(x, samples, gamma)``.

    Args:
        x: Query points; the einsum subscripts below imply shape ``(B, D)``.
        samples: Reference points; implied shape ``(N, D)``.
        gamma: Bandwidth forwarded to ``weighting_function``.
        rho: Offset added before taking the reciprocal. It keeps the
            denominator away from zero when ``rho`` is positive.

    Returns:
        Tensor of shape ``(B, D)`` holding the per-dimension reciprocal
        values (presumably the diagonal of a diagonal metric -- confirm
        against callers).
    """
    kernel = weighting_function(x, samples, gamma)

    # Per-pair, per-dimension squared offsets: (1, N, D) - (B, 1, D).
    sq_offsets = (samples.unsqueeze(0) - x.unsqueeze(1)).pow(2)

    # Weighted sum over the sample axis n, regularised by rho.
    weighted_sum = torch.einsum("bn,bnd->bd", kernel, sq_offsets) + rho

    return 1.0 / weighted_sum
| |
|
| |
|
def weighting_function_dt(x, dx_dt, samples, gamma, weights):
    """Return the time derivative of the Gaussian kernel weights.

    Evaluates, for each pair (b, n),
    ``-((x[b] - samples[n]) . dx_dt[b]) * weights[b, n] / gamma^2``,
    which is the chain-rule derivative of
    ``exp(-||x - s||^2 / (2 * gamma^2))`` along the velocity ``dx_dt``.

    Args:
        x: Query points, presumably ``(B, D)`` -- confirm against callers.
        dx_dt: Velocity of ``x``, indexed the same way as ``x``.
        samples: Reference points, presumably ``(N, D)``.
        gamma: Kernel bandwidth.
        weights: Precomputed kernel weights, multiplied elementwise with the
            ``(B, N)`` projection; expected to match ``x``/``samples``/``gamma``.

    Returns:
        Tensor broadcast-compatible with ``weights`` holding d(weights)/dt.
    """
    offsets = x.unsqueeze(1) - samples.unsqueeze(0)
    velocity = dx_dt.unsqueeze(1)
    # Dot product of each offset with the query's velocity, over features.
    projection = (offsets * velocity).sum(dim=-1)
    return -projection * weights / (gamma**2)
| |
|