moPPIt / models /peptide_classifiers.py
AlienChen's picture
Update models/peptide_classifiers.py
30fd543 verified
import pdb
import time

import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import pytorch_lightning as pl
from transformers import AutoModel, AutoConfig, AutoTokenizer, EsmModel
import xgboost as xgb
import esm
from flow_matching.path import MixtureDiscreteProbPath
from flow_matching.path.scheduler import PolynomialConvexScheduler
from flow_matching.solver import MixtureDiscreteEulerSolver
from flow_matching.utils import ModelWrapper
from flow_matching.loss import MixturePathGeneralizedKL

from models.peptide_models import CNNModel
# Kept last on purpose: names exported by this star import still take
# precedence over the explicit imports above, exactly as before.
from modules.bindevaluator_modules import *
def parse_motifs(motif: str) -> torch.Tensor:
    """Convert a 1-based motif specification into 0-based position indices.

    ``motif`` is a comma-separated list whose items are either a single
    position (``"7"``) or an inclusive range (``"3-5"``).  Whitespace around
    items is ignored and empty items (e.g. from a trailing comma) are skipped.

    Args:
        motif: Specification such as ``"1-3, 7"``.

    Returns:
        A 1-D ``torch.long`` tensor with the 0-based positions, in the order
        they were listed.

    Raises:
        ValueError: If an item is not an integer or integer range, if a range
            has its start after its end, or if no position is given at all.
    """
    positions = []
    for part in motif.split(','):
        part = part.strip()
        if not part:
            # Tolerate stray commas such as "3-5," instead of failing on int('').
            continue
        if '-' in part:
            start, end = map(int, part.split('-'))
            if start > end:
                # range(start, end + 1) would silently contribute nothing.
                raise ValueError(f'Invalid motif range (start > end): {part!r}')
            positions.extend(range(start, end + 1))
        else:
            positions.append(int(part))
    if not positions:
        raise ValueError(f'No motif positions found in {motif!r}')
    # The specification is 1-based; tensors are indexed from 0.
    result = [pos - 1 for pos in positions]
    print(f'Target Motifs: {result}')
    return torch.tensor(result, dtype=torch.long)
