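"""Scoring models and loaders for peptide binder evaluation.

Defines the BindEvaluator binding-site classifier, pooled and unpooled
binding-affinity predictors, XGBoost-based hemolysis, nonfouling and
solubility scorers, a half-life CNN, and helpers for loading these models
and a discrete flow-matching solver from checkpoints.
"""
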
import pdb
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import xgboost as xgb
import esm
from transformers import AutoModel, AutoConfig, AutoTokenizer, EsmModel

from flow_matching.path import MixtureDiscreteProbPath
from flow_matching.path.scheduler import PolynomialConvexScheduler
from flow_matching.solver import MixtureDiscreteEulerSolver
from flow_matching.utils import ModelWrapper
from flow_matching.loss import MixturePathGeneralizedKL

from models.peptide_models import CNNModel
from modules.bindevaluator_modules import *


def parse_motifs(motif: str) -> torch.Tensor:
    """Parse a motif string such as '3,5-9' into a tensor of 0-indexed positions."""
    parts = motif.split(',')
    result = []

    for part in parts:
        part = part.strip()
        if '-' in part:
            start, end = map(int, part.split('-'))
            result.extend(range(start, end + 1))
        else:
            result.append(int(part))

    # Shift from 1-indexed motif positions to 0-indexed residue indices.
    result = [pos - 1 for pos in result]
    print(f'Target Motifs: {result}')
    return torch.tensor(result)
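
# Example (illustrative): parse_motifs("3,5-7") returns tensor([2, 4, 5, 6]),
# i.e. the 1-indexed motif positions shifted to 0-indexed residue indices.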


class BindEvaluator(pl.LightningModule):
    """Per-residue binding-site classifier built on a frozen ESM-2 encoder."""

    def __init__(self, n_layers, d_model, d_hidden, n_head,
                 d_k, d_v, d_inner, dropout=0.2,
                 learning_rate=0.00001, max_epochs=15, kl_weight=1):
        super(BindEvaluator, self).__init__()

        self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
        self.esm_model.eval()

        for param in self.esm_model.parameters():
            param.requires_grad = False

        self.repeated_module = RepeatedModule3(n_layers, d_model, d_hidden,
                                               n_head, d_k, d_v, d_inner, dropout=dropout)

        self.final_attention_layer = MultiHeadAttentionSequence(n_head, d_model,
                                                                d_k, d_v, dropout=dropout)

        self.final_ffn = FFN(d_model, d_inner, dropout=dropout)

        self.output_projection_prot = nn.Linear(d_model, 1)

        self.learning_rate = learning_rate
        self.max_epochs = max_epochs
        self.kl_weight = kl_weight

        self.classification_threshold = nn.Parameter(torch.tensor(0.5))
        self.historical_memory = 0.9
        self.class_weights = torch.tensor([3.000471363174231, 0.5999811490272925])

    def forward(self, binder_tokens, target_tokens):
        peptide_sequence = self.esm_model(**binder_tokens).last_hidden_state
        protein_sequence = self.esm_model(**target_tokens).last_hidden_state

        (prot_enc, sequence_enc, sequence_attention_list, prot_attention_list,
         seq_prot_attention_list, prot_seq_attention_list) = self.repeated_module(peptide_sequence,
                                                                                  protein_sequence)

        prot_enc, final_prot_seq_attention = self.final_attention_layer(prot_enc, sequence_enc, sequence_enc)

        prot_enc = self.final_ffn(prot_enc)

        prot_enc = self.output_projection_prot(prot_enc)

        return prot_enc

    def get_probs(self, x_t, target_sequence):
        '''
        Inputs:
            - x_t: Shape (bsz, seq_len)
            - target_sequence: Shape (1, tgt_len)
        '''
        target_sequence = target_sequence.repeat(x_t.shape[0], 1)
        binder_attention_mask = torch.ones_like(x_t)
        target_attention_mask = torch.ones_like(target_sequence)

        # Mask out the first and last (BOS/EOS) token positions.
        binder_attention_mask[:, 0] = binder_attention_mask[:, -1] = 0
        target_attention_mask[:, 0] = target_attention_mask[:, -1] = 0

        binder_tokens = {'input_ids': x_t, 'attention_mask': binder_attention_mask.to(x_t.device)}
        target_tokens = {'input_ids': target_sequence, 'attention_mask': target_attention_mask.to(target_sequence.device)}

        logits = self.forward(binder_tokens, target_tokens).squeeze(-1)

        # Push the BOS/EOS positions to probability ~0 before the sigmoid.
        logits[:, 0] = logits[:, -1] = -100
        probs = torch.sigmoid(logits)

        return probs

    def motif_score(self, x_t, target_sequence, motifs):
        probs = self.get_probs(x_t, target_sequence)
        motif_probs = probs[:, motifs]
        motif_score = motif_probs.sum(dim=-1) / len(motifs)

        return motif_score

    def non_motif_score(self, x_t, target_sequence, motifs):
        probs = self.get_probs(x_t, target_sequence)
        non_motif_probs = probs[:, [i for i in range(probs.shape[1]) if i not in motifs]]
        mask = non_motif_probs >= 0.5
        count = mask.sum(dim=-1)

        non_motif_score = torch.where(count > 0,
                                      (non_motif_probs * mask).sum(dim=-1) / count,
                                      torch.zeros_like(count, dtype=probs.dtype))

        return non_motif_score

    def scoring(self, x_t, target_sequence, motifs, penalty=False):
        probs = self.get_probs(x_t, target_sequence)
        motif_probs = probs[:, motifs]
        motif_score = motif_probs.sum(dim=-1) / len(motifs)

        if penalty:
            # Penalise confident (>= 0.5) binding predictions outside the motif.
            non_motif_probs = probs[:, [i for i in range(probs.shape[1]) if i not in motifs]]
            mask = non_motif_probs >= 0.5
            count = mask.sum(dim=-1)

            non_motif_score = count / target_sequence.shape[1]
            return motif_score, 1 - non_motif_score
        else:
            return motif_score
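
# Illustrative use of BindEvaluator.scoring (hypothetical tensors): given tokenised
# binders `x_t` of shape (bsz, seq_len), one tokenised target of shape (1, tgt_len)
# and motif indices from parse_motifs(), scoring(x_t, target, motifs, penalty=True)
# returns a per-binder motif score and a penalty term equal to one minus the fraction
# of target positions outside the motif scoring >= 0.5, both in [0, 1].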


class MotifModel(nn.Module):
    def __init__(self, bindevaluator, target_sequence, motifs, penalty=False):
        super(MotifModel, self).__init__()
        self.bindevaluator = bindevaluator
        self.target_sequence = target_sequence
        self.motifs = motifs
        self.penalty = penalty

    def forward(self, x):
        return self.bindevaluator.scoring(x, self.target_sequence, self.motifs, self.penalty)


class UnpooledBindingPredictor(nn.Module):
    """Binding-affinity predictor over unpooled ESM token embeddings.

    Per-sequence CNN feature extractors are pooled and fused with stacked
    cross-attention layers between the protein and binder representations.
    """

    def __init__(self,
                 esm_model_name="facebook/esm2_t33_650M_UR50D",
                 hidden_dim=512,
                 kernel_sizes=[3, 5, 7],
                 n_heads=8,
                 n_layers=3,
                 dropout=0.1,
                 freeze_esm=True):
        super().__init__()

        # Affinity thresholds used to bin predictions into tight/medium/weak classes.
        self.tight_threshold = 7.5
        self.weak_threshold = 6.0

        self.esm_model = AutoModel.from_pretrained(esm_model_name)
        self.config = AutoConfig.from_pretrained(esm_model_name)

        if freeze_esm:
            for param in self.esm_model.parameters():
                param.requires_grad = False

        esm_dim = self.config.hidden_size

        output_channels_per_kernel = 64

        # Parallel 1D convolutions over the ESM embeddings, one per kernel size.
        self.protein_conv_layers = nn.ModuleList([
            nn.Conv1d(
                in_channels=esm_dim,
                out_channels=output_channels_per_kernel,
                kernel_size=k,
                padding='same'
            ) for k in kernel_sizes
        ])

        self.binder_conv_layers = nn.ModuleList([
            nn.Conv1d(
                in_channels=esm_dim,
                out_channels=output_channels_per_kernel,
                kernel_size=k,
                padding='same'
            ) for k in kernel_sizes
        ])

        # Max- and avg-pooled features are concatenated, hence the factor of 2.
        total_features_per_seq = output_channels_per_kernel * len(kernel_sizes) * 2

        self.protein_projection = nn.Linear(total_features_per_seq, hidden_dim)
        self.binder_projection = nn.Linear(total_features_per_seq, hidden_dim)

        self.protein_norm = nn.LayerNorm(hidden_dim)
        self.binder_norm = nn.LayerNorm(hidden_dim)

        self.cross_attention_layers = nn.ModuleList([
            nn.ModuleDict({
                'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout),
                'norm1': nn.LayerNorm(hidden_dim),
                'ffn': nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim * 4),
                    nn.ReLU(),
                    nn.Dropout(dropout),
                    nn.Linear(hidden_dim * 4, hidden_dim)
                ),
                'norm2': nn.LayerNorm(hidden_dim)
            }) for _ in range(n_layers)
        ])

        self.shared_head = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

        self.regression_head = nn.Linear(hidden_dim, 1)

        self.classification_head = nn.Linear(hidden_dim, 3)

    def get_binding_class(self, affinity):
        """Convert affinity values to class indices.

        0: tight binding (>= 7.5)
        1: medium binding (6.0-7.5)
        2: weak binding (< 6.0)
        """
        if isinstance(affinity, torch.Tensor):
            tight_mask = affinity >= self.tight_threshold
            weak_mask = affinity < self.weak_threshold
            medium_mask = ~(tight_mask | weak_mask)

            classes = torch.zeros_like(affinity, dtype=torch.long)
            classes[medium_mask] = 1
            classes[weak_mask] = 2
            return classes
        else:
            if affinity >= self.tight_threshold:
                return 0
            elif affinity < self.weak_threshold:
                return 2
            else:
                return 1
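
    # Example (illustrative): get_binding_class(torch.tensor([8.1, 6.5, 5.2]))
    # returns tensor([0, 1, 2]): tight, medium and weak binders respectively.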

    def compute_embeddings(self, input_ids, attention_mask=None):
        """Compute ESM embeddings on the fly."""
        esm_outputs = self.esm_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )

        return esm_outputs.last_hidden_state

    def process_sequence(self, unpooled_emb, conv_layers, attention_mask=None):
        """Process a sequence through CNN layers and pooling."""
        # Conv1d expects (bsz, channels, seq_len).
        x = unpooled_emb.transpose(1, 2)

        conv_outputs = []
        for conv in conv_layers:
            conv_out = F.relu(conv(x))
            conv_outputs.append(conv_out)

        conv_output = torch.cat(conv_outputs, dim=1)

        if attention_mask is not None:
            # Broadcast the padding mask over the channel dimension.
            expanded_mask = attention_mask.unsqueeze(1).expand(-1, conv_output.size(1), -1)

            # Max pooling that ignores padded positions.
            masked_output = conv_output.clone()
            masked_output = masked_output.masked_fill(expanded_mask == 0, float('-inf'))
            max_pooled = torch.max(masked_output, dim=2)[0]

            # Average pooling over the valid (unpadded) positions only.
            sum_pooled = torch.sum(conv_output * expanded_mask, dim=2)
            valid_positions = torch.sum(expanded_mask, dim=2)
            valid_positions = torch.clamp(valid_positions, min=1.0)
            avg_pooled = sum_pooled / valid_positions
        else:
            max_pooled = torch.max(conv_output, dim=2)[0]
            avg_pooled = torch.mean(conv_output, dim=2)

        pooled = torch.cat([max_pooled, avg_pooled], dim=1)

        return pooled
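
    # Shape note (derived from the defaults above): with kernel_sizes=[3, 5, 7] and
    # 64 output channels per kernel, conv_output has 192 channels, so the concatenated
    # max/avg pooling yields a (bsz, 384) feature vector per sequence, matching
    # total_features_per_seq.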

    def forward(self, protein_input_ids, binder_input_ids, protein_mask=None, binder_mask=None):
        protein_unpooled = self.compute_embeddings(protein_input_ids, protein_mask)
        binder_unpooled = self.compute_embeddings(binder_input_ids, binder_mask)

        protein_features = self.process_sequence(protein_unpooled, self.protein_conv_layers, protein_mask)
        binder_features = self.process_sequence(binder_unpooled, self.binder_conv_layers, binder_mask)

        protein = self.protein_norm(self.protein_projection(protein_features))
        binder = self.binder_norm(self.binder_projection(binder_features))

        # Treat each pooled feature vector as a length-1 sequence, since
        # nn.MultiheadAttention expects (seq_len, bsz, hidden) by default.
        protein = protein.unsqueeze(0)
        binder = binder.unsqueeze(0)

        for layer in self.cross_attention_layers:
            # Protein attends to binder.
            attended_protein = layer['attention'](
                protein, binder, binder
            )[0]
            protein = layer['norm1'](protein + attended_protein)
            protein = layer['norm2'](protein + layer['ffn'](protein))

            # Binder attends to protein.
            attended_binder = layer['attention'](
                binder, protein, protein
            )[0]
            binder = layer['norm1'](binder + attended_binder)
            binder = layer['norm2'](binder + layer['ffn'](binder))

        protein_pool = protein.squeeze(0)
        binder_pool = binder.squeeze(0)

        combined = torch.cat([protein_pool, binder_pool], dim=-1)

        shared_features = self.shared_head(combined)

        regression_output = self.regression_head(shared_features)

        return regression_output


class ImprovedBindingPredictor(nn.Module):
    """Binding-affinity predictor over precomputed protein and binder embeddings,
    fused with stacked cross-attention layers."""

    def __init__(self,
                 esm_dim=1280,
                 smiles_dim=1280,
                 hidden_dim=512,
                 n_heads=8,
                 n_layers=5,
                 dropout=0.1):
        super().__init__()

        # Affinity thresholds used to bin predictions into tight/medium/weak classes.
        self.tight_threshold = 7.5
        self.weak_threshold = 6.0

        self.smiles_projection = nn.Linear(smiles_dim, hidden_dim)
        self.protein_projection = nn.Linear(esm_dim, hidden_dim)
        self.protein_norm = nn.LayerNorm(hidden_dim)
        self.smiles_norm = nn.LayerNorm(hidden_dim)

        self.cross_attention_layers = nn.ModuleList([
            nn.ModuleDict({
                'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout),
                'norm1': nn.LayerNorm(hidden_dim),
                'ffn': nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim * 4),
                    nn.ReLU(),
                    nn.Dropout(dropout),
                    nn.Linear(hidden_dim * 4, hidden_dim)
                ),
                'norm2': nn.LayerNorm(hidden_dim)
            }) for _ in range(n_layers)
        ])

        self.shared_head = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

        self.regression_head = nn.Linear(hidden_dim, 1)

        self.classification_head = nn.Linear(hidden_dim, 3)

    def get_binding_class(self, affinity):
        """Convert affinity values to class indices.

        0: tight binding (>= 7.5)
        1: medium binding (6.0-7.5)
        2: weak binding (< 6.0)
        """
        if isinstance(affinity, torch.Tensor):
            tight_mask = affinity >= self.tight_threshold
            weak_mask = affinity < self.weak_threshold
            medium_mask = ~(tight_mask | weak_mask)

            classes = torch.zeros_like(affinity, dtype=torch.long)
            classes[medium_mask] = 1
            classes[weak_mask] = 2
            return classes
        else:
            if affinity >= self.tight_threshold:
                return 0
            elif affinity < self.weak_threshold:
                return 2
            else:
                return 1

    def forward(self, protein_emb, binder_emb):
        protein = self.protein_norm(self.protein_projection(protein_emb))
        smiles = self.smiles_norm(self.smiles_projection(binder_emb))

        # nn.MultiheadAttention expects (seq_len, bsz, hidden) by default.
        protein = protein.transpose(0, 1)
        smiles = smiles.transpose(0, 1)

        for layer in self.cross_attention_layers:
            # Protein attends to binder.
            attended_protein = layer['attention'](
                protein, smiles, smiles
            )[0]
            protein = layer['norm1'](protein + attended_protein)
            protein = layer['norm2'](protein + layer['ffn'](protein))

            # Binder attends to protein.
            attended_smiles = layer['attention'](
                smiles, protein, protein
            )[0]
            smiles = layer['norm1'](smiles + attended_smiles)
            smiles = layer['norm2'](smiles + layer['ffn'](smiles))

        # Mean-pool over the sequence dimension.
        protein_pool = torch.mean(protein, dim=0)
        smiles_pool = torch.mean(smiles, dim=0)

        combined = torch.cat([protein_pool, smiles_pool], dim=-1)

        shared_features = self.shared_head(combined)

        regression_output = self.regression_head(shared_features)

        return regression_output


class PooledAffinityModel(nn.Module):
    def __init__(self, affinity_predictor, target_sequence):
        super(PooledAffinityModel, self).__init__()
        self.affinity_predictor = affinity_predictor
        self.target_sequence = target_sequence
        self.esm_model = AutoModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(self.target_sequence.device)
        for param in self.esm_model.parameters():
            param.requires_grad = False

    def compute_embeddings(self, input_ids, attention_mask=None):
        """Compute ESM embeddings on the fly."""
        esm_outputs = self.esm_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )

        return esm_outputs.last_hidden_state

    def forward(self, x):
        target_sequence = self.target_sequence.repeat(x.shape[0], 1)

        protein_emb = self.compute_embeddings(input_ids=target_sequence)
        binder_emb = self.compute_embeddings(input_ids=x)
        return self.affinity_predictor(protein_emb=protein_emb, binder_emb=binder_emb).squeeze(-1)


class AffinityModel(nn.Module):
    def __init__(self, affinity_predictor, target_sequence):
        super(AffinityModel, self).__init__()
        self.affinity_predictor = affinity_predictor
        self.target_sequence = target_sequence

    def forward(self, x):
        target_sequence = self.target_sequence.repeat(x.shape[0], 1)
        affinity = self.affinity_predictor(protein_input_ids=target_sequence, binder_input_ids=x).squeeze(-1)
        # Rescale so typical affinity values land roughly in [0, 1].
        return affinity / 10


class HemolysisModel:
    def __init__(self, device):
        self.predictor = xgb.Booster(model_file='./classifier_ckpt/best_model_hemolysis.json')

        self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
        self.model.eval()

        self.device = device

    def generate_embeddings(self, sequences):
        """Generate mean-pooled ESM embeddings for tokenised sequences."""
        with torch.no_grad():
            embeddings = self.model(input_ids=sequences).last_hidden_state.mean(dim=1)
            embeddings = embeddings.cpu().numpy()

        return embeddings

    def get_scores(self, input_seqs):
        scores = np.ones(len(input_seqs))
        features = self.generate_embeddings(input_seqs)

        if len(features) == 0:
            return torch.from_numpy(scores).to(self.device)

        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        features = xgb.DMatrix(features)

        probs = self.predictor.predict(features)

        # Score is 1 minus the predicted probability, so higher is better.
        return torch.from_numpy(scores - probs).to(self.device)

    def __call__(self, input_seqs: list):
        scores = self.get_scores(input_seqs)
        return scores


class NonfoulingModel:
    def __init__(self, device):
        self.predictor = xgb.Booster(model_file='./classifier_ckpt/best_model_nonfouling.json')

        self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
        self.model.eval()

        self.device = device

    def generate_embeddings(self, sequences):
        """Generate mean-pooled ESM embeddings for tokenised sequences."""
        with torch.no_grad():
            embeddings = self.model(input_ids=sequences).last_hidden_state.mean(dim=1)
            embeddings = embeddings.cpu().numpy()

        return embeddings

    def get_scores(self, input_seqs):
        scores = np.zeros(len(input_seqs))
        features = self.generate_embeddings(input_seqs)

        if len(features) == 0:
            return torch.from_numpy(scores).to(self.device)

        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        features = xgb.DMatrix(features)

        scores = self.predictor.predict(features)
        return torch.from_numpy(scores).to(self.device)

    def __call__(self, input_seqs: list):
        scores = self.get_scores(input_seqs)
        return scores


class SolubilityModel:
    def __init__(self, device):
        self.predictor = xgb.Booster(model_file='./classifier_ckpt/best_model_solubility.json')

        self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
        self.model.eval()

        self.device = device

    def generate_embeddings(self, sequences):
        """Generate mean-pooled ESM embeddings for tokenised sequences."""
        with torch.no_grad():
            embeddings = self.model(input_ids=sequences).last_hidden_state.mean(dim=1)
            embeddings = embeddings.cpu().numpy()

        return embeddings

    def get_scores(self, input_seqs: list):
        scores = np.zeros(len(input_seqs))
        features = self.generate_embeddings(input_seqs)

        if len(features) == 0:
            return torch.from_numpy(scores).to(self.device)

        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        features = xgb.DMatrix(features)

        scores = self.predictor.predict(features)
        return torch.from_numpy(scores).to(self.device)

    def __call__(self, input_seqs: list):
        scores = self.get_scores(input_seqs)
        return scores


class SolubilityModelNew:
    def __init__(self, device):
        # ESM-2 token ids for the hydrophobic residues A, V, L, I, M, F, W and P.
        self.hydro_ids = torch.tensor([5, 7, 4, 12, 20, 18, 22, 14], device=device)
        self.device = device

    def get_scores(self, x):
        # Score each sequence as one minus its fraction of hydrophobic tokens.
        mask = (x.unsqueeze(-1) == self.hydro_ids).any(dim=-1)
        ratios = mask.float().mean(dim=1)
        return 1 - ratios

    def __call__(self, input_seqs: list):
        scores = self.get_scores(input_seqs)
        return scores
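
# Example (illustrative): a binder whose tokens are 40% hydrophobic under the ids
# above gets a SolubilityModelNew score of 0.6 (one minus the hydrophobic fraction).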


class PeptideCNN(nn.Module):
    def __init__(self, input_dim, hidden_dims, output_dim, dropout_rate):
        super().__init__()
        self.conv1 = nn.Conv1d(input_dim, hidden_dims[0], kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(hidden_dims[0], hidden_dims[1], kernel_size=5, padding=1)
        self.fc = nn.Linear(hidden_dims[1], output_dim)
        self.dropout = nn.Dropout(dropout_rate)
        self.predictor = nn.Linear(output_dim, 1)

        self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
        self.esm_model.eval()

    def forward(self, input_ids, attention_mask=None, return_features=False):
        with torch.no_grad():
            x = self.esm_model(input_ids, attention_mask).last_hidden_state

        x = x.permute(0, 2, 1)
        x = nn.functional.relu(self.conv1(x))
        x = self.dropout(x)
        x = nn.functional.relu(self.conv2(x))
        x = self.dropout(x)
        x = x.permute(0, 2, 1)

        x = x.mean(dim=1)

        features = self.fc(x)
        if return_features:
            return features
        return self.predictor(features)


class HalfLifeModel:
    def __init__(self, device):
        input_dim = 1280
        hidden_dims = [input_dim // 2, input_dim // 4]
        output_dim = input_dim // 8
        dropout_rate = 0.3
        self.model = PeptideCNN(input_dim, hidden_dims, output_dim, dropout_rate).to(device)
        self.model.load_state_dict(torch.load('./classifier_ckpt/best_model_half_life.pth', map_location=device, weights_only=False))
        self.model.eval()

    def __call__(self, x):
        prediction = self.model(x, return_features=False)
        # Clamp the predicted half-life to [0, 2] and rescale to [0, 1].
        halflife = torch.clamp(prediction.squeeze(-1), max=2.0, min=0.0)
        return halflife / 2


def load_bindevaluator(checkpoint_path, device):
    bindevaluator = BindEvaluator.load_from_checkpoint(checkpoint_path, weights_only=False, n_layers=8, d_model=128,
                                                       d_hidden=128, n_head=8, d_k=64, d_v=128, d_inner=64).to(device)
    bindevaluator.eval()
    for param in bindevaluator.parameters():
        param.requires_grad = False

    return bindevaluator


def load_solver(checkpoint_path, vocab_size, device):
    # Training-time hyperparameters, kept for reference; only the model dimensions
    # are needed when loading the denoiser for sampling.
    lr = 1e-4
    epochs = 200
    embed_dim = 512
    hidden_dim = 256
    epsilon = 1e-3
    batch_size = 256
    warmup_epochs = epochs // 10

    probability_denoiser = CNNModel(alphabet_size=vocab_size, embed_dim=embed_dim, hidden_dim=hidden_dim).to(device)
    probability_denoiser.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=False))
    probability_denoiser.eval()
    for param in probability_denoiser.parameters():
        param.requires_grad = False

    scheduler = PolynomialConvexScheduler(n=2.0)
    path = MixtureDiscreteProbPath(scheduler=scheduler)

    class WrappedModel(ModelWrapper):
        def forward(self, x: torch.Tensor, t: torch.Tensor, **extras):
            return torch.softmax(self.model(x, t), dim=-1)

    wrapped_probability_denoiser = WrappedModel(probability_denoiser)
    solver = MixtureDiscreteEulerSolver(model=wrapped_probability_denoiser, path=path, vocabulary_size=vocab_size)

    return solver
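
# Illustrative sampling call (assumed flow_matching API; check the library docs for
# the exact signature, and treat the shapes and paths here as hypothetical):
# solver = load_solver('./ckpt/denoiser.pth', vocab_size=33, device=device)
# x_init = torch.randint(0, 33, (batch_size, seq_len), device=device)
# samples = solver.sample(x_init=x_init, step_size=1 / 100)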


def load_pooled_affinity_predictor(checkpoint_path, device):
    """Load a trained ImprovedBindingPredictor from a checkpoint."""
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    model = ImprovedBindingPredictor().to(device)

    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    return model


def load_affinity_predictor(checkpoint_path, device):
    """Load a trained UnpooledBindingPredictor from a checkpoint."""
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    model = UnpooledBindingPredictor(
        esm_model_name="facebook/esm2_t33_650M_UR50D",
        hidden_dim=384,
        kernel_sizes=[3, 5, 7],
        n_heads=8,
        n_layers=4,
        dropout=0.14561457009902096,
        freeze_esm=True
    ).to(device)

    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    return model
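

# Illustrative wiring of the affinity guidance model (hypothetical checkpoint path
# and tensors; the helpers above define the real signatures):
# device = torch.device('cuda:0')
# predictor = load_affinity_predictor('./classifier_ckpt/affinity.pt', device)
# target_ids = ...  # tokenised target sequence, shape (1, tgt_len)
# affinity_model = AffinityModel(predictor, target_ids)
# scores = affinity_model(binder_ids)  # shape (bsz,), predicted affinity / 10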