import sys
import os

sys.path.append('/home/st512/peptune/scripts/peptide-mdlm-mcts')

import warnings

import numpy as np
import torch
import xgboost as xgb
from rdkit import Chem, rdBase, DataStructs
from transformers import AutoTokenizer, EsmModel

rdBase.DisableLog('rdApp.error')
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# Defaults preserved from the original hard-coded values; override via
# the Hemolysis constructor instead of editing the source.
DEFAULT_HEMOLYSIS_MODEL_PATH = '/home/tc415/flow_matching/classifier_ckpt/best_model_hemolysis.json'
DEFAULT_ESM_MODEL_NAME = "facebook/esm2_t33_650M_UR50D"
DEFAULT_BATCH_SIZE = 8


class Hemolysis:
    """Score amino-acid sequences with an ESM-2 + XGBoost hemolysis classifier.

    Each sequence is embedded with an ESM model by mean-pooling the last
    hidden state over tokens where the attention mask is 1. The pooled
    vectors go to a pre-trained XGBoost booster. ``get_scores`` /
    ``__call__`` return ``1 - booster_prediction``. Per the original
    author's comment, that is the probability of the sequence being *not*
    hemolytic.
    """

    def __init__(self, model_path=DEFAULT_HEMOLYSIS_MODEL_PATH,
                 esm_model_name=DEFAULT_ESM_MODEL_NAME,
                 batch_size=DEFAULT_BATCH_SIZE, device=None):
        """Load the XGBoost booster and the ESM tokenizer/model.

        Args:
            model_path: Path to the XGBoost booster file (JSON).
            esm_model_name: Hugging Face name or path of the ESM model used
                for embeddings. It must match the model the booster was
                trained on.
            batch_size: Number of sequences embedded per forward pass.
            device: Torch device for the ESM model. ``None`` picks CUDA when
                available, else CPU.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.batch_size = batch_size

        self.predictor = xgb.Booster(model_file=model_path)

        # Load ESM model and tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(esm_model_name)
        self.model = EsmModel.from_pretrained(esm_model_name)
        self.model.eval()

        # Place the model once here, rather than re-checking CUDA and
        # re-moving the model on every batch.
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)
        self.model.to(self.device)

    def generate_embeddings(self, sequences):
        """Return mean-pooled ESM embeddings for ``sequences``.

        Args:
            sequences: Iterable of amino-acid strings.

        Returns:
            A 2-D numpy array with one pooled embedding row per input
            sequence, in input order. If ``sequences`` is empty, an empty
            1-D array (``np.array([])``) is returned.
        """
        sequences = list(sequences)
        embeddings = []

        # Process sequences in batches to bound memory use.
        for start in range(0, len(sequences), self.batch_size):
            batch_sequences = sequences[start:start + self.batch_size]
            inputs = self.tokenizer(
                batch_sequences,
                padding=True,
                truncation=True,
                return_tensors="pt",
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = self.model(**inputs)
            last_hidden_states = outputs.last_hidden_state

            # Mean pooling over positions where attention_mask == 1, so
            # padding does not contribute. The clamp guards against
            # dividing by zero for a row with no unmasked tokens.
            attention_mask = inputs['attention_mask'].unsqueeze(-1).to(last_hidden_states.dtype)
            sum_hidden_states = (last_hidden_states * attention_mask).sum(dim=1)
            seq_lengths = attention_mask.sum(dim=1).clamp(min=1)
            batch_embeddings = sum_hidden_states / seq_lengths

            embeddings.append(batch_embeddings.cpu().numpy())

        if not embeddings:
            return np.array([])
        return np.vstack(embeddings)

    def get_scores(self, input_seqs: list):
        """Return ``1 - booster_prediction`` for each sequence.

        Per the original author, this is the probability of each sequence
        being not hemolytic. For empty input an empty float array is
        returned.
        """
        input_seqs = list(input_seqs)
        scores = np.ones(len(input_seqs))
        features = self.generate_embeddings(input_seqs)
        if len(features) == 0:
            return scores

        # Sanitise features before building the DMatrix: NaN -> 0, then
        # clamp into the float32 range.
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)
        probs = self.predictor.predict(xgb.DMatrix(features))

        # return the probability of it being not hemolytic
        return scores - probs

    def __call__(self, input_seqs: list):
        """Alias for :meth:`get_scores`."""
        return self.get_scores(input_seqs)


def unittest():
    """Smoke test: score two sequences and print ``1 - score`` for each."""
    hemolysis = Hemolysis()
    sequences = [
        "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG",
        "MSEGIRQAFVLAKSIWPARVARFTVDNRIRSLVKTYEAIKVDPYNPAFLEVLD",
    ]
    scores = hemolysis(input_seqs=sequences)
    print([1 - score for score in scores])


if __name__ == '__main__':
    unittest()