class BindEvaluator(pl.LightningModule):
    """Per-position binding predictor over a target sequence.

    Binder and target token ids are embedded with a frozen ESM-2 model, passed
    through ``RepeatedModule3`` and a final attention + feed-forward stage, and
    projected to one logit per target position.  ``motif_score``,
    ``non_motif_score`` and ``scoring`` reduce the resulting probabilities to
    scalar scores for a given set of motif positions.
    """

    def __init__(self, n_layers, d_model, d_hidden, n_head,
                 d_k, d_v, d_inner, dropout=0.2,
                 learning_rate=0.00001, max_epochs=15, kl_weight=1):
        """Build the network.

        The architecture arguments are forwarded to the project modules
        (``RepeatedModule3``, ``MultiHeadAttentionSequence``, ``FFN``).
        ``learning_rate``, ``max_epochs`` and ``kl_weight`` are only stored on
        the instance here.
        """
        super(BindEvaluator, self).__init__()

        self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
        self.esm_model.eval()
        # Freeze all the ESM parameters.
        for param in self.esm_model.parameters():
            param.requires_grad = False

        self.repeated_module = RepeatedModule3(n_layers, d_model, d_hidden,
                                               n_head, d_k, d_v, d_inner, dropout=dropout)

        self.final_attention_layer = MultiHeadAttentionSequence(n_head, d_model,
                                                                d_k, d_v, dropout=dropout)

        self.final_ffn = FFN(d_model, d_inner, dropout=dropout)

        self.output_projection_prot = nn.Linear(d_model, 1)

        self.learning_rate = learning_rate
        self.max_epochs = max_epochs
        self.kl_weight = kl_weight

        self.classification_threshold = nn.Parameter(torch.tensor(0.5))  # Initial threshold
        self.historical_memory = 0.9
        # Binding-site weight, non-binding-site weight.
        self.class_weights = torch.tensor([3.000471363174231, 0.5999811490272925])

    def forward(self, binder_tokens, target_tokens):
        """Return one logit per target position, shape ``(bsz, tgt_len, 1)``.

        Both arguments are keyword dicts that are unpacked into the ESM model
        (``get_probs`` passes ``input_ids`` and ``attention_mask``).
        """
        peptide_sequence = self.esm_model(**binder_tokens).last_hidden_state
        protein_sequence = self.esm_model(**target_tokens).last_hidden_state

        # The repeated module returns six values; only the two encodings are
        # used.  The four attention outputs are discarded explicitly (the
        # previous code bound two of them to the same name).
        prot_enc, sequence_enc, _, _, _, _ = self.repeated_module(peptide_sequence,
                                                                  protein_sequence)

        prot_enc, _ = self.final_attention_layer(prot_enc, sequence_enc, sequence_enc)
        prot_enc = self.final_ffn(prot_enc)
        return self.output_projection_prot(prot_enc)

    def get_probs(self, x_t, target_sequence):
        '''
        Per-position probabilities over the target for every binder in the batch.

        Inputs:
            - x_t: binder token ids, shape (bsz, seq_len)
            - target_sequence: target token ids, shape (1, tgt_len)

        Returns:
            Tensor of shape (bsz, tgt_len).  The first and last positions
            (presumably the special start/end tokens) are forced to ~0.
        '''
        target_sequence = target_sequence.repeat(x_t.shape[0], 1)

        # ones_like already allocates on the same device as its argument.
        binder_attention_mask = torch.ones_like(x_t)
        target_attention_mask = torch.ones_like(target_sequence)

        # Exclude the first and last token of each sequence from attention.
        binder_attention_mask[:, 0] = binder_attention_mask[:, -1] = 0
        target_attention_mask[:, 0] = target_attention_mask[:, -1] = 0

        binder_tokens = {'input_ids': x_t, 'attention_mask': binder_attention_mask}
        target_tokens = {'input_ids': target_sequence, 'attention_mask': target_attention_mask}

        logits = self.forward(binder_tokens, target_tokens).squeeze(-1)
        # A large negative logit drives the sigmoid to ~0 without producing inf.
        logits[:, 0] = logits[:, -1] = -100
        return torch.sigmoid(logits)  # shape (bsz, tgt_len)

    @staticmethod
    def _motif_mean(probs, motifs):
        """Mean probability over the motif columns of ``probs``."""
        return probs[:, motifs].sum(dim=-1) / len(motifs)

    @staticmethod
    def _non_motif_probs(probs, motifs):
        """Columns of ``probs`` that are not listed in ``motifs``.

        Uses one boolean column mask instead of a Python membership test per
        position.  Negative indices in ``motifs`` are interpreted the same way
        ``probs[:, motifs]`` interprets them.
        """
        keep = torch.ones(probs.shape[1], dtype=torch.bool, device=probs.device)
        keep[torch.as_tensor(motifs, dtype=torch.long, device=probs.device)] = False
        return probs[:, keep]

    def motif_score(self, x_t, target_sequence, motifs):
        """Mean predicted probability over the motif positions, shape ``(bsz,)``."""
        probs = self.get_probs(x_t, target_sequence)
        return self._motif_mean(probs, motifs)

    def non_motif_score(self, x_t, target_sequence, motifs):
        """Mean probability over non-motif positions predicted at >= 0.5.

        Returns 0 for batch rows in which no non-motif position reaches 0.5.
        """
        probs = self.get_probs(x_t, target_sequence)
        non_motif_probs = self._non_motif_probs(probs, motifs)
        mask = non_motif_probs >= 0.5
        count = mask.sum(dim=-1)
        total = (non_motif_probs * mask).sum(dim=-1)
        # The fallback has the same (floating) dtype as the selected branch, so
        # torch.where never has to reconcile float and int64 inputs.
        return torch.where(count > 0, total / count, torch.zeros_like(total))

    def scoring(self, x_t, target_sequence, motifs, penalty=False):
        """Score binders against the motif positions.

        Returns ``motif_score`` alone when ``penalty`` is false.  Otherwise
        returns ``(motif_score, 1 - off_target_fraction)``, where the fraction
        is the number of non-motif positions predicted at >= 0.5 divided by
        ``target_sequence.shape[1]``.
        """
        probs = self.get_probs(x_t, target_sequence)
        motif_score = self._motif_mean(probs, motifs)
        if not penalty:
            return motif_score

        non_motif_probs = self._non_motif_probs(probs, motifs)
        count = (non_motif_probs >= 0.5).sum(dim=-1)
        non_motif_score = count / target_sequence.shape[1]
        return motif_score, 1 - non_motif_score
class MotifModel(nn.Module):
    """Callable that scores binders against one fixed target and motif set.

    Wraps an evaluator exposing ``scoring(x, target_sequence, motifs, penalty)``
    so that only the binder batch has to be supplied at call time.
    """

    def __init__(self, bindevaluator, target_sequence, motifs, penalty=False):
        super().__init__()
        self.bindevaluator = bindevaluator
        # Fixed scoring context, reused on every forward call.
        self.target_sequence, self.motifs, self.penalty = target_sequence, motifs, penalty

    def forward(self, x):
        """Return whatever ``bindevaluator.scoring`` returns for binders ``x``."""
        evaluator = self.bindevaluator
        return evaluator.scoring(x, self.target_sequence, self.motifs, self.penalty)
