# PepFlow / models_con / utils.py
# Commit ef423c5 (Irwiny123): add initial code for the PepFlow model
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import math
from tqdm.auto import tqdm
import functools
from torch.utils.data import DataLoader
import os
import argparse
import pandas as pd
def process_dic(state_dict):
    """Return a copy of ``state_dict`` with a leading ``module.`` stripped from keys.

    Only keys that *start with* ``module.`` are rewritten; every other key is
    copied unchanged. (Presumably the prefix comes from a DataParallel-style
    wrapper around the model -- confirm against the checkpoint producer.)

    Args:
        state_dict: mapping from parameter name to value.

    Returns:
        A new plain ``dict`` with the same values and the cleaned-up keys, in
        the original iteration order.
    """
    prefix = 'module.'
    # Match the prefix explicitly: a substring test such as ``'module' in k``
    # would also hit keys like ``encoder.module_list.0`` and then chop off
    # their first seven characters.
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
def calc_distogram(pos, min_bin, max_bin, num_bins):
    """Bin all pairwise distances between points into a one-hot style histogram.

    Args:
        pos: tensor indexed as ``pos[batch, point, coord]``; the norm is taken
            over the last axis.
        min_bin: lower edge of the first bin.
        max_bin: lower edge of the last bin.
        num_bins: number of bins.

    Returns:
        Tensor of dtype ``pos.dtype`` with one trailing axis of size
        ``num_bins`` appended to the pairwise-distance grid. An entry is 1 when
        the distance lies strictly between the bin's lower and upper edge, so
        distances exactly on an edge (or at/below ``min_bin``) land in no bin.
    """
    # Pairwise differences via broadcasting, then the norm over coordinates.
    pairwise_delta = pos[:, :, None, :] - pos[:, None, :, :]
    pairwise_dist = torch.linalg.norm(pairwise_delta, dim=-1, keepdim=True)

    bin_lower = torch.linspace(min_bin, max_bin, num_bins, device=pos.device)
    # The last bin is open-ended in practice: its upper edge is a huge sentinel.
    sentinel_edge = bin_lower.new_tensor([1e8])
    bin_upper = torch.cat((bin_lower[1:], sentinel_edge), dim=-1)

    inside_bin = torch.logical_and(pairwise_dist > bin_lower, pairwise_dist < bin_upper)
    return inside_bin.type(pos.dtype)
def get_index_embedding(indices, embed_size, max_len=2056):
    """Creates sine / cosine positional embeddings from prespecified indices.

    Args:
        indices: offsets of size [..., N_edges] of type integer
        embed_size: dimension of the embeddings to create
        max_len: maximum length.

    Returns:
        positional embedding of shape [..., N_edges, 2 * (embed_size // 2)]:
        the sine features followed by the cosine features along the last
        axis. The last axis equals ``embed_size`` when ``embed_size`` is even.
    """
    K = torch.arange(embed_size // 2, device=indices.device)
    # Sine and cosine share the same argument, so build it once. It is already
    # on ``indices.device`` because both operands live there.
    angles = indices[..., None] * math.pi / (max_len ** (2 * K[None] / embed_size))
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
def get_time_embedding(timesteps, embedding_dim, max_positions=2000):
    """Build sinusoidal embeddings for a 1-D batch of timesteps.

    Args:
        timesteps: 1-D tensor of timesteps; it is multiplied by
            ``max_positions`` before the sinusoids are evaluated.
        embedding_dim: size of the embedding to create.
        max_positions: scale applied to ``timesteps`` and base of the
            geometric frequency progression.

    Returns:
        Float tensor of shape ``[len(timesteps), embedding_dim]``: the sine
        features followed by the cosine features, with one trailing zero
        column when ``embedding_dim`` is odd.

    Raises:
        AssertionError: if ``timesteps`` is not 1-D.
    """
    # Code from https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py
    assert len(timesteps.shape) == 1
    timesteps = timesteps * max_positions
    half_dim = embedding_dim // 2
    # With a single frequency (embedding_dim of 2 or 3) ``half_dim - 1`` is 0
    # and the original division raised ZeroDivisionError. The only frequency
    # is exp(0) == 1 whatever the step is, so clamping the denominator is
    # safe and leaves every other size unchanged.
    log_step = math.log(max_positions) / max(half_dim - 1, 1)
    freqs = torch.exp(
        torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -log_step)
    emb = timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = F.pad(emb, (0, 1), mode='constant')
    assert emb.shape == (timesteps.shape[0], embedding_dim)
    return emb