from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import DataStructs
import numpy as np
from sklearn.cluster import MiniBatchKMeans
from collections import defaultdict
from tqdm import tqdm
import selfies as sf
from multiprocessing import Pool, cpu_count
from functools import partial


def generate_fingerprint_batch_selfies(selfies_batch):
    """Decode a batch of SELFIES strings and compute Morgan fingerprints.

    Returns an (n_valid, 2048) array of fingerprints plus the list of
    SELFIES strings that decoded to a valid molecule; invalid entries
    are skipped.
    """
    fps = []
    valid_selfies = []

    for selfies in tqdm(selfies_batch, desc="Generating fingerprints", leave=False):
        try:
            smiles = sf.decoder(selfies)
            mol = Chem.MolFromSmiles(smiles)
            if mol is not None:
                # Radius-2 Morgan fingerprint (ECFP4-like), 2048 bits.
                fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048)
                # ConvertToNumpyArray resizes arr to the fingerprint length.
                arr = np.zeros((1,))
                DataStructs.ConvertToNumpyArray(fp, arr)
                fps.append(arr)
                valid_selfies.append(selfies)
        except Exception:
            # Skip SELFIES that fail to decode or parse into a molecule.
            continue

    return np.array(fps), valid_selfies
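
# Minimal sanity check (hypothetical inputs, not part of the pipeline):
# "[C][C][O]" decodes to ethanol, so both strings below should survive
# decoding and each contribute one 2048-bit fingerprint row.
#
#   fps, kept = generate_fingerprint_batch_selfies(["[C][C][O]", "[C][O]"])
#   print(fps.shape, len(kept))  # expected: (2, 2048) 2
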
def process_batch(batch, n_clusters, seed):
    """Fingerprint one batch of SELFIES and cluster it with MiniBatchKMeans."""
    fps, valid_selfies = generate_fingerprint_batch_selfies(batch)
    if len(fps) > 0:
        # The last batch (or one with many invalid SELFIES) can hold fewer
        # molecules than n_clusters, so clamp the cluster count.
        clusterer = MiniBatchKMeans(n_clusters=min(n_clusters, len(fps)),
                                    random_state=seed)
        labels = clusterer.fit_predict(fps)
        return list(zip(labels, valid_selfies))
    return []
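
# Illustrative call (assumed toy inputs): the labels are batch-local
# integers, meaningful only relative to the clusterer fitted on that batch.
#
#   pairs = process_batch(["[C][C][O]", "[C][O]", "[C][C][C][O]"],
#                         n_clusters=2, seed=42)
#   # e.g. [(0, '[C][C][O]'), (1, '[C][O]'), (0, '[C][C][C][O]')]
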
def parallel_clustering_split_selfies(selfies_list, batch_size=10000, n_clusters=1000, train_ratio=0.9, seed=42):
    """Cluster-based train/validation split on Morgan fingerprints.

    Whole clusters are assigned to one split or the other, so structurally
    similar molecules do not leak between train and validation.
    """
    np.random.seed(seed)

    # Slice the dataset into batches; each batch is clustered independently.
    batches = [selfies_list[i:i + batch_size]
               for i in range(0, len(selfies_list), batch_size)]

    # Cap the worker count at the machine's available cores.
    n_cores = min(12, cpu_count())
    process_batch_partial = partial(process_batch, n_clusters=n_clusters, seed=seed)

    with Pool(n_cores) as pool:
        results = list(tqdm(
            pool.imap(process_batch_partial, batches),
            total=len(batches),
            desc="Processing batches"
        ))

    # Cluster labels are local to each batch, so key on (batch index, label)
    # to avoid merging unrelated clusters from different batches.
    cluster_assignments = defaultdict(list)
    for batch_idx, batch_results in enumerate(results):
        for label, selfies in batch_results:
            cluster_assignments[(batch_idx, label)].append(selfies)

    # Shuffle whole clusters, then fill the training split until it reaches
    # train_ratio; every cluster lands entirely in one split.
    clusters = list(cluster_assignments.values())
    np.random.shuffle(clusters)

    train_selfies = []
    val_selfies = []
    total_mols = sum(len(cluster) for cluster in clusters)

    for cluster in tqdm(clusters, desc="Splitting clusters"):
        if len(train_selfies) / total_mols < train_ratio:
            train_selfies.extend(cluster)
        else:
            val_selfies.extend(cluster)

    print(f"Final splits: Train={len(train_selfies)}, Validation={len(val_selfies)}")
    return train_selfies, val_selfies
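
# Quick leakage check (illustrative; assumes selfies_list is loaded and
# deduplicated): the two splits should share no sequences, and the train
# fraction should sit near train_ratio, give or take one cluster.
#
#   train, val = parallel_clustering_split_selfies(selfies_list, train_ratio=0.8)
#   assert not set(train) & set(val)
#   print(len(train) / (len(train) + len(val)))  # roughly 0.8
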
input_path = '/home/yz927/projects/peptune/tokens/filtered_peptides_selfies.txt'

if __name__ == '__main__':
    # Guard the entry point so multiprocessing workers do not re-run the
    # pipeline when the module is imported.
    try:
        with open(input_path, 'r') as f:
            selfies_list = [line.strip() for line in f if line.strip()]
        print(f"Loaded {len(selfies_list)} SELFIES sequences from file")
    except FileNotFoundError:
        raise FileNotFoundError(f"Could not find the file at {input_path}")
    except Exception as e:
        raise Exception(f"Error reading file: {e}") from e

    train_selfies, val_selfies = parallel_clustering_split_selfies(
        selfies_list,
        batch_size=10000,
        n_clusters=1000,
        train_ratio=0.8
    )

    with open('/home/yz927/projects/peptune/tokens/11M_selfies/train_selfies.txt', 'w') as f:
        for line in train_selfies:
            f.write(f"{line}\n")

    with open('/home/yz927/projects/peptune/tokens/11M_selfies/val_selfies.txt', 'w') as f:
        for line in val_selfies:
            f.write(f"{line}\n")
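
# Optional post-write check (hypothetical; run right after the writes):
# read each file back and confirm the line counts match the in-memory splits.
#
#   with open('/home/yz927/projects/peptune/tokens/11M_selfies/train_selfies.txt') as f:
#       assert sum(1 for _ in f) == len(train_selfies)
#   with open('/home/yz927/projects/peptune/tokens/11M_selfies/val_selfies.txt') as f:
#       assert sum(1 for _ in f) == len(val_selfies)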