class UnpooledBindingPredictor(nn.Module):
    """Affinity regressor working on unpooled ESM embeddings.

    Protein and binder token ids are embedded with an ESM model, summarised by
    multi-kernel 1-D convolutions with masked max/average pooling, projected to
    a common width, mixed by cross-attention blocks, and mapped to a single
    regression value per pair.
    """

    def __init__(self,
                 esm_model_name="facebook/esm2_t33_650M_UR50D",
                 hidden_dim=512,
                 kernel_sizes=(3, 5, 7),
                 n_heads=8,
                 n_layers=3,
                 dropout=0.1,
                 freeze_esm=True):
        """Build the network.

        Args:
            esm_model_name: Hugging Face identifier of the ESM model to load.
            hidden_dim: Width of the projected features and attention blocks.
            kernel_sizes: Convolution kernel sizes; any sequence of ints (an
                immutable default replaces the former mutable list default).
            n_heads: Attention heads per cross-attention block.
            n_layers: Number of cross-attention blocks.
            dropout: Dropout probability used in attention, FFN and head.
            freeze_esm: If true, ESM parameters get ``requires_grad = False``.
        """
        super().__init__()

        # Define binding thresholds
        self.tight_threshold = 7.5  # Kd/Ki/IC50 ≤ ~30nM
        self.weak_threshold = 6.0  # Kd/Ki/IC50 > 1μM

        # Load ESM model for computing embeddings on the fly
        self.esm_model = AutoModel.from_pretrained(esm_model_name)
        self.config = AutoConfig.from_pretrained(esm_model_name)

        # Freeze ESM parameters if needed
        if freeze_esm:
            for param in self.esm_model.parameters():
                param.requires_grad = False

        # Get ESM hidden size
        esm_dim = self.config.hidden_size

        # Output channels for CNN layers
        output_channels_per_kernel = 64

        # CNN layers for handling variable length sequences
        self.protein_conv_layers = nn.ModuleList([
            nn.Conv1d(
                in_channels=esm_dim,
                out_channels=output_channels_per_kernel,
                kernel_size=k,
                padding='same'
            ) for k in kernel_sizes
        ])

        self.binder_conv_layers = nn.ModuleList([
            nn.Conv1d(
                in_channels=esm_dim,
                out_channels=output_channels_per_kernel,
                kernel_size=k,
                padding='same'
            ) for k in kernel_sizes
        ])

        # Max- and average-pooled features are concatenated, hence the factor 2.
        total_features_per_seq = output_channels_per_kernel * len(kernel_sizes) * 2

        # Project to same dimension after CNN processing
        self.protein_projection = nn.Linear(total_features_per_seq, hidden_dim)
        self.binder_projection = nn.Linear(total_features_per_seq, hidden_dim)

        self.protein_norm = nn.LayerNorm(hidden_dim)
        self.binder_norm = nn.LayerNorm(hidden_dim)

        # Cross attention blocks with layer norm
        self.cross_attention_layers = nn.ModuleList([
            nn.ModuleDict({
                'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout),
                'norm1': nn.LayerNorm(hidden_dim),
                'ffn': nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim * 4),
                    nn.ReLU(),
                    nn.Dropout(dropout),
                    nn.Linear(hidden_dim * 4, hidden_dim)
                ),
                'norm2': nn.LayerNorm(hidden_dim)
            }) for _ in range(n_layers)
        ])

        # Prediction heads
        self.shared_head = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

        # Regression head
        self.regression_head = nn.Linear(hidden_dim, 1)

        # Classification head (3 classes: tight, medium, loose binding).
        # Not used by forward(); kept so the module's parameters are unchanged.
        self.classification_head = nn.Linear(hidden_dim, 3)

    def get_binding_class(self, affinity):
        """Convert affinity values to class indices
        0: tight binding (>= 7.5)
        1: medium binding (6.0-7.5)
        2: weak binding (< 6.0)

        Accepts either a tensor (returns a ``torch.long`` tensor of the same
        shape) or a plain number (returns an ``int``).
        """
        if isinstance(affinity, torch.Tensor):
            tight_mask = affinity >= self.tight_threshold
            weak_mask = affinity < self.weak_threshold
            medium_mask = ~(tight_mask | weak_mask)

            classes = torch.zeros_like(affinity, dtype=torch.long)
            classes[medium_mask] = 1
            classes[weak_mask] = 2
            return classes
        else:
            if affinity >= self.tight_threshold:
                return 0  # tight binding
            elif affinity < self.weak_threshold:
                return 2  # weak binding
            else:
                return 1  # medium binding

    def compute_embeddings(self, input_ids, attention_mask=None):
        """Compute ESM embeddings on the fly.

        Returns the unpooled last hidden states
        (batch_size x seq_length x hidden_size).
        """
        esm_outputs = self.esm_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        return esm_outputs.last_hidden_state

    def process_sequence(self, unpooled_emb, conv_layers, attention_mask=None):
        """Run the CNN stack and pool over the sequence dimension.

        Args:
            unpooled_emb: Embeddings, ``[batch_size, seq_length, hidden_size]``.
            conv_layers: Iterable of ``Conv1d`` layers applied in parallel.
            attention_mask: Optional ``[batch_size, seq_length]`` mask with 1
                for valid positions and 0 for padding.

        Returns:
            ``[batch_size, 2 * total_channels]``: max-pooled features followed
            by average-pooled features.  Rows whose mask is entirely 0 yield
            zeros rather than ``-inf``.
        """
        # Transpose for CNN: [batch_size, hidden_size, seq_length]
        x = unpooled_emb.transpose(1, 2)

        # Apply every kernel size and concatenate along the channel dimension.
        conv_output = torch.cat([F.relu(conv(x)) for conv in conv_layers], dim=1)

        if attention_mask is None:
            # No mask: plain pooling over every position.
            max_pooled = torch.max(conv_output, dim=2)[0]
            avg_pooled = torch.mean(conv_output, dim=2)
        else:
            # Expand mask to match conv_output channels.
            expanded_mask = attention_mask.unsqueeze(1).expand(-1, conv_output.size(1), -1)
            is_padding = expanded_mask == 0

            # masked_fill is out-of-place, so no defensive clone is needed.
            max_pooled = conv_output.masked_fill(is_padding, float('-inf')).max(dim=2)[0]
            # A row without any valid position would pool to -inf and turn into
            # NaN in the following linear layers; fall back to 0 instead.
            max_pooled = max_pooled.masked_fill(is_padding.all(dim=2), 0.0)

            # Average pooling: sum over valid positions / number of valid positions.
            sum_pooled = torch.sum(conv_output * expanded_mask, dim=2)
            valid_positions = torch.sum(expanded_mask, dim=2).clamp(min=1)  # avoid division by zero
            avg_pooled = sum_pooled / valid_positions

        # Concatenate the pooled features
        return torch.cat([max_pooled, avg_pooled], dim=1)

    @staticmethod
    def _cross_attend(layer, query, context):
        """One block: ``query`` attends to ``context``, then residual + norm + FFN."""
        attended = layer['attention'](query, context, context)[0]
        query = layer['norm1'](query + attended)
        return layer['norm2'](query + layer['ffn'](query))

    def forward(self, protein_input_ids, binder_input_ids, protein_mask=None, binder_mask=None):
        """Predict one regression value per (protein, binder) pair.

        Returns:
            Tensor of shape ``[batch_size, 1]``.
        """
        # Compute embeddings on the fly using the ESM model
        protein_unpooled = self.compute_embeddings(protein_input_ids, protein_mask)
        binder_unpooled = self.compute_embeddings(binder_input_ids, binder_mask)

        # Process protein and binder sequences through CNN layers
        protein_features = self.process_sequence(protein_unpooled, self.protein_conv_layers, protein_mask)
        binder_features = self.process_sequence(binder_unpooled, self.binder_conv_layers, binder_mask)

        # Project to same dimension
        protein = self.protein_norm(self.protein_projection(protein_features))
        binder = self.binder_norm(self.binder_projection(binder_features))

        # Reshape for attention: from [batch_size, hidden_dim] to [1, batch_size, hidden_dim]
        protein = protein.unsqueeze(0)
        binder = binder.unsqueeze(0)

        # Each block is shared by both directions; the binder attends to the
        # protein representation that was just updated.
        for layer in self.cross_attention_layers:
            protein = self._cross_attend(layer, protein, binder)
            binder = self._cross_attend(layer, binder, protein)

        # Remove sequence dimension
        protein_pool = protein.squeeze(0)
        binder_pool = binder.squeeze(0)

        # Concatenate both representations
        combined = torch.cat([protein_pool, binder_pool], dim=-1)

        # Shared features, then the regression head only.
        shared_features = self.shared_head(combined)
        return self.regression_head(shared_features)
