import argparse
import warnings
from pathlib import Path
from time import time

import torch
from rdkit import Chem
from tqdm import tqdm

from lightning_modules import LigandPocketDDPM
from analysis.molecule_builder import process_molecule
import utils

MAXITER = 10
MAXNTRIES = 10


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('checkpoint', type=Path)
    parser.add_argument('--test_dir', type=Path)
    parser.add_argument('--test_list', type=Path, default=None)
    parser.add_argument('--outdir', type=Path)
    parser.add_argument('--n_samples', type=int, default=100)
    parser.add_argument('--all_frags', action='store_true')
    parser.add_argument('--sanitize', action='store_true')
    parser.add_argument('--relax', action='store_true')
    parser.add_argument('--batch_size', type=int, default=120)
    parser.add_argument('--resamplings', type=int, default=10)
    parser.add_argument('--jump_length', type=int, default=1)
    parser.add_argument('--timesteps', type=int, default=None)
    parser.add_argument('--fix_n_nodes', action='store_true')
    parser.add_argument('--n_nodes_bias', type=int, default=0)
    parser.add_argument('--n_nodes_min', type=int, default=0)
    parser.add_argument('--skip_existing', action='store_true')
    args = parser.parse_args()

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Create output directories (allow reuse only when resuming a run)
    args.outdir.mkdir(exist_ok=args.skip_existing)
    raw_sdf_dir = Path(args.outdir, 'raw')
    raw_sdf_dir.mkdir(exist_ok=args.skip_existing)
    processed_sdf_dir = Path(args.outdir, 'processed')
    processed_sdf_dir.mkdir(exist_ok=args.skip_existing)
    times_dir = Path(args.outdir, 'pocket_times')
    times_dir.mkdir(exist_ok=args.skip_existing)

    # Load model
    model = LigandPocketDDPM.load_from_checkpoint(
        args.checkpoint, map_location=device)
    model = model.to(device)

    # Collect test ligands; an optional comma-separated list restricts the
    # run to a subset of them
    test_files = list(args.test_dir.glob('[!.]*.sdf'))
    if args.test_list is not None:
        with open(args.test_list, 'r') as f:
            test_list = set(f.read().split(','))
        test_files = [x for x in test_files if x.stem in test_list]

    pbar = tqdm(test_files)
    time_per_pocket = {}
    for sdf_file in pbar:
        ligand_name = sdf_file.stem

        pdb_name, pocket_id, *suffix = ligand_name.split('_')
        pdb_file = Path(sdf_file.parent, f"{pdb_name}.pdb")
        txt_file = Path(sdf_file.parent, f"{ligand_name}.txt")
        sdf_out_file_raw = Path(raw_sdf_dir, f'{ligand_name}_gen.sdf')
        sdf_out_file_processed = Path(processed_sdf_dir,
                                      f'{ligand_name}_gen.sdf')
        time_file = Path(times_dir, f'{ligand_name}.txt')

        # Skip pockets that were fully processed in a previous run
        if args.skip_existing and time_file.exists() \
                and sdf_out_file_processed.exists() \
                and sdf_out_file_raw.exists():
            with open(time_file, 'r') as f:
                time_per_pocket[str(sdf_file)] = float(f.read().split()[1])
            continue

        for n_try in range(MAXNTRIES):
            try:
                t_pocket_start = time()

                with open(txt_file, 'r') as f:
                    resi_list = f.read().split()

                if args.fix_n_nodes:
                    # Some ligands (e.g. 6JWS_bio1_PT1:A:801) could not be
                    # read with sanitize=True
                    suppl = Chem.SDMolSupplier(str(sdf_file), sanitize=False)
                    num_nodes_lig = suppl[0].GetNumAtoms()
                else:
                    num_nodes_lig = None

                all_molecules = []
                valid_molecules = []
                processed_molecules = []  # only used as temporary variable
                n_iter = 0
                n_generated = 0
                n_valid = 0
                while len(valid_molecules) < args.n_samples:
                    n_iter += 1
                    if n_iter > MAXITER:
                        raise RuntimeError(
                            'Maximum number of iterations has been exceeded.')

                    num_nodes_lig_inflated = None if num_nodes_lig is None \
                        else torch.ones(args.batch_size, dtype=torch.long) \
                        * num_nodes_lig

                    # Turn all filters off first
                    mols_batch = model.generate_ligands(
                        pdb_file, args.batch_size, resi_list,
                        num_nodes_lig=num_nodes_lig_inflated,
                        timesteps=args.timesteps, sanitize=False,
                        largest_frag=False, relax_iter=0,
                        n_nodes_bias=args.n_nodes_bias,
                        n_nodes_min=args.n_nodes_min,
                        resamplings=args.resamplings,
                        jump_length=args.jump_length)
                    all_molecules.extend(mols_batch)

                    # Filter to find valid molecules
                    mols_batch_processed = [
                        process_molecule(
                            m, sanitize=args.sanitize,
                            relax_iter=(200 if args.relax else 0),
                            largest_frag=not args.all_frags)
                        for m in mols_batch
                    ]
                    processed_molecules.extend(mols_batch_processed)
                    valid_mols_batch = [m for m in mols_batch_processed
                                        if m is not None]

                    n_generated += args.batch_size
                    n_valid += len(valid_mols_batch)
                    valid_molecules.extend(valid_mols_batch)

                # Remove excess molecules from the list
                valid_molecules = valid_molecules[:args.n_samples]

                # Reorder the raw molecules so that those passing the filters
                # come first
                all_molecules = \
                    [all_molecules[i]
                     for i, m in enumerate(processed_molecules)
                     if m is not None] + \
                    [all_molecules[i]
                     for i, m in enumerate(processed_molecules)
                     if m is None]

                # Write SDF files
                utils.write_sdf_file(sdf_out_file_raw, all_molecules)
                utils.write_sdf_file(sdf_out_file_processed, valid_molecules)

                # Time the sampling process
                time_per_pocket[str(sdf_file)] = time() - t_pocket_start
                with open(time_file, 'w') as f:
                    f.write(f"{str(sdf_file)} "
                            f"{time_per_pocket[str(sdf_file)]}")

                pbar.set_description(
                    f'Last processed: {ligand_name}. '
                    f'Validity: {n_valid / n_generated * 100:.2f}%. '
                    f'{(time() - t_pocket_start) / len(valid_molecules):.2f} '
                    f'sec/mol.')

                break  # no more tries needed

            except (RuntimeError, ValueError) as e:
                if n_try >= MAXNTRIES - 1:
                    raise RuntimeError(
                        "Maximum number of retries exceeded") from e
                warnings.warn(f"Attempt {n_try + 1}/{MAXNTRIES} failed with "
                              f"error: '{e}'. Trying again...")

    with open(Path(args.outdir, 'pocket_times.txt'), 'w') as f:
        for k, v in time_per_pocket.items():
            f.write(f"{k} {v}\n")

    times_arr = torch.tensor(list(time_per_pocket.values()))
    print(f"Time per pocket: {times_arr.mean():.3f} \u00B1 "
          f"{times_arr.std(unbiased=False):.2f}")
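
# Example invocation (a sketch, not taken from the source; the script name,
# checkpoint file, and directory paths below are hypothetical placeholders):
#
#   python generate_ligands.py checkpoints/model.ckpt \
#       --test_dir data/test \
#       --outdir results/test_run \
#       --n_samples 100 --batch_size 120 --sanitize --relax
#
# The test_dir is expected to contain, per pocket, a <name>.sdf reference
# ligand, a matching <name>.txt residue list, and the <pdb>.pdb receptor file.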