# PepTune / data_preprocessing / data_split.py
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import DataStructs
import numpy as np
from sklearn.cluster import MiniBatchKMeans
from collections import defaultdict
from tqdm import tqdm
import selfies as sf
from multiprocessing import Pool, cpu_count
from functools import partial

def generate_fingerprint_batch_selfies(selfies_batch):
    """Decode a batch of SELFIES strings and compute 2048-bit Morgan fingerprints.

    Returns the fingerprint matrix and the subset of SELFIES that decoded to a
    valid molecule; strings that fail to decode or parse are skipped.
    """
    fps = []
    valid_selfies = []
    for selfies in tqdm(selfies_batch, desc="Generating fingerprints", leave=False):
        try:
            # Convert SELFIES to SMILES, then to an RDKit molecule
            smiles = sf.decoder(selfies)
            mol = Chem.MolFromSmiles(smiles)
            if mol is not None:
                fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=2048)
                # ConvertToNumpyArray resizes the target array to the
                # fingerprint length (2048); the initial shape is a placeholder.
                arr = np.zeros((1,))
                DataStructs.ConvertToNumpyArray(fp, arr)
                fps.append(arr)
                valid_selfies.append(selfies)
        except Exception:
            # Invalid SELFIES/SMILES: skip the sequence rather than abort the batch
            continue
    return np.array(fps), valid_selfies
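
# Quick usage sketch (assumption: "[C][C][O]" is an illustrative SELFIES
# string, not one from the dataset). It decodes to ethanol ("CCO") and yields
# a single 2048-bit fingerprint row:
#   fps, kept = generate_fingerprint_batch_selfies(["[C][C][O]"])
#   fps.shape  # -> (1, 2048); kept == ["[C][C][O]"]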

def process_batch(batch, n_clusters, seed):
    """Fingerprint one batch and cluster it with MiniBatchKMeans.

    Labels are local to this batch; the caller keys clusters by
    (batch index, label) so clusters from different batches are not merged.
    """
    fps, valid_selfies = generate_fingerprint_batch_selfies(batch)
    if len(fps) > 0:
        # Guard against batches with fewer valid molecules than requested clusters
        clusterer = MiniBatchKMeans(n_clusters=min(n_clusters, len(fps)), random_state=seed)
        labels = clusterer.fit_predict(fps)
        return list(zip(labels, valid_selfies))
    return []

def parallel_clustering_split_selfies(selfies_list, batch_size=10000, n_clusters=1000, train_ratio=0.9, seed=42):
    """Split SELFIES into train/validation sets so similar molecules stay together.

    Each batch is fingerprinted and clustered in a worker process; whole
    clusters are then assigned to one split or the other.
    """
    np.random.seed(seed)
    # Create batches
    batches = [selfies_list[i:i + batch_size]
               for i in range(0, len(selfies_list), batch_size)]
    # Initialize parallel processing
    n_cores = min(12, cpu_count())
    process_batch_partial = partial(process_batch, n_clusters=n_clusters, seed=seed)
    cluster_assignments = defaultdict(list)
    with Pool(n_cores) as pool:
        results = list(tqdm(
            pool.imap(process_batch_partial, batches),
            total=len(batches),
            desc="Processing batches"
        ))
    # Combine results, keying by (batch index, label): labels from
    # independent per-batch clusterings are unrelated, so merging on the
    # raw label alone would conflate clusters from different batches.
    for batch_idx, batch_results in enumerate(results):
        for label, selfies in batch_results:
            cluster_assignments[(batch_idx, label)].append(selfies)
    # Split into train/val at the cluster level
    clusters = list(cluster_assignments.values())
    np.random.shuffle(clusters)
    train_selfies = []
    val_selfies = []
    total_mols = sum(len(cluster) for cluster in clusters)
    for cluster in tqdm(clusters, desc="Splitting clusters"):
        # Keep filling the training split until it reaches the target ratio
        if len(train_selfies) / total_mols < train_ratio:
            train_selfies.extend(cluster)
        else:
            val_selfies.extend(cluster)
    print(f"Final splits: Train={len(train_selfies)}, Validation={len(val_selfies)}")
    return train_selfies, val_selfies

if __name__ == "__main__":
    input_path = '/home/yz927/projects/peptune/tokens/filtered_peptides_selfies.txt'
    try:
        with open(input_path, 'r') as f:
            selfies_list = [line.strip() for line in f if line.strip()]
        print(f"Loaded {len(selfies_list)} SELFIES sequences from file")
    except FileNotFoundError:
        raise FileNotFoundError(f"Could not find the file at {input_path}")
    except Exception as e:
        raise RuntimeError(f"Error reading file: {e}") from e
    train_selfies, val_selfies = parallel_clustering_split_selfies(
        selfies_list,
        batch_size=10000,
        n_clusters=1000,
        train_ratio=0.8
    )
    with open('/home/yz927/projects/peptune/tokens/11M_selfies/train_selfies.txt', 'w') as f:
        for line in train_selfies:
            f.write(f"{line}\n")
    with open('/home/yz927/projects/peptune/tokens/11M_selfies/val_selfies.txt', 'w') as f:
        for line in val_selfies:
            f.write(f"{line}\n")
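
    # Optional sanity check (not in the original script): a cluster-level split
    # should leave the two sets disjoint, assuming duplicate sequences were
    # already filtered out of the input file.
    overlap = set(train_selfies) & set(val_selfies)
    print(f"Train/val overlap: {len(overlap)} sequences (expected 0)")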