class ImprovedBindingPredictor(nn.Module):
    """Affinity regressor over precomputed protein and binder embeddings.

    Both embedding sequences are projected to ``hidden_dim``, refined by a
    stack of bidirectional cross-attention blocks, mean-pooled over the
    sequence axis and mapped to a single regression value.
    """

    def __init__(self,
                 esm_dim=1280,
                 smiles_dim=1280,
                 hidden_dim=512,
                 n_heads=8,
                 n_layers=5,
                 dropout=0.1):
        super().__init__()

        # Affinity cut-offs used by get_binding_class.
        self.tight_threshold = 7.5  # Kd/Ki/IC50 ≤ ~30nM
        self.weak_threshold = 6.0  # Kd/Ki/IC50 > 1μM

        # Bring both inputs to a common width.
        self.smiles_projection = nn.Linear(smiles_dim, hidden_dim)
        self.protein_projection = nn.Linear(esm_dim, hidden_dim)
        self.protein_norm = nn.LayerNorm(hidden_dim)
        self.smiles_norm = nn.LayerNorm(hidden_dim)

        # One attention / norm / feed-forward / norm block per layer.
        blocks = []
        for _ in range(n_layers):
            blocks.append(nn.ModuleDict({
                'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout),
                'norm1': nn.LayerNorm(hidden_dim),
                'ffn': nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim * 4),
                    nn.ReLU(),
                    nn.Dropout(dropout),
                    nn.Linear(hidden_dim * 4, hidden_dim)
                ),
                'norm2': nn.LayerNorm(hidden_dim)
            }))
        self.cross_attention_layers = nn.ModuleList(blocks)

        # Trunk shared by the prediction heads.
        self.shared_head = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

        # Regression head
        self.regression_head = nn.Linear(hidden_dim, 1)

        # Classification head (3 classes: tight, medium, loose binding)
        self.classification_head = nn.Linear(hidden_dim, 3)

    def get_binding_class(self, affinity):
        """Map affinity values to class indices.

        0: tight binding (>= 7.5)
        1: medium binding (6.0-7.5)
        2: weak binding (< 6.0)

        A tensor input yields a ``torch.long`` tensor of the same shape; a
        plain number yields an ``int``.
        """
        if not isinstance(affinity, torch.Tensor):
            if affinity >= self.tight_threshold:
                return 0  # tight binding
            return 2 if affinity < self.weak_threshold else 1

        is_tight = affinity >= self.tight_threshold
        is_weak = affinity < self.weak_threshold

        labels = torch.zeros_like(affinity, dtype=torch.long)
        labels[~(is_tight | is_weak)] = 1
        labels[is_weak] = 2
        return labels

    @staticmethod
    def _attend(block, query, memory):
        """Let ``query`` attend to ``memory``, then apply residual norms and the FFN."""
        attended = block['attention'](query, memory, memory)[0]
        query = block['norm1'](query + attended)
        return block['norm2'](query + block['ffn'](query))

    def forward(self, protein_emb, binder_emb):
        """Return one regression value per pair.

        The projected inputs have their first two axes swapped before being
        fed to ``nn.MultiheadAttention``, and are mean-pooled over axis 0
        afterwards.
        """
        protein = self.protein_norm(self.protein_projection(protein_emb))
        smiles = self.smiles_norm(self.smiles_projection(binder_emb))

        protein, smiles = protein.transpose(0, 1), smiles.transpose(0, 1)

        for block in self.cross_attention_layers:
            # Protein attends to SMILES first; SMILES then attends to the
            # already-updated protein representation.
            protein = self._attend(block, protein, smiles)
            smiles = self._attend(block, smiles, protein)

        # Sequence-level representations, concatenated.
        pooled_pair = [torch.mean(protein, dim=0), torch.mean(smiles, dim=0)]
        combined = torch.cat(pooled_pair, dim=-1)

        return self.regression_head(self.shared_head(combined))
