|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
import copy |
|
|
import math |
|
|
from tqdm.auto import tqdm |
|
|
import functools |
|
|
from torch.utils.data import DataLoader |
|
|
import os |
|
|
import argparse |
|
|
|
|
|
import pandas as pd |
|
|
|
|
|
def process_dic(state_dict):
    """Strip a leading ``module.`` prefix from every key of a state dict.

    Wrappers such as ``torch.nn.DataParallel`` prepend ``module.`` to every
    parameter name. This removes that prefix so the weights can be loaded
    into an unwrapped model. Keys without the prefix are kept unchanged.

    Args:
        state_dict: mapping from parameter name to value.

    Returns:
        A new ``dict`` with the renamed keys, in the same order. The values
        are not copied, and the input mapping is not modified.
    """
    prefix = 'module.'
    # Match only a *leading* "module.". A plain substring test would also
    # fire on names like "encoder.submodule.w", and slicing off the first
    # seven characters would then corrupt them.
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
|
|
|
|
|
|
|
|
def calc_distogram(pos, min_bin, max_bin, num_bins):
    """One-hot encode pairwise Euclidean distances into distance bins.

    Args:
        pos: rank-3 tensor of points, shape [B, N, D] (the indexing below
            requires exactly three dimensions).
        min_bin: left edge of the first bin.
        max_bin: left edge of the last bin.
        num_bins: number of bins. The left edges are ``num_bins`` evenly
            spaced values from ``min_bin`` to ``max_bin`` inclusive.

    Returns:
        Tensor of shape [B, N, N, num_bins] with dtype ``pos.dtype``.
        Entry ``[b, i, j, k]`` is 1 when
        ``edge[k] < |pos[b, i] - pos[b, j]| < edge[k + 1]`` and 0 otherwise.
        The last bin's right edge is the sentinel ``1e8``.

    Note:
        Both comparisons are strict. A distance below ``min_bin``, or one
        exactly equal to a bin edge, therefore falls in no bin and gives an
        all-zero vector. In particular, the zero self-distances on the
        diagonal are all-zero in the usual ``0 <= min_bin <= max_bin`` setup.
    """
    # [B, N, 1, D] - [B, 1, N, D] -> [B, N, N, D]: every ordered pair.
    pair_delta = pos[:, :, None, :] - pos[:, None, :, :]
    pair_dist = torch.linalg.norm(pair_delta, dim=-1).unsqueeze(-1)

    left_edges = torch.linspace(min_bin, max_bin, num_bins, device=pos.device)
    # Each bin's right edge is the next bin's left edge. The final bin is
    # closed off with a very large sentinel.
    sentinel = left_edges.new_tensor([1e8])
    right_edges = torch.cat((left_edges[1:], sentinel), dim=-1)

    in_bin = torch.logical_and(pair_dist > left_edges, pair_dist < right_edges)
    return in_bin.to(pos.dtype)
|
|
|
|
|
|
|
|
def get_index_embedding(indices, embed_size, max_len=2056):
    """Create sine/cosine positional embeddings from prespecified indices.

    Args:
        indices: tensor of offsets, shape [..., N_edges] (integer or float).
        embed_size: dimension of the embeddings to create. If it is odd, the
            final channel is zero padding, so the output width is always
            ``embed_size``.
        max_len: maximum length. It sets the base of the frequency
            progression: channel ``k`` uses the angle
            ``indices * pi / max_len ** (2k / embed_size)``.

    Returns:
        Tensor of shape [..., N_edges, embed_size]. The first
        ``embed_size // 2`` channels are sines and the next
        ``embed_size // 2`` are cosines, followed by one zero channel when
        ``embed_size`` is odd.
    """
    K = torch.arange(embed_size // 2, device=indices.device)
    # The sine and cosine halves share one angle tensor.
    # Shapes: [..., N, 1] broadcast with [1, embed_size // 2]
    #      -> [..., N, embed_size // 2].
    angles = indices[..., None] * math.pi / (max_len ** (2 * K[None] / embed_size))
    pos_embedding = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
    if embed_size % 2 == 1:
        # sin/cos only fill 2 * (embed_size // 2) channels. Pad one zero
        # channel so the width matches embed_size, mirroring the odd-size
        # handling in get_time_embedding.
        pos_embedding = F.pad(pos_embedding, (0, 1), mode='constant')
    return pos_embedding
|
|
|
|
|
|
|
|
def get_time_embedding(timesteps, embedding_dim, max_positions=2000):
    """Sinusoidal embedding of a batch of timesteps.

    Args:
        timesteps: 1-D tensor of shape [B]. It is multiplied by
            ``max_positions`` before embedding (presumably the timesteps are
            continuous values in [0, 1]; confirm with the callers).
        embedding_dim: size of the output embedding. If it is odd, the final
            channel is zero padding.
        max_positions: scale applied to ``timesteps``. It also sets the
            frequencies, which decay geometrically from 1 down to
            ``1 / max_positions``.

    Returns:
        float32 tensor of shape [B, embedding_dim]. The first
        ``embedding_dim // 2`` channels are sines and the next
        ``embedding_dim // 2`` are cosines.
    """
    assert len(timesteps.shape) == 1
    timesteps = timesteps * max_positions
    half_dim = embedding_dim // 2
    # Log-spaced frequencies: exp(-k * log(max_positions) / (half_dim - 1)).
    # When half_dim == 1 (embedding_dim 2 or 3), the only frequency is
    # exp(0) = 1 whatever the step. Guarding the denominator avoids the
    # ZeroDivisionError the unguarded division raised there.
    emb = math.log(max_positions) / max(half_dim - 1, 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
    # [B, 1] * [1, half_dim] -> [B, half_dim]
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:
        # Odd sizes get one trailing zero channel.
        emb = F.pad(emb, (0, 1), mode='constant')
    assert emb.shape == (timesteps.shape[0], embedding_dim)
    return emb