File size: 7,308 Bytes
4742cab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import argparse
import warnings
from pathlib import Path
from time import time

import torch
from rdkit import Chem
from tqdm import tqdm

from lightning_modules import LigandPocketDDPM
from analysis.molecule_builder import process_molecule
import utils

# Retry budgets for sampling a single pocket:
#   MAXITER   - max generate/filter rounds per attempt; exceeding it raises a
#               RuntimeError, which the retry loop below catches.
#   MAXNTRIES - max attempts per pocket before the script aborts.
MAXITER = 10
MAXNTRIES = 10


def parse_test_list(text):
    """Return the set of ligand names in the text of a comma-separated list file.

    Whitespace around each entry (e.g. a trailing newline at end of file) is
    stripped and empty entries are dropped, so the names can be compared
    directly against ``Path.stem`` of the test SDF files.
    """
    return {name.strip() for name in text.split(',') if name.strip()}


def read_pocket_time(text):
    """Return the sampling time stored in a per-pocket time file.

    This script writes time files as ``"<sdf path> <seconds>"``. The time is
    taken from the last whitespace-separated token so that SDF paths
    containing spaces are still parsed correctly.
    """
    return float(text.split()[-1])


if __name__ == "__main__":
    # Sample ligands for every test pocket with a trained LigandPocketDDPM
    # checkpoint. Outputs written under --outdir:
    #   raw/<ligand>_gen.sdf        all generated molecules (valid ones first)
    #   processed/<ligand>_gen.sdf  first --n_samples molecules that survive
    #                               process_molecule
    #   pocket_times/<ligand>.txt   "<sdf path> <seconds>" for that pocket
    #   pocket_times.txt            one such line per pocket
    parser = argparse.ArgumentParser()
    parser.add_argument('checkpoint', type=Path)
    parser.add_argument('--test_dir', type=Path)
    parser.add_argument('--test_list', type=Path, default=None)
    parser.add_argument('--outdir', type=Path)
    parser.add_argument('--n_samples', type=int, default=100)
    parser.add_argument('--all_frags', action='store_true')
    parser.add_argument('--sanitize', action='store_true')
    parser.add_argument('--relax', action='store_true')
    parser.add_argument('--batch_size', type=int, default=120)
    parser.add_argument('--resamplings', type=int, default=10)
    parser.add_argument('--jump_length', type=int, default=1)
    parser.add_argument('--timesteps', type=int, default=None)
    parser.add_argument('--fix_n_nodes', action='store_true')
    parser.add_argument('--n_nodes_bias', type=int, default=0)
    parser.add_argument('--n_nodes_min', type=int, default=0)
    parser.add_argument('--skip_existing', action='store_true')
    args = parser.parse_args()

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Without --skip_existing, mkdir raises FileExistsError if a directory is
    # already there, which prevents silently overwriting a previous run.
    args.outdir.mkdir(exist_ok=args.skip_existing)
    raw_sdf_dir = Path(args.outdir, 'raw')
    raw_sdf_dir.mkdir(exist_ok=args.skip_existing)
    processed_sdf_dir = Path(args.outdir, 'processed')
    processed_sdf_dir.mkdir(exist_ok=args.skip_existing)
    times_dir = Path(args.outdir, 'pocket_times')
    times_dir.mkdir(exist_ok=args.skip_existing)

    # Load model
    model = LigandPocketDDPM.load_from_checkpoint(
        args.checkpoint, map_location=device)
    model = model.to(device)

    # Skip hidden files; sort so the processing order does not depend on the
    # filesystem's directory-listing order.
    test_files = sorted(args.test_dir.glob('[!.]*.sdf'))
    if args.test_list is not None:
        with open(args.test_list, 'r') as f:
            test_list = parse_test_list(f.read())
        test_files = [x for x in test_files if x.stem in test_list]

    pbar = tqdm(test_files)
    time_per_pocket = {}
    for sdf_file in pbar:
        ligand_name = sdf_file.stem

        # The receptor "<pdb_name>.pdb" and the residue list "<ligand>.txt"
        # are expected next to the ligand SDF, where pdb_name is the part of
        # the ligand name before the first underscore.
        pdb_name, pocket_id, *suffix = ligand_name.split('_')
        pdb_file = Path(sdf_file.parent, f"{pdb_name}.pdb")
        txt_file = Path(sdf_file.parent, f"{ligand_name}.txt")
        sdf_out_file_raw = Path(raw_sdf_dir, f'{ligand_name}_gen.sdf')
        sdf_out_file_processed = Path(processed_sdf_dir,
                                      f'{ligand_name}_gen.sdf')
        time_file = Path(times_dir, f'{ligand_name}.txt')

        if args.skip_existing and time_file.exists() \
                and sdf_out_file_processed.exists() \
                and sdf_out_file_raw.exists():

            # Reuse the recorded time so the final summary still covers
            # this pocket.
            with open(time_file, 'r') as f:
                time_per_pocket[str(sdf_file)] = read_pocket_time(f.read())

            continue

        for n_try in range(MAXNTRIES):

            try:
                # Only the successful attempt is timed; the clock restarts
                # on every retry.
                t_pocket_start = time()

                with open(txt_file, 'r') as f:
                    resi_list = f.read().split()

                if args.fix_n_nodes:
                    # some ligands (e.g. 6JWS_bio1_PT1:A:801) could not be read with sanitize=True
                    suppl = Chem.SDMolSupplier(str(sdf_file), sanitize=False)
                    num_nodes_lig = suppl[0].GetNumAtoms()
                else:
                    num_nodes_lig = None

                # With --fix_n_nodes every sample in a batch gets the
                # reference ligand's atom count. This is loop-invariant, so
                # it is built once per attempt.
                num_nodes_lig_inflated = None if num_nodes_lig is None else \
                    torch.ones(args.batch_size, dtype=int) * num_nodes_lig

                all_molecules = []
                valid_molecules = []
                processed_molecules = []  # only used as temporary variable
                n_iter = 0
                n_generated = 0
                n_valid = 0
                while len(valid_molecules) < args.n_samples:
                    n_iter += 1
                    if n_iter > MAXITER:
                        raise RuntimeError('Maximum number of iterations has been exceeded.')

                    # Turn all filters off first
                    mols_batch = model.generate_ligands(
                        pdb_file, args.batch_size, resi_list,
                        num_nodes_lig=num_nodes_lig_inflated,
                        timesteps=args.timesteps, sanitize=False,
                        largest_frag=False, relax_iter=0,
                        n_nodes_bias=args.n_nodes_bias,
                        n_nodes_min=args.n_nodes_min,
                        resamplings=args.resamplings,
                        jump_length=args.jump_length)

                    all_molecules.extend(mols_batch)

                    # Filter to find valid molecules; process_molecule's
                    # result is treated as rejected when it is None.
                    mols_batch_processed = [
                        process_molecule(m, sanitize=args.sanitize,
                                         relax_iter=(200 if args.relax else 0),
                                         largest_frag=not args.all_frags)
                        for m in mols_batch
                    ]
                    processed_molecules.extend(mols_batch_processed)
                    valid_mols_batch = [m for m in mols_batch_processed if m is not None]

                    n_generated += args.batch_size
                    n_valid += len(valid_mols_batch)
                    valid_molecules.extend(valid_mols_batch)

                # Remove excess molecules from list
                valid_molecules = valid_molecules[:args.n_samples]

                # Reorder raw molecules: those that passed processing first,
                # then the rejected ones, each group in generation order.
                # all_molecules and processed_molecules are extended in
                # lock-step, so zip pairs each raw molecule with its result.
                all_molecules = \
                    [raw for raw, proc in zip(all_molecules, processed_molecules)
                     if proc is not None] + \
                    [raw for raw, proc in zip(all_molecules, processed_molecules)
                     if proc is None]

                # Write SDF files
                utils.write_sdf_file(sdf_out_file_raw, all_molecules)
                utils.write_sdf_file(sdf_out_file_processed, valid_molecules)

                # Time the sampling process
                elapsed = time() - t_pocket_start
                time_per_pocket[str(sdf_file)] = elapsed
                with open(time_file, 'w') as f:
                    f.write(f"{str(sdf_file)} {elapsed}")

                # Guard against division by zero when nothing was generated
                # (e.g. --n_samples <= 0).
                validity = (n_valid / n_generated * 100
                            if n_generated > 0 else float('nan'))
                sec_per_mol = (elapsed / len(valid_molecules)
                               if valid_molecules else float('nan'))
                pbar.set_description(
                    f'Last processed: {ligand_name}. '
                    f'Validity: {validity:.2f}%. '
                    f'{sec_per_mol:.2f} '
                    f'sec/mol.')

                break  # no more tries needed

            except (RuntimeError, ValueError) as e:
                if n_try >= MAXNTRIES - 1:
                    # Chain the last failure so its traceback is not lost.
                    raise RuntimeError("Maximum number of retries exceeded") from e
                warnings.warn(f"Attempt {n_try + 1}/{MAXNTRIES} failed with "
                              f"error: '{e}'. Trying again...")

    with open(Path(args.outdir, 'pocket_times.txt'), 'w') as f:
        for k, v in time_per_pocket.items():
            f.write(f"{k} {v}\n")

    if time_per_pocket:
        times_arr = torch.tensor(list(time_per_pocket.values()))
        # "\\pm" prints a literal "\pm"; a bare "\p" is an invalid escape.
        print(f"Time per pocket: {times_arr.mean():.3f} \\pm "
              f"{times_arr.std(unbiased=False):.2f}")
    else:
        print("No pockets were processed.")