class PooledAffinityModel(nn.Module):
    """Score binders against one fixed target with an embedding-level predictor.

    Token ids are embedded with a frozen ESM-2 model and the embeddings are
    handed to ``affinity_predictor`` as ``protein_emb`` / ``binder_emb``.
    """

    def __init__(self, affinity_predictor, target_sequence):
        super().__init__()
        self.affinity_predictor = affinity_predictor
        self.target_sequence = target_sequence

        # Put the embedding model on the same device as the target tokens.
        target_device = self.target_sequence.device
        self.esm_model = AutoModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(target_device)
        for weight in self.esm_model.parameters():
            weight.requires_grad = False

    def compute_embeddings(self, input_ids, attention_mask=None):
        """Return the unpooled last hidden states (batch_size x seq_length x hidden_size)."""
        outputs = self.esm_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        return outputs.last_hidden_state

    def forward(self, x):
        """Return the predictor output for binder ids ``x`` with the last axis squeezed."""
        batch_size = x.shape[0]
        # Repeat the target once per binder in the batch.
        tiled_target = self.target_sequence.repeat(batch_size, 1)
        protein_emb = self.compute_embeddings(input_ids=tiled_target)
        binder_emb = self.compute_embeddings(input_ids=x)
        prediction = self.affinity_predictor(protein_emb=protein_emb, binder_emb=binder_emb)
        return prediction.squeeze(-1)
