import sys
import os

sys.path.append('/home/st512/peptune/scripts/peptide-mdlm-mcts')

import warnings

import numpy as np
import torch
import xgboost as xgb
from rdkit import Chem, rdBase, DataStructs
from transformers import AutoTokenizer, EsmModel

rdBase.DisableLog('rdApp.error')
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
class Hemolysis:
    def __init__(self):
        # change model path
        self.predictor = xgb.Booster(model_file='/home/tc415/flow_matching/classifier_ckpt/best_model_hemolysis.json')
        # Load ESM model and tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
        self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
        self.model.eval()
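        # Assumption: the XGBoost booster was trained on these 1280-dim ESM-2
        # embeddings, so its feature layout must match generate_embeddings().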

    def generate_embeddings(self, sequences):
        """Generate ESM embeddings for protein sequences."""
        embeddings = []
        # Process sequences in batches to avoid memory issues
        batch_size = 8
        for i in range(0, len(sequences), batch_size):
            batch_sequences = sequences[i:i + batch_size]
            inputs = self.tokenizer(
                batch_sequences,
                padding=True,
                truncation=True,
                return_tensors="pt"
            )
            if torch.cuda.is_available():
                inputs = {k: v.cuda() for k, v in inputs.items()}
                self.model = self.model.cuda()
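            # Moving the model inside the loop is safe: .cuda() is effectively
            # a no-op once the weights already live on the GPU.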
            # Generate embeddings
            with torch.no_grad():
                outputs = self.model(**inputs)
            # Get last hidden states
            last_hidden_states = outputs.last_hidden_state
            # Compute mean pooling (excluding padding tokens)
            attention_mask = inputs['attention_mask'].unsqueeze(-1)
            masked_hidden_states = last_hidden_states * attention_mask
            sum_hidden_states = masked_hidden_states.sum(dim=1)
            seq_lengths = attention_mask.sum(dim=1)
            batch_embeddings = sum_hidden_states / seq_lengths
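            # batch_embeddings: (batch_size, hidden_dim), the mean of token
            # embeddings over real (non-padding) positions only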
            batch_embeddings = batch_embeddings.cpu().numpy()
            embeddings.append(batch_embeddings)
        if embeddings:
            return np.vstack(embeddings)
        else:
            return np.array([])

    def get_scores(self, input_seqs: list):
        scores = np.ones(len(input_seqs))
        features = self.generate_embeddings(input_seqs)
        if len(features) == 0:
            return scores
        # Sanitize features before handing them to XGBoost
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)
        features = xgb.DMatrix(features)
        probs = self.predictor.predict(features)
        # Return 1 - P(hemolytic), i.e., the probability of being non-hemolytic
        return scores - probs

    def __call__(self, input_seqs: list):
        return self.get_scores(input_seqs)
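

# Quick smoke test. Scores are 1 - P(hemolytic), so 1 - score below recovers
# the classifier's raw hemolysis probability for each sequence.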
def unittest():
    hemolysis = Hemolysis()
    sequences = [
        "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG",
        "MSEGIRQAFVLAKSIWPARVARFTVDNRIRSLVKTYEAIKVDPYNPAFLEVLD",
    ]
    scores = hemolysis(input_seqs=sequences)
    print([1 - score for score in scores])


if __name__ == '__main__':
    unittest()