class AffinityModel(nn.Module):
    """Score binders against one fixed target with a token-level predictor.

    The predictor is called with ``protein_input_ids`` / ``binder_input_ids``
    and its output is divided by 10.
    """

    def __init__(self, affinity_predictor, target_sequence):
        super().__init__()
        self.affinity_predictor = affinity_predictor
        self.target_sequence = target_sequence

    def forward(self, x):
        """Return the predicted affinity for binder ids ``x``, scaled by 1/10."""
        # Repeat the target once per binder in the batch.
        tiled_target = self.target_sequence.repeat(x.shape[0], 1)
        raw = self.affinity_predictor(protein_input_ids=tiled_target, binder_input_ids=x)
        affinity = raw.squeeze(-1)
        return affinity / 10
class HemolysisModel:
    """Non-hemolysis scorer: an XGBoost model on mean-pooled ESM-2 embeddings."""

    def __init__(self, device):
        """Load the XGBoost booster and the ESM-2 embedding model onto ``device``."""
        self.predictor = xgb.Booster(model_file='./classifier_ckpt/best_model_hemolysis.json')

        self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
        self.model.eval()

        self.device = device

    def generate_embeddings(self, sequences):
        """Return mean-pooled ESM embeddings as a numpy array.

        ``sequences`` is passed to the ESM model as ``input_ids``; the last
        hidden state is averaged over axis 1.
        """
        with torch.no_grad():
            embeddings = self.model(input_ids=sequences).last_hidden_state.mean(dim=1)
            embeddings = embeddings.cpu().numpy()

        return embeddings

    def get_scores(self, input_seqs):
        """Return ``1 - P(hemolytic)`` for every input as a tensor on ``self.device``.

        A tensor is returned for an empty batch as well (the empty case used
        to return a numpy array).
        """
        scores = np.ones(len(input_seqs))
        features = self.generate_embeddings(input_seqs)

        if len(features) == 0:
            return torch.from_numpy(scores).to(self.device)

        # Sanitise the features before handing them to XGBoost.
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        features = xgb.DMatrix(features)

        probs = self.predictor.predict(features)
        # return the probability of it being not hemolytic
        return torch.from_numpy(scores - probs).to(self.device)

    def __call__(self, input_seqs):
        """Shorthand for ``get_scores``; ``input_seqs`` is forwarded to ESM as ``input_ids``."""
        return self.get_scores(input_seqs)
class NonfoulingModel:
    """Non-fouling scorer: an XGBoost model on mean-pooled ESM-2 embeddings."""

    def __init__(self, device):
        """Load the XGBoost booster and the ESM-2 embedding model onto ``device``."""
        # change model path
        self.predictor = xgb.Booster(model_file='./classifier_ckpt/best_model_nonfouling.json')

        self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
        self.model.eval()

        self.device = device

    def generate_embeddings(self, sequences):
        """Return mean-pooled ESM embeddings as a numpy array.

        ``sequences`` is passed to the ESM model as ``input_ids``; the last
        hidden state is averaged over axis 1.
        """
        with torch.no_grad():
            embeddings = self.model(input_ids=sequences).last_hidden_state.mean(dim=1)
            embeddings = embeddings.cpu().numpy()

        return embeddings

    def get_scores(self, input_seqs):
        """Return the booster's prediction for every input as a tensor on ``self.device``.

        A float32 tensor is returned for an empty batch as well (the empty
        case used to return a numpy array).
        """
        features = self.generate_embeddings(input_seqs)

        if len(features) == 0:
            # Same type and dtype as the non-empty path (booster output is float32).
            return torch.zeros(len(input_seqs), dtype=torch.float32, device=self.device)

        # Sanitise the features before handing them to XGBoost.
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        features = xgb.DMatrix(features)

        scores = self.predictor.predict(features)
        return torch.from_numpy(scores).to(self.device)

    def __call__(self, input_seqs):
        """Shorthand for ``get_scores``; ``input_seqs`` is forwarded to ESM as ``input_ids``."""
        return self.get_scores(input_seqs)
class SolubilityModel:
    """Solubility scorer: an XGBoost model on mean-pooled ESM-2 embeddings."""

    def __init__(self, device):
        """Load the XGBoost booster and the ESM-2 embedding model onto ``device``."""
        # change model path
        self.predictor = xgb.Booster(model_file='./classifier_ckpt/best_model_solubility.json')

        self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
        self.model.eval()

        self.device = device

    def generate_embeddings(self, sequences):
        """Return mean-pooled ESM embeddings as a numpy array.

        ``sequences`` is passed to the ESM model as ``input_ids``; the last
        hidden state is averaged over axis 1.
        """
        with torch.no_grad():
            embeddings = self.model(input_ids=sequences).last_hidden_state.mean(dim=1)
            embeddings = embeddings.cpu().numpy()

        return embeddings

    def get_scores(self, input_seqs):
        """Return the booster's prediction for every input as a tensor on ``self.device``.

        A float32 tensor is returned for an empty batch as well (the empty
        case used to return a numpy array).
        """
        features = self.generate_embeddings(input_seqs)

        if len(features) == 0:
            # Same type and dtype as the non-empty path (booster output is float32).
            return torch.zeros(len(input_seqs), dtype=torch.float32, device=self.device)

        # Sanitise the features before handing them to XGBoost.
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        features = xgb.DMatrix(features)

        scores = self.predictor.predict(features)
        return torch.from_numpy(scores).to(self.device)

    def __call__(self, input_seqs):
        """Shorthand for ``get_scores``; ``input_seqs`` is forwarded to ESM as ``input_ids``."""
        return self.get_scores(input_seqs)
class SolubilityModelNew:
    """Rule-based score: one minus the fraction of tokens found in ``hydro_ids``.

    NOTE(review): ``hydro_ids`` presumably holds the tokenizer ids of
    hydrophobic residues; confirm against the tokenizer vocabulary.
    """

    def __init__(self, device):
        self.device = device
        self.hydro_ids = torch.tensor([5, 7, 4, 12, 20, 18, 22, 14], device=device)

    def get_scores(self, x):
        """Return ``1 - fraction`` per row of the 2-D token-id tensor ``x``.

        Every position along axis 1 counts towards the denominator.
        """
        # Compare each token with every listed id, then collapse the id axis.
        matches_any_id = (x[..., None] == self.hydro_ids).any(dim=-1)
        listed_fraction = matches_any_id.to(torch.float32).mean(dim=1)
        return 1 - listed_fraction

    def __call__(self, input_seqs: list):
        """Shorthand for ``get_scores``."""
        return self.get_scores(input_seqs)
class PeptideCNN(nn.Module):
    """Two-layer 1-D CNN on ESM-2 token embeddings with a scalar prediction head."""

    def __init__(self, input_dim, hidden_dims, output_dim, dropout_rate):
        super().__init__()
        first_width = hidden_dims[0]
        second_width = hidden_dims[1]
        self.conv1 = nn.Conv1d(input_dim, first_width, kernel_size=3, padding=1)
        # kernel_size=5 with padding=1 shortens the sequence axis by 2.
        self.conv2 = nn.Conv1d(first_width, second_width, kernel_size=5, padding=1)
        self.fc = nn.Linear(second_width, output_dim)
        self.dropout = nn.Dropout(dropout_rate)
        self.predictor = nn.Linear(output_dim, 1)  # For regression/classification

        self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
        self.esm_model.eval()

    def forward(self, input_ids, attention_mask=None, return_features=False):
        """Embed, convolve, average-pool and predict.

        Returns the ``(B, output_dim)`` feature vector when ``return_features``
        is true, otherwise the ``(B, 1)`` prediction.
        """
        # The embedding model is run without gradient tracking.
        with torch.no_grad():
            hidden = self.esm_model(input_ids, attention_mask).last_hidden_state
        # hidden: (B, L, input_dim) -> (B, input_dim, L) for Conv1d
        signal = hidden.permute(0, 2, 1)
        for conv in (self.conv1, self.conv2):
            signal = self.dropout(nn.functional.relu(conv(signal)))
        # Back to (B, L', hidden_dims[1]), then global average over the sequence axis.
        pooled = signal.permute(0, 2, 1).mean(dim=1)

        features = self.fc(pooled)  # (B, output_dim)
        if return_features:
            return features
        return self.predictor(features)  # (B, 1)
class HalfLifeModel:
    """Half-life scorer backed by a ``PeptideCNN`` checkpoint; output lies in [0, 1]."""

    def __init__(self, device):
        embedding_width = 1280
        conv_widths = [embedding_width // 2, embedding_width // 4]
        feature_width = embedding_width // 8
        dropout_rate = 0.3

        network = PeptideCNN(embedding_width, conv_widths, feature_width, dropout_rate)
        self.model = network.to(device)
        # weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
        state = torch.load('./classifier_ckpt/best_model_half_life.pth', map_location=device, weights_only=False)
        self.model.load_state_dict(state)
        self.model.eval()

    def __call__(self, x):
        """Return the prediction clamped to [0, 2] and divided by 2."""
        raw = self.model(x, return_features=False)
        clamped = raw.squeeze(-1).clamp(min=0.0, max=2.0)
        return clamped / 2
def load_bindevaluator(checkpoint_path, device):
    """Restore a frozen, eval-mode ``BindEvaluator`` from a Lightning checkpoint.

    The architecture hyper-parameters are fixed here and passed to
    ``load_from_checkpoint``.  ``weights_only=False`` unpickles arbitrary
    objects, so the checkpoint must come from a trusted source.
    """
    architecture = dict(n_layers=8, d_model=128, d_hidden=128, n_head=8,
                        d_k=64, d_v=128, d_inner=64)
    model = BindEvaluator.load_from_checkpoint(checkpoint_path, weights_only=False, **architecture)
    model = model.to(device)
    model.eval()
    # Inference only: no parameter should track gradients.
    for weight in model.parameters():
        weight.requires_grad = False
    return model
def load_solver(checkpoint_path, vocab_size, device):
    """Build a discrete flow-matching Euler solver around a trained denoiser.

    Args:
        checkpoint_path: Path to the ``CNNModel`` state dict.  It is loaded
            with ``weights_only=False`` (arbitrary unpickling), so it must
            come from a trusted source.
        vocab_size: Alphabet size of the denoiser and of the solver.
        device: Device for the denoiser.  This argument is now honoured; it
            used to be overwritten with a hard-coded ``'cuda:0'``.

    Returns:
        A ``MixtureDiscreteEulerSolver`` whose model returns softmax
        probabilities of the frozen, eval-mode denoiser.
    """
    embed_dim = 512
    hidden_dim = 256

    probability_denoiser = CNNModel(alphabet_size=vocab_size, embed_dim=embed_dim, hidden_dim=hidden_dim).to(device)
    probability_denoiser.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=False))
    probability_denoiser.eval()
    for param in probability_denoiser.parameters():
        param.requires_grad = False

    # instantiate a convex path object
    scheduler = PolynomialConvexScheduler(n=2.0)
    path = MixtureDiscreteProbPath(scheduler=scheduler)

    class WrappedModel(ModelWrapper):
        """Turn the denoiser's logits into probabilities for the solver."""

        def forward(self, x: torch.Tensor, t: torch.Tensor, **extras):
            return torch.softmax(self.model(x, t), dim=-1)

    wrapped_probability_denoiser = WrappedModel(probability_denoiser)
    solver = MixtureDiscreteEulerSolver(model=wrapped_probability_denoiser, path=path, vocabulary_size=vocab_size)

    return solver
def load_pooled_affinity_predictor(checkpoint_path, device):
    """Load an eval-mode ``ImprovedBindingPredictor`` (default sizes) from a checkpoint.

    The checkpoint is a dict whose ``'model_state_dict'`` entry holds the
    weights.  ``weights_only=False`` unpickles arbitrary objects, so the file
    must come from a trusted source.
    """
    saved = torch.load(checkpoint_path, map_location=device, weights_only=False)

    predictor = ImprovedBindingPredictor().to(device)
    predictor.load_state_dict(saved['model_state_dict'])
    predictor.eval()  # inference mode: disables dropout
    return predictor
def load_affinity_predictor(checkpoint_path, device):
    """Load an eval-mode ``UnpooledBindingPredictor`` from a checkpoint.

    The architecture settings below are fixed to match the saved weights,
    which are read from the checkpoint's ``'model_state_dict'`` entry.
    ``weights_only=False`` unpickles arbitrary objects, so the file must come
    from a trusted source.
    """
    saved = torch.load(checkpoint_path, map_location=device, weights_only=False)

    settings = dict(
        esm_model_name="facebook/esm2_t33_650M_UR50D",
        hidden_dim=384,
        kernel_sizes=[3, 5, 7],
        n_heads=8,
        n_layers=4,
        dropout=0.14561457009902096,
        freeze_esm=True,
    )
    predictor = UnpooledBindingPredictor(**settings).to(device)

    predictor.load_state_dict(saved['model_state_dict'])
    predictor.eval()
    return predictor