Irwiny123 commited on
Commit
ef423c5
·
1 Parent(s): 4132f99

添加PepFlow模型初始代码

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +180 -0
  2. LICENSE +21 -0
  3. README.md +106 -3
  4. configs/learn_angle.yaml +74 -0
  5. environment.yml +261 -0
  6. eval/align.py +17 -0
  7. eval/energy.py +94 -0
  8. eval/foldx.py +77 -0
  9. eval/geometry.py +127 -0
  10. eval/run_esmfold.py +73 -0
  11. eval/run_esmif.py +33 -0
  12. eval/run_mpnn.py +146 -0
  13. eval/run_rfdiffusion.py +75 -0
  14. eval/run_scwrl4.py +30 -0
  15. eval/utils.py +106 -0
  16. models_con/edge.py +112 -0
  17. models_con/flow_model.py +472 -0
  18. models_con/ga.py +127 -0
  19. models_con/inference.py +101 -0
  20. models_con/ipa_pytorch.py +687 -0
  21. models_con/node.py +105 -0
  22. models_con/pep_dataloader.py +212 -0
  23. models_con/sample.py +145 -0
  24. models_con/torsion.py +239 -0
  25. models_con/torus.py +34 -0
  26. models_con/utils.py +72 -0
  27. openfold/config.py +4 -0
  28. openfold/model/__init__.py +16 -0
  29. openfold/model/dropout.py +78 -0
  30. openfold/model/embedders.py +352 -0
  31. openfold/model/evoformer.py +630 -0
  32. openfold/model/heads.py +251 -0
  33. openfold/model/model.py +446 -0
  34. openfold/model/msa.py +392 -0
  35. openfold/model/outer_product_mean.py +129 -0
  36. openfold/model/pair_transition.py +99 -0
  37. openfold/model/primitives.py +587 -0
  38. openfold/model/structure_module.py +820 -0
  39. openfold/model/template.py +333 -0
  40. openfold/model/torchscript.py +215 -0
  41. openfold/model/triangular_attention.py +139 -0
  42. openfold/model/triangular_multiplicative_update.py +127 -0
  43. openfold/np/__init__.py +16 -0
  44. openfold/np/protein.py +438 -0
  45. openfold/np/relax/__init__.py +16 -0
  46. openfold/np/relax/amber_minimize.py +612 -0
  47. openfold/np/relax/cleanup.py +131 -0
  48. openfold/np/relax/relax.py +90 -0
  49. openfold/np/relax/utils.py +88 -0
  50. openfold/np/residue_constants.py +1310 -0
.gitignore ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ logs/
2
+ lightning_logs/
3
+ wandb/
4
+ Data/
5
+ wandb*.json
6
+ **/*.ckpt
7
+ **/*.pth
8
+ **/*.json
9
+ .vscode/
10
+ **/*.pdb
11
+ ckpt/
12
+ *.code-workspace
13
+ outputs/
14
+ **/*.txt
15
+ **/lightning_logs/
16
+ **/inference_outputs/
17
+ .hydra
18
+ preprocessed/
19
+ misc/
20
+
21
+ # Byte-compiled / optimized / DLL files
22
+ __pycache__/
23
+ *.py[cod]
24
+ *$py.class
25
+
26
+ # C extensions
27
+ *.so
28
+
29
+ # Distribution / packaging
30
+ .Python
31
+ build/
32
+ develop-eggs/
33
+ dist/
34
+ downloads/
35
+ eggs/
36
+ .eggs/
37
+ lib/
38
+ lib64/
39
+ parts/
40
+ sdist/
41
+ var/
42
+ wheels/
43
+ share/python-wheels/
44
+ *.egg-info/
45
+ .installed.cfg
46
+ *.egg
47
+ MANIFEST
48
+
49
+ # PyInstaller
50
+ # Usually these files are written by a python script from a template
51
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
52
+ *.manifest
53
+ *.spec
54
+
55
+ # Installer logs
56
+ pip-log.txt
57
+ pip-delete-this-directory.txt
58
+
59
+ # Unit test / coverage reports
60
+ htmlcov/
61
+ .tox/
62
+ .nox/
63
+ .coverage
64
+ .coverage.*
65
+ .cache
66
+ nosetests.xml
67
+ coverage.xml
68
+ *.cover
69
+ *.py,cover
70
+ .hypothesis/
71
+ .pytest_cache/
72
+ cover/
73
+
74
+ # Translations
75
+ *.mo
76
+ *.pot
77
+
78
+ # Django stuff:
79
+ *.log
80
+ local_settings.py
81
+ db.sqlite3
82
+ db.sqlite3-journal
83
+
84
+ # Flask stuff:
85
+ instance/
86
+ .webassets-cache
87
+
88
+ # Scrapy stuff:
89
+ .scrapy
90
+
91
+ # Sphinx documentation
92
+ docs/_build/
93
+
94
+ # PyBuilder
95
+ .pybuilder/
96
+ target/
97
+
98
+ # Jupyter Notebook
99
+ .ipynb_checkpoints
100
+
101
+ # IPython
102
+ profile_default/
103
+ ipython_config.py
104
+
105
+ # pyenv
106
+ # For a library or package, you might want to ignore these files since the code is
107
+ # intended to run in multiple environments; otherwise, check them in:
108
+ # .python-version
109
+
110
+ # pipenv
111
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
112
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
113
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
114
+ # install all needed dependencies.
115
+ #Pipfile.lock
116
+
117
+ # poetry
118
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
119
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
120
+ # commonly ignored for libraries.
121
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
122
+ #poetry.lock
123
+
124
+ # pdm
125
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
126
+ #pdm.lock
127
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
128
+ # in version control.
129
+ # https://pdm.fming.dev/#use-with-ide
130
+ .pdm.toml
131
+
132
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
133
+ __pypackages__/
134
+
135
+ # Celery stuff
136
+ celerybeat-schedule
137
+ celerybeat.pid
138
+
139
+ # SageMath parsed files
140
+ *.sage.py
141
+
142
+ # Environments
143
+ .env
144
+ .venv
145
+ env/
146
+ venv/
147
+ ENV/
148
+ env.bak/
149
+ venv.bak/
150
+
151
+ # Spyder project settings
152
+ .spyderproject
153
+ .spyproject
154
+
155
+ # Rope project settings
156
+ .ropeproject
157
+
158
+ # mkdocs documentation
159
+ /site
160
+
161
+ # mypy
162
+ .mypy_cache/
163
+ .dmypy.json
164
+ dmypy.json
165
+
166
+ # Pyre type checker
167
+ .pyre/
168
+
169
+ # pytype static type analyzer
170
+ .pytype/
171
+
172
+ # Cython debug symbols
173
+ cython_debug/
174
+
175
+ # PyCharm
176
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
177
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
178
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
179
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
180
+ #.idea/
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Cedlijh
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,3 +1,106 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # PepFlow: Full-Atom Peptide Design
2
+
3
+ ![alt text](teaser.png)
4
+
5
+
6
+ This repository contains the official implementation of 💡 Full-Atom Peptide Design based on Multi-modal Flow Matching (ICML 2024).
7
+
8
+ You can find our [paper](https://arxiv.org/abs/2406.00735) here. We also appreciate the inspiration from [diffab](https://github.com/luost26/diffab) and [frameflow](https://github.com/microsoft/protein-frame-flow).
9
+
10
+ If you have any questions, please contact lijiahanypc@pku.edu.cn or ced3ljhypc@gmail.com. Thank you! :)
11
+
12
+ ## Install
13
+
14
+
15
+ ### Environment
16
+
17
+ Please adjust the CUDA and torch versions to match your machine; we tested our code on CUDA >= 11.7. We also suggest using [micromamba](https://mamba.readthedocs.io/en/latest/installation/micromamba-installation.html) as a replacement for conda.
18
+
19
+ ```bash
20
+ conda env create -f environment.yml # or use micromamba instead of conda
21
+
22
+ conda activate flow
23
+
24
+ pip install torch-scatter -f https://data.pyg.org/whl/torch-2.0.0+cu117.html
25
+
26
+ pip install joblib lmdb easydict
27
+
28
+ ```
29
+
30
+ ### Clone Repo
31
+
32
+ ```bash
33
+ git clone https://github.com/Ced3-han/PepFlowww.git
34
+ ```
35
+
36
+ We suggest adding the code to the Python environment variable, or you can use setup tools.
37
+
38
+ ```bash
39
+ export PYTHONPATH=$(pwd):$PYTHONPATH
40
+ python setup.py develop
41
+ ```
42
+
43
+
44
+ ### Data and Weights Download
45
+
46
+ We provide data and pretrained model weights [here](https://drive.google.com/drive/folders/1bHaKDF3uCDPtfsihjZs0zmjwF6UU1uVl?usp=sharing).
47
+
48
+ + PepMerge_release.zip: 1.2GB
49
+ + PepMerge_lmdb.zip: 180MB
50
+ + model1.pt: 80MB
51
+ + model2.pt: 80MB
52
+
53
+ The ```PepMerge_release.zip``` contains filtered data of peptide-receptor pairs. For example, in the folder ```1a0n_A```, the ```P``` chain in the PDB file ```1a0n``` is the peptide. In this folder, we provide the FASTA and PDB files of the peptide and receptor. The postfix _merge means the peptide and receptor are in the same PDB file. We also extract the binding pocket of the receptor, where our model is trained to generate peptides based on the binding pocket. You can also download [PepBDB](http://huanglab.phys.hust.edu.cn/pepbdb/db/1cta_A/) and [QBioLip](https://yanglab.qd.sdu.edu.cn/Q-BioLiP/Download), and use ```playgrounds/gen_dataset.ipynb``` to reproduce the dataset.
54
+
55
+ The ```PepMerge_lmdb.zip``` contains several different splits of the dataset. We use ```mmseqs2``` to cluster complexes based on receptor sequence identity. See ```playgrounds/cluster.ipynb``` for details. The names.txt file contains the names of complexes in the test set. You can use ```models_con/pep_dataloader.py``` to load these datasets. We suggest putting these LMDBs in a single ```Data``` folder.
56
+
57
+ Besides, ```model1.pt``` and ```model2.pt``` are two checkpoints that you can load using ```models_con/flow_model.py``` together with the config file configs/learn_angle.yaml. We suggest using model1 for benchmark evaluation and model2 for real-world peptide design tasks, the latter is trained on a larger dataset.
58
+
59
+
60
+ ## Usage
61
+
62
+ We will add more user-friendly straightforward pipelines (generation and evaluation) later.
63
+
64
+ ### Inference and Generate
65
+
66
+ By default, we support sampling of generated peptides from our processed dataset. You can use ```models_con/sample.py``` to sample, and ```models_con/inference.py``` to reconstruct PDB files.
67
+
68
+ If you want to use your own data, you can organize your data (peptide and pocket) as we did in PepMerge_release and construct a dataset for sampling and reconstruction. You can also use ```models_con/pep_dataloader/preprocess_structure``` to parse a single data point.
69
+
70
+
71
+
72
+
73
+ ### Evaluation
74
+
75
+ Our evaluation involves many third-party packages, and we include some useful evaluation scripts in ```eval```. Please refer to our paper for details and download the corresponding packages for evaluation. Please use different python environments for these tools.
76
+
77
+
78
+
79
+ ### Train
80
+
81
+ You can use ```train.py``` for single-GPU training and ```train_ddp.py``` for multi-GPU training.
82
+
83
+
84
+ ## Future Work
85
+
86
+
87
+ Future improvements on peptide generation models may include chemical modifications, non-canonical amino acids, pretraining on larger datasets, language models, better sampling methods, etc. Stay tuned and feel free to contact us for collaboration and discussion!
88
+
89
+
90
+
91
+ ## Reference
92
+
93
+ ```bibtex
94
+ @InProceedings{pmlr-v235-li24o,
95
+ title={Full-Atom Peptide Design based on Multi-modal Flow Matching},
96
+ author={Li, Jiahan and Cheng, Chaoran and Wu, Zuofan and Guo, Ruihan and Luo, Shitong and Ren, Zhizhou and Peng, Jian and Ma, Jianzhu},
97
+ booktitle={Proceedings of the 41st International Conference on Machine Learning},
98
+ pages={27615--27640},
99
+ year={2024},
100
+ editor={Salakhutdinov, Ruslan and Kolter, Zico and Heller, Katherine and Weller, Adrian and Oliver, Nuria and Scarlett, Jonathan and Berkenkamp, Felix},
101
+ volume={235},
102
+ series={Proceedings of Machine Learning Research},
103
+ month={21--27 Jul},
104
+ publisher={PMLR},
105
+ }
106
+ ```
configs/learn_angle.yaml ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ encoder:
3
+ node_embed_size: 128
4
+ edge_embed_size: 64
5
+ ipa:
6
+ c_s: 128 #${model.node_embed_size}
7
+ c_z: 64 #${model.edge_embed_size}
8
+ c_hidden: 128
9
+ no_heads: 8
10
+ no_qk_points: 8
11
+ no_v_points: 12
12
+ seq_tfmr_num_heads: 4
13
+ seq_tfmr_num_layers: 2
14
+ num_blocks: 6
15
+ stop_grad: False
16
+ interpolant:
17
+ min_t: 1.e-2
18
+ t_normalization_clip: 0.9
19
+ sample_sequence: True
20
+ sample_structure: True
21
+ rots:
22
+ train_schedule: linear
23
+ sample_schedule: exp
24
+ exp_rate: 10
25
+ trans:
26
+ train_schedule: linear
27
+ sample_schedule: linear
28
+ sigma: 1.0
29
+ seqs:
30
+ num_classes: 20
31
+ simplex_value: 5.0
32
+ sampling:
33
+ num_timesteps: 100
34
+ self_condition: False
35
+
36
+ train:
37
+ loss_weights:
38
+ trans_loss: 0.5 # 1.0 for dreamfold, 0.05 for yim
39
+ rot_loss: 0.5 # 1.0 for dreamfold, 0.5 for yim
40
+ bb_atom_loss: 0.25
41
+ seqs_loss: 1.0
42
+ angle_loss: 1.0
43
+ torsion_loss: 0.5
44
+ max_iters: 400000000
45
+ val_freq: 20000
46
+ batch_size: 32
47
+ accum_grad: 1
48
+ seed: 114514
49
+ max_grad_norm: 100.0
50
+ optimizer:
51
+ type: adam
52
+ lr: 5.e-4 #1.e-4
53
+ weight_decay: 0.0
54
+ beta1: 0.9
55
+ beta2: 0.999
56
+ scheduler:
57
+ type: plateau
58
+ factor: 0.8
59
+ patience: 10
60
+ min_lr: 5.e-6
61
+
62
+ dataset:
63
+ train:
64
+ type: peprec
65
+ structure_dir: /datapool/data2/home/jiahan/Data/PepMerge_new/
66
+ dataset_dir: /datapool/data2/home/jiahan/ResProj/PepDiff/frame-flow/Data/Fixed Data
67
+ name: pep_pocket_train
68
+ reset: False
69
+ val:
70
+ type: peprec
71
+ structure_dir: /datapool/data2/home/jiahan/Data/PepMerge_new/
72
+ dataset_dir: /datapool/data2/home/jiahan/ResProj/PepDiff/frame-flow/Data/Fixed Data
73
+ name: pep_pocket_test
74
+ reset: False
environment.yml ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: flow
2
+ channels:
3
+ - conda-forge
4
+ - nvidia
5
+ - pytorch
6
+ dependencies:
7
+ - _libgcc_mutex==0.1=conda_forge
8
+ - _openmp_mutex==4.5=2_gnu
9
+ - anyio==3.7.1=pyhd8ed1ab_0
10
+ - argon2-cffi==21.3.0=pyhd8ed1ab_0
11
+ - argon2-cffi-bindings==21.2.0=py310h5764c6d_3
12
+ - arrow==1.2.3=pyhd8ed1ab_0
13
+ - asttokens==2.2.1=pyhd8ed1ab_0
14
+ - astunparse==1.6.3=pyhd8ed1ab_0
15
+ - async-lru==2.0.4=pyhd8ed1ab_0
16
+ - attrs==23.1.0=pyh71513ae_1
17
+ - babel==2.12.1=pyhd8ed1ab_1
18
+ - backcall==0.2.0=pyh9f0ad1d_0
19
+ - backports==1.0=pyhd8ed1ab_3
20
+ - backports.functools_lru_cache==1.6.5=pyhd8ed1ab_0
21
+ - beautifulsoup4==4.12.2=pyha770c72_0
22
+ - biopython==1.81=py310h1fa729e_0
23
+ - biotite==0.38.0
24
+ - bleach==6.0.0=pyhd8ed1ab_0
25
+ - blosc==1.21.4=h0f2a231_0
26
+ - brotli==1.0.9=h166bdaf_9
27
+ - brotli-bin==1.0.9=h166bdaf_9
28
+ - brotli-python==1.0.9=py310hd8f1fbe_9
29
+ - bzip2==1.0.8=h7f98852_4
30
+ - c-ares==1.19.1=hd590300_0
31
+ - c-blosc2==2.10.2=hb4ffafa_0
32
+ - ca-certificates==2023.7.22=hbcca054_0
33
+ - cached-property==1.5.2=hd8ed1ab_1
34
+ - cached_property==1.5.2=pyha770c72_1
35
+ - certifi==2023.7.22=pyhd8ed1ab_0
36
+ - cffi==1.15.1=py310h255011f_3
37
+ - charset-normalizer==3.2.0=pyhd8ed1ab_0
38
+ - comm==0.1.4=pyhd8ed1ab_0
39
+ - contourpy==1.1.0=py310hd41b1e2_0
40
+ - cuda==11.6.0=0
41
+ - cuda-cccl==11.6.55=hf6102b2_0
42
+ - cuda-command-line-tools==11.6.2=0
43
+ - cuda-compiler==11.6.2=0
44
+ - cuda-cudart==11.6.55=he381448_0
45
+ - cuda-cudart-dev==11.6.55=h42ad0f4_0
46
+ - cuda-cuobjdump==11.6.124=h2eeebcb_0
47
+ - cuda-cupti==11.6.124=h86345e5_0
48
+ - cuda-cuxxfilt==11.6.124=hecbf4f6_0
49
+ - cuda-driver-dev==11.6.55=0
50
+ - cuda-gdb==12.0.90=hd47b8d6_0
51
+ - cuda-libraries==11.6.2=0
52
+ - cuda-libraries-dev==11.6.0=0
53
+ - cuda-memcheck==11.8.86=0
54
+ - cuda-nsight==12.0.78=ha770c72_0
55
+ - cuda-nsight-compute==12.2.1=0
56
+ - cuda-nvcc==11.6.124=hbba6d2d_0
57
+ - cuda-nvdisasm==12.0.76=h59595ed_0
58
+ - cuda-nvml-dev==11.6.55=haa9ef22_0
59
+ - cuda-nvprof==12.0.90=h59595ed_0
60
+ - cuda-nvprune==11.6.124=he22ec0a_0
61
+ - cuda-nvrtc==11.6.124=h020bade_0
62
+ - cuda-nvrtc-dev==11.6.124=h249d397_0
63
+ - cuda-nvtx==11.6.124=h0630a44_0
64
+ - cuda-nvvp==12.0.90=h59595ed_0
65
+ - cuda-runtime==11.6.2=0
66
+ - cuda-samples==11.6.101=h8efea70_0
67
+ - cuda-sanitizer-api==12.0.90=h59595ed_0
68
+ - cuda-toolkit==11.6.0=0
69
+ - cuda-tools==11.6.0=0
70
+ - cuda-version==12.0=hffde075_2
71
+ - cuda-visual-tools==11.6.0=0
72
+ - cycler==0.11.0=pyhd8ed1ab_0
73
+ - debugpy==1.6.8=py310hc6cd4ac_0
74
+ - decorator==5.1.1=pyhd8ed1ab_0
75
+ - defusedxml==0.7.1=pyhd8ed1ab_0
76
+ - entrypoints==0.4=pyhd8ed1ab_0
77
+ - exceptiongroup==1.1.3=pyhd8ed1ab_0
78
+ - executing==1.2.0=pyhd8ed1ab_0
79
+ - flit-core==3.9.0=pyhd8ed1ab_0
80
+ - fonttools==4.42.0=py310h2372a71_0
81
+ - fqdn==1.5.1=pyhd8ed1ab_0
82
+ - freetype==2.12.1=hca18f0e_1
83
+ - gds-tools==1.5.0.59=hcb278e6_0
84
+ - gmp==6.2.1=h58526e2_0
85
+ - hdf5==1.14.1=nompi_h4f84152_100
86
+ - idna==3.4=pyhd8ed1ab_0
87
+ - importlib-metadata==6.8.0=pyha770c72_0
88
+ - importlib_metadata==6.8.0=hd8ed1ab_0
89
+ - importlib_resources==6.0.1=pyhd8ed1ab_0
90
+ - ipykernel==6.25.1=pyh71e2992_0
91
+ - ipython==8.14.0=pyh41d4057_0
92
+ - isoduration==20.11.0=pyhd8ed1ab_0
93
+ - jedi==0.19.0=pyhd8ed1ab_0
94
+ - jinja2==3.1.2=pyhd8ed1ab_1
95
+ - json5==0.9.14=pyhd8ed1ab_0
96
+ - jsonpointer==2.0=py_0
97
+ - jsonschema==4.19.0=pyhd8ed1ab_1
98
+ - jsonschema-specifications==2023.7.1=pyhd8ed1ab_0
99
+ - jsonschema-with-format-nongpl==4.19.0=pyhd8ed1ab_1
100
+ - jupyter-lsp==2.2.0=pyhd8ed1ab_0
101
+ - jupyter_client==8.3.0=pyhd8ed1ab_0
102
+ - jupyter_core==5.3.1=py310hff52083_0
103
+ - jupyter_events==0.7.0=pyhd8ed1ab_2
104
+ - jupyter_server==2.7.1=pyhd8ed1ab_0
105
+ - jupyter_server_terminals==0.4.4=pyhd8ed1ab_1
106
+ - jupyterlab==4.0.5=pyhd8ed1ab_0
107
+ - jupyterlab_pygments==0.2.2=pyhd8ed1ab_0
108
+ - jupyterlab_server==2.24.0=pyhd8ed1ab_0
109
+ - keyutils==1.6.1=h166bdaf_0
110
+ - kiwisolver==1.4.4=py310hbf28c38_1
111
+ - krb5==1.21.2=h659d440_0
112
+ - lcms2==2.15=haa2dc70_1
113
+ - ld_impl_linux-64==2.40=h41732ed_0
114
+ - lerc==4.0.0=h27087fc_0
115
+ - libaec==1.0.6=hcb278e6_1
116
+ - libblas==3.9.0=17_linux64_openblas
117
+ - libbrotlicommon==1.0.9=h166bdaf_9
118
+ - libbrotlidec==1.0.9=h166bdaf_9
119
+ - libbrotlienc==1.0.9=h166bdaf_9
120
+ - libcblas==3.9.0=17_linux64_openblas
121
+ - libcublas==12.0.1.189=hcb278e6_2
122
+ - libcublas-dev==12.0.1.189=hcb278e6_2
123
+ - libcufft==11.0.0.21=hcb278e6_1
124
+ - libcufft-dev==11.0.0.21=hcb278e6_1
125
+ - libcufile==1.5.0.59=hcb278e6_0
126
+ - libcufile-dev==1.5.0.59=hcb278e6_0
127
+ - libcurand==10.3.1.50=hcb278e6_0
128
+ - libcurand-dev==10.3.1.50=hcb278e6_0
129
+ - libcurl==8.2.1=hca28451_0
130
+ - libcusolver==11.4.2.57=hcb278e6_1
131
+ - libcusparse==12.0.0.76=hcb278e6_1
132
+ - libdeflate==1.18=h0b41bf4_0
133
+ - libedit==3.1.20191231=he28a2e2_2
134
+ - libev==4.33=h516909a_1
135
+ - libffi==3.4.2=h7f98852_5
136
+ - libgcc-ng==13.1.0=he5830b7_0
137
+ - libgfortran-ng==13.1.0=h69a702a_0
138
+ - libgfortran5==13.1.0=h15d22d2_0
139
+ - libgomp==13.1.0=he5830b7_0
140
+ - libjpeg-turbo==2.1.5.1=h0b41bf4_0
141
+ - liblapack==3.9.0=17_linux64_openblas
142
+ - libnghttp2==1.52.0=h61bc06f_0
143
+ - libnpp==12.0.0.30=h59595ed_0
144
+ - libnpp-dev==12.0.0.30=h59595ed_0
145
+ - libnsl==2.0.0=h7f98852_0
146
+ - libnuma==2.0.16=h0b41bf4_1
147
+ - libnvjitlink==12.0.76=hcb278e6_1
148
+ - libnvjpeg==12.0.0.28=hcb278e6_0
149
+ - libnvjpeg-dev==12.0.0.28=ha770c72_0
150
+ - libopenblas==0.3.23=pthreads_h80387f5_0
151
+ - libpng==1.6.39=h753d276_0
152
+ - libsodium==1.0.18=h36c2ea0_1
153
+ - libsqlite==3.42.0=h2797004_0
154
+ - libssh2==1.11.0=h0841786_0
155
+ - libstdcxx-ng==13.1.0=hfd8a6a1_0
156
+ - libtiff==4.5.1=h8b53f26_0
157
+ - libuuid==2.38.1=h0b41bf4_0
158
+ - libwebp-base==1.3.1=hd590300_0
159
+ - libxcb==1.15=h0b41bf4_0
160
+ - libzlib==1.2.13=hd590300_5
161
+ - lz4-c==1.9.4=hcb278e6_0
162
+ - lzo==2.10=h516909a_1000
163
+ - markupsafe==2.1.3=py310h2372a71_0
164
+ - matplotlib-base==3.7.2=py310hf38f957_0
165
+ - matplotlib-inline==0.1.6=pyhd8ed1ab_0
166
+ - mdtraj==1.9.9=py310h8e08b51_0
167
+ - mistune==3.0.1=pyhd8ed1ab_0
168
+ - munkres==1.1.4=pyh9f0ad1d_0
169
+ - nbclient==0.8.0=pyhd8ed1ab_0
170
+ - nbconvert-core==7.7.4=pyhd8ed1ab_0
171
+ - nbformat==5.9.2=pyhd8ed1ab_0
172
+ - ncurses==6.4=hcb278e6_0
173
+ - nest-asyncio==1.5.6=pyhd8ed1ab_0
174
+ - nomkl==1.0=h5ca1d4c_0
175
+ - notebook-shim==0.2.3=pyhd8ed1ab_0
176
+ - nsight-compute==2023.2.1.3=0
177
+ - numexpr==2.8.4=py310hd91493a_101
178
+ - numpy==1.25.2=py310ha4c1d20_0
179
+ - openjpeg==2.5.0=hfec8fc6_2
180
+ - openssl==3.1.2=hd590300_0
181
+ - overrides==7.4.0=pyhd8ed1ab_0
182
+ - packaging==23.1=pyhd8ed1ab_0
183
+ - pandas==2.0.3=py310h7cbd5c2_1
184
+ - pandocfilters==1.5.0=pyhd8ed1ab_0
185
+ - parso==0.8.3=pyhd8ed1ab_0
186
+ - patsy==0.5.3=pyhd8ed1ab_0
187
+ - pexpect==4.8.0=pyh1a96a4e_2
188
+ - pickleshare==0.7.5=py_1003
189
+ - pillow==10.0.0=py310h582fbeb_0
190
+ - pip
191
+ - pkgutil-resolve-name==1.3.10=pyhd8ed1ab_0
192
+ - platformdirs==3.10.0=pyhd8ed1ab_0
193
+ - pooch==1.7.0=pyha770c72_3
194
+ - prometheus_client==0.17.1=pyhd8ed1ab_0
195
+ - prompt-toolkit==3.0.39=pyha770c72_0
196
+ - prompt_toolkit==3.0.39=hd8ed1ab_0
197
+ - psutil==5.9.5=py310h1fa729e_0
198
+ - pthread-stubs==0.4=h36c2ea0_1001
199
+ - ptyprocess==0.7.0=pyhd3deb0d_0
200
+ - pure_eval==0.2.2=pyhd8ed1ab_0
201
+ - py-cpuinfo==9.0.0=pyhd8ed1ab_0
202
+ - pycparser==2.21=pyhd8ed1ab_0
203
+ - pygments==2.16.1=pyhd8ed1ab_0
204
+ - pyparsing==3.0.9=pyhd8ed1ab_0
205
+ - pysocks==1.7.1=pyha2e5f31_6
206
+ - pytables==3.8.0=py310ha028ce3_2
207
+ - python==3.10.12=hd12c33a_0_cpython
208
+ - python-dateutil==2.8.2=pyhd8ed1ab_0
209
+ - python-fastjsonschema==2.18.0=pyhd8ed1ab_0
210
+ - python-json-logger==2.0.7=pyhd8ed1ab_0
211
+ - python-tzdata==2023.3=pyhd8ed1ab_0
212
+ - python_abi==3.10=3_cp310
213
+ - pytorch-cuda==11.6=h867d48c_0
214
+ - pytz==2023.3=pyhd8ed1ab_0
215
+ - pyyaml==6.0=py310h5764c6d_5
216
+ - pyzmq==25.1.1=py310h5bbb5d0_0
217
+ - readline==8.2=h8228510_1
218
+ - referencing==0.30.2=pyhd8ed1ab_0
219
+ - requests==2.31.0=pyhd8ed1ab_0
220
+ - rfc3339-validator==0.1.4=pyhd8ed1ab_0
221
+ - rfc3986-validator==0.1.1=pyh9f0ad1d_0
222
+ - rpds-py==0.9.2=py310hcb5633a_0
223
+ - scipy==1.11.1=py310ha4c1d20_0
224
+ - seaborn==0.12.2=hd8ed1ab_0
225
+ - seaborn-base==0.12.2=pyhd8ed1ab_0
226
+ - send2trash==1.8.2=pyh41d4057_0
227
+ - setuptools==68.1.2=pyhd8ed1ab_0
228
+ - six==1.16.0=pyh6c4a22f_0
229
+ - snappy==1.1.10=h9fff704_0
230
+ - sniffio==1.3.0=pyhd8ed1ab_0
231
+ - soupsieve==2.3.2.post1=pyhd8ed1ab_0
232
+ - stack_data==0.6.2=pyhd8ed1ab_0
233
+ - statsmodels==0.14.0=py310h278f3c1_1
234
+ - terminado==0.17.1=pyh41d4057_0
235
+ - tinycss2==1.2.1=pyhd8ed1ab_0
236
+ - tk==8.6.12=h27826a3_0
237
+ - tomli==2.0.1=pyhd8ed1ab_0
238
+ - tornado==6.3.3=py310h2372a71_0
239
+ - traitlets==5.9.0=pyhd8ed1ab_0
240
+ - typing-extensions==4.7.1=hd8ed1ab_0
241
+ - typing_extensions==4.7.1=pyha770c72_0
242
+ - typing_utils==0.1.0=pyhd8ed1ab_0
243
+ - tzdata==2023c=h71feb2d_0
244
+ - unicodedata2==15.0.0=py310h5764c6d_0
245
+ - uri-template==1.3.0=pyhd8ed1ab_0
246
+ - urllib3==2.0.4=pyhd8ed1ab_0
247
+ - wcwidth==0.2.6=pyhd8ed1ab_0
248
+ - webcolors==1.13=pyhd8ed1ab_0
249
+ - webencodings==0.5.1=py_1
250
+ - websocket-client==1.6.1=pyhd8ed1ab_0
251
+ - wheel==0.41.1=pyhd8ed1ab_0
252
+ - xorg-libxau==1.0.11=hd590300_0
253
+ - xorg-libxdmcp==1.1.3=h7f98852_0
254
+ - xz==5.2.6=h166bdaf_0
255
+ - yaml==0.2.5=h7f98852_2
256
+ - zeromq==4.3.4=h9c3ff4c_1
257
+ - zipp==3.16.2=pyhd8ed1ab_0
258
+ - zlib==1.2.13=hd590300_5
259
+ - zlib-ng==2.0.7=h0b41bf4_0
260
+ - zstd==1.5.2=hfc55251_7
261
+
eval/align.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import subprocess
2
+ import re
3
+ from tqdm import tqdm
4
+ import os
5
+
6
+
7
+ RUNNER = '/datapool/data2/home/jiahan/Tool/TMalign-20180426/MMalign'
8
+
9
def align_pdb(pdb1, pdb2, pdb1_out):
    """Superpose pdb1 onto pdb2 with the MMalign binary, writing the aligned structure to pdb1_out."""
    command = [RUNNER, pdb1, pdb2, '-o', pdb1_out]
    subprocess.run(command, stdout=subprocess.PIPE)
11
+
12
def get_tm_score(pdb1, pdb2):
    """Run the external TMscore tool on two PDB files.

    Returns (rmsd, tm_score) parsed from TMscore's stdout.
    Raises AttributeError if either value is missing from the output.
    """
    completed = subprocess.run(['TMscore', pdb1, pdb2], stdout=subprocess.PIPE)
    report = completed.stdout.decode()
    tm_match = re.search(r"TM-score\s+=\s+(\d+\.\d+)", report)
    rmsd_match = re.search(r"RMSD of the common residues=\s+(\d+\.\d+)", report)
    return float(rmsd_match.group(1)), float(tm_match.group(1))
eval/energy.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pyrosetta
2
+ from pyrosetta import init, pose_from_pdb, get_fa_scorefxn
3
+ from pyrosetta.rosetta.protocols.relax import FastRelax
4
+ from pyrosetta.rosetta.protocols.analysis import InterfaceAnalyzerMover
5
+ from pyrosetta.rosetta.core.pack.task import TaskFactory
6
+ from pyrosetta.rosetta.core.pack.task.operation import RestrictToRepacking
7
+ from pyrosetta.rosetta.protocols.minimization_packing import PackRotamersMover
8
+
9
+ import os
10
+ import pandas as pd
11
+ import subprocess
12
+ import numpy as np
13
+ import shutil
14
+ from tqdm import tqdm
15
+ import pickle
16
+
17
+ from joblib import delayed, Parallel
18
+ from utils import *
19
+
20
+ input_dir=".Tests"
21
+ output_dir="./Pack"
22
+
23
def get_chain_dic(input_pdb):
    """Map each chain id in the PDB file to its count of amino-acid residues that have a CA atom."""
    structure = PDBParser().get_structure("protein", input_pdb)
    counts = {}
    for model in structure:
        for chain in model:
            # Count only proper amino acids with at least an alpha carbon.
            n_res = sum(1 for res in chain if is_aa(res) and res.has_id('CA'))
            counts[chain.id] = n_res
    return counts
+
33
def get_rosetta_score_base(pdb_path, chain_id='A'):
    """Relax a complex with PyRosetta FastRelax and score stability and binding.

    Args:
        pdb_path: path to the complex PDB file.
        chain_id: peptide chain id; the interface is "<chain_id>_<all other chains>".

    Returns:
        dict with 'name' (the input path), 'stab' (mean total score over 5
        relax repeats) and 'bind' (mean interface dG_separated). On any
        failure both scores are the sentinel 999.0.
    """
    try:
        init()
        pose = pyrosetta.pose_from_pdb(pdb_path)
        chains = list(get_chain_dic(pdb_path).keys())
        chains.remove(chain_id)
        interface = f'{chain_id}_{"".join(chains)}'
        fast_relax = FastRelax()  # cannot be pickled, so built per call
        scorefxn = get_fa_scorefxn()
        fast_relax.set_scorefxn(scorefxn)
        mover = InterfaceAnalyzerMover(interface)
        mover.set_pack_separated(True)
        stabs, binds = [], []
        for _ in range(5):
            fast_relax.apply(pose)
            stabs.append(scorefxn(pose))
            mover.apply(pose)
            binds.append(pose.scores['dG_separated'])
        return {'name': pdb_path, 'stab': np.array(stabs).mean(), 'bind': np.array(binds).mean()}
    except Exception:
        # Best-effort scoring: was a bare `except:` that also swallowed
        # KeyboardInterrupt/SystemExit. Failures yield sentinel values.
        return {'name': pdb_path, 'stab': 999.0, 'bind': 999.0}
56
+
57
+
58
def get_rosetta_score(pdb_path, chain='A'):
    """Relax and score a complex assuming a fixed 'A_B' interface.

    Args:
        pdb_path: path to the complex PDB file (chains renumbered to A/B).
        chain: kept for interface compatibility; the interface string is
            hard-coded to 'A_B' here.

    Returns:
        [pdb_path, total_energy, dG_separated]; on any failure the
        sentinel [pdb_path, 999.0, 999.0].
    """
    try:
        init()
        pose = pyrosetta.pose_from_pdb(pdb_path)
        interface = 'A_B'
        fast_relax = FastRelax()  # cannot be pickled, so built per call
        scorefxn = get_fa_scorefxn()
        fast_relax.set_scorefxn(scorefxn)
        mover = InterfaceAnalyzerMover(interface)
        mover.set_pack_separated(True)
        fast_relax.apply(pose)
        energy = scorefxn(pose)
        mover.apply(pose)
        dg = pose.scores['dG_separated']
        return [pdb_path, energy, dg]
    except Exception:
        # Was a bare `except:`; narrowed so Ctrl-C still interrupts.
        return [pdb_path, 999.0, 999.0]
78
+
79
def pack_sc(name='1a1m_C', num_samples=10):
    """Repack side chains of a backbone-only structure with PyRosetta.

    Reads <input_dir>/<name>/pocket_merge_renum_bb.pdb and writes
    packed_{i}.pdb files into <output_dir>/<name>/rosetta (recreated from
    scratch each call). Returns None in all cases, including failure.
    """
    try:
        rosetta_dir = os.path.join(output_dir, name, 'rosetta')
        if os.path.exists(rosetta_dir):
            shutil.rmtree(rosetta_dir)
        os.makedirs(rosetta_dir, exist_ok=True)
        init()
        tf = TaskFactory()
        tf.push_back(RestrictToRepacking())  # only repack, don't change amino acid types
        packer = PackRotamersMover()
        packer.task_factory(tf)
        src_pdb = os.path.join(input_dir, name, 'pocket_merge_renum_bb.pdb')
        for i in range(num_samples):
            pose = pose_from_pdb(src_pdb)
            packer.apply(pose)
            pose.dump_pdb(os.path.join(rosetta_dir, f'packed_{i}.pdb'))
    except Exception:
        # Best-effort: skip structures that fail to pack (was a bare `except:`).
        return None
eval/foldx.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import pandas as pd
4
+ import torch
5
+ from joblib import Parallel, delayed
6
+ from tqdm import tqdm
7
+ import tempfile
8
+ import os
9
+ import shutil
10
+ import subprocess
11
+
12
+ from Bio.PDB import PDBParser
13
+
14
def fetch_stability_score(path):
    """Read a FoldX tab-separated stability file and return row 0, column 1."""
    table = pd.read_csv(path, sep='\t', header=None)
    return table.iloc[0, 1]
17
+
18
def fetch_binding_affinity(path):
    """Parse a FoldX AnalyseComplex .fxout summary.

    Returns the third-from-last tab-separated field on the file's last line
    as a float (the interaction energy column).
    """
    with open(path, 'r') as handle:
        lines = handle.readlines()
    last_line = lines[-1]
    return float(last_line.split("\t")[-3])
22
+
23
class FoldXSession(object):
    """Context manager owning a temporary working directory for one FoldX run."""

    def __init__(self):
        super().__init__()
        # Scratch directory; removed by cleanup() / context-manager exit.
        self.tmpdir = tempfile.TemporaryDirectory()
        self.pdb_names = []

    def cleanup(self):
        """Delete the scratch directory and drop the handle."""
        self.tmpdir.cleanup()
        self.tmpdir = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.cleanup()

    @property
    def workdir(self):
        """Absolute path of the scratch directory."""
        return self.tmpdir.name

    def path(self, filename):
        """Absolute path of *filename* inside the scratch directory."""
        return os.path.join(self.workdir, filename)

    def preprocess_data(self, pdb_dir, pdb_name):
        """Copy pdb_name from pdb_dir into the scratch directory; return the new path."""
        destination = self.path(pdb_name)
        shutil.copy(os.path.join(pdb_dir, pdb_name), destination)
        return destination
49
+
50
def get_chain_names(pdb_dir, pdb_name):
    """Build FoldX's '<pep_chain>,<other chains>' chain string for a complex.

    The peptide chain id is the first character after the last underscore in
    the file name (e.g. '1a0n_P.pdb' -> 'P'); every other chain id found in
    the structure is appended after the comma.
    """
    pep_chain = pdb_name.split("_")[-1][0]
    structure = PDBParser().get_structure("name", os.path.join(pdb_dir, pdb_name))
    all_chain_ids = [chain.get_id() for model in structure for chain in model]
    others = [cid for cid in all_chain_ids if cid != pep_chain]
    return f"{pep_chain}," + "".join(others)
60
+
61
def process_one_file(pdb_dir, pdb_name):
    """Run FoldX AnalyseComplex on one PDB and return (name, binding_affinity).

    Returns (name, None) when FoldX fails or its summary file cannot be parsed.
    """
    chains = get_chain_names(pdb_dir, pdb_name)
    with FoldXSession() as session:
        # Defined up front so the except branch can inspect it safely; the
        # original crashed with UnboundLocalError when the failure happened
        # before fxout_path was assigned.
        fxout_path = None
        try:
            session.preprocess_data(pdb_dir, pdb_name)
            assert (os.path.exists(session.path(pdb_name)))
            subprocess.run(
                ['/datapool/data2/home/ruihan/bin/foldx',
                 '--command=' + 'AnalyseComplex',
                 '--pdb=' + pdb_name,
                 f'--analyseComplexChains={chains}'],
                cwd=session.workdir, stdout=None)
            fxout_path = session.path(f'Summary_{pdb_name.split(".")[0]}_AC.fxout')
            assert (os.path.exists(fxout_path))
            return (pdb_name.split('.')[0], fetch_binding_affinity(fxout_path))
        except Exception:
            print(f"Error in {pdb_name}")
            print(fxout_path is not None and os.path.exists(fxout_path))
            return (pdb_name.split('.')[0], None)
77
+
eval/geometry.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from Bio.PDB import PDBParser, Superimposer, is_aa, Select, NeighborSearch
2
+ import tmtools
3
+ import os
4
+ import numpy as np
5
+ import mdtraj as md
6
+ from Bio.SeqUtils import seq1
7
+
8
+ import warnings
9
+ from Bio import BiopythonWarning, SeqIO
10
+
11
+ import difflib
12
+ import torch
13
+
14
+ # Ignore Biopython PDB construction warnings
15
+ warnings.filterwarnings('ignore', category=BiopythonWarning)
16
+
17
def get_chain_from_pdb(pdb_path, chain_id='A'):
    """Return the chain with `chain_id` from the first model of `pdb_path`, or None."""
    model = PDBParser().get_structure('X', pdb_path)[0]
    return next((chain for chain in model if chain.id == chain_id), None)
25
+
26
def diff_ratio(str1, str2):
    """Similarity ratio in [0, 1] between two strings (difflib.SequenceMatcher)."""
    return difflib.SequenceMatcher(None, str1, str2).ratio()
32
+
33
+ #######################################
34
+
35
+ #RMSD and Tm
36
+
37
+ #######################################
38
def align_chains(chain1, chain2):
    """Pair residues of two chains positionally.

    Keeps a pair only when chain1's residue is an amino acid with a CA atom;
    chain2's residue at the same index is kept unconditionally.
    """
    kept_pairs = [
        (r1, r2)
        for r1, r2 in zip(chain1.get_residues(), chain2.get_residues())
        if is_aa(r1) and r1.has_id('CA')  # at least have CA
    ]
    reslist1 = [pair[0] for pair in kept_pairs]
    reslist2 = [pair[1] for pair in kept_pairs]
    return reslist1, reslist2
46
+
47
def get_rmsd(chain1, chain2):
    """Return (unaligned CA RMSD, superimposed CA RMSD), or None if a chain is missing."""
    if chain1 is None or chain2 is None:
        return None
    ca1 = [atom for atom in chain1.get_atoms() if atom.name == 'CA']
    ca2 = [atom for atom in chain2.get_atoms() if atom.name == 'CA']
    coords1 = np.array([atom.get_coord() for atom in ca1])
    coords2 = np.array([atom.get_coord() for atom in ca2])
    # RMSD in the original coordinate frames (no superposition).
    rmsd_raw = np.sqrt(np.sum((coords1 - coords2) ** 2) / len(coords1))
    # RMSD after optimal superposition.
    sup = Superimposer()
    sup.set_atoms(ca1, ca2)
    return rmsd_raw, sup.rms
60
+
61
def get_tm(chain1, chain2):
    """TM-score between the CA traces of two chains, normalized by chain2's length."""
    coords1 = np.array([atom.get_coord() for atom in chain1.get_atoms() if atom.name == 'CA'])
    coords2 = np.array([atom.get_coord() for atom in chain2.get_atoms() if atom.name == 'CA'])
    # Sequences are irrelevant for the alignment here, so poly-Ala is passed.
    result = tmtools.tm_align(coords1, coords2, 'A' * len(coords1), 'A' * len(coords2))
    return result.tm_norm_chain2
69
+
70
def get_traj_chain(pdb, chain):
    """Load `pdb` with mdtraj and slice out only the atoms of the given chain id."""
    model = PDBParser().get_structure('X', pdb)[0]
    # mdtraj addresses chains by index; map Biopython chain ids to indices.
    chain_index = {c.id: i for i, c in enumerate(model)}[chain]
    traj = md.load(pdb)
    selected = traj.topology.select(f"chainid {chain_index}")
    return traj.atom_slice(selected)
78
+
79
def get_second_stru(pdb, chain):
    """Simplified DSSP secondary-structure labels for one chain of `pdb`.

    Delegates trajectory extraction to get_traj_chain (the body previously
    duplicated it line-for-line).
    """
    return md.compute_dssp(get_traj_chain(pdb, chain), simplified=True)
87
+
88
def get_ss(traj1, traj2):
    """Fraction of positions with identical simplified DSSP labels in two trajectories."""
    dssp_a = md.compute_dssp(traj1, simplified=True)
    dssp_b = md.compute_dssp(traj2, simplified=True)
    return (dssp_a == dssp_b).mean()
92
+
93
def get_bind_site(pdb, chain_id):
    """Residue numbers of receptor residues whose CA lies within 10 A of any peptide CA.

    The peptide is chain `chain_id`; every other chain is treated as receptor.
    """
    model = PDBParser().get_structure('X', pdb)[0]
    pep_cas = [atom for res in model[chain_id] for atom in res if atom.get_name() == 'CA']
    rec_cas = [
        atom
        for ch in model if ch.get_id() != chain_id
        for res in ch
        for atom in res if atom.get_name() == 'CA'
    ]
    searcher = NeighborSearch(rec_cas)
    hits = []
    for ca in pep_cas:
        hits.extend(searcher.search(ca.get_coord(), 10.0, level='R'))
    return {res.get_id()[1] for res in hits}
105
+
106
def get_bind_ratio(pdb1, pdb2, chain_id1, chain_id2):
    """Overlap of predicted vs reference binding sites (pdb2/chain_id2 is the reference)."""
    site1 = get_bind_site(pdb1, chain_id1)
    site2 = get_bind_site(pdb2, chain_id2)
    return len(site1 & site2) / (len(site2) + 1e-10)  # last one is gt
111
+
112
def get_dihedral(pdb,chain):
    """Unfinished stub: loads the single-chain trajectory; dihedral computation not implemented yet (returns None)."""
    traj = get_traj_chain(pdb,chain)
    #TODO: dihedral
115
+
116
def get_seq(pdb, chain_id):
    """One-letter sequence of `chain_id` in `pdb`.

    No is_aa filtering on purpose: intended for extracting sequences from
    generated PDBs.
    """
    chain = PDBParser().get_structure('X', pdb)[0][chain_id]
    three_letter = [residue.get_resname() for residue in chain]
    return seq1("".join(three_letter))
120
+
121
def get_mpnn_seqs(path):
    """Read a FASTA file; return each record's sequence as a list of characters."""
    return [list(str(record.seq)) for record in SeqIO.parse(path, "fasta")]
127
+
eval/run_esmfold.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ import subprocess
4
+ import torch
5
+ import esm
6
+ import numpy as np
7
+ import shutil
8
+ from tqdm import tqdm
9
+
10
+ from joblib import delayed, Parallel
11
+
12
+ import warnings
13
+ from Bio import BiopythonWarning, SeqIO
14
+
15
+ from geometry import *
16
+
17
+ # Ignore Biopython PDB construction warnings
18
+ warnings.filterwarnings('ignore', category=BiopythonWarning)
19
+
20
# Directories of test complexes (inputs) and codesign outputs.
input_dir="./Data/Baselines_new/Tests"
output_dir="/datapool/data2/home/jiahan/ResProj/PepDiff/frame-flow/Data/Baselines_new/Codesign"

# Load ESMFold once at import time; the functions below reuse this global.
# NOTE(review): device is hard-coded to cuda:2 — confirm availability on the target host.
model = esm.pretrained.esmfold_v1()
model = model.eval().to('cuda:2')
25
+
26
def process_rf(name='1aze_B'):
    """ESMFold-refold the last MPNN sequence of each fasta in <name>/mpnns/seqs.

    Writes one PDB per fasta into <name>/rfs_refold. Uses the module-level
    ESMFold `model`.
    """
    # Was ".Data/Baselines_new/Codesign" — missing path separator
    # (the module-level directories use "./Data/...").
    output_dir = "./Data/Baselines_new/Codesign"
    struct_dir = os.path.join(output_dir, name, 'rfs_refold')
    seq_dir = os.path.join(output_dir, name, 'mpnns', 'seqs')
    os.makedirs(struct_dir, exist_ok=True)
    seqs = {}
    for seq_path in os.listdir(seq_dir):
        if seq_path.endswith('.fasta'):
            records = [str(r.seq) for r in SeqIO.parse(os.path.join(seq_dir, seq_path), "fasta")]
            seqs[seq_path.split('.')[0]] = records[-1]  # keep only the last record
    for seq_name, seq in seqs.items():
        with torch.no_grad():
            pdb_str = model.infer_pdb(seq)
        with open(os.path.join(struct_dir, seq_name + '.pdb'), 'w') as f:
            f.write(pdb_str)
44
+
45
def process_pg(name='1aze_B', chain_id='A'):
    """ESMFold-refold protein_generator designs for one target.

    Extracts the sequence of `chain_id` from each PDB in <name>/pgs and writes
    the refolded structures into <name>/pgs_refold.
    """
    # Was ".Data/Baselines_new/Codesign" — missing path separator
    # (the module-level directories use "./Data/...").
    output_dir = "./Data/Baselines_new/Codesign"
    struct_dir = os.path.join(output_dir, name, 'pgs_refold')
    seq_dir = os.path.join(output_dir, name, 'pgs')
    os.makedirs(struct_dir, exist_ok=True)
    seqs = {}
    for seq_path in os.listdir(seq_dir):
        if seq_path.endswith('.pdb'):
            seqs[seq_path.split('.')[0]] = get_seq(os.path.join(seq_dir, seq_path), chain_id)
    for seq_name, seq in seqs.items():
        with torch.no_grad():
            pdb_str = model.infer_pdb(seq)
        with open(os.path.join(struct_dir, seq_name + '.pdb'), 'w') as f:
            f.write(pdb_str)
60
+
61
def refold(name, chain_id, sub_dir):
    """ESMFold-refold every PDB under Models_new/Codesign/<sub_dir>/pdbs/<name>.

    Outputs go to the sibling pdbs_refold/<name> directory.
    """
    base = '/datapool/data2/home/jiahan/ResProj/PepDiff/frame-flow/Data/Models_new/Codesign'
    raw_dir = os.path.join(base, sub_dir, 'pdbs')
    refold_dir = os.path.join(base, sub_dir, 'pdbs_refold')
    os.makedirs(os.path.join(refold_dir, name), exist_ok=True)
    pdb_files = [p for p in os.listdir(os.path.join(raw_dir, name)) if p.endswith('.pdb')]
    seqs = {p.split('.')[0]: get_seq(os.path.join(raw_dir, name, p), chain_id) for p in pdb_files}
    for seq_name, seq in seqs.items():
        with torch.no_grad():
            folded = model.infer_pdb(seq)
        with open(os.path.join(refold_dir, name, seq_name + '.pdb'), 'w') as f:
            f.write(folded)
eval/run_esmif.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils import *
2
+
3
+ import os
4
+ import pandas as pd
5
+ import subprocess
6
+ import torch
7
+ import numpy as np
8
+ import shutil
9
+ from tqdm import tqdm
10
+
11
+ from joblib import delayed, Parallel
12
+
13
# NOTE(review): other eval scripts use "./Data/Baselines_new/Tests" — this path
# is missing the "Data/" component; confirm which is intended.
input_dir="./Baselines_new/Tests"
# output_dir="/datapool/data2/home/jiahan/Res Proj/PepDiff/frame-flow/Data/RF_samples"
output_dir="./Data/Baselines_new/Fixbb"

# ESM-IF sequence-sampling entry script.
RUNNER = "/datapool/data2/home/jiahan/Tool/esm/examples/inverse_folding/sample_sequences.py"
18
+
19
+
20
def process_one_item_esmif(name='1a1m_C', chains_to_design="A", num_samples=10, temperature=0.1):
    """Sample sequences with ESM-IF for one complex via its sampling script.

    Results are written to <output_dir>/<name>/esms/pocket_merge_renum.fasta.
    """
    dirname = os.path.join(output_dir, name, 'esms')
    # Idempotent creation replaces the previous exists-check + assert.
    os.makedirs(dirname, exist_ok=True)
    cmd = [
        "python", RUNNER, os.path.join(input_dir, name, 'pocket_merge_renum.pdb'),
        "--chain", chains_to_design, "--temperature", f"{temperature}", "--num-samples", f"{num_samples}",
        "--outpath", os.path.join(dirname, 'pocket_merge_renum.fasta'),
        "--multichain-backbone", "--nogpu"
    ]
    subprocess.run(cmd)
eval/run_mpnn.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils import *
2
+ from geometry import *
3
+
4
+ import os
5
+ import pandas as pd
6
+ import subprocess
7
+ import torch
8
+ import numpy as np
9
+ import shutil
10
+ from tqdm import tqdm
11
+
12
+ from joblib import delayed, Parallel
13
+
14
+ from Bio.PDB import PDBParser, PDBIO, Select
15
+
16
+
17
+ HELPERS = "/datapool/data2/home/jiahan/Tool/ProteinMPNN/helper_scripts"
18
+ RUNNER = "/datapool/data2/home/jiahan/Tool/ProteinMPNN/protein_mpnn_run.py"
19
+
20
def get_chain_nums(pdb_path, chain_id):
    """Residue sequence numbers of one chain in a PDB file."""
    model = PDBParser().get_structure('X', pdb_path)[0]
    return [residue.get_id()[1] for residue in model[chain_id]]
25
+
26
def process_mpnn_bb(name='1aze_B',chains_to_design="A",num_samples=1):
    """Design sequences with ProteinMPNN for generated backbones of one target.

    Parses all PDBs under bb/pdbs/<name>, fixes everything except the residues
    of `chains_to_design` found in gt.pdb, and writes designed sequences to
    bb/seqs/<name>.
    """
    input_dir = './Data/Models_new/Codesign/bb/pdbs'
    output_dir = './Data/Models_new/Codesign/bb/seqs'
    if not os.path.exists(os.path.join(output_dir,name)):
        os.makedirs(os.path.join(output_dir,name))
    dirname = os.path.join(output_dir,name)
    # defined dirs: intermediate jsonl files produced by the MPNN helper scripts
    path_for_parsed_chains=os.path.join(dirname,'parsed_pdbs.jsonl')
    path_for_assigned_chains=os.path.join(dirname,'assigned_pdbs.jsonl')
    path_for_fixed_positions=os.path.join(dirname,'fixed_pdbs.jsonl')
    # Positions to design are taken from the ground-truth structure's numbering.
    residue_nums = get_chain_nums(os.path.join(input_dir,name,'gt.pdb'),chains_to_design)
    design_only_positions = " ".join(map(str,residue_nums)) #design only these residues; use flag --specify_non_fixed
    # print(path_for_assigned_chains)
    # print(design_only_positions)
    # Step 1: parse every PDB in the input directory into jsonl.
    subprocess.run([
        "python", os.path.join(HELPERS,"parse_multiple_chains.py"),
        "--input_path", os.path.join(input_dir,name),
        "--output_path", path_for_parsed_chains,
    ])
    # Step 2: mark which chains are designable.
    subprocess.run([
        "python", os.path.join(HELPERS,"assign_fixed_chains.py"),
        "--input_path", path_for_parsed_chains,
        "--output_path", path_for_assigned_chains,
        '--chain_list', chains_to_design,
    ])
    # Step 3: with --specify_non_fixed, the listed positions are the ONLY ones designed.
    subprocess.run([
        "python", os.path.join(HELPERS,"make_fixed_positions_dict.py"),
        "--input_path", path_for_parsed_chains,
        "--output_path", path_for_fixed_positions,
        '--chain_list', chains_to_design,
        '--position_list', design_only_positions,
        '--specify_non_fixed'
    ])
    # run mpnn
    # print('run mpnns')
    subprocess.run([
        "python", RUNNER,
        "--jsonl_path", path_for_parsed_chains,
        "--chain_id_jsonl", path_for_assigned_chains,
        "--fixed_positions_jsonl", path_for_fixed_positions,
        "--out_folder", dirname,
        "--num_seq_per_target", f"{num_samples}",
        "--sampling_temp", "0.1",
        "--seed", "37",
        "--batch_size","1",
        '--device','cuda:1'
    ])
73
+
74
def process_one_item_mpnn(name='1a1m_C',chains_to_design="A",num_samples=1):
    """Design peptide sequences with ProteinMPNN on RFdiffusion backbones.

    The peptide length is read from seq.fasta; positions 1..pep_len of
    `chains_to_design` are designed, everything else is kept fixed.
    Outputs go to <output_dir>/<name>/mpnns.
    """
    input_dir="./Data/Baselines_new/Tests"
    output_dir="./Data/Baselines_new/Codesign"
    if not os.path.exists(os.path.join(output_dir,name,'mpnns')):
        os.makedirs(os.path.join(output_dir,name,'mpnns'))
    # if not os.path.exists(os.path.join(output_dir,name,'pocket_merge_renum.pdb')):
    #     chain_dic = renumber_pdb(os.path.join(input_dir,name,'pocket_merge.pdb'),os.path.join(output_dir,name,'pocket_merge_renum.pdb'))
    dirname = os.path.join(output_dir,name,'mpnns')
    # defined dirs: intermediate jsonl files produced by the MPNN helper scripts
    path_for_parsed_chains=os.path.join(dirname,'parsed_pdbs.jsonl')
    path_for_assigned_chains=os.path.join(dirname,'assigned_pdbs.jsonl')
    path_for_fixed_positions=os.path.join(dirname,'fixed_pdbs.jsonl')
    # Peptide length = length of the second line (sequence) of seq.fasta.
    with open(os.path.join(input_dir,name,'seq.fasta'),'r') as f:
        pep_len = len(f.readlines()[1].strip())
    design_only_positions=" ".join(map(str,list(range(1,pep_len+1)))) #design only these residues; use flag --specify_non_fixed
    # print(design_only_positions)
    # parsed chains
    # print("parsing chains")
    # Step 1: parse the RFdiffusion backbone PDBs into jsonl.
    subprocess.run([
        "python", os.path.join(HELPERS,"parse_multiple_chains.py"),
        "--input_path", os.path.join('./Data/Baselines_new/Codesign',name,'rfs'),#os.path.join('/datapool/data2/home/jiahan/ResProj/PepDiff/frame-flow/Data/Baselines/Fixbb/',name),
        "--output_path", path_for_parsed_chains,
    ])
    # Step 2: mark which chains are designable.
    subprocess.run([
        "python", os.path.join(HELPERS,"assign_fixed_chains.py"),
        "--input_path", path_for_parsed_chains,
        "--output_path", path_for_assigned_chains,
        '--chain_list', chains_to_design,
    ])
    # Step 3: with --specify_non_fixed, the listed positions are the ONLY ones designed.
    subprocess.run([
        "python", os.path.join(HELPERS,"make_fixed_positions_dict.py"),
        "--input_path", path_for_parsed_chains,
        "--output_path", path_for_fixed_positions,
        '--chain_list', chains_to_design,
        '--position_list', design_only_positions,
        '--specify_non_fixed'
    ])
    # run mpnn
    # print('run mpnns')
    subprocess.run([
        "python", RUNNER,
        "--jsonl_path", path_for_parsed_chains,
        "--chain_id_jsonl", path_for_assigned_chains,
        "--fixed_positions_jsonl", path_for_fixed_positions,
        "--out_folder", dirname,
        "--num_seq_per_target", f"{num_samples}",
        "--sampling_temp", "0.1",
        "--seed", "37",
        "--batch_size","1",
        '--device','cuda:1'
    ])
125
+
126
+
127
def write_seq_to_pdb(seq_path, pdb_path, out_path, chain_id):
    """Overwrite residue names of `chain_id` in `pdb_path` with the last FASTA record of `seq_path`.

    Used to replace the GGGGG placeholder in RFdiffusion backbones with an
    MPNN-generated sequence before further processing.
    """
    aa_mapping = {"A": "ALA","C": "CYS","D": "ASP","E": "GLU","F": "PHE","G": "GLY","H": "HIS","I": "ILE","K": "LYS","L": "LEU","M": "MET","N": "ASN","P": "PRO","Q": "GLN","R": "ARG","S": "SER","T": "THR","V": "VAL","W": "TRP","Y": "TYR",
                  'X':'UNK'}
    records = [str(record.seq) for record in SeqIO.parse(seq_path, "fasta")]
    seq = records[-1]  # keep only the last record

    structure = PDBParser().get_structure("X", pdb_path)
    for chain in structure[0]:
        if chain.id == chain_id:  # the chain whose residues are renamed
            # zip truncates at the shorter of (residues, sequence),
            # matching the original "i < len(seq)" guard.
            for res, one_letter in zip(chain, seq):
                res.resname = aa_mapping[one_letter]
    io = PDBIO()
    io.set_structure(structure)
    io.save(out_path)
eval/run_rfdiffusion.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils import *
2
+ from geometry import *
3
+
4
+ import os
5
+ import pandas as pd
6
+ import subprocess
7
+ import torch
8
+ import numpy as np
9
+ import shutil
10
+ from tqdm import tqdm
11
+
12
+ from joblib import delayed, Parallel
13
+
14
# Test complexes (inputs) and codesign outputs.
input_dir="./Data/Baselines_new/Tests"
# NOTE(review): ".Data/..." looks like a typo for "./Data/..." (cf. input_dir) — confirm.
output_dir=".Data/Baselines_new/Codesign"

# protein_generator inference entry point.
PROGEN="/datapool/data2/home/jiahan/Tool/protein_generator/inference.py"
18
+
19
def process_one_item_rf(name='1a1m_C', num_samples=10):
    """Generate peptide backbones with RFdiffusion for one pocket.

    Builds a contig string keeping each receptor chain fixed and appending a
    free peptide segment of the target length. Returns `name` on success,
    None on failure.
    """
    os.makedirs(os.path.join(output_dir, name, 'rfs'), exist_ok=True)
    chain_dic = get_chain_dic(os.path.join(input_dir, name, 'pocket_renum.pdb'))
    # Peptide length = length of the second line (sequence) of seq.fasta.
    with open(os.path.join(input_dir, name, 'seq.fasta'), 'r') as f:
        pep_len = len(f.readlines()[1].strip())
    # rfdiffusion contig syntax: "<chain>1-<len>/0" fixes a chain; the trailing
    # "<pep_len>-<pep_len>" is the newly generated peptide segment.
    contigs = [f'{chain}1-{chain_len}/0' for chain, chain_len in chain_dic.items()]
    contigs.append(f'{pep_len}-{pep_len}')
    contigs = " ".join(contigs)
    command = [
        "run_inference.py",
        f"inference.output_prefix='{os.path.join(output_dir, name, 'rfs', 'sample')}'",
        f"inference.input_pdb='{os.path.join(input_dir, name, 'pocket_renum.pdb')}'",
        f"contigmap.contigs=[{contigs}]",
        f"inference.num_designs={num_samples}",
    ]
    # print(command)
    try:
        subprocess.run(command, check=True, capture_output=True, text=True)
        return name
    except Exception:  # was a bare except; don't swallow KeyboardInterrupt/SystemExit
        return None
44
+
45
def process_one_item_pg(name='1a1m_C', num_samples=10):
    """Generate peptides with protein_generator for one pocket.

    Returns `name` on success, None on failure.
    """
    # Single idempotent call (was duplicated: once guarded, once with exist_ok).
    os.makedirs(os.path.join(output_dir, name, 'pgs'), exist_ok=True)
    chain_dic = get_chain_dic(os.path.join(input_dir, name, 'pocket_renum.pdb'))
    # Peptide length = length of the second line (sequence) of seq.fasta.
    with open(os.path.join(input_dir, name, 'seq.fasta'), 'r') as f:
        pep_len = len(f.readlines()[1].strip())
    # protein_generator settings: each fixed chain "<chain>1-<len>,0" plus a
    # free segment of the peptide length.
    contigs = [f'{chain}1-{chain_len},0' for chain, chain_len in chain_dic.items()]
    contigs.append(f'{pep_len}-{pep_len}')
    command = [
        "python", PROGEN,
        "--num_designs", f"{num_samples}",
        "--out", os.path.join(output_dir, name, 'pgs', 'sample'),
        "--pdb", os.path.join(input_dir, name, 'pocket_renum.pdb'),
        "--T", "25", # default setting
        "--save_best_plddt", # default setting
        "--contigs", *contigs,
    ]
    # print(command)
    try:
        subprocess.run(command, check=True, capture_output=True, text=True)
        return name
    except Exception:  # was a bare except; don't swallow KeyboardInterrupt/SystemExit
        return None
72
+
73
def process_one_item(name='1a1m_C', num_samples=10):
    """Run both protein_generator and RFdiffusion sampling for one target."""
    process_one_item_pg(name=name, num_samples=num_samples)
    process_one_item_rf(name=name, num_samples=num_samples)
eval/run_scwrl4.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils import *
2
+
3
+ import os
4
+ import pandas as pd
5
+ import subprocess
6
+ import numpy as np
7
+ import shutil
8
+ from tqdm import tqdm
9
+
10
+ from joblib import delayed, Parallel
11
+
12
# Test complexes (inputs) and side-chain packing outputs.
input_dir="./Data/Baselines_new/Tests"
output_dir="./Data/Baselines_new/Pack"

# Scwrl4 side-chain packing executable.
RUNNER = "/datapool/data2/home/jiahan/Tool/bin/Scwrl4"
16
+
17
def process_one_item_scwrl4(name='1a1m_C', num_samples=10):
    """Pack side chains with Scwrl4 `num_samples` times for one backbone-only complex.

    NOTE(review): Scwrl4 appears to be run with identical inputs each
    iteration — confirm whether repeated samples are intended to differ.
    """
    dirname = os.path.join(output_dir, name, 'scwrls')
    # Idempotent creation replaces the previous exists-check.
    os.makedirs(dirname, exist_ok=True)
    for i in range(num_samples):
        cmd = [
            RUNNER,
            '-i', os.path.join(input_dir, name, 'pocket_merge_renum_bb.pdb'),
            '-o', os.path.join(dirname, f'packed_{i}.pdb'),
        ]
        subprocess.run(cmd)
eval/utils.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import pandas as pd
4
+ import subprocess
5
+ from difflib import SequenceMatcher
6
+
7
+ from Bio import SeqIO
8
+ from Bio.PDB import PDBParser, PDBIO, Chain, Select, is_aa
9
+ from Bio.PDB.Polypeptide import PPBuilder
10
+
11
+ from Bio.PDB import PDBParser
12
+ from Bio.SeqUtils import seq1
13
+
14
+ # def parse_pdb_chains(pdb_file):
15
+ # parser = PDBParser()
16
+ # structure = parser.get_structure("protein", pdb_file)
17
+ # pp_builder = PPBuilder()
18
+
19
+ # sequences = {}
20
+ # for model in structure:
21
+ # for chain in model:
22
+ # chain_id = chain.get_id()
23
+ # sequence = "".join([str(pp.get_sequence()) for pp in pp_builder.build_peptides(chain)])
24
+ # print(len(sequence))
25
+ # sequences[chain_id] = sequence
26
+
27
+ # return sequences
28
+
29
def get_fasta_from_pdb(pdb_file):
    """Map chain id -> one-letter sequence for every chain in `pdb_file`."""
    structure = PDBParser().get_structure("pdb", pdb_file)
    return {
        chain.id: "".join(seq1(residue.get_resname()) for residue in chain.get_residues())
        for chain in structure.get_chains()
    }
41
+
42
def parse_fasta(file):
    """Parse a FASTA file; map record index -> list of '/'-separated sub-sequences."""
    with open(file, "r") as handle:
        return {
            idx: str(record.seq).split("/")
            for idx, record in enumerate(SeqIO.parse(handle, "fasta"))
        }
48
+
49
def renumber_pdb(input_pdb, output_pdb):
    """Renumber every chain's residues consecutively from 1 and write the result.

    Returns:
        dict: chain id -> residue count (hetero residues included).
    """
    parser = PDBParser()
    structure = parser.get_structure("protein", input_pdb)

    chain_dic = {}

    for model in structure:
        old_chains = []
        new_chains = []
        for chain in model: # this may include HETATM residues
            # Build a renumbered copy under a temporary "_renum" id so it
            # cannot collide with an existing chain id while both exist.
            new_chain_id = chain.id + "_renum"
            new_chain = Chain.Chain(new_chain_id)
            for i, residue in enumerate(chain):
                new_residue = residue.copy()
                new_residue_id = (residue.id[0], i + 1, residue.id[2])
                new_residue.id = new_residue_id
                new_chain.add(new_residue)
            old_chains.append(chain)
            new_chains.append(new_chain)
            chain_dic[chain.id] = len(list(chain))

        # Swap each original chain for its renumbered copy, restoring the id.
        for chain, new_chain in zip(old_chains, new_chains):
            model.detach_child(chain.id)
            new_chain.id = chain.id
            model.add(new_chain)

    io = PDBIO()
    io.set_structure(structure)
    io.save(output_pdb)

    return chain_dic
80
+
81
def get_chain_dic(input_pdb):
    """Map chain id -> number of amino-acid residues with a CA atom (all models)."""
    structure = PDBParser().get_structure("protein", input_pdb)
    chain_dic = {}
    for model in structure:
        for chain in model:
            chain_dic[chain.id] = sum(1 for res in chain if is_aa(res) and res.has_id('CA'))
    return chain_dic
92
+
93
+
94
def keep_backbone_atoms(input_file, output_file):
    """Write a copy of `input_file` keeping only backbone atoms (N, CA, C, O)."""

    class _BackboneOnly(Select):
        # PDBIO Select callback: keep an atom iff it is a backbone heavy atom.
        def accept_atom(self, atom):
            return atom.get_name() in ("N", "CA", "C", "O")

    structure = PDBParser().get_structure("protein", input_file)
    writer = PDBIO()
    writer.set_structure(structure)
    writer.save(output_file, _BackboneOnly())
models_con/edge.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ from pepflow.modules.common.geometry import angstrom_to_nm, pairwise_dihedrals
7
+ from pepflow.modules.common.layers import AngularEncoding
8
+ from pepflow.modules.protein.constants import BBHeavyAtom, AA
9
+
10
+
11
class EdgeEmbedder(nn.Module):
    """Pairwise (edge) feature embedder for residue pairs.

    Combines: amino-acid pair identity, clamped relative sequence position
    (same-chain only), Gaussian-encoded inter-atom distances with learned
    per-pair decay coefficients, and pairwise backbone dihedrals.
    """

    def __init__(self, feat_dim, max_num_atoms, max_aa_types=22, max_relpos=32, num_bins=16):
        super().__init__()
        self.max_num_atoms = max_num_atoms
        self.max_aa_types = max_aa_types
        self.max_relpos = max_relpos
        self.num_bins = num_bins
        # One embedding per ordered (aa_i, aa_j) pair.
        self.aa_pair_embed = nn.Embedding(self.max_aa_types*self.max_aa_types, feat_dim)
        # Sequence offset clamped to [-max_relpos, max_relpos] -> 2*max_relpos+1 bins.
        self.relpos_embed = nn.Embedding(2*max_relpos+1, feat_dim)

        # Learned Gaussian decay coefficient per (aa-pair, atom-pair); zero init
        # makes all distance features start at exp(0)=1 before training.
        self.aapair_to_distcoef = nn.Embedding(self.max_aa_types*self.max_aa_types, max_num_atoms*max_num_atoms)
        nn.init.zeros_(self.aapair_to_distcoef.weight)
        self.distance_embed = nn.Sequential(
            nn.Linear(max_num_atoms*max_num_atoms, feat_dim), nn.ReLU(),
            nn.Linear(feat_dim, feat_dim), nn.ReLU(),
        )

        self.dihedral_embed = AngularEncoding()
        feat_dihed_dim = self.dihedral_embed.get_out_dim(2) # Phi and Psi

        # Concatenation of aa-pair + relpos + distance + dihedral features.
        infeat_dim = feat_dim + feat_dim + feat_dim + feat_dihed_dim
        self.out_mlp = nn.Sequential(
            nn.Linear(infeat_dim, feat_dim), nn.ReLU(),
            nn.Linear(feat_dim, feat_dim), nn.ReLU(),
            nn.Linear(feat_dim, feat_dim),
        )

    def forward(self, aa, res_nb, chain_nb, pos_atoms, mask_atoms, structure_mask=None, sequence_mask=None):
        """
        Args:
            aa: (N, L).
            res_nb: (N, L).
            chain_nb: (N, L).
            pos_atoms: (N, L, A, 3)
            mask_atoms: (N, L, A)
            trans, sc_trans: (N,L,3)
            structure_mask: (N, L)
            sequence_mask: (N, L), mask out unknown amino acids to generate.

        Returns:
            (N, L, L, feat_dim)
        """
        N, L = aa.size()

        # Remove other atoms
        pos_atoms = pos_atoms[:, :, :self.max_num_atoms]
        mask_atoms = mask_atoms[:, :, :self.max_num_atoms]

        # A residue is valid iff its CA atom exists.
        mask_residue = mask_atoms[:, :, BBHeavyAtom.CA] # (N, L)
        mask_pair = mask_residue[:, :, None] * mask_residue[:, None, :]
        pair_structure_mask = structure_mask[:, :, None] * structure_mask[:, None, :] if structure_mask is not None else None

        # Pair identities
        if sequence_mask is not None:
            # Avoid data leakage at training time
            aa = torch.where(sequence_mask, aa, torch.full_like(aa, fill_value=AA.UNK))
        aa_pair = aa[:,:,None]*self.max_aa_types + aa[:,None,:] # (N, L, L)
        feat_aapair = self.aa_pair_embed(aa_pair)

        # Relative sequential positions (zeroed for cross-chain pairs)
        same_chain = (chain_nb[:, :, None] == chain_nb[:, None, :])
        relpos = torch.clamp(
            res_nb[:,:,None] - res_nb[:,None,:],
            min=-self.max_relpos, max=self.max_relpos,
        ) # (N, L, L)
        feat_relpos = self.relpos_embed(relpos + self.max_relpos) * same_chain[:,:,:,None]

        # Distances: all atom-pair distances, encoded as exp(-c * d^2) with a
        # learned per-(aa-pair, atom-pair) coefficient c >= 0 (softplus).
        d = angstrom_to_nm(torch.linalg.norm(
            pos_atoms[:,:,None,:,None] - pos_atoms[:,None,:,None,:],
            dim = -1, ord = 2,
        )).reshape(N, L, L, -1) # (N, L, L, A*A)
        c = F.softplus(self.aapair_to_distcoef(aa_pair))    # (N, L, L, A*A)
        d_gauss = torch.exp(-1 * c * d**2)
        mask_atom_pair = (mask_atoms[:,:,None,:,None] * mask_atoms[:,None,:,None,:]).reshape(N, L, L, -1)
        feat_dist = self.distance_embed(d_gauss * mask_atom_pair)
        if pair_structure_mask is not None:
            # Avoid data leakage at training time
            feat_dist = feat_dist * pair_structure_mask[:, :, :, None]

        # Orientations
        dihed = pairwise_dihedrals(pos_atoms) # (N, L, L, 2)
        feat_dihed = self.dihedral_embed(dihed)
        if pair_structure_mask is not None:
            # Avoid data leakage at training time
            feat_dihed = feat_dihed * pair_structure_mask[:, :, :, None]

        # # trans embed
        # dist_feats = calc_distogram(
        #     trans, min_bin=1e-3, max_bin=20.0, num_bins=self.num_bins)
        # if sc_trans == None:
        #     sc_trans = torch.zeros_like(trans)
        # sc_feats = calc_distogram(
        #     sc_trans, min_bin=1e-3, max_bin=20.0, num_bins=self.num_bins)

        # All features concatenated, mixed by the MLP, invalid pairs zeroed.
        feat_all = torch.cat([feat_aapair, feat_relpos, feat_dist, feat_dihed], dim=-1)
        feat_all = self.out_mlp(feat_all) # (N, L, L, F)
        feat_all = feat_all * mask_pair[:, :, :, None]

        return feat_all
models_con/flow_model.py ADDED
@@ -0,0 +1,472 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ import copy
7
+ import math
8
+ from tqdm.auto import tqdm
9
+ import functools
10
+ from torch.utils.data import DataLoader
11
+ import os
12
+ import argparse
13
+
14
+ import pandas as pd
15
+
16
+ from models_con.edge import EdgeEmbedder
17
+ from models_con.node import NodeEmbedder
18
+ from pepflow.modules.common.layers import sample_from, clampped_one_hot
19
+ from models_con.ga import GAEncoder
20
+ from pepflow.modules.protein.constants import AA, BBHeavyAtom, max_num_heavyatoms
21
+ from pepflow.modules.common.geometry import construct_3d_basis
22
+ from pepflow.utils.data import mask_select_data, find_longest_true_segment, PaddingCollate
23
+ from pepflow.utils.misc import seed_all
24
+ from pepflow.utils.train import sum_weighted_losses
25
+ from torch.nn.utils import clip_grad_norm_
26
+
27
+ from pepflow.modules.so3.dist import centered_gaussian,uniform_so3
28
+ from pepflow.modules.common.geometry import batch_align, align
29
+
30
+ from tqdm import tqdm
31
+
32
+ import wandb
33
+
34
+ from data import so3_utils
35
+ from data import all_atom
36
+
37
+ from models_con.pep_dataloader import PepDataset
38
+
39
+ from pepflow.utils.misc import load_config
40
+ from pepflow.utils.train import recursive_to
41
+ from easydict import EasyDict
42
+
43
+ from models_con.utils import process_dic
44
+ from models_con.torsion import get_torsion_angle, torsions_mask
45
+ import models_con.torus as torus
46
+
47
+ import gc
48
+
49
+ from copy import deepcopy
50
+ from pepflow.utils.data import PaddingCollate
51
+ collate_fn = PaddingCollate(eight=False)
52
+ from pepflow.utils.train import recursive_to
53
+
54
# Number of heavy atoms kept per residue for each structural resolution level.
resolution_to_num_atoms = {
    'backbone+CB': 5,            # backbone plus C-beta
    'full': max_num_heavyatoms   # all heavy atoms
}
58
+
59
+ class FlowModel(nn.Module):
60
    def __init__(self,cfg):
        """Build node/edge embedders and the GA encoder from `cfg`.

        Args:
            cfg: model config with `encoder` (embedding sizes, ipa settings)
                and `interpolant` (sampling flags, sequence-simplex settings)
                sections.
        """
        super().__init__()
        self._model_cfg = cfg.encoder
        self._interpolant_cfg = cfg.interpolant

        self.node_embedder = NodeEmbedder(cfg.encoder.node_embed_size,max_num_heavyatoms)
        self.edge_embedder = EdgeEmbedder(cfg.encoder.edge_embed_size,max_num_heavyatoms)
        self.ga_encoder = GAEncoder(cfg.encoder.ipa)

        # Which modalities are generated (vs taken as fixed context).
        self.sample_structure = self._interpolant_cfg.sample_structure
        self.sample_sequence = self._interpolant_cfg.sample_sequence

        # K: number of residue classes; k: magnitude of the +-k simplex encoding.
        self.K = self._interpolant_cfg.seqs.num_classes
        self.k = self._interpolant_cfg.seqs.simplex_value
74
+
75
    def encode(self, batch):
        """Extract ground-truth frames/angles/sequence and compute embeddings.

        Returns:
            rotmats_1: (N, L, 3, 3) backbone frames built from CA/C/N atoms.
            trans_1: (N, L, 3) CA positions.
            angles_1: torsion angles taken from the batch.
            seqs_1: (N, L) residue types.
            node_embed, edge_embed: embeddings computed on context only.
        """
        rotmats_1 = construct_3d_basis(batch['pos_heavyatom'][:, :, BBHeavyAtom.CA],batch['pos_heavyatom'][:, :, BBHeavyAtom.C],batch['pos_heavyatom'][:, :, BBHeavyAtom.N] )
        trans_1 = batch['pos_heavyatom'][:, :, BBHeavyAtom.CA]
        seqs_1 = batch['aa']

        # ignore psi
        # batch['torsion_angle'] = batch['torsion_angle'][:,:,1:]
        # batch['torsion_angle_mask'] = batch['torsion_angle_mask'][:,:,1:]
        angles_1 = batch['torsion_angle']

        # Context = residues that exist (CA present) and are NOT being generated;
        # when a modality is sampled, the embedders only see context for it,
        # which prevents leaking the generation targets.
        context_mask = torch.logical_and(batch['mask_heavyatom'][:, :, BBHeavyAtom.CA], ~batch['generate_mask'])
        structure_mask = context_mask if self.sample_structure else None
        sequence_mask = context_mask if self.sample_sequence else None
        node_embed = self.node_embedder(batch['aa'], batch['res_nb'], batch['chain_nb'], batch['pos_heavyatom'],
                                        batch['mask_heavyatom'], structure_mask=structure_mask, sequence_mask=sequence_mask)
        edge_embed = self.edge_embedder(batch['aa'], batch['res_nb'], batch['chain_nb'], batch['pos_heavyatom'],
                                        batch['mask_heavyatom'], structure_mask=structure_mask, sequence_mask=sequence_mask)

        return rotmats_1, trans_1, angles_1, seqs_1, node_embed, edge_embed
94
+
95
+ def zero_center_part(self,pos,gen_mask,res_mask):
96
+ """
97
+ move pos by center of gen_mask
98
+ pos: (B,N,3)
99
+ gen_mask, res_mask: (B,N)
100
+ """
101
+ center = torch.sum(pos * gen_mask[...,None], dim=1) / (torch.sum(gen_mask,dim=-1,keepdim=True) + 1e-8) # (B,N,3)*(B,N,1)->(B,3)/(B,1)->(B,3)
102
+ center = center.unsqueeze(1) # (B,1,3)
103
+ # center = 0. it seems not center didnt influence the result, but its good for training stabilty
104
+ pos = pos - center
105
+ pos = pos * res_mask[...,None]
106
+ return pos,center
107
+
108
+ def seq_to_simplex(self,seqs):
109
+ return clampped_one_hot(seqs, self.K).float() * self.k * 2 - self.k # (B,L,K)
110
+
111
+ def forward(self, batch):
112
+
113
+ num_batch, num_res = batch['aa'].shape
114
+ gen_mask,res_mask,angle_mask = batch['generate_mask'].long(),batch['res_mask'].long(),batch['torsion_angle_mask'].long()
115
+
116
+ #encode
117
+ rotmats_1, trans_1, angles_1, seqs_1, node_embed, edge_embed = self.encode(batch) # no generate mask
118
+
119
+ # prepare for denoise
120
+ trans_1_c,_ = self.zero_center_part(trans_1,gen_mask,res_mask)
121
+ trans_1_c = trans_1 # already centered when constructing dataset
122
+ seqs_1_simplex = self.seq_to_simplex(seqs_1)
123
+ seqs_1_prob = F.softmax(seqs_1_simplex,dim=-1)
124
+
125
+ with torch.no_grad():
126
+ t = torch.rand((num_batch,1), device=batch['aa'].device)
127
+ t = t*(1-2 * self._interpolant_cfg.t_normalization_clip) + self._interpolant_cfg.t_normalization_clip # avoid 0
128
+ if self.sample_structure:
129
+ # corrupt trans
130
+ trans_0 = torch.randn((num_batch,num_res,3), device=batch['aa'].device) * self._interpolant_cfg.trans.sigma # scale with sigma?
131
+ trans_0_c,_ = self.zero_center_part(trans_0,gen_mask,res_mask)
132
+ trans_t = (1-t[...,None])*trans_0_c + t[...,None]*trans_1_c
133
+ trans_t_c = torch.where(batch['generate_mask'][...,None],trans_t,trans_1_c)
134
+ # corrupt rotmats
135
+ rotmats_0 = uniform_so3(num_batch,num_res,device=batch['aa'].device)
136
+ rotmats_t = so3_utils.geodesic_t(t[..., None], rotmats_1, rotmats_0)
137
+ rotmats_t = torch.where(batch['generate_mask'][...,None,None],rotmats_t,rotmats_1)
138
+ # corrup angles
139
+ angles_0 = torus.tor_random_uniform(angles_1.shape, device=batch['aa'].device, dtype=angles_1.dtype) # (B,L,5)
140
+ angles_t = torus.tor_geodesic_t(t[..., None], angles_1, angles_0)
141
+ angles_t = torch.where(batch['generate_mask'][...,None],angles_t,angles_1)
142
+ else:
143
+ trans_t_c = trans_1_c.detach().clone()
144
+ rotmats_t = rotmats_1.detach().clone()
145
+ angles_t = angles_1.detach().clone()
146
+ if self.sample_sequence:
147
+ # corrupt seqs
148
+ seqs_0_simplex = self.k * torch.randn_like(seqs_1_simplex) # (B,L,K)
149
+ seqs_0_prob = F.softmax(seqs_0_simplex,dim=-1) # (B,L,K)
150
+ seqs_t_simplex = ((1 - t[..., None]) * seqs_0_simplex) + (t[..., None] * seqs_1_simplex) # (B,L,K)
151
+ seqs_t_simplex = torch.where(batch['generate_mask'][...,None],seqs_t_simplex,seqs_1_simplex)
152
+ seqs_t_prob = F.softmax(seqs_t_simplex,dim=-1) # (B,L,K)
153
+ seqs_t = sample_from(seqs_t_prob) # (B,L)
154
+ seqs_t = torch.where(batch['generate_mask'],seqs_t,seqs_1)
155
+ else:
156
+ seqs_t = seqs_1.detach().clone()
157
+ seqs_t_simplex = seqs_1_simplex.detach().clone()
158
+ seqs_t_prob = seqs_1_prob.detach().clone()
159
+
160
+ # denoise
161
+ pred_rotmats_1, pred_trans_1, pred_angles_1, pred_seqs_1_prob = self.ga_encoder(t, rotmats_t, trans_t_c, angles_t, seqs_t, node_embed, edge_embed, gen_mask, res_mask)
162
+ pred_seqs_1 = sample_from(F.softmax(pred_seqs_1_prob,dim=-1))
163
+ pred_seqs_1 = torch.where(batch['generate_mask'],pred_seqs_1,torch.clamp(seqs_1,0,19))
164
+ pred_trans_1_c,_ = self.zero_center_part(pred_trans_1,gen_mask,res_mask)
165
+ pred_trans_1_c = pred_trans_1 # implicitly enforce zero center in gen_mask, in this way, we dont need to move receptor when sampling
166
+
167
+ norm_scale = 1 / (1 - torch.min(t[...,None], torch.tensor(self._interpolant_cfg.t_normalization_clip))) # yim etal.trick, 1/1-t
168
+
169
+ # trans vf loss
170
+ trans_loss = torch.sum((pred_trans_1_c - trans_1_c)**2*gen_mask[...,None],dim=(-1,-2)) / (torch.sum(gen_mask,dim=-1) + 1e-8) # (B,)
171
+ trans_loss = torch.mean(trans_loss)
172
+
173
+ # rots vf loss
174
+ gt_rot_vf = so3_utils.calc_rot_vf(rotmats_t, rotmats_1)
175
+ pred_rot_vf = so3_utils.calc_rot_vf(rotmats_t, pred_rotmats_1)
176
+ rot_loss = torch.sum(((gt_rot_vf - pred_rot_vf) * norm_scale)**2*gen_mask[...,None],dim=(-1,-2)) / (torch.sum(gen_mask,dim=-1) + 1e-8) # (B,)
177
+ rot_loss = torch.mean(rot_loss)
178
+
179
+ # bb aux loss
180
+ gt_bb_atoms = all_atom.to_atom37(trans_1_c, rotmats_1)[:, :, :3]
181
+ pred_bb_atoms = all_atom.to_atom37(pred_trans_1_c, pred_rotmats_1)[:, :, :3]
182
+ # gt_bb_atoms = all_atom.to_bb_atoms(trans_1_c, rotmats_1, angles_1[:,:,0]) # N,CA,C,O,CB
183
+ # pred_bb_atoms = all_atom.to_bb_atoms(pred_trans_1_c, pred_rotmats_1, pred_angles_1[:,:,0])
184
+ # print(gt_bb_atoms.shape)
185
+ bb_atom_loss = torch.sum(
186
+ (gt_bb_atoms - pred_bb_atoms) ** 2 * gen_mask[..., None, None],
187
+ dim=(-1, -2, -3)
188
+ ) / (torch.sum(gen_mask,dim=-1) + 1e-8) # (B,)
189
+ bb_atom_loss = torch.mean(bb_atom_loss)
190
+ # bb_atom_loss = torch.mean(torch.where(t[:,0]>=0.75,bb_atom_loss,torch.zeros_like(bb_atom_loss))) # penalty for near gt point
191
+
192
+ # seqs vf loss
193
+ seqs_loss = F.cross_entropy(pred_seqs_1_prob.view(-1,pred_seqs_1_prob.shape[-1]),torch.clamp(seqs_1,0,19).view(-1), reduction='none').view(pred_seqs_1_prob.shape[:-1]) # (N,L), not softmax
194
+ seqs_loss = torch.sum(seqs_loss * gen_mask, dim=-1) / (torch.sum(gen_mask,dim=-1) + 1e-8)
195
+ seqs_loss = torch.mean(seqs_loss)
196
+
197
+ # we should not use angle mask, as you dont know aa type when generating
198
+ # angle_mask_loss = torch.cat([angle_mask,angle_mask],dim=-1) # (B,L,10)
199
+ # angle vf loss
200
+ angle_mask_loss = torsions_mask.to(batch['aa'].device)
201
+ angle_mask_loss = angle_mask_loss[pred_seqs_1.reshape(-1)].reshape(num_batch,num_res,-1) # (B,L,5)
202
+ angle_mask_loss = torch.cat([angle_mask_loss,angle_mask_loss],dim=-1) # (B,L,10)
203
+ angle_mask_loss = torch.logical_and(batch['generate_mask'][...,None].bool(),angle_mask_loss)
204
+ gt_angle_vf = torus.tor_logmap(angles_t, angles_1)
205
+ gt_angle_vf_vec = torch.cat([torch.sin(gt_angle_vf),torch.cos(gt_angle_vf)],dim=-1)
206
+ pred_angle_vf = torus.tor_logmap(angles_t, pred_angles_1)
207
+ pred_angle_vf_vec = torch.cat([torch.sin(pred_angle_vf),torch.cos(pred_angle_vf)],dim=-1)
208
+ # angle_loss = torch.sum(((gt_angle_vf_vec - pred_angle_vf_vec) * norm_scale)**2*gen_mask[...,None],dim=(-1,-2)) / ((torch.sum(gen_mask,dim=-1)) + 1e-8) # (B,)
209
+ angle_loss = torch.sum(((gt_angle_vf_vec - pred_angle_vf_vec) * norm_scale)**2*angle_mask_loss,dim=(-1,-2)) / (torch.sum(angle_mask_loss,dim=(-1,-2)) + 1e-8) # (B,)
210
+ angle_loss = torch.mean(angle_loss)
211
+
212
+
213
+ # angle aux loss
214
+ angles_1_vec = torch.cat([torch.sin(angles_1),torch.cos(angles_1)],dim=-1)
215
+ pred_angles_1_vec = torch.cat([torch.sin(pred_angles_1),torch.cos(pred_angles_1)],dim=-1)
216
+ # torsion_loss = torch.sum((pred_angles_1_vec - angles_1_vec)**2*gen_mask[...,None],dim=(-1,-2)) / (torch.sum(gen_mask,dim=-1) + 1e-8) # (B,)
217
+ torsion_loss = torch.sum((pred_angles_1_vec - angles_1_vec)**2*angle_mask_loss,dim=(-1,-2)) / (torch.sum(angle_mask_loss,dim=(-1,-2)) + 1e-8) # (B,)
218
+ torsion_loss = torch.mean(torsion_loss)
219
+
220
+ return {
221
+ "trans_loss": trans_loss,
222
+ 'rot_loss': rot_loss,
223
+ 'bb_atom_loss': bb_atom_loss,
224
+ 'seqs_loss': seqs_loss,
225
+ 'angle_loss': angle_loss,
226
+ 'torsion_loss': torsion_loss,
227
+ }
228
+
229
+ @torch.no_grad()
230
+ def sample(self, batch, num_steps = 100, sample_bb=True, sample_ang=True, sample_seq=True):
231
+
232
+ num_batch, num_res = batch['aa'].shape
233
+ gen_mask,res_mask = batch['generate_mask'],batch['res_mask']
234
+ K = self._interpolant_cfg.seqs.num_classes
235
+ k = self._interpolant_cfg.seqs.simplex_value
236
+ angle_mask_loss = torsions_mask.to(batch['aa'].device)
237
+
238
+ #encode
239
+ rotmats_1, trans_1, angles_1, seqs_1, node_embed, edge_embed = self.encode(batch)
240
+ # trans_1_c,center = self.zero_center_part(trans_1,gen_mask,res_mask)
241
+ trans_1_c = trans_1
242
+ seqs_1_simplex = self.seq_to_simplex(seqs_1)
243
+ seqs_1_prob = F.softmax(seqs_1_simplex,dim=-1)
244
+
245
+ # # # only sample bb, angle and seq with noise
246
+ # angles_1 = torch.where(batch['generate_mask'][...,None],angles_1,torus.tor_random_uniform(angles_1.shape, device=batch['aa'].device, dtype=angles_1.dtype))
247
+ # seqs_1 = torch.where(batch['generate_mask'],seqs_1,torch.randint_like(seqs_1,0,20))
248
+ # seqs_1_simplex = self.seq_to_simplex(seqs_1)
249
+ # seqs_1_prob = F.softmax(seqs_1_simplex,dim=-1)
250
+
251
+ #initial noise
252
+ if sample_bb:
253
+ rotmats_0 = uniform_so3(num_batch,num_res,device=batch['aa'].device)
254
+ rotmats_0 = torch.where(batch['generate_mask'][...,None,None],rotmats_0,rotmats_1)
255
+ trans_0 = torch.randn((num_batch,num_res,3), device=batch['aa'].device) # scale with sigma?
256
+ # move center and receptor
257
+ trans_0_c,center = self.zero_center_part(trans_0,gen_mask,res_mask)
258
+ trans_0_c = torch.where(batch['generate_mask'][...,None],trans_0_c,trans_1_c)
259
+ else:
260
+ rotmats_0 = rotmats_1.detach().clone()
261
+ trans_0_c = trans_1_c.detach().clone()
262
+ if sample_ang:
263
+ # angle noise
264
+ angles_0 = torus.tor_random_uniform(angles_1.shape, device=batch['aa'].device, dtype=angles_1.dtype) # (B,L,5)
265
+ angles_0 = torch.where(batch['generate_mask'][...,None],angles_0,angles_1)
266
+ else:
267
+ angles_0 = angles_1.detach().clone()
268
+ if sample_seq:
269
+ seqs_0_simplex = k * torch.randn((num_batch,num_res,K), device=batch['aa'].device)
270
+ seqs_0_prob = F.softmax(seqs_0_simplex,dim=-1)
271
+ seqs_0 = sample_from(seqs_0_prob)
272
+ seqs_0 = torch.where(batch['generate_mask'],seqs_0,seqs_1)
273
+ seqs_0_simplex = torch.where(batch['generate_mask'][...,None],seqs_0_simplex,seqs_1_simplex)
274
+ else:
275
+ seqs_0 = seqs_1.detach().clone()
276
+ seqs_0_prob = seqs_1_prob.detach().clone()
277
+ seqs_0_simplex = seqs_1_simplex.detach().clone()
278
+
279
+ # Set-up time
280
+ ts = torch.linspace(1.e-2, 1.0, num_steps)
281
+ t_1 = ts[0]
282
+ # prot_traj = [{'rotmats':rotmats_0,'trans':trans_0_c,'seqs':seqs_0,'seqs_simplex':seqs_0_simplex,'rotmats_1':rotmats_1,'trans_1':trans_1-center,'seqs_1':seqs_1}]
283
+ clean_traj = []
284
+ rotmats_t_1, trans_t_1_c, angles_t_1, seqs_t_1, seqs_t_1_simplex = rotmats_0, trans_0_c, angles_0, seqs_0, seqs_0_simplex
285
+
286
+ # denoise loop
287
+ for t_2 in ts[1:]:
288
+ t = torch.ones((num_batch, 1), device=batch['aa'].device) * t_1
289
+ # rots
290
+ pred_rotmats_1, pred_trans_1, pred_angles_1, pred_seqs_1_prob = self.ga_encoder(t, rotmats_t_1, trans_t_1_c, angles_t_1, seqs_t_1, node_embed, edge_embed, batch['generate_mask'].long(), batch['res_mask'].long())
291
+ pred_rotmats_1 = torch.where(batch['generate_mask'][...,None,None],pred_rotmats_1,rotmats_1)
292
+ # trans, move center
293
+ # pred_trans_1_c,center = self.zero_center_part(pred_trans_1,gen_mask,res_mask)
294
+ pred_trans_1_c = torch.where(batch['generate_mask'][...,None],pred_trans_1,trans_1_c) # move receptor also
295
+ # angles
296
+ pred_angles_1 = torch.where(batch['generate_mask'][...,None],pred_angles_1,angles_1)
297
+ # seqs
298
+ pred_seqs_1 = sample_from(F.softmax(pred_seqs_1_prob,dim=-1))
299
+ pred_seqs_1 = torch.where(batch['generate_mask'],pred_seqs_1,seqs_1)
300
+ pred_seqs_1_simplex = self.seq_to_simplex(pred_seqs_1)
301
+ # seq-angle
302
+ torsion_mask = angle_mask_loss[pred_seqs_1.reshape(-1)].reshape(num_batch,num_res,-1) # (B,L,5)
303
+ pred_angles_1 = torch.where(torsion_mask.bool(),pred_angles_1,torch.zeros_like(pred_angles_1))
304
+ if not sample_bb:
305
+ pred_trans_1_c = trans_1_c.detach().clone()
306
+ # _,center = self.zero_center_part(trans_1,gen_mask,res_mask)
307
+ pred_rotmats_1 = rotmats_1.detach().clone()
308
+ if not sample_ang:
309
+ pred_angles_1 = angles_1.detach().clone()
310
+ if not sample_seq:
311
+ pred_seqs_1 = seqs_1.detach().clone()
312
+ pred_seqs_1_simplex = seqs_1_simplex.detach().clone()
313
+ clean_traj.append({'rotmats':pred_rotmats_1.cpu(),'trans':pred_trans_1_c.cpu(),'angles':pred_angles_1.cpu(),'seqs':pred_seqs_1.cpu(),'seqs_simplex':pred_seqs_1_simplex.cpu(),
314
+ 'rotmats_1':rotmats_1.cpu(),'trans_1':trans_1_c.cpu(),'angles_1':angles_1.cpu(),'seqs_1':seqs_1.cpu()})
315
+ # reverse step, also only for gen mask region
316
+ d_t = (t_2-t_1) * torch.ones((num_batch, 1), device=batch['aa'].device)
317
+ # Euler step
318
+ trans_t_2 = trans_t_1_c + (pred_trans_1_c-trans_0_c)*d_t[...,None]
319
+ # trans_t_2_c,center = self.zero_center_part(trans_t_2,gen_mask,res_mask)
320
+ trans_t_2_c = torch.where(batch['generate_mask'][...,None],trans_t_2,trans_1_c) # move receptor also
321
+ # rotmats_t_2 = so3_utils.geodesic_t(d_t[...,None] / (1-t[...,None]), pred_rotmats_1, rotmats_t_1)
322
+ rotmats_t_2 = so3_utils.geodesic_t(d_t[...,None] * 10, pred_rotmats_1, rotmats_t_1)
323
+ rotmats_t_2 = torch.where(batch['generate_mask'][...,None,None],rotmats_t_2,rotmats_1)
324
+ # angles
325
+ angles_t_2 = torus.tor_geodesic_t(d_t[...,None],pred_angles_1, angles_t_1)
326
+ angles_t_2 = torch.where(batch['generate_mask'][...,None],angles_t_2,angles_1)
327
+ # seqs
328
+ seqs_t_2_simplex = seqs_t_1_simplex + (pred_seqs_1_simplex - seqs_0_simplex) * d_t[...,None]
329
+ seqs_t_2 = sample_from(F.softmax(seqs_t_2_simplex,dim=-1))
330
+ seqs_t_2 = torch.where(batch['generate_mask'],seqs_t_2,seqs_1)
331
+ # seq-angle
332
+ torsion_mask = angle_mask_loss[seqs_t_2.reshape(-1)].reshape(num_batch,num_res,-1) # (B,L,5)
333
+ angles_t_2 = torch.where(torsion_mask.bool(),angles_t_2,torch.zeros_like(angles_t_2))
334
+
335
+ if not sample_bb:
336
+ trans_t_2_c = trans_1_c.detach().clone()
337
+ rotmats_t_2 = rotmats_1.detach().clone()
338
+ if not sample_ang:
339
+ angles_t_2 = angles_1.detach().clone()
340
+ if not sample_seq:
341
+ seqs_t_2 = seqs_1.detach().clone()
342
+ rotmats_t_1, trans_t_1_c, angles_t_1, seqs_t_1, seqs_t_1_simplex = rotmats_t_2, trans_t_2_c, angles_t_2, seqs_t_2, seqs_t_2_simplex
343
+ t_1 = t_2
344
+
345
+ # final step
346
+ t_1 = ts[-1]
347
+ t = torch.ones((num_batch, 1), device=batch['aa'].device) * t_1
348
+ pred_rotmats_1, pred_trans_1, pred_angles_1, pred_seqs_1_prob = self.ga_encoder(t, rotmats_t_1, trans_t_1_c, angles_t_1, seqs_t_1, node_embed, edge_embed, batch['generate_mask'].long(), batch['res_mask'].long())
349
+ pred_rotmats_1 = torch.where(batch['generate_mask'][...,None,None],pred_rotmats_1,rotmats_1)
350
+ # move center
351
+ # pred_trans_1_c,center = self.zero_center_part(pred_trans_1,gen_mask,res_mask)
352
+ pred_trans_1_c = torch.where(batch['generate_mask'][...,None],pred_trans_1,trans_1_c) # move receptor also
353
+ # angles
354
+ pred_angles_1 = torch.where(batch['generate_mask'][...,None],pred_angles_1,angles_1)
355
+ # seqs
356
+ pred_seqs_1 = sample_from(F.softmax(pred_seqs_1_prob,dim=-1))
357
+ pred_seqs_1 = torch.where(batch['generate_mask'],pred_seqs_1,seqs_1)
358
+ pred_seqs_1_simplex = self.seq_to_simplex(pred_seqs_1)
359
+ # seq-angle
360
+ torsion_mask = angle_mask_loss[pred_seqs_1.reshape(-1)].reshape(num_batch,num_res,-1) # (B,L,5)
361
+ pred_angles_1 = torch.where(torsion_mask.bool(),pred_angles_1,torch.zeros_like(pred_angles_1))
362
+ if not sample_bb:
363
+ pred_trans_1_c = trans_1_c.detach().clone()
364
+ # _,center = self.zero_center_part(trans_1,gen_mask,res_mask)
365
+ pred_rotmats_1 = rotmats_1.detach().clone()
366
+ if not sample_ang:
367
+ pred_angles_1 = angles_1.detach().clone()
368
+ if not sample_seq:
369
+ pred_seqs_1 = seqs_1.detach().clone()
370
+ pred_seqs_1_simplex = seqs_1_simplex.detach().clone()
371
+ clean_traj.append({'rotmats':pred_rotmats_1.cpu(),'trans':pred_trans_1_c.cpu(),'angles':pred_angles_1.cpu(),'seqs':pred_seqs_1.cpu(),'seqs_simplex':pred_seqs_1_simplex.cpu(),
372
+ 'rotmats_1':rotmats_1.cpu(),'trans_1':trans_1_c.cpu(),'angles_1':angles_1.cpu(),'seqs_1':seqs_1.cpu()})
373
+
374
+ return clean_traj
375
+
376
+
377
+ # if __name__ == '__main__':
378
+ # prefix_dir = './pepflowww'
379
+ # # config,cfg_name = load_config("../configs/angle/learn_sc.yaml")
380
+ # config,cfg_name = load_config(os.path.join(prefix_dir,"configs/angle/learn_sc.yaml"))
381
+ # # print(config)
382
+ # device = 'cuda:0'
383
+ # dataset = PepDataset(structure_dir = config.dataset.val.structure_dir, dataset_dir = config.dataset.val.dataset_dir,
384
+ # name = config.dataset.val.name, transform=None, reset=config.dataset.val.reset)
385
+ # dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=PaddingCollate(eight=False), num_workers=4, pin_memory=True)
386
+ # ckpt = torch.load("./checkpoints/600000.pt", map_location=device)
387
+ # seed_all(114514)
388
+ # model = FlowModel(config.model).to(device)
389
+ # model.load_state_dict(process_dic(ckpt['model']))
390
+ # model.eval()
391
+
392
+ # # print(model)
393
+
394
+ # # print(dataset[0]['chain_id'])
395
+ # # print(dataset[0]['id'])
396
+ # # print(dataset[0]['resseq'])
397
+ # # print(dataset[0]['res_nb'])
398
+ # # print(dataset[0]['icode'])
399
+
400
+ # dic = {'id':[],'len':[],'tran':[],'aar':[],'rot':[],'trans_loss':[],'rot_loss':[]}
401
+
402
+ # # for batch in tqdm(dataloader):
403
+ # # batch = recursive_to(batch,device)
404
+ # for i in tqdm(range(len(dataset))):
405
+ # item = dataset[i]
406
+ # data_list = [deepcopy(item) for _ in range(16)]
407
+ # batch = recursive_to(collate_fn(data_list),device)
408
+ # loss_dic = model(batch)
409
+ # # traj_1 = model.sample(batch,num_steps=50,sample_bb=False,sample_ang=True,sample_seq=False)
410
+ # traj_1 = model.sample(batch,num_steps=50,sample_bb=True,sample_ang=True,sample_seq=True)
411
+ # ca_dist = torch.sqrt(torch.sum((traj_1[-1]['trans']-traj_1[-1]['trans_1'])**2*batch['generate_mask'][...,None].cpu().long()) / (torch.sum(batch['generate_mask']) + 1e-8).cpu()) # rmsd
412
+ # rot_dist = torch.sqrt(torch.sum((traj_1[-1]['rotmats']-traj_1[-1]['rotmats_1'])**2*batch['generate_mask'][...,None,None].long().cpu()) / (torch.sum(batch['generate_mask']) + 1e-8).cpu()) # rmsd
413
+ # aar = torch.sum((traj_1[-1]['seqs']==traj_1[-1]['seqs_1']) * batch['generate_mask'].long().cpu()) / (torch.sum(batch['generate_mask']).cpu() + 1e-8)
414
+
415
+
416
+ # print(loss_dic)
417
+ # print(f'tran:{ca_dist},rot:{rot_dist},aar:{aar},len:{batch["generate_mask"].sum().item()}')
418
+
419
+ # # free
420
+ # torch.cuda.empty_cache()
421
+ # gc.collect()
422
+
423
+ # # dic['tran'].append(ca_dist.item())
424
+ # # dic['rot'].append(rot_dist.item())
425
+ # dic['aar'].append(aar.item())
426
+ # dic['trans_loss'].append(loss_dic['trans_loss'].item())
427
+ # dic['rot_loss'].append(loss_dic['rot_loss'].item())
428
+ # dic['id'].append(batch['id'][0])
429
+ # dic['len'].append(batch['generate_mask'].sum().item())
430
+ # # # break
431
+
432
+ # # traj_1[-1]['batch'] = batch
433
+ # # torch.save(traj_1[-1],f'/datapool/data2/home/jiahan/ResProj/PepDiff/frame-flow/Data/Models_new/Pack_new/outputs/{batch["id"][0]}.pt')
434
+
435
+ # # print(dic)
436
+ # # dic = pd.DataFrame(dic)
437
+ # # dic.to_csv(f'/datapool/data2/home/jiahan/ResProj/PepDiff/frame-flow/Data/Models_new/Pack/outputs.csv',index=None)
438
+
439
+ # print(np.mean(dic['aar']))
440
+ # print(np.mean(dic['trans_loss']))
441
+
442
+ # if __name__ == '__main__':
443
+ # config,cfg_name = load_config("./configs/angle/learn_angle.yaml")
444
+ # seed_all(114514)
445
+ # device = 'cpu'
446
+ # dataset = PepDataset(structure_dir = config.dataset.train.structure_dir, dataset_dir = config.dataset.train.dataset_dir,
447
+ # name = config.dataset.train.name, transform=None, reset=config.dataset.train.reset)
448
+ # dataloader = DataLoader(dataset, batch_size=2, shuffle=False, collate_fn=PaddingCollate(), num_workers=4, pin_memory=True)
449
+ # model = FlowModel(config.model).to(device)
450
+ # optimizer = torch.optim.Adam(model.parameters(),lr=1.e-4)
451
+
452
+ # # ckpt = torch.load('./checkpoints/90000.pt',map_location=device)
453
+ # # model.load_state_dict(process_dic(ckpt['model']))
454
+ # # optimizer.load_state_dict(ckpt['optimizer'])
455
+
456
+
457
+ # # torch.autograd.set_detect_anomaly(True)
458
+ # for i,batch in tqdm(enumerate(dataloader)):
459
+ # batch = recursive_to(batch,device)
460
+ # loss_dict = model(batch)
461
+ # loss = sum_weighted_losses(loss_dict, config.train.loss_weights)
462
+ # # if torch.isnan(loss):
463
+ # # print(i)
464
+ # # print(batch['id'])
465
+
466
+ # loss.backward()
467
+ # orig_grad_norm = clip_grad_norm_(model.parameters(), config.train.max_grad_norm)
468
+
469
+ # print(f'{loss_dict},{loss},{orig_grad_norm}')
470
+
471
+ # optimizer.step()
472
+ # optimizer.zero_grad()
models_con/ga.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+ from models_con import ipa_pytorch as ipa_pytorch
5
+ from data import utils as du
6
+
7
+ from models_con.utils import get_index_embedding, get_time_embedding
8
+
9
+ from pepflow.modules.protein.constants import ANG_TO_NM_SCALE, NM_TO_ANG_SCALE
10
+ from pepflow.modules.common.layers import AngularEncoding
11
+
12
+ import math
13
+
14
+
15
class GAEncoder(nn.Module):
    """IPA-based denoising trunk.

    Mixes the current sequence, torsion angles and diffusion time into the
    node features, then runs `num_blocks` rounds of Invariant Point Attention
    + sequence transformer + backbone-frame updates, and finally decodes
    clean-state predictions for frames, angles and sequence logits.

    NOTE: the `trunk` ModuleDict keys (ipa_{b}, seq_tfmr_{b}, ...) are
    state_dict keys for released checkpoints — do not rename.
    """

    def __init__(self, ipa_conf):
        super().__init__()
        self._ipa_conf = ipa_conf

        # torsion-angle featurization and per-residue angle decoder (5 angles)
        self.angles_embedder = AngularEncoding(num_funcs=12) # 25*5=120, for competitive embedding size (NOTE(review): arithmetic in this comment looks off — verify get_out_dim)
        self.angle_net = nn.Sequential(
            nn.Linear(self._ipa_conf.c_s, self._ipa_conf.c_s),nn.ReLU(),
            nn.Linear(self._ipa_conf.c_s, self._ipa_conf.c_s),nn.ReLU(),
            nn.Linear(self._ipa_conf.c_s, 5)
            # nn.Linear(self._ipa_conf.c_s, 22)
        )

        # conditioning on the current (noisy) sequence, plus the 20-way decoder
        self.current_seq_embedder = nn.Embedding(22, self._ipa_conf.c_s)
        self.seq_net = nn.Sequential(
            nn.Linear(self._ipa_conf.c_s, self._ipa_conf.c_s),nn.ReLU(),
            nn.Linear(self._ipa_conf.c_s, self._ipa_conf.c_s),nn.ReLU(),
            nn.Linear(self._ipa_conf.c_s, 20)
            # nn.Linear(self._ipa_conf.c_s, 22)
        )

        # mixer: [node_embed | seq embed | time embed | angle embed] -> c_s
        self.res_feat_mixer = nn.Sequential(
            nn.Linear(3 * self._ipa_conf.c_s + self.angles_embedder.get_out_dim(in_dim=5), self._ipa_conf.c_s),
            nn.ReLU(),
            nn.Linear(self._ipa_conf.c_s, self._ipa_conf.c_s),
        )

        self.feat_dim = self._ipa_conf.c_s

        # Attention trunk
        self.trunk = nn.ModuleDict()
        for b in range(self._ipa_conf.num_blocks):
            self.trunk[f'ipa_{b}'] = ipa_pytorch.InvariantPointAttention(self._ipa_conf)
            self.trunk[f'ipa_ln_{b}'] = nn.LayerNorm(self._ipa_conf.c_s)
            tfmr_in = self._ipa_conf.c_s
            tfmr_layer = torch.nn.TransformerEncoderLayer(
                d_model=tfmr_in,
                nhead=self._ipa_conf.seq_tfmr_num_heads,
                dim_feedforward=tfmr_in,
                batch_first=True,
                dropout=0.0,
                norm_first=False
            )
            self.trunk[f'seq_tfmr_{b}'] = torch.nn.TransformerEncoder(
                tfmr_layer, self._ipa_conf.seq_tfmr_num_layers, enable_nested_tensor=False)
            self.trunk[f'post_tfmr_{b}'] = ipa_pytorch.Linear(
                tfmr_in, self._ipa_conf.c_s, init="final")
            self.trunk[f'node_transition_{b}'] = ipa_pytorch.StructureModuleTransition(
                c=self._ipa_conf.c_s)
            self.trunk[f'bb_update_{b}'] = ipa_pytorch.BackboneUpdate(
                self._ipa_conf.c_s, use_rot_updates=True)

            if b < self._ipa_conf.num_blocks-1:
                # No edge update on the last block.
                edge_in = self._ipa_conf.c_z
                self.trunk[f'edge_transition_{b}'] = ipa_pytorch.EdgeTransition(
                    node_embed_size=self._ipa_conf.c_s,
                    edge_embed_in=edge_in,
                    edge_embed_out=self._ipa_conf.c_z,
                )

    def embed_t(self, timesteps, mask):
        """Sinusoidal time embedding broadcast to every residue position.

        timesteps: (B, 1); mask: (B, L). Returns (B, L, feat_dim).
        """
        timestep_emb = get_time_embedding(
            timesteps[:, 0],
            self.feat_dim,
            max_positions=2056
        )[:, None, :].repeat(1, mask.shape[1], 1)
        return timestep_emb

    def forward(self, t, rotmats_t, trans_t, angles_t, seqs_t, node_embed, edge_embed, generate_mask, res_mask):
        """Predict clean frames, torsions and sequence logits from the noisy state.

        Returns (pred_rotmats1, pred_trans1, pred_angles1 in [0, 2pi),
        pred_seqs1_prob raw 20-way logits). `generate_mask` is accepted but
        unused here; masking is applied by the caller.
        """
        num_batch, num_res = seqs_t.shape

        # incorporate current sequence, timestep and angles into node features
        node_mask = res_mask
        edge_mask = node_mask[:, None] * node_mask[:, :, None]

        node_embed = self.res_feat_mixer(torch.cat([node_embed, self.current_seq_embedder(seqs_t), self.embed_t(t,node_mask), self.angles_embedder(angles_t).reshape(num_batch,num_res,-1)],dim=-1))
        node_embed = node_embed * node_mask[..., None]
        curr_rigids = du.create_rigid(rotmats_t, trans_t)
        for b in range(self._ipa_conf.num_blocks):
            ipa_embed = self.trunk[f'ipa_{b}'](
                node_embed,
                edge_embed,
                curr_rigids,
                node_mask)
            ipa_embed *= node_mask[..., None]
            node_embed = self.trunk[f'ipa_ln_{b}'](node_embed + ipa_embed)
            seq_tfmr_out = self.trunk[f'seq_tfmr_{b}'](
                node_embed, src_key_padding_mask=(1 - node_mask).bool())
            node_embed = node_embed + self.trunk[f'post_tfmr_{b}'](seq_tfmr_out)
            node_embed = self.trunk[f'node_transition_{b}'](node_embed)
            node_embed = node_embed * node_mask[..., None]
            # compose the predicted per-block frame update onto current rigids
            rigid_update = self.trunk[f'bb_update_{b}'](
                node_embed * node_mask[..., None])
            curr_rigids = curr_rigids.compose_q_update_vec(
                rigid_update, node_mask[..., None])

            if b < self._ipa_conf.num_blocks-1:
                edge_embed = self.trunk[f'edge_transition_{b}'](
                    node_embed, edge_embed)
                edge_embed *= edge_mask[..., None]

        # curr_rigids = self.rigids_nm_to_ang(curr_rigids)
        pred_trans1 = curr_rigids.get_trans()
        pred_rotmats1 = curr_rigids.get_rots().get_rot_mats()
        pred_seqs1_prob = self.seq_net(node_embed)
        pred_angles1 = self.angle_net(node_embed)
        pred_angles1 = pred_angles1 % (2*math.pi) # inductive bias to bound between (0,2pi)

        return pred_rotmats1, pred_trans1, pred_angles1, pred_seqs1_prob
models_con/inference.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torch.utils.data import DataLoader
6
+
7
+ import copy
8
+ import math
9
+ from tqdm.auto import tqdm
10
+ import functools
11
+ import os
12
+ import argparse
13
+ import pandas as pd
14
+ from copy import deepcopy
15
+
16
+ from models_con.pep_dataloader import PepDataset
17
+
18
+ from pepflow.utils.misc import load_config
19
+ from pepflow.utils.train import recursive_to
20
+
21
+ from pepflow.modules.common.geometry import reconstruct_backbone, reconstruct_backbone_partially, align, batch_align
22
+ from pepflow.modules.protein.writers import save_pdb
23
+
24
+ from pepflow.utils.data import PaddingCollate
25
+
26
+ from models_con.utils import process_dic
27
+
28
+ import gc
29
+
30
+ from models_con.flow_model import FlowModel
31
+
32
+ from pepflow.utils.misc import seed_all
33
+
34
+ from models_con.torsion import full_atom_reconstruction, get_heavyatom_mask
35
+
36
+ collate_fn = PaddingCollate(eight=False)
37
+
38
+ import argparse
39
+
40
+
41
if __name__ == '__main__':
    # Inference/evaluation script: for every validation complex, sample
    # `num_samples` designs, report CA-RMSD / rotation distance / sequence
    # recovery on the generated region, save the final denoised state per
    # complex, and dump a summary CSV.

    def _str2bool(v):
        """Parse a boolean CLI value; plain `type=bool` maps 'False' to True."""
        if isinstance(v, bool):
            return v
        return str(v).strip().lower() in ('1', 'true', 't', 'yes', 'y')

    args = argparse.ArgumentParser()
    args.add_argument('--config', type=str)
    args.add_argument('--device', type=str)
    args.add_argument('--ckpt', type=str)
    args.add_argument('--output', type=str)
    args.add_argument('--num_steps', type=int, default=200)
    # Registered exactly once: the original registered --num_samples three
    # times, which makes argparse raise ArgumentError before parsing anything.
    args.add_argument('--num_samples', type=int, default=64)
    # _str2bool keeps `--sample_bb True/False` working as intended.
    args.add_argument('--sample_bb', type=_str2bool, default=True)
    args.add_argument('--sample_ang', type=_str2bool, default=True)
    args.add_argument('--sample_seq', type=_str2bool, default=True)
    parser = args.parse_args()

    config,cfg_name = load_config(parser.config)
    device = parser.device
    dataset = PepDataset(structure_dir = config.dataset.val.structure_dir, dataset_dir = config.dataset.val.dataset_dir,
                         name = config.dataset.val.name, transform=None, reset=config.dataset.val.reset)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=PaddingCollate(eight=False), num_workers=4, pin_memory=True)
    ckpt = torch.load(parser.ckpt, map_location=device)

    seed_all(114514)
    model = FlowModel(config.model).to(device)
    model.load_state_dict(process_dic(ckpt['model']))
    model.eval()

    # Ensure the per-item output directory exists before the loop writes to it.
    os.makedirs(os.path.join(parser.output, 'outputs'), exist_ok=True)

    dic = {'id':[],'len':[],'tran':[],'aar':[],'rot':[],'trans_loss':[],'rot_loss':[]}

    for i in tqdm(range(len(dataset))):
        item = dataset[i]
        # Batch `num_samples` copies of the same complex to sample in parallel.
        data_list = [deepcopy(item) for _ in range(parser.num_samples)]
        batch = recursive_to(collate_fn(data_list),device)
        # Evaluation only: sample() is already @torch.no_grad, but forward()
        # is not — disable autograd to avoid building a throwaway graph.
        with torch.no_grad():
            loss_dic = model(batch)
        traj_1 = model.sample(batch,num_steps=parser.num_steps,sample_bb=parser.sample_bb,sample_ang=parser.sample_ang,sample_seq=parser.sample_seq)
        # Metrics over the generated region of the final denoised state:
        # CA RMSD, rotation-matrix Frobenius distance, amino-acid recovery.
        ca_dist = torch.sqrt(torch.sum((traj_1[-1]['trans']-traj_1[-1]['trans_1'])**2*batch['generate_mask'][...,None].cpu().long()) / (torch.sum(batch['generate_mask']) + 1e-8).cpu()) # rmsd
        rot_dist = torch.sqrt(torch.sum((traj_1[-1]['rotmats']-traj_1[-1]['rotmats_1'])**2*batch['generate_mask'][...,None,None].long().cpu()) / (torch.sum(batch['generate_mask']) + 1e-8).cpu()) # rmsd
        aar = torch.sum((traj_1[-1]['seqs']==traj_1[-1]['seqs_1']) * batch['generate_mask'].long().cpu()) / (torch.sum(batch['generate_mask']).cpu() + 1e-8)

        print(loss_dic)
        print(f'tran:{ca_dist},rot:{rot_dist},aar:{aar},len:{batch["generate_mask"].sum().item()}')

        # free GPU memory between complexes
        torch.cuda.empty_cache()
        gc.collect()

        dic['tran'].append(ca_dist.item())
        dic['rot'].append(rot_dist.item())
        dic['aar'].append(aar.item())
        dic['trans_loss'].append(loss_dic['trans_loss'].item())
        dic['rot_loss'].append(loss_dic['rot_loss'].item())
        dic['id'].append(batch['id'][0])
        dic['len'].append(batch['generate_mask'].sum().item())

        # Persist the final denoised state (plus its input batch) per complex.
        traj_1[-1]['batch'] = batch
        torch.save(traj_1[-1],f'{parser.output}/outputs/{batch["id"][0]}.pt')

    dic = pd.DataFrame(dic)
    dic.to_csv(f'{parser.output}/outputs.csv',index=None)
models_con/ipa_pytorch.py ADDED
@@ -0,0 +1,687 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Modified code of Openfold's IPA."""
17
+
18
+ import numpy as np
19
+ import torch
20
+ import math
21
+ from scipy.stats import truncnorm
22
+ import torch.nn as nn
23
+ from typing import Optional, Callable, List, Sequence
24
+ from openfold.utils.rigid_utils import Rigid
25
+ from data import all_atom
26
+
27
+
28
def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
    """Permute the trailing ``len(inds)`` dimensions of ``tensor`` per ``inds``,
    leaving all leading (batch-like) dimensions untouched."""
    n_trailing = len(inds)
    leading = list(range(tensor.dim() - n_trailing))
    trailing = [i - n_trailing for i in inds]
    return tensor.permute(leading + trailing)
32
+
33
+
34
def flatten_final_dims(t: torch.Tensor, no_dims: int):
    """Collapse the last ``no_dims`` dimensions of ``t`` into a single one."""
    lead_shape = t.shape[:-no_dims]
    return t.reshape(*lead_shape, -1)
36
+
37
+
38
+ def ipa_point_weights_init_(weights):
39
+ with torch.no_grad():
40
+ softplus_inverse_1 = 0.541324854612918
41
+ weights.fill_(softplus_inverse_1)
42
+
43
+ def _prod(nums):
44
+ out = 1
45
+ for n in nums:
46
+ out = out * n
47
+ return out
48
+
49
+
50
+ def _calculate_fan(linear_weight_shape, fan="fan_in"):
51
+ fan_out, fan_in = linear_weight_shape
52
+
53
+ if fan == "fan_in":
54
+ f = fan_in
55
+ elif fan == "fan_out":
56
+ f = fan_out
57
+ elif fan == "fan_avg":
58
+ f = (fan_in + fan_out) / 2
59
+ else:
60
+ raise ValueError("Invalid fan option")
61
+
62
+ return f
63
+
64
def trunc_normal_init_(weights, scale=1.0, fan="fan_in"):
    """Fan-scaled truncated-normal init (clipped at +/-2 sigma), in place.

    The raw truncnorm std is divided out so the drawn samples have the
    target std sqrt(scale / fan) exactly.
    """
    shape = weights.shape
    eff_scale = scale / max(1, _calculate_fan(shape, fan))
    lo, hi = -2, 2
    std = math.sqrt(eff_scale) / truncnorm.std(a=lo, b=hi, loc=0, scale=1)
    samples = truncnorm.rvs(a=lo, b=hi, loc=0, scale=std, size=_prod(shape))
    reshaped = np.reshape(samples, shape)
    with torch.no_grad():
        weights.copy_(torch.tensor(reshaped, device=weights.device))
76
+
77
+
78
def lecun_normal_init_(weights):
    # LeCun: truncated normal with variance 1/fan_in.
    trunc_normal_init_(weights, scale=1.0)


def he_normal_init_(weights):
    # He/Kaiming: variance 2/fan_in, suited to ReLU activations.
    trunc_normal_init_(weights, scale=2.0)


def glorot_uniform_init_(weights):
    # Glorot/Xavier uniform (fan-average) initialization.
    nn.init.xavier_uniform_(weights, gain=1)


def final_init_(weights):
    # Zero init for "final" projection layers so residual blocks start as identity.
    with torch.no_grad():
        weights.fill_(0.0)


def gating_init_(weights):
    # Zero weights for gating layers; the bias is set to 1 by the caller (Linear).
    with torch.no_grad():
        weights.fill_(0.0)


def normal_init_(weights):
    # Kaiming-normal with "linear" nonlinearity, i.e. std = 1/sqrt(fan_in).
    torch.nn.init.kaiming_normal_(weights, nonlinearity="linear")
102
+
103
+
104
def compute_angles(ca_pos, pts):
    """Angles between inter-residue C-alpha vectors and IPA point offsets.

    Args:
        ca_pos: C-alpha coordinates; assumed (batch, num_res, 3) from the
            indexing below — TODO confirm against callers.
        pts: per-head IPA points, (batch, num_res, num_heads, num_pts, 3).
    Returns:
        (batch, num_res, num_res, num_heads, num_pts) angles computed by
        all_atom.calculate_neighbor_angles on flattened vector pairs.
    """
    batch_size, num_res, num_heads, num_pts, _ = pts.shape
    # +1e-10 avoids an exactly-zero vector on the i == j diagonal.
    calpha_vecs = (ca_pos[:, :, None, :] - ca_pos[:, None, :, :]) + 1e-10
    calpha_vecs = torch.tile(calpha_vecs[:, :, :, None, None, :], (1, 1, 1, num_heads, num_pts, 1))
    # Offsets of every point from the query residue's C-alpha, broadcast over j.
    ipa_pts = pts[:, :, None, :, :, :] - torch.tile(ca_pos[:, :, None, None, None, :], (1, 1, num_res, num_heads, num_pts, 1))
    phi_angles = all_atom.calculate_neighbor_angles(
        calpha_vecs.reshape(-1, 3),
        ipa_pts.reshape(-1, 3)
    ).reshape(batch_size, num_res, num_res, num_heads, num_pts)
    return phi_angles
114
+
115
+
116
class Linear(nn.Linear):
    """
    torch.nn.Linear with selectable AlphaFold-style initializers.

    The `init` string picks one of the schemes from supplement section
    1.11.4 (plus a few found in the official code); `init_fn`, when given,
    overrides `init` entirely.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        bias: bool = True,
        init: str = "default",
        init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None,
    ):
        """
        Args:
            in_dim: Final dimension of layer inputs.
            out_dim: Final dimension of layer outputs.
            bias: Whether to learn an additive bias. Initialized to zero,
                except for the "gating" scheme which sets it to one.
            init: One of:
                "default" — LeCun fan-in truncated normal
                "relu"    — He truncated normal
                "glorot"  — fan-average Glorot uniform
                "gating"  — weights 0, bias 1
                "normal"  — normal with std 1/sqrt(fan_in)
                "final"   — weights 0, bias 0
            init_fn: Custom initializer taking (weight, bias); overrides
                `init` when not None.
        """
        super().__init__(in_dim, out_dim, bias=bias)

        if bias:
            with torch.no_grad():
                self.bias.fill_(0)

        if init_fn is not None:
            init_fn(self.weight, self.bias)
        elif init == "default":
            lecun_normal_init_(self.weight)
        elif init == "relu":
            he_normal_init_(self.weight)
        elif init == "glorot":
            glorot_uniform_init_(self.weight)
        elif init == "gating":
            gating_init_(self.weight)
            if bias:
                with torch.no_grad():
                    self.bias.fill_(1.0)
        elif init == "normal":
            normal_init_(self.weight)
        elif init == "final":
            final_init_(self.weight)
        else:
            raise ValueError("Invalid init string.")
182
+
183
+
184
class StructureModuleTransition(nn.Module):
    """Residual two-hidden-layer MLP transition with a final LayerNorm,
    as used in the AF2 structure module."""

    def __init__(self, c):
        super().__init__()
        self.c = c
        # Attribute names preserved for checkpoint compatibility.
        self.linear_1 = Linear(self.c, self.c, init="relu")
        self.linear_2 = Linear(self.c, self.c, init="relu")
        self.linear_3 = Linear(self.c, self.c, init="final")
        self.relu = nn.ReLU()
        self.ln = nn.LayerNorm(self.c)

    def forward(self, s):
        residual = s
        hidden = self.relu(self.linear_1(s))
        hidden = self.relu(self.linear_2(hidden))
        out = self.linear_3(hidden) + residual
        return self.ln(out)
207
+
208
+
209
class EdgeTransition(nn.Module):
    """Update edge features from the pair of endpoint node embeddings.

    Each edge (i, j) is concatenated with down-projected node features of
    residues i and j, passed through a residual MLP trunk, projected to
    `edge_embed_out` and layer-normalized.
    """

    def __init__(
        self,
        *,
        node_embed_size,
        edge_embed_in,
        edge_embed_out,
        num_layers=2,
        node_dilation=2
    ):
        super().__init__()
        bias_embed_size = node_embed_size // node_dilation
        self.initial_embed = Linear(
            node_embed_size, bias_embed_size, init="relu")
        hidden_size = bias_embed_size * 2 + edge_embed_in
        layers = []
        for _ in range(num_layers):
            layers.append(Linear(hidden_size, hidden_size, init="relu"))
            layers.append(nn.ReLU())
        self.trunk = nn.Sequential(*layers)
        self.final_layer = Linear(hidden_size, edge_embed_out, init="final")
        self.layer_norm = nn.LayerNorm(edge_embed_out)

    def forward(self, node_embed, edge_embed):
        node_embed = self.initial_embed(node_embed)
        batch_size, num_res, _ = node_embed.shape
        # Broadcast node features to every (i, j) pair and concatenate.
        row = node_embed[:, :, None, :].expand(-1, -1, num_res, -1)
        col = node_embed[:, None, :, :].expand(-1, num_res, -1, -1)
        edge_bias = torch.cat([row, col], dim=-1)
        flat = torch.cat([edge_embed, edge_bias], dim=-1).reshape(
            batch_size * num_res ** 2, -1)
        flat = self.final_layer(self.trunk(flat) + flat)
        flat = self.layer_norm(flat)
        return flat.reshape(batch_size, num_res, num_res, -1)
249
+
250
+
251
class InvariantPointAttention(nn.Module):
    """
    Implements Algorithm 22 (Invariant Point Attention).

    Attends jointly over scalar features, pair features, and 3-D points
    expressed in per-residue rigid frames; the output is invariant to
    global rotations/translations because points are mapped back into
    local frames before being consumed.
    """
    def __init__(
        self,
        ipa_conf,
        inf: float = 1e5,
        eps: float = 1e-8,
    ):
        """
        Args:
            ipa_conf: config providing
                c_s: single representation channel dimension
                c_z: pair representation channel dimension
                c_hidden: hidden channel dimension
                no_heads: number of attention heads
                no_qk_points: number of query/key points to generate
                no_v_points: number of value points to generate
            inf: large value used to mask attention logits.
            eps: numerical epsilon for the point-norm sqrt.
        """
        super(InvariantPointAttention, self).__init__()
        self._ipa_conf = ipa_conf

        self.c_s = ipa_conf.c_s
        self.c_z = ipa_conf.c_z
        self.c_hidden = ipa_conf.c_hidden
        self.no_heads = ipa_conf.no_heads
        self.no_qk_points = ipa_conf.no_qk_points
        self.no_v_points = ipa_conf.no_v_points
        self.inf = inf
        self.eps = eps

        # These linear layers differ from their specifications in the
        # supplement. There, they lack bias and use Glorot initialization.
        # Here as in the official source, they have bias and use the default
        # Lecun initialization.
        hc = self.c_hidden * self.no_heads
        self.linear_q = Linear(self.c_s, hc)
        self.linear_kv = Linear(self.c_s, 2 * hc)

        hpq = self.no_heads * self.no_qk_points * 3
        self.linear_q_points = Linear(self.c_s, hpq)

        hpkv = self.no_heads * (self.no_qk_points + self.no_v_points) * 3
        self.linear_kv_points = Linear(self.c_s, hpkv)

        self.linear_b = Linear(self.c_z, self.no_heads)
        # Down-projection of pair features mixed into the output (c_z // 4).
        self.down_z = Linear(self.c_z, self.c_z // 4)

        # Learned per-head weights on the point-attention term; initialized
        # so softplus(head_weights) == 1.
        self.head_weights = nn.Parameter(torch.zeros((ipa_conf.no_heads)))
        ipa_point_weights_init_(self.head_weights)

        # scalar out + point coords (3) + point norms (1) + pair features.
        concat_out_dim = (
            self.c_z // 4 + self.c_hidden + self.no_v_points * 4
        )
        self.linear_out = Linear(self.no_heads * concat_out_dim, self.c_s, init="final")

        self.softmax = nn.Softmax(dim=-1)
        self.softplus = nn.Softplus()

    def forward(
        self,
        s: torch.Tensor,
        z: Optional[torch.Tensor],
        r: Rigid,
        mask: torch.Tensor,
        _offload_inference: bool = False,
        _z_reference_list: Optional[Sequence[torch.Tensor]] = None,
    ) -> torch.Tensor:
        """
        Args:
            s:
                [*, N_res, C_s] single representation
            z:
                [*, N_res, N_res, C_z] pair representation
            r:
                [*, N_res] transformation object (per-residue rigid frames)
            mask:
                [*, N_res] mask
            _offload_inference:
                if True, the pair tensor is moved to CPU between uses
                (memory saving during inference); z is then taken from
                _z_reference_list.
        Returns:
            [*, N_res, C_s] single representation update
        """
        if _offload_inference:
            z = _z_reference_list
        else:
            z = [z]

        #######################################
        # Generate scalar and point activations
        #######################################
        # [*, N_res, H * C_hidden]
        q = self.linear_q(s)
        kv = self.linear_kv(s)

        # [*, N_res, H, C_hidden]
        q = q.view(q.shape[:-1] + (self.no_heads, -1))

        # [*, N_res, H, 2 * C_hidden]
        kv = kv.view(kv.shape[:-1] + (self.no_heads, -1))

        # [*, N_res, H, C_hidden]
        k, v = torch.split(kv, self.c_hidden, dim=-1)

        # [*, N_res, H * P_q * 3]
        q_pts = self.linear_q_points(s)

        # This is kind of clunky, but it's how the original does it
        # [*, N_res, H * P_q, 3]
        q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1)
        q_pts = torch.stack(q_pts, dim=-1)
        # Map the predicted points into the global frame.
        q_pts = r[..., None].apply(q_pts)

        # [*, N_res, H, P_q, 3]
        q_pts = q_pts.view(
            q_pts.shape[:-2] + (self.no_heads, self.no_qk_points, 3)
        )

        # [*, N_res, H * (P_q + P_v) * 3]
        kv_pts = self.linear_kv_points(s)

        # [*, N_res, H * (P_q + P_v), 3]
        kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1)
        kv_pts = torch.stack(kv_pts, dim=-1)
        kv_pts = r[..., None].apply(kv_pts)

        # [*, N_res, H, (P_q + P_v), 3]
        kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.no_heads, -1, 3))

        # [*, N_res, H, P_q/P_v, 3]
        k_pts, v_pts = torch.split(
            kv_pts, [self.no_qk_points, self.no_v_points], dim=-2
        )

        ##########################
        # Compute attention scores
        ##########################
        # [*, N_res, N_res, H]
        b = self.linear_b(z[0])

        if(_offload_inference):
            z[0] = z[0].cpu()

        # [*, H, N_res, N_res]
        a = torch.matmul(
            permute_final_dims(q, (1, 0, 2)),  # [*, H, N_res, C_hidden]
            permute_final_dims(k, (1, 2, 0)),  # [*, H, C_hidden, N_res]
        )
        # 1/3 factors: scalar, pair-bias, and point terms each contribute
        # a third of the total logit variance (Algorithm 22).
        a *= math.sqrt(1.0 / (3 * self.c_hidden))
        a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))

        # [*, N_res, N_res, H, P_q, 3]
        pt_displacement = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
        pt_att = pt_displacement ** 2

        # [*, N_res, N_res, H, P_q]
        pt_att = sum(torch.unbind(pt_att, dim=-1))
        head_weights = self.softplus(self.head_weights).view(
            *((1,) * len(pt_att.shape[:-2]) + (-1, 1))
        )
        head_weights = head_weights * math.sqrt(
            1.0 / (3 * (self.no_qk_points * 9.0 / 2))
        )
        pt_att = pt_att * head_weights

        # [*, N_res, N_res, H]
        pt_att = torch.sum(pt_att, dim=-1) * (-0.5)
        # [*, N_res, N_res] -- 0 where both residues valid, -inf otherwise.
        square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
        square_mask = self.inf * (square_mask - 1)

        # [*, H, N_res, N_res]
        pt_att = permute_final_dims(pt_att, (2, 0, 1))

        a = a + pt_att
        a = a + square_mask.unsqueeze(-3)
        a = self.softmax(a)

        ################
        # Compute output
        ################
        # [*, N_res, H, C_hidden]
        o = torch.matmul(
            a, v.transpose(-2, -3)
        ).transpose(-2, -3)

        # [*, N_res, H * C_hidden]
        o = flatten_final_dims(o, 2)

        # [*, H, 3, N_res, P_v]
        o_pt = torch.sum(
            (
                a[..., None, :, :, None]
                * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
            ),
            dim=-2,
        )

        # [*, N_res, H, P_v, 3] -- back into the local frame of each residue.
        o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
        o_pt = r[..., None, None].invert_apply(o_pt)

        # [*, N_res, H * P_v]
        o_pt_dists = torch.sqrt(torch.sum(o_pt ** 2, dim=-1) + self.eps)
        o_pt_norm_feats = flatten_final_dims(
            o_pt_dists, 2)

        # [*, N_res, H * P_v, 3]
        o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)

        if(_offload_inference):
            z[0] = z[0].to(o_pt.device)

        # [*, N_res, H, C_z // 4]
        pair_z = self.down_z(z[0])
        o_pair = torch.matmul(a.transpose(-2, -3), pair_z)

        # [*, N_res, H * C_z // 4]
        o_pair = flatten_final_dims(o_pair, 2)

        o_feats = [o, *torch.unbind(o_pt, dim=-1), o_pt_norm_feats, o_pair]

        # [*, N_res, C_s]
        s = self.linear_out(
            torch.cat(
                o_feats, dim=-1
            )
        )

        return s
485
+
486
+
487
class TorsionAngles(nn.Module):
    """Predict torsion angles as sin/cos pairs, returning both the raw and
    the unit-normalized vectors."""

    def __init__(self, c, num_torsions, eps=1e-8):
        super().__init__()
        self.c = c
        self.eps = eps
        self.num_torsions = num_torsions

        self.linear_1 = Linear(self.c, self.c, init="relu")
        self.linear_2 = Linear(self.c, self.c, init="relu")
        # Unused in forward; kept so published checkpoints still load.
        # TODO: Remove after published checkpoint is updated without these weights.
        self.linear_3 = Linear(self.c, self.c, init="final")
        self.linear_final = Linear(
            self.c, self.num_torsions * 2, init="final")

        self.relu = nn.ReLU()

    def forward(self, s):
        hidden = self.relu(self.linear_1(s))
        hidden = self.linear_2(hidden)
        hidden = hidden + s  # residual connection
        raw = self.linear_final(hidden)
        # Normalize the whole output vector to unit length (clamped sqrt
        # for numerical stability near zero).
        denom = torch.sqrt(
            torch.clamp(
                torch.sum(raw ** 2, dim=-1, keepdim=True),
                min=self.eps,
            )
        )
        return raw, raw / denom
521
+
522
+
523
class RotationVFLayer(nn.Module):
    """Residual MLP head mapping node features to a 6-D rotation
    vector-field output."""

    def __init__(self, dim):
        super().__init__()
        self.linear_1 = Linear(dim, dim, init="relu")
        self.linear_2 = Linear(dim, dim, init="relu")
        self.linear_3 = Linear(dim, dim)
        self.final_linear = Linear(dim, 6, init="final")
        self.relu = nn.ReLU()

    def forward(self, s):
        h = self.relu(self.linear_1(s))
        h = self.relu(self.linear_2(h))
        h = self.linear_3(h) + s  # residual connection
        return self.final_linear(h)
542
+
543
+
544
class BackboneUpdate(nn.Module):
    """
    Implements part of Algorithm 23: project node features to a rigid
    frame update vector.
    """

    def __init__(self, c_s, use_rot_updates=True):
        """
        Args:
            c_s:
                Single representation channel dimension.
            use_rot_updates:
                If True, emit a 6-D update (3 rotation + 3 translation
                components) as consumed by Rigid.compose_q_update_vec;
                if False, a 3-D translation-only update. Defaults to True
                so call sites that pass only `c_s` (e.g. IpaScore's
                `BackboneUpdate(ipa_conf.c_s)`) work instead of raising
                a TypeError for the missing argument.
        """
        super(BackboneUpdate, self).__init__()

        self.c_s = c_s
        self._use_rot_updates = use_rot_updates
        update_dim = 6 if use_rot_updates else 3
        self.linear = Linear(self.c_s, update_dim, init="final")

    def forward(self, s: torch.Tensor):
        """
        Args:
            s: [*, N_res, C_s] single representation.
        Returns:
            [*, N_res, 6] (or [*, N_res, 3]) update vector.
        """
        update = self.linear(s)

        return update
573
+
574
class IpaScore(nn.Module):
    """IPA trunk that predicts rotation/translation scores and psi torsions
    for a diffusion model over rigid backbone frames."""

    def __init__(self, model_conf, diffuser):
        super(IpaScore, self).__init__()
        self._model_conf = model_conf
        ipa_conf = model_conf.ipa
        self._ipa_conf = ipa_conf
        # Diffusion process object; used to turn predicted frames into scores.
        self.diffuser = diffuser

        # Work in scaled coordinates inside the trunk, unscale on the way out.
        self.scale_pos = lambda x: x * ipa_conf.coordinate_scaling
        self.scale_rigids = lambda x: x.apply_trans_fn(self.scale_pos)

        self.unscale_pos = lambda x: x / ipa_conf.coordinate_scaling
        self.unscale_rigids = lambda x: x.apply_trans_fn(self.unscale_pos)
        self.trunk = nn.ModuleDict()

        for b in range(ipa_conf.num_blocks):
            self.trunk[f'ipa_{b}'] = InvariantPointAttention(ipa_conf)
            self.trunk[f'ipa_ln_{b}'] = nn.LayerNorm(ipa_conf.c_s)
            # Skip connection from the initial node embedding into each block.
            self.trunk[f'skip_embed_{b}'] = Linear(
                self._model_conf.node_embed_size,
                self._ipa_conf.c_skip,
                init="final"
            )
            tfmr_in = ipa_conf.c_s + self._ipa_conf.c_skip
            tfmr_layer = torch.nn.TransformerEncoderLayer(
                d_model=tfmr_in,
                nhead=ipa_conf.seq_tfmr_num_heads,
                dim_feedforward=tfmr_in,
                batch_first=True,
                dropout=0.0,
                norm_first=False
            )
            self.trunk[f'seq_tfmr_{b}'] = torch.nn.TransformerEncoder(
                tfmr_layer, ipa_conf.seq_tfmr_num_layers)
            self.trunk[f'post_tfmr_{b}'] = Linear(
                tfmr_in, ipa_conf.c_s, init="final")
            self.trunk[f'node_transition_{b}'] = StructureModuleTransition(
                c=ipa_conf.c_s)
            # NOTE(review): BackboneUpdate is called with a single argument;
            # confirm its __init__ provides a default for use_rot_updates,
            # otherwise this raises TypeError at construction.
            self.trunk[f'bb_update_{b}'] = BackboneUpdate(ipa_conf.c_s)

            if b < ipa_conf.num_blocks-1:
                # No edge update on the last block.
                edge_in = self._model_conf.edge_embed_size
                self.trunk[f'edge_transition_{b}'] = EdgeTransition(
                    node_embed_size=ipa_conf.c_s,
                    edge_embed_in=edge_in,
                    edge_embed_out=self._model_conf.edge_embed_size,
                )

        self.torsion_pred = TorsionAngles(ipa_conf.c_s, 1)

    def forward(self, init_node_embed, edge_embed, input_feats):
        """
        Args:
            init_node_embed: (B, N, node_embed_size) initial node features.
            edge_embed: (B, N, N, edge_embed_size) edge features.
            input_feats: dict with 'res_mask', 'fixed_mask', 'rigids_t'
                (tensor-7 rigid frames at time t) and diffusion time 't'.
        Returns:
            dict with 'psi' torsion prediction, 'rot_score', 'trans_score'
            and the predicted 'final_rigids'.
        """
        node_mask = input_feats['res_mask'].type(torch.float32)
        # Residues that are diffused (not held fixed) and valid.
        diffuse_mask = (1 - input_feats['fixed_mask'].type(torch.float32)) * node_mask
        edge_mask = node_mask[..., None] * node_mask[..., None, :]
        init_frames = input_feats['rigids_t'].type(torch.float32)

        curr_rigids = Rigid.from_tensor_7(torch.clone(init_frames))
        init_rigids = Rigid.from_tensor_7(init_frames)
        init_rots = init_rigids.get_rots()

        # Main trunk
        curr_rigids = self.scale_rigids(curr_rigids)
        init_node_embed = init_node_embed * node_mask[..., None]
        node_embed = init_node_embed * node_mask[..., None]
        for b in range(self._ipa_conf.num_blocks):
            ipa_embed = self.trunk[f'ipa_{b}'](
                node_embed,
                edge_embed,
                curr_rigids,
                node_mask)
            ipa_embed *= node_mask[..., None]
            node_embed = self.trunk[f'ipa_ln_{b}'](node_embed + ipa_embed)
            seq_tfmr_in = torch.cat([
                node_embed, self.trunk[f'skip_embed_{b}'](init_node_embed)
            ], dim=-1)
            # NOTE(review): src_key_padding_mask is a float (1 - mask);
            # newer PyTorch expects a bool mask here — confirm version
            # compatibility.
            seq_tfmr_out = self.trunk[f'seq_tfmr_{b}'](
                seq_tfmr_in, src_key_padding_mask=1 - node_mask)
            node_embed = node_embed + self.trunk[f'post_tfmr_{b}'](seq_tfmr_out)
            node_embed = self.trunk[f'node_transition_{b}'](node_embed)
            node_embed = node_embed * node_mask[..., None]
            # Only diffused residues contribute a frame update.
            rigid_update = self.trunk[f'bb_update_{b}'](
                node_embed * diffuse_mask[..., None])
            curr_rigids = curr_rigids.compose_q_update_vec(
                rigid_update, diffuse_mask[..., None])

            if b < self._ipa_conf.num_blocks-1:
                edge_embed = self.trunk[f'edge_transition_{b}'](
                    node_embed, edge_embed)
                edge_embed *= edge_mask[..., None]
        # Scores are computed from (initial, predicted) frames via the diffuser.
        rot_score = self.diffuser.calc_rot_score(
            init_rigids.get_rots(),
            curr_rigids.get_rots(),
            input_feats['t']
        )
        rot_score = rot_score * node_mask[..., None]

        curr_rigids = self.unscale_rigids(curr_rigids)
        trans_score = self.diffuser.calc_trans_score(
            init_rigids.get_trans(),
            curr_rigids.get_trans(),
            input_feats['t'][:, None, None],
            use_torch=True,
        )
        trans_score = trans_score * node_mask[..., None]
        _, psi_pred = self.torsion_pred(node_embed)
        model_out = {
            'psi': psi_pred,
            'rot_score': rot_score,
            'trans_score': trans_score,
            'final_rigids': curr_rigids,
        }
        return model_out
models_con/node.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+ from pepflow.modules.common.geometry import construct_3d_basis, global_to_local, get_backbone_dihedral_angles
5
+ from pepflow.modules.common.layers import AngularEncoding
6
+ from pepflow.modules.protein.constants import BBHeavyAtom, AA
7
+
8
+
9
class NodeEmbedder(nn.Module):
    """Per-residue node embedder.

    Combines amino-acid identity, per-amino-acid local atom coordinates,
    and backbone dihedral encodings into a single feature vector per
    residue via an MLP.
    """

    def __init__(self, feat_dim, max_num_atoms, max_aa_types=22):
        super().__init__()
        self.max_num_atoms = max_num_atoms
        self.max_aa_types = max_aa_types
        self.feat_dim = feat_dim
        self.aatype_embed = nn.Embedding(self.max_aa_types, feat_dim)
        self.dihed_embed = AngularEncoding()

        # aa embedding + per-aa-type local coordinates + 3 backbone dihedrals.
        infeat_dim = feat_dim + (self.max_aa_types*max_num_atoms*3) + self.dihed_embed.get_out_dim(3)
        self.mlp = nn.Sequential(
            nn.Linear(infeat_dim, feat_dim * 2), nn.ReLU(),
            nn.Linear(feat_dim * 2, feat_dim), nn.ReLU(),
            nn.Linear(feat_dim, feat_dim), nn.ReLU(),
            nn.Linear(feat_dim, feat_dim)
        )

    def forward(self, aa, res_nb, chain_nb, pos_atoms, mask_atoms, structure_mask=None, sequence_mask=None):
        """
        Args:
            aa: (N, L) amino-acid type indices.
            res_nb: (N, L) residue numbers.
            chain_nb: (N, L) chain numbers.
            pos_atoms: (N, L, A, 3) heavy-atom coordinates.
            mask_atoms: (N, L, A) heavy-atom validity mask.
            structure_mask: (N, L), mask out unknown structures to generate.
            sequence_mask: (N, L), mask out unknown amino acids to generate.
        Returns:
            (N, L, feat_dim) node features, zeroed where CA is absent.
        """
        N, L = aa.size()
        mask_residue = mask_atoms[:, :, BBHeavyAtom.CA]  # (N, L)

        # Remove other atoms beyond the first max_num_atoms slots.
        pos_atoms = pos_atoms[:, :, :self.max_num_atoms]
        mask_atoms = mask_atoms[:, :, :self.max_num_atoms]

        # Amino acid identity features
        if sequence_mask is not None:
            # Avoid data leakage at training time: hide residues to be generated.
            aa = torch.where(sequence_mask, aa, torch.full_like(aa, fill_value=AA.UNK))
        aa_feat = self.aatype_embed(aa)  # (N, L, feat)

        # Coordinate features: express heavy atoms in the residue's local
        # backbone frame (built from CA, C, N).
        R = construct_3d_basis(
            pos_atoms[:, :, BBHeavyAtom.CA],
            pos_atoms[:, :, BBHeavyAtom.C],
            pos_atoms[:, :, BBHeavyAtom.N]
        )
        t = pos_atoms[:, :, BBHeavyAtom.CA]
        crd = global_to_local(R, t, pos_atoms)  # (N, L, A, 3)
        crd_mask = mask_atoms[:, :, :, None].expand_as(crd)
        crd = torch.where(crd_mask, crd, torch.zeros_like(crd))

        # Scatter coordinates into a per-amino-acid-type slot so the MLP can
        # condition atom layout on residue identity.
        aa_expand = aa[:, :, None, None, None].expand(N, L, self.max_aa_types, self.max_num_atoms, 3)
        rng_expand = torch.arange(0, self.max_aa_types)[None, None, :, None, None].expand(N, L, self.max_aa_types, self.max_num_atoms, 3).to(aa_expand)
        place_mask = (aa_expand == rng_expand)
        crd_expand = crd[:, :, None, :, :].expand(N, L, self.max_aa_types, self.max_num_atoms, 3)
        crd_expand = torch.where(place_mask, crd_expand, torch.zeros_like(crd_expand))
        crd_feat = crd_expand.reshape(N, L, self.max_aa_types*self.max_num_atoms*3)
        if structure_mask is not None:
            # Avoid data leakage at training time
            crd_feat = crd_feat * structure_mask[:, :, None]

        # Backbone dihedral features
        bb_dihedral, mask_bb_dihed = get_backbone_dihedral_angles(pos_atoms, chain_nb=chain_nb, res_nb=res_nb, mask=mask_residue)
        dihed_feat = self.dihed_embed(bb_dihedral[:, :, :, None]) * mask_bb_dihed[:, :, :, None]  # (N, L, 3, dihed/3)
        dihed_feat = dihed_feat.reshape(N, L, -1)
        if structure_mask is not None:
            # Dihedrals couple neighboring residues, so also require both
            # neighbors to be known structure.
            dihed_mask = torch.logical_and(
                structure_mask,
                torch.logical_and(
                    torch.roll(structure_mask, shifts=+1, dims=1),
                    torch.roll(structure_mask, shifts=-1, dims=1)
                ),
            )  # Avoid slight data leakage via dihedral angles of anchor residues
            dihed_feat = dihed_feat * dihed_mask[:, :, None]

        out_feat = self.mlp(torch.cat([aa_feat, crd_feat, dihed_feat], dim=-1))  # (N, L, F)
        out_feat = out_feat * mask_residue[:, :, None]

        return out_feat
models_con/pep_dataloader.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """pep-rec dataset"""
2
+ import os
3
+ import logging
4
+ import joblib
5
+ import pickle
6
+ import lmdb
7
+ from Bio import PDB
8
+ from Bio.PDB import PDBExceptions
9
+ from torch.utils.data import Dataset
10
+ from tqdm.auto import tqdm
11
+
12
+ from pepflow.modules.protein.parsers import parse_pdb
13
+ from pepflow.modules.common.geometry import *
14
+ from pepflow.modules.protein.constants import *
15
+ from pepflow.utils.data import mask_select_data, find_longest_true_segment, PaddingCollate
16
+ from torch.utils.data import DataLoader
17
+
18
+ from omegaconf import OmegaConf
19
+ from easydict import EasyDict
20
+
21
+ from torch.utils.data import DataLoader, Dataset
22
+ from torch.utils.data.distributed import DistributedSampler, dist
23
+
24
+ from pepflow.utils.misc import load_config
25
+ from pepflow.utils.train import recursive_to
26
+
27
+ from models_con.torsion import get_torsion_angle
28
+
29
+ import torch
30
+
31
+ from pepflow.modules.protein.writers import save_pdb
32
+
33
+ # bind_dic = torch.load("/datapool/data2/home/jiahan/ResProj/PepDiff/frame-flow/misc/affinity_dict.pt")
34
+
35
+ # testset
36
# Held-out test-set complex IDs; entries listed here are excluded from the
# training cache by preprocess_structure.
names = []
with open('/datapool/data2/home/ruihan/data/jiahan/ResProj/PepDiff/pepflowww/Data/names.txt','r') as f:
    names = [line.strip() for line in f]
40
+
41
def preprocess_structure(task):
    """Parse one peptide/receptor complex into a merged training example.

    Args:
        task: dict with 'id' and 'pdb_path'; the path must contain
            peptide.pdb and pocket.pdb.
    Returns:
        A feature dict (receptor residues first, then peptide, with
        'generate_mask' marking peptide positions), or None when the entry
        is skipped or fails to parse.
    """
    try:
        # Held-out test-set entries must not enter the training cache.
        # (Fixed: the old message said "not in names" although the
        # condition is `in names`, which made the skip log misleading.)
        if task['id'] in names:
            raise ValueError(f'{task["id"]} is in the held-out test set, skipped')
        pdb_path = task['pdb_path']

        # Peptide: parse, then center all coordinates on the peptide's
        # CA center of mass so the generated ligand sits at the origin.
        pep = parse_pdb(os.path.join(pdb_path,'peptide.pdb'))[0]
        center = torch.sum(pep['pos_heavyatom'][pep['mask_heavyatom'][:, BBHeavyAtom.CA], BBHeavyAtom.CA], dim=0) / (torch.sum(pep['mask_heavyatom'][:, BBHeavyAtom.CA]) + 1e-8)
        pep['pos_heavyatom'] = pep['pos_heavyatom'] - center[None, None, :]
        pep['torsion_angle'],pep['torsion_angle_mask'] = get_torsion_angle(pep['pos_heavyatom'],pep['aa']) # calc angles after translation
        if len(pep['aa'])<3 or len(pep['aa'])>25:
            raise ValueError('peptide length not in [3,25]')

        # Receptor pocket: apply the same centering; bump chain numbers so
        # receptor chains never collide with the peptide chain.
        rec = parse_pdb(os.path.join(pdb_path,'pocket.pdb'))[0]
        rec['pos_heavyatom'] = rec['pos_heavyatom'] - center[None, None, :]
        rec['torsion_angle'],rec['torsion_angle_mask'] = get_torsion_angle(rec['pos_heavyatom'],rec['aa']) # calc angles after translation
        rec['chain_nb'] += 1

        # Merge receptor + peptide; generate_mask is True on peptide residues.
        data = {}
        data['id'] = task['id']
        data['generate_mask'] = torch.cat([torch.zeros_like(rec['aa']), torch.ones_like(pep['aa'])], dim=0).bool()
        for k in rec.keys():
            if isinstance(rec[k], torch.Tensor):
                data[k] = torch.cat([rec[k], pep[k]], dim=0)
            elif isinstance(rec[k], list):
                data[k] = rec[k] + pep[k]
            else:
                raise ValueError(f'Unknown type of {rec[k]}')
        return data

    except (
        PDBExceptions.PDBConstructionException,
        KeyError,
        ValueError,
        TypeError
    ) as e:
        # Deliberate best-effort: a bad entry is logged and dropped, never fatal.
        logging.warning('[{}] {}: {}'.format(
            task['id'],
            e.__class__.__name__,
            str(e)
        ))
        return None
85
+
86
+
87
class PepDataset(Dataset):
    """Peptide-receptor complex dataset backed by an LMDB cache.

    On first use (or with reset=True) every structure directory under
    `structure_dir` is preprocessed in parallel and written to an LMDB
    file in `dataset_dir`; afterwards items are served straight from the
    cache.
    """

    MAP_SIZE = 32*(1024*1024*1024) # 32GB  -- LMDB maximum map size

    def __init__(self, structure_dir = "./Data/PepMerge_new/", dataset_dir = "./Data/",
                        name = 'pep', transform=None, reset=False):
        """
        Args:
            structure_dir: directory of per-complex subfolders (peptide.pdb / pocket.pdb).
            dataset_dir: directory where the LMDB cache file is stored.
            name: cache file prefix.
            transform: optional callable applied to each item in __getitem__.
            reset: if True, delete and rebuild the cache.
        """
        super().__init__()
        self.structure_dir = structure_dir
        self.dataset_dir = dataset_dir
        self.transform = transform
        self.name = name

        # LMDB handle is opened lazily (per access) so that DataLoader
        # worker processes each open their own environment after fork.
        self.db_conn = None
        self.db_ids = None
        self._load_structures(reset)

    @property
    def _cache_db_path(self):
        # Path of the LMDB cache file (subdir=False: single file, not a dir).
        return os.path.join(self.dataset_dir, f'{self.name}_structure_cache.lmdb')

    def _connect_db(self):
        # Re-open the environment read-only and refresh the id list.
        self._close_db()
        self.db_conn = lmdb.open(
            self._cache_db_path,
            map_size=self.MAP_SIZE,
            create=False,
            subdir=False,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )
        with self.db_conn.begin() as txn:
            keys = [k.decode() for k in txn.cursor().iternext(values=False)]
        self.db_ids = keys

    def _close_db(self):
        if self.db_conn is not None:
            self.db_conn.close()
        self.db_conn = None
        self.db_ids = None

    def _load_structures(self, reset):
        """Decide which PDB folders still need preprocessing and run it."""
        all_pdbs = os.listdir(self.structure_dir)

        if reset:
            # Drop the existing cache (and its lock file) and rebuild everything.
            if os.path.exists(self._cache_db_path):
                os.remove(self._cache_db_path)
                lock_file = self._cache_db_path + "-lock"
                if os.path.exists(lock_file):
                    os.remove(lock_file)
            self._close_db()
            todo_pdbs = all_pdbs
        else:
            if not os.path.exists(self._cache_db_path):
                todo_pdbs = all_pdbs
            else:
                # Cache exists: nothing to do (incremental update is disabled).
                todo_pdbs = []

        if len(todo_pdbs) > 0:
            self._preprocess_structures(todo_pdbs)

    def _preprocess_structures(self, pdb_list):
        """Preprocess `pdb_list` in parallel and write results to LMDB."""
        tasks = []
        for pdb_fname in pdb_list:
            pdb_path = os.path.join(self.structure_dir, pdb_fname)
            tasks.append({
                'id': pdb_fname,
                'pdb_path': pdb_path,
            })

        data_list = joblib.Parallel(
            n_jobs = max(joblib.cpu_count() // 2, 1),
        )(
            joblib.delayed(preprocess_structure)(task)
            for task in tqdm(tasks, dynamic_ncols=True, desc='Preprocess')
        )

        db_conn = lmdb.open(
            self._cache_db_path,
            map_size = self.MAP_SIZE,
            create=True,
            subdir=False,
            readonly=False,
        )
        # NOTE(review): `ids` is collected but never persisted or returned —
        # keys are re-read from LMDB in _connect_db instead.
        ids = []
        with db_conn.begin(write=True, buffers=True) as txn:
            for data in tqdm(data_list, dynamic_ncols=True, desc='Write to LMDB'):
                if data is None:
                    # Entry failed preprocessing (already logged); skip.
                    continue
                ids.append(data['id'])
                txn.put(data['id'].encode('utf-8'), pickle.dumps(data))

    def __len__(self):
        self._connect_db() # make sure db_ids is not None
        return len(self.db_ids)

    def __getitem__(self, index):
        self._connect_db()
        id = self.db_ids[index]
        with self.db_conn.begin() as txn:
            data = pickle.loads(txn.get(id.encode()))
        if self.transform is not None:
            data = self.transform(data)
        return data
197
+
198
+
199
+
200
if __name__ == '__main__':
    # Smoke test: (re)build the dataset cache with reset=True, then pull
    # one padded batch and check the torsion tensors' shapes.
    device = 'cuda:1'
    config,cfg_name = load_config("./configs/learn/learn_all.yaml")
    dataset = PepDataset(structure_dir = "./Data/PepMerge_new/", dataset_dir = "/Data/Fixed Data",
                            name = 'pep_pocket_test', transform=None, reset=True)
    print(len(dataset))
    print(dataset[0])

    dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4, collate_fn=PaddingCollate(eight=False))

    batch = next(iter(dataloader))
    print(batch['torsion_angle'].shape)
    print(batch['torsion_angle_mask'].shape)
models_con/sample.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torch.utils.data import DataLoader
6
+
7
+ import copy
8
+ import math
9
+ from tqdm.auto import tqdm
10
+ import functools
11
+ import os
12
+ import argparse
13
+ import pandas as pd
14
+ from copy import deepcopy
15
+
16
+ from models_con.pep_dataloader import PepDataset
17
+
18
+ from pepflow.utils.train import recursive_to
19
+
20
+ from pepflow.modules.common.geometry import reconstruct_backbone, reconstruct_backbone_partially, align, batch_align
21
+ from pepflow.modules.protein.writers import save_pdb
22
+
23
+ from pepflow.utils.data import PaddingCollate
24
+
25
+ from models_con.utils import process_dic
26
+
27
+ from models_con.flow_model import FlowModel
28
+
29
+ from models_con.torsion import full_atom_reconstruction, get_heavyatom_mask
30
+
31
+ collate_fn = PaddingCollate(eight=False)
32
+
33
+ import argparse
34
+
35
+
36
def item_to_batch(item, nums=32):
    """Replicate one dataset item ``nums`` times and collate the copies into a batch."""
    replicas = [deepcopy(item) for _ in range(nums)]
    return collate_fn(replicas)
39
+
40
def sample_for_data_bb(data, model, device, save_root, num_steps=200, sample_structure=True, sample_sequence=True, nums=8):
    """Sample ``nums`` backbone-level designs for one item and write them as PDBs.

    Output: ``save_root/<id>/<id>_{i}.pdb`` per sample plus ``<id>_gt.pdb`` for
    the ground truth. Side-chain atoms of generated residues are zero-padded.
    """
    if not os.path.exists(os.path.join(save_root,data["id"])):
        os.makedirs(os.path.join(save_root,data["id"]))
    batch = recursive_to(item_to_batch(data, nums=nums),device=device)
    traj = model.sample(batch, num_steps=num_steps, sample_structure=sample_structure, sample_sequence=sample_sequence)
    final = recursive_to(traj[-1], device=device)  # last step of the sampling trajectory
    pos_bb = reconstruct_backbone(R=final['rotmats'],t=final['trans'],aa=final['seqs'],chain_nb=batch['chain_nb'],res_nb=batch['res_nb'],mask=batch['res_mask']) # (32,L,4,3)
    # Pad the 4 backbone atoms up to the 15-slot heavy-atom layout with zeros.
    pos_ha = F.pad(pos_bb, pad=(0,0,0,15-4), value=0.) # (32,L,A,3) pos14 A=14
    # Generated residues take the reconstructed coordinates; context keeps its own.
    pos_new = torch.where(batch['generate_mask'][:,:,None,None],pos_ha,batch['pos_heavyatom'])
    mask_bb_atoms = torch.zeros_like(batch['mask_heavyatom'])
    mask_bb_atoms[:,:,:4] = True  # only N, CA, C, O exist for generated residues
    mask_new = torch.where(batch['generate_mask'][:,:,None],mask_bb_atoms,batch['mask_heavyatom'])
    aa_new = final['seqs']

    # NOTE(review): chain_nb/chain_id/icode below are computed but not used in
    # data_saved — confirm whether they were meant to replace data's fields.
    chain_nb = torch.LongTensor([0 if gen_mask else 1 for gen_mask in data['generate_mask']])
    chain_id = ['A' if gen_mask else 'B' for gen_mask in data['generate_mask']]
    icode = [' ' for _ in range(len(data['icode']))]
    for i in range(nums):
        # NOTE(review): ref_bb_pos/pred_bb_pos are unused; also indexing
        # data['pos_heavyatom'] (a single item) by sample index i looks suspect.
        ref_bb_pos = data['pos_heavyatom'][i][:,:4].cpu()
        pred_bb_pos = pos_new[i][:,:4].cpu()
        data_saved = {
            'chain_nb':data['chain_nb'],'chain_id':data['chain_id'],'resseq':data['resseq'],'icode':data['icode'],
            'aa':aa_new[i].cpu(), 'mask_heavyatom':mask_new[i].cpu(), 'pos_heavyatom':pos_new[i].cpu(),
        }

        save_pdb(data_saved,path=os.path.join(save_root,data["id"],f'{data["id"]}_{i}.pdb'))
    save_pdb(data,path=os.path.join(save_root,data["id"],f'{data["id"]}_gt.pdb'))
67
+
68
def save_samples_bb(samples,save_dir):
    """Write a batch of backbone-level samples (and the ground truth) as PDB files.

    Args:
        samples: dict with 'batch' (the input batch) and per-sample 'rotmats',
            'trans', 'seqs' tensors from the model.
        save_dir: directory receiving ``sample_{i}.pdb`` and ``gt.pdb``.
    """
    # meta data
    batch = recursive_to(samples['batch'],'cpu')
    chain_id = [list(item) for item in zip(*batch['chain_id'])][0] # fix chain id in collate func
    icode = [' ' for _ in range(len(chain_id))] # batch icode have same problem
    nums = len(batch['id'])
    id = batch['id'][0]
    # batch convert
    # aa=batch['aa] if only bb level
    pos_bb = reconstruct_backbone(R=samples['rotmats'],t=samples['trans'],aa=samples['seqs'],chain_nb=batch['chain_nb'],res_nb=batch['res_nb'],mask=batch['res_mask']) # (32,L,4,3)
    # Pad backbone atoms up to the 15-slot heavy-atom layout with zeros.
    pos_ha = F.pad(pos_bb, pad=(0,0,0,15-4), value=0.) # (32,L,A,3) pos14 A=14
    pos_new = torch.where(batch['generate_mask'][:,:,None,None],pos_ha,batch['pos_heavyatom'])
    mask_bb_atoms = torch.zeros_like(batch['mask_heavyatom'])
    mask_bb_atoms[:,:,:4] = True  # only N, CA, C, O exist for generated residues
    mask_new = torch.where(batch['generate_mask'][:,:,None],mask_bb_atoms,batch['mask_heavyatom'])
    aa_new = samples['seqs']
    for i in range(nums):
        data_saved = {
            'chain_nb':batch['chain_nb'][0],'chain_id':chain_id,'resseq':batch['resseq'][0],'icode':icode,
            'aa':aa_new[i], 'mask_heavyatom':mask_new[i], 'pos_heavyatom':pos_new[i],
        }
        save_pdb(data_saved,path=os.path.join(save_dir,f'sample_{i}.pdb'))
    # All batch entries are replicas of one item, so entry 0 serves as ground truth.
    data_saved = {
        'chain_nb':batch['chain_nb'][0],'chain_id':chain_id,'resseq':batch['resseq'][0],'icode':icode,
        'aa':batch['aa'][0], 'mask_heavyatom':batch['mask_heavyatom'][0], 'pos_heavyatom':batch['pos_heavyatom'][0],
    }
    save_pdb(data_saved,path=os.path.join(save_dir,f'gt.pdb'))
95
+
96
def save_samples_sc(samples,save_dir):
    """Write full-atom (side-chain) samples and the ground truth as PDB files.

    Unlike save_samples_bb, side chains are rebuilt from the sampled torsion
    angles via full_atom_reconstruction.
    """
    # meta data
    batch = recursive_to(samples['batch'],'cpu')
    chain_id = [list(item) for item in zip(*batch['chain_id'])][0] # fix chain id in collate func
    icode = [' ' for _ in range(len(chain_id))] # batch icode have same problem
    nums = len(batch['id'])
    id = batch['id'][0]
    # batch convert
    # aa=batch['aa] if only bb level
    pos_ha,_,_ = full_atom_reconstruction(R_bb=samples['rotmats'],t_bb=samples['trans'],angles=samples['angles'],aa=samples['seqs']) # (32,L,14,3), instead of 15, ignore OXT masked
    # Pad atom14 up to the 15-slot heavy-atom layout (OXT slot left zero).
    pos_ha = F.pad(pos_ha, pad=(0,0,0,15-14), value=0.) # (32,L,A,3) pos14 A=14
    pos_new = torch.where(batch['generate_mask'][:,:,None,None],pos_ha,batch['pos_heavyatom'])
    mask_new = get_heavyatom_mask(samples['seqs'])
    aa_new = samples['seqs']
    for i in range(nums):
        data_saved = {
            'chain_nb':batch['chain_nb'][0],'chain_id':chain_id,'resseq':batch['resseq'][0],'icode':icode,
            'aa':aa_new[i], 'mask_heavyatom':mask_new[i], 'pos_heavyatom':pos_new[i],
        }
        save_pdb(data_saved,path=os.path.join(save_dir,f'sample_{i}.pdb'))
    # All batch entries are replicas of one item, so entry 0 serves as ground truth.
    data_saved = {
        'chain_nb':batch['chain_nb'][0],'chain_id':chain_id,'resseq':batch['resseq'][0],'icode':icode,
        'aa':batch['aa'][0], 'mask_heavyatom':batch['mask_heavyatom'][0], 'pos_heavyatom':batch['pos_heavyatom'][0],
    }
    save_pdb(data_saved,path=os.path.join(save_dir,f'gt.pdb'))
121
+
122
if __name__ == '__main__':
    # Convert every sampled tensor dump under <SAMPLEDIR>/outputs into
    # per-complex full-atom PDB files under <SAMPLEDIR>/pdbs/<name>/.
    cli = argparse.ArgumentParser()
    cli.add_argument('--SAMPLEDIR', type=str)
    opts = cli.parse_args()
    SAMPLE_DIR = opts.SAMPLEDIR

    dump_names = [fname.split('.')[0] for fname in os.listdir(os.path.join(SAMPLE_DIR, 'outputs'))]
    for name in tqdm(dump_names):
        sample = torch.load(os.path.join(SAMPLE_DIR, 'outputs', f'{name}.pt'))
        out_dir = os.path.join(SAMPLE_DIR, 'pdbs', name)
        os.makedirs(out_dir, exist_ok=True)
        save_samples_sc(sample, out_dir)
models_con/torsion.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import math
3
+
4
+ from typing import Any, Optional, Union, cast
5
+
6
+ from pepflow.modules.common.geometry import *
7
+ import pepflow.modules.protein.constants as constants
8
+
9
+ """
10
+ calc torsion angles between (0,2pi)
11
+ """
12
+
13
+ def _get_torsion(p0, p1, p2, p3):
14
+ """
15
+ Args:
16
+ p0-3: (*, 3).
17
+ Returns:
18
+ Dihedral angles in radian, (*, ).
19
+ """
20
+ v0 = p2 - p1
21
+ v1 = p0 - p1
22
+ v2 = p3 - p2
23
+ u1 = torch.cross(v0, v1, dim=-1)
24
+ n1 = u1 / torch.linalg.norm(u1, dim=-1, keepdim=True)
25
+ u2 = torch.cross(v0, v2, dim=-1)
26
+ n2 = u2 / torch.linalg.norm(u2, dim=-1, keepdim=True)
27
+ sgn = torch.sign( (torch.cross(v1, v2, dim=-1) * v0).sum(-1) )
28
+ dihed = sgn*torch.acos( (n1 * n2).sum(-1).clamp(min=-0.999999, max=0.999999))
29
+ return dihed
30
+
31
def get_chi_angles(restype, pos14):
    """Compute the four chi torsions of a single residue from atom14 coordinates.

    Args:
        restype: residue-type index into constants.chi_angles_atoms.
        pos14: (14, 3) heavy-atom coordinates in atom14 ordering.
    Returns:
        (4,) tensor of chi angles in radians; chis that do not exist for this
        residue type remain +inf (callers filter them with isfinite()).
    """
    chi_angles = torch.full([4], fill_value=float("inf")).to(pos14)
    base_atom_names = constants.chi_angles_atoms[restype]
    for i, four_atom_names in enumerate(base_atom_names):
        atom_indices = [constants.restype_atom14_name_to_index[restype][a] for a in four_atom_names]
        p = torch.stack([pos14[i] for i in atom_indices])
        # if torch.eq(p, 99999).any():
        #     continue
        torsion = _get_torsion(*torch.unbind(p, dim=0))
        chi_angles[i] = torsion
    return chi_angles
42
+
43
+
44
def get_psi_angle(pos14: torch.Tensor) -> torch.Tensor:
    """AlphaFold-style psi torsion from backbone atoms N, CA, C, O (atom14 slots 0-3), shape (1,)."""
    return _get_torsion(pos14[0], pos14[1], pos14[2], pos14[3]).reshape([1]) # af style psi, N,CA,C,O
46
+
47
+
48
def get_torsion_angle(pos14: torch.Tensor, aa: torch.LongTensor):
    """Per-residue [psi, chi1-4] torsions mapped to [0, 2*pi).

    Args:
        pos14: (L, 14, 3) atom14 heavy-atom coordinates.
        aa: (L,) residue-type indices.
    Returns:
        torsion: (L, 5) angles in [0, 2*pi); undefined entries are 0.
        torsion_mask: (L, 5) bool, True where the torsion exists.
    """
    torsion, torsion_mask = [], []
    for i in range(pos14.shape[0]):
        if aa[i] < constants.AA.UNK: # 0-19: one of the 20 standard residues
            chi = get_chi_angles(aa[i].item(), pos14[i])
            psi = get_psi_angle(pos14[i])
            torsion_this = torch.cat([psi, chi], dim=0)
            # get_chi_angles marks missing chis with +inf; isfinite() masks them out.
            torsion_mask_this = torsion_this.isfinite()
        else:
            # UNK / padding: no torsions at all.
            torsion_this = torch.full([5], 0.)
            torsion_mask_this = torch.full([5], False)
        # Replace the +inf placeholders with 0 before stacking.
        torsion.append(torsion_this.nan_to_num(posinf=0.))
        torsion_mask.append(torsion_mask_this)

    torsion = torch.stack(torsion) % (2*math.pi)
    torsion_mask = torch.stack(torsion_mask).bool()

    return torsion, torsion_mask
66
+
67
+ def _make_psi_chi_rotation_matrices(angles: torch.Tensor) -> torch.Tensor:
68
+ """Compute psi and chi rotation matrices from torsional angles.
69
+
70
+ Here we provide angles instead of alpha in af2 between (0,2pi)
71
+
72
+ See alphafold supplementary Algorithm 25 for details.
73
+
74
+ Args:
75
+ angles: (B, N, 5), angles between (0,2pi)
76
+
77
+ Returns:
78
+ Torsional angle rotation matrices, (B, N, 5, 3, 3).
79
+ """
80
+ batch_size, n_res = angles.shape[:2]
81
+ sine,cosine = torch.sin(angles), torch.cos(angles)
82
+ sine = sine.reshape(batch_size, n_res, -1, 1, 1)
83
+ cosine = cosine.reshape(batch_size, n_res, -1, 1, 1)
84
+ zero = torch.zeros_like(sine)
85
+ one = torch.ones_like(sine)
86
+
87
+ row1 = torch.cat([one, zero, zero], dim=-1) # (B, N, 5, 1, 3)
88
+ row2 = torch.cat([zero, cosine, -sine], dim=-1) # (B, N, 5, 1, 3)
89
+ row3 = torch.cat([zero, sine, cosine], dim=-1) # (B, N, 5, 1, 3)
90
+ R = torch.cat([row1, row2, row3], dim=-2) # (B, N, 5, 3, 3)
91
+
92
+ return R
93
+
94
+
95
def _get_rigid_group(aa: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Extract rigid group constants.

    Args:
        aa: Amino acid types, (B, N).

    Returns:
        A tuple of rigid group rotation, translation, atom14 group and atom14 position:
            rotation: (B, N, 8, 3, 3) per-frame rotations (backbone + 7 torsion frames).
            translation: (B, N, 8, 3) per-frame translations.
            atom14_group: (B, N, 14) index of the frame each atom14 slot belongs to.
            atom14_position: (B, N, 14, 3) atom coordinates in their local frame.
    """
    batch_size, n_res = aa.size()
    # Flatten so the constant tables can be indexed in one gather, then restore shape.
    aa = aa.flatten()
    rotation = constants.restype_rigid_group_rotation.to(aa.device)[aa].reshape(batch_size, n_res, 8, 3, 3)
    translation = constants.restype_rigid_group_translation.to(aa.device)[aa].reshape(batch_size, n_res, 8, 3)
    atom14_group = constants.restype_heavyatom_to_rigid_group.to(aa.device)[aa].reshape(batch_size, n_res, 14)
    atom14_position = constants.restype_heavyatom_rigid_group_positions.to(aa.device)[aa].reshape(
        batch_size, n_res, 14, 3
    )
    return rotation, translation, atom14_group, atom14_position
113
+
114
+
115
# construct heavy atom masks for generating
# restype_to_heavyatom_masks = {
#     restype: [name != "" and name !='OXT' for name in names]
#     for restype, names in constants.restype_to_heavyatom_names.items()
# }
# print(restype_to_heavyatom_masks[0])

# (22, 15) lookup: True where a heavy-atom slot exists for the residue type.
# OXT is deliberately excluded; rows 0-20 are filled, the last row (padding)
# stays all-False. Built once at import time.
restype_to_heavyatom_masks = torch.zeros([22,15]).bool()
for i in range(21):
    restype_to_heavyatom_masks[i] = torch.tensor([name != "" and name !='OXT' for name in constants.restype_to_heavyatom_names[i]]).bool()

def get_heavyatom_mask(aa: torch.Tensor) -> torch.Tensor:
    """Compute heavy atom masks from amino acid types.

    Args:
        aa: Amino acid types, (B, N).

    Returns:
        Heavy atom masks, (B, N, 15).
    """
    batch_size, n_res = aa.size()
    aa = aa.flatten()
    mask = restype_to_heavyatom_masks.to(aa.device)[aa].reshape(batch_size, n_res, 15)
    return mask
139
+
140
def full_atom_reconstruction(
    R_bb: torch.Tensor,
    t_bb: torch.Tensor,
    angles: torch.Tensor,
    aa: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute full atom positions from backbone frames and torsional angles.

    See alphafold supplementary Algorithm 24 for details.

    Args:
        R_bb: Rotation of backbone frames, (B, N, 3, 3).
        t_bb: Translation of backbone frames, (B, N, 3).
        angles: (B, N, 5), angles between (0,2pi)
        aa: Amino acid types, (B, N).

    Returns:
        A tuple of atom positions and full frames, (pos14, R, t).
        pos14: Full atom positions in pos14 representations, (B, N, 14, 3).
        R: Rotation of backbone, psi, chi1-4 frames, (B, N, 5, 3, 3).
        t: Rotation of backbone, psi, chi1-4 frames, (B, N, 5, 3).
    """
    N, L = aa.size()

    rot_psi, rot_chi1, rot_chi2, rot_chi3, rot_chi4 = _make_psi_chi_rotation_matrices(angles).unbind(dim=2)
    # (B, N, 3, 3)
    zeros = torch.zeros_like(t_bb)

    rigid_rotation, rigid_translation, atom14_group, atom14_position = _get_rigid_group(aa)

    # The psi frame hangs directly off the backbone; chi frames chain off each other.
    R_psi, t_psi = compose_chain(
        [
            (R_bb, t_bb),
            (rigid_rotation[:, :, constants.PSI_FRAME], rigid_translation[:, :, constants.PSI_FRAME]),
            (rot_psi, zeros),
        ]
    )

    R_chi1, t_chi1 = compose_chain(
        [
            (R_bb, t_bb),
            (rigid_rotation[:, :, constants.CHI1_FRAME], rigid_translation[:, :, constants.CHI1_FRAME]),
            (rot_chi1, zeros),
        ]
    )

    R_chi2, t_chi2 = compose_chain(
        [
            (R_chi1, t_chi1),
            (rigid_rotation[:, :, constants.CHI2_FRAME], rigid_translation[:, :, constants.CHI2_FRAME]),
            (rot_chi2, zeros),
        ]
    )

    R_chi3, t_chi3 = compose_chain(
        [
            (R_chi2, t_chi2),
            (rigid_rotation[:, :, constants.CHI3_FRAME], rigid_translation[:, :, constants.CHI3_FRAME]),
            (rot_chi3, zeros),
        ]
    )

    R_chi4, t_chi4 = compose_chain(
        [
            (R_chi3, t_chi3),
            (rigid_rotation[:, :, constants.CHI4_FRAME], rigid_translation[:, :, constants.CHI4_FRAME]),
            (rot_chi4, zeros),
        ]
    )

    # Return Frame
    R_ret = torch.stack([R_bb, R_psi, R_chi1, R_chi2, R_chi3, R_chi4], dim=2)
    t_ret = torch.stack([t_bb, t_psi, t_chi1, t_chi2, t_chi3, t_chi4], dim=2)

    # Backbone, Omega, Phi, Psi, Chi1,2,3,4 — omega/phi atoms reuse the backbone frame.
    R_all = torch.stack([R_bb, R_bb, R_bb, R_psi, R_chi1, R_chi2, R_chi3, R_chi4], dim=2) # (B, N, 8, 3, 3)
    t_all = torch.stack([t_bb, t_bb, t_bb, t_psi, t_chi1, t_chi2, t_chi3, t_chi4], dim=2) # (B, N, 8, 3)

    # Gather each atom's frame by its rigid-group index.
    index_R = atom14_group.reshape(N, L, 14, 1, 1).repeat(1, 1, 1, 3, 3) # (B, N, 14, 3, 3)
    index_t = atom14_group.reshape(N, L, 14, 1).repeat(1, 1, 1, 3) # (B, N, 14, 3)

    R_atom = torch.gather(R_all, dim=2, index=index_R) # (N, L, 14, 3, 3)
    t_atom = torch.gather(t_all, dim=2, index=index_t) # (N, L, 14, 3)
    p_atom = atom14_position # (N, L, 14, 3)

    # Map each atom's local coordinates into the global frame.
    pos14 = torch.matmul(R_atom, p_atom.unsqueeze(-1)).squeeze(-1) + t_atom
    return pos14, R_ret, t_ret
227
+
228
+
229
+
230
# (22, 5) availability mask over [psi, chi1-4]; psi is always available for
# rows 0-20, chi availability comes from constants.chi_angles_mask. The last
# row (padding) stays all-zero. Built once at import time.
torsions_mask = torch.zeros([22,5]).float() # 0-19, X, PAD
for i in range(21):
    torsions_mask[i] = torch.tensor([True] + constants.chi_angles_mask[i]).float()
# print(angles_mask)

if __name__ =='__main__':
    # Smoke test: an all-THR batch should produce THR's heavy-atom mask.
    aa = torch.full([3,8],fill_value=constants.AA.THR).long()
    mask = get_heavyatom_mask(aa)
    print(mask)
    print(mask.shape)
models_con/torus.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+
4
+
5
def tor_expmap(x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
    """Exponential map on the flat torus: step from x along tangent u, wrapped into [0, 2*pi)."""
    stepped = x + u
    return stepped % (2 * math.pi)
7
+
8
def tor_logmap(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Log map on the torus: minimal signed angular difference y - x, in (-pi, pi]."""
    delta = y - x
    return torch.atan2(torch.sin(delta), torch.cos(delta))
10
+
11
def tor_projx(x: torch.Tensor) -> torch.Tensor:
    """Project arbitrary angles onto the canonical torus range [0, 2*pi)."""
    full_turn = 2 * math.pi
    return x % full_turn
13
+
14
def tor_random_uniform(*size, dtype=None, device=None) -> torch.Tensor:
    """Draw angles uniformly from [0, 2*pi) with the given shape."""
    unit_samples = torch.rand(*size, dtype=dtype, device=device)
    return unit_samples * 2 * math.pi
17
+
18
def tor_uniform_logprob(x):
    """Log-density of the uniform distribution on the d-torus, d = x.shape[-1]; shape x.shape[:-1]."""
    n_angles = x.size(-1)
    return torch.full_like(x[..., 0], -n_angles * math.log(2 * math.pi))
21
+
22
def tor_geodesic_t(t, angles_1, angles_0):
    """Point a fraction t along the toroidal geodesic from angles_0 (base) toward angles_1 (target).

    Equivalent to tor_expmap(angles_0, t * tor_logmap(angles_0, angles_1)),
    with the log/exp maps inlined.
    """
    diff = angles_1 - angles_0
    step = t * torch.atan2(torch.sin(diff), torch.cos(diff))
    return (angles_0 + step) % (2 * math.pi)
26
+ return points_at_time_t
27
+
28
if __name__ =='__main__':
    # Smoke test: interpolate 20% of the way from b toward a on the torus.
    a = tor_random_uniform((2,3,5))
    b = tor_random_uniform((2,3,5))
    t = torch.ones((2,1)) * 0.2
    c = tor_geodesic_t(t[...,None],a,b)
    print(c)
    print(c.shape)
models_con/utils.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ import copy
7
+ import math
8
+ from tqdm.auto import tqdm
9
+ import functools
10
+ from torch.utils.data import DataLoader
11
+ import os
12
+ import argparse
13
+
14
+ import pandas as pd
15
+
16
def process_dic(state_dict):
    """Strip the DataParallel/DDP 'module.' prefix from state-dict keys.

    The previous version tested ``'module' in k`` and then sliced ``k[7:]``,
    which mangled any key merely *containing* the substring "module"
    (e.g. 'submodule.x' -> 'le.x'). Only a true 'module.' prefix is stripped now.

    Args:
        state_dict: mapping of parameter names to tensors.
    Returns:
        A new dict with 'module.' prefixes removed; other keys unchanged.
    """
    prefix = 'module.'
    new_state_dict = {}
    for k, v in state_dict.items():
        if k.startswith(prefix):
            new_state_dict[k[len(prefix):]] = v
        else:
            new_state_dict[k] = v
    return new_state_dict
24
+
25
+
26
def calc_distogram(pos, min_bin, max_bin, num_bins):
    """One-hot distogram of pairwise distances.

    Args:
        pos: (B, N, 3) coordinates.
        min_bin, max_bin: lower edges of the first and last bin.
        num_bins: number of bins; the last bin extends to a large cap (1e8).

    Returns:
        (B, N, N, num_bins) one-hot encoding in pos.dtype. Distances landing
        exactly on a bin edge encode as all zeros (strict inequalities).
    """
    pairwise = torch.linalg.norm(
        pos[:, :, None, :] - pos[:, None, :, :], axis=-1)[..., None]
    lower_edges = torch.linspace(min_bin, max_bin, num_bins, device=pos.device)
    upper_edges = torch.cat([lower_edges[1:], lower_edges.new_tensor([1e8])], dim=-1)
    in_bin = (pairwise > lower_edges) * (pairwise < upper_edges)
    return in_bin.type(pos.dtype)
37
+
38
+
39
def get_index_embedding(indices, embed_size, max_len=2056):
    """Creates sine / cosine positional embeddings from a prespecified indices.

    Args:
        indices: offsets of size [..., N_edges] of type integer
        max_len: maximum length.
        embed_size: dimension of the embeddings to create

    Returns:
        positional embedding of shape [N, embed_size]
    """
    half = embed_size // 2
    freq_index = torch.arange(half, device=indices.device)
    # One geometric frequency per channel, scaled by pi / max_len^(2k/d).
    phase = indices[..., None] * math.pi / (max_len ** (2 * freq_index[None] / embed_size))
    sin_part = torch.sin(phase).to(indices.device)
    cos_part = torch.cos(phase).to(indices.device)
    return torch.cat([sin_part, cos_part], axis=-1)
58
+
59
+
60
def get_time_embedding(timesteps, embedding_dim, max_positions=2000):
    """Transformer-style sinusoidal embedding of (scaled) diffusion timesteps.

    Code adapted from
    https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py

    Args:
        timesteps: 1-D tensor of times (scaled internally by max_positions).
        embedding_dim: output feature size; zero-padded on the right if odd.
    Returns:
        (len(timesteps), embedding_dim) float tensor.
    """
    assert timesteps.dim() == 1
    scaled = timesteps * max_positions
    half_dim = embedding_dim // 2
    # Geometric frequency ladder from 1 down to 1/max_positions.
    log_step = math.log(max_positions) / (half_dim - 1)
    freqs = torch.exp(-log_step * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device))
    phases = scaled.float()[:, None] * freqs[None, :]
    out = torch.cat([torch.sin(phases), torch.cos(phases)], dim=1)
    if embedding_dim % 2 == 1: # zero pad
        out = F.pad(out, (0, 1), mode='constant')
    assert out.shape == (timesteps.shape[0], embedding_dim)
    return out
openfold/config.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
# NOTE(review): these appear to be symbolic shape-placeholder strings used in
# feature-shape specifications (stand-ins for runtime-determined dimensions) —
# confirm against their usages elsewhere in openfold.
NUM_RES = "num residues placeholder"
NUM_MSA_SEQ = "msa placeholder"
NUM_EXTRA_SEQ = "extra msa placeholder"
NUM_TEMPLATES = "num templates placeholder"
openfold/model/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import glob
import importlib as importlib

# Eagerly import every sibling .py module in this package and re-export it on
# the package namespace, so submodules are available without listing them by hand.
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
__all__ = [
    os.path.basename(f)[:-3]
    for f in _files
    if os.path.isfile(f) and not f.endswith("__init__.py")
]
_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
for _m in _modules:
    globals()[_m[0]] = _m[1]

# Avoid needlessly cluttering the global namespace
del _files, _m, _modules
openfold/model/dropout.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ from functools import partialmethod
19
+ from typing import Union, List
20
+
21
+
22
class Dropout(nn.Module):
    """
    Implementation of dropout with the ability to share the dropout mask
    along a particular dimension.

    If not in training mode, this module computes the identity function.
    """

    def __init__(self, r: float, batch_dim: Union[int, List[int]]):
        """
        Args:
            r:
                Dropout rate
            batch_dim:
                Dimension(s) along which the dropout mask is shared
        """
        super(Dropout, self).__init__()

        self.r = r
        # isinstance instead of `type(...) == int`: also accepts int subclasses
        # and is the idiomatic type check. Normalize to a list so forward()
        # can iterate uniformly.
        if isinstance(batch_dim, int):
            batch_dim = [batch_dim]
        self.batch_dim = batch_dim
        self.dropout = nn.Dropout(self.r)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x:
                Tensor to which dropout is applied. Can have any shape
                compatible with self.batch_dim
        Returns:
            x with the (broadcast) dropout mask applied in place.
        """
        shape = list(x.shape)
        if self.batch_dim is not None:
            for bd in self.batch_dim:
                # Size-1 dims broadcast, so one mask is shared along them.
                shape[bd] = 1
        mask = x.new_ones(shape)
        mask = self.dropout(mask)
        # NOTE: in-place multiply — mutates the caller's tensor, matching the
        # original openfold behavior.
        x *= mask
        return x
61
+
62
+
63
class DropoutRowwise(Dropout):
    """
    Convenience class for rowwise dropout as described in subsection
    1.11.6.
    """

    # Shares the mask along dim -3 (the row dimension of a pair activation).
    __init__ = partialmethod(Dropout.__init__, batch_dim=-3)
70
+
71
+
72
class DropoutColumnwise(Dropout):
    """
    Convenience class for columnwise dropout as described in subsection
    1.11.6.
    """

    # Shares the mask along dim -2 (the column dimension of a pair activation).
    __init__ = partialmethod(Dropout.__init__, batch_dim=-2)
openfold/model/embedders.py ADDED
@@ -0,0 +1,352 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ from typing import Tuple
19
+
20
+ from openfold.model.primitives import Linear, LayerNorm
21
+ from openfold.utils.tensor_utils import one_hot
22
+
23
+
24
class InputEmbedder(nn.Module):
    """
    Embeds a subset of the input features.

    Implements Algorithms 3 (InputEmbedder) and 4 (relpos).
    """

    def __init__(
        self,
        tf_dim: int,
        msa_dim: int,
        c_z: int,
        c_m: int,
        relpos_k: int,
        **kwargs,
    ):
        """
        Args:
            tf_dim:
                Final dimension of the target features
            msa_dim:
                Final dimension of the MSA features
            c_z:
                Pair embedding dimension
            c_m:
                MSA embedding dimension
            relpos_k:
                Window size used in relative positional encoding
        """
        super(InputEmbedder, self).__init__()

        self.tf_dim = tf_dim
        self.msa_dim = msa_dim

        self.c_z = c_z
        self.c_m = c_m

        self.linear_tf_z_i = Linear(tf_dim, c_z)
        self.linear_tf_z_j = Linear(tf_dim, c_z)
        self.linear_tf_m = Linear(tf_dim, c_m)
        self.linear_msa_m = Linear(msa_dim, c_m)

        # RPE stuff
        self.relpos_k = relpos_k
        self.no_bins = 2 * relpos_k + 1
        self.linear_relpos = Linear(self.no_bins, c_z)

    def relpos(self, ri: torch.Tensor):
        """
        Computes relative positional encodings

        Implements Algorithm 4.

        Args:
            ri:
                "residue_index" features of shape [*, N]
        """
        # [*, N, N] signed index offsets, one-hot against 2k+1 boundary bins.
        d = ri[..., None] - ri[..., None, :]
        boundaries = torch.arange(
            start=-self.relpos_k, end=self.relpos_k + 1, device=d.device
        )
        oh = one_hot(d, boundaries).type(ri.dtype)
        return self.linear_relpos(oh)

    def forward(
        self,
        tf: torch.Tensor,
        ri: torch.Tensor,
        msa: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            tf:
                "target_feat" features of shape [*, N_res, tf_dim]
            ri:
                "residue_index" features of shape [*, N_res]
            msa:
                "msa_feat" features of shape [*, N_clust, N_res, msa_dim]
        Returns:
            msa_emb:
                [*, N_clust, N_res, C_m] MSA embedding
            pair_emb:
                [*, N_res, N_res, C_z] pair embedding

        """
        # [*, N_res, c_z]
        tf_emb_i = self.linear_tf_z_i(tf)
        tf_emb_j = self.linear_tf_z_j(tf)

        # [*, N_res, N_res, c_z] — outer sum of the two projections, plus relpos.
        pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
        pair_emb = pair_emb + self.relpos(ri.type(pair_emb.dtype))

        # [*, N_clust, N_res, c_m] — target-feature embedding broadcast to every
        # MSA cluster and added to the MSA embedding.
        n_clust = msa.shape[-3]
        tf_m = (
            self.linear_tf_m(tf)
            .unsqueeze(-3)
            .expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1)))
        )
        msa_emb = self.linear_msa_m(msa) + tf_m

        return msa_emb, pair_emb
127
+
128
+
129
class RecyclingEmbedder(nn.Module):
    """
    Embeds the output of an iteration of the model for recycling.

    Implements Algorithm 32.
    """

    def __init__(
        self,
        c_m: int,
        c_z: int,
        min_bin: float,
        max_bin: float,
        no_bins: int,
        inf: float = 1e8,
        **kwargs,
    ):
        """
        Args:
            c_m:
                MSA channel dimension
            c_z:
                Pair embedding channel dimension
            min_bin:
                Smallest distogram bin (Angstroms)
            max_bin:
                Largest distogram bin (Angstroms)
            no_bins:
                Number of distogram bins
        """
        super(RecyclingEmbedder, self).__init__()

        self.c_m = c_m
        self.c_z = c_z
        self.min_bin = min_bin
        self.max_bin = max_bin
        self.no_bins = no_bins
        self.inf = inf

        # Lazily built on first forward() call.
        self.bins = None

        self.linear = Linear(self.no_bins, self.c_z)
        self.layer_norm_m = LayerNorm(self.c_m)
        self.layer_norm_z = LayerNorm(self.c_z)

    def forward(
        self,
        m: torch.Tensor,
        z: torch.Tensor,
        x: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            m:
                First row of the MSA embedding. [*, N_res, C_m]
            z:
                [*, N_res, N_res, C_z] pair embedding
            x:
                [*, N_res, 3] predicted C_beta coordinates
        Returns:
            m:
                [*, N_res, C_m] MSA embedding update
            z:
                [*, N_res, N_res, C_z] pair embedding update
        """
        # NOTE(review): self.bins is cached with the dtype/device of the FIRST
        # call; later calls on another device/dtype would mismatch — confirm
        # that inputs never change device after the first forward.
        if self.bins is None:
            self.bins = torch.linspace(
                self.min_bin,
                self.max_bin,
                self.no_bins,
                dtype=x.dtype,
                device=x.device,
                requires_grad=False,
            )

        # [*, N, C_m]
        m_update = self.layer_norm_m(m)

        # This squared method might become problematic in FP16 mode.
        # I'm using it because my homegrown method had a stubborn discrepancy I
        # couldn't find in time.
        squared_bins = self.bins ** 2
        upper = torch.cat(
            [squared_bins[1:], squared_bins.new_tensor([self.inf])], dim=-1
        )
        # Squared pairwise distances, kept in the last dim for bin comparison.
        d = torch.sum(
            (x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1, keepdims=True
        )

        # [*, N, N, no_bins] one-hot distogram (strict inequalities).
        d = ((d > squared_bins) * (d < upper)).type(x.dtype)

        # [*, N, N, C_z]
        d = self.linear(d)
        z_update = d + self.layer_norm_z(z)

        return m_update, z_update
226
+
227
+
228
class TemplateAngleEmbedder(nn.Module):
    """
    Embeds the "template_angle_feat" feature.

    Implements Algorithm 2, line 7.
    """

    def __init__(
        self,
        c_in: int,
        c_out: int,
        **kwargs,
    ):
        """
        Args:
            c_in:
                Final dimension of "template_angle_feat"
            c_out:
                Output channel dimension
        """
        super(TemplateAngleEmbedder, self).__init__()

        self.c_out = c_out
        self.c_in = c_in

        # Two-layer MLP with a single ReLU in between.
        self.linear_1 = Linear(self.c_in, self.c_out, init="relu")
        self.relu = nn.ReLU()
        self.linear_2 = Linear(self.c_out, self.c_out, init="relu")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [*, N_templ, N_res, c_in] "template_angle_feat" features
        Returns:
            x: [*, N_templ, N_res, C_out] embedding
        """
        return self.linear_2(self.relu(self.linear_1(x)))
+
270
+
271
class TemplatePairEmbedder(nn.Module):
    """
    Embeds "template_pair_feat" features.

    Implements Algorithm 2, line 9.
    """

    def __init__(
        self,
        c_in: int,
        c_out: int,
        **kwargs,
    ):
        """
        Args:
            c_in:
                Input channel dimension
            c_out:
                Output channel dimension
        """
        super(TemplatePairEmbedder, self).__init__()

        self.c_in = c_in
        self.c_out = c_out

        # Despite there being no relu nearby, the source uses that initializer
        self.linear = Linear(self.c_in, self.c_out, init="relu")

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            x:
                [*, C_in] input tensor
        Returns:
            [*, C_out] output tensor
        """
        x = self.linear(x)

        return x
313
+
314
+
315
class ExtraMSAEmbedder(nn.Module):
    """
    Embeds unclustered MSA sequences.

    Implements Algorithm 2, line 15
    """

    def __init__(
        self,
        c_in: int,
        c_out: int,
        **kwargs,
    ):
        """
        Args:
            c_in:
                Input channel dimension
            c_out:
                Output channel dimension
        """
        super(ExtraMSAEmbedder, self).__init__()

        self.c_in = c_in
        self.c_out = c_out

        # Single linear projection; no activation.
        self.linear = Linear(self.c_in, self.c_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x:
                [*, N_extra_seq, N_res, C_in] "extra_msa_feat" features
        Returns:
            [*, N_extra_seq, N_res, C_out] embedding
        """
        return self.linear(x)
openfold/model/evoformer.py ADDED
@@ -0,0 +1,630 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+ import torch
18
+ import torch.nn as nn
19
+ from typing import Tuple, Optional
20
+ from functools import partial
21
+
22
+ from openfold.model.primitives import Linear, LayerNorm
23
+ from openfold.model.dropout import DropoutRowwise, DropoutColumnwise
24
+ from openfold.model.msa import (
25
+ MSARowAttentionWithPairBias,
26
+ MSAColumnAttention,
27
+ MSAColumnGlobalAttention,
28
+ )
29
+ from openfold.model.outer_product_mean import OuterProductMean
30
+ from openfold.model.pair_transition import PairTransition
31
+ from openfold.model.triangular_attention import (
32
+ TriangleAttentionStartingNode,
33
+ TriangleAttentionEndingNode,
34
+ )
35
+ from openfold.model.triangular_multiplicative_update import (
36
+ TriangleMultiplicationOutgoing,
37
+ TriangleMultiplicationIncoming,
38
+ )
39
+ from openfold.utils.checkpointing import checkpoint_blocks, get_checkpoint_fn
40
+ from openfold.utils.tensor_utils import chunk_layer
41
+
42
+
43
class MSATransition(nn.Module):
    """
    Feed-forward network applied to MSA activations after attention.

    Implements Algorithm 9
    """

    def __init__(self, c_m, n):
        """
        Args:
            c_m:
                MSA channel dimension
            n:
                Factor multiplied to c_m to obtain the hidden channel
                dimension
        """
        super(MSATransition, self).__init__()

        self.c_m = c_m
        self.n = n

        hidden = self.n * self.c_m
        self.layer_norm = LayerNorm(self.c_m)
        self.linear_1 = Linear(self.c_m, hidden, init="relu")
        self.relu = nn.ReLU()
        self.linear_2 = Linear(hidden, self.c_m, init="final")

    def _transition(self, m, mask):
        # Two-layer MLP; the mask zeroes out padded positions in the output.
        return self.linear_2(self.relu(self.linear_1(m))) * mask

    @torch.jit.ignore
    def _chunk(self,
        m: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: int,
    ) -> torch.Tensor:
        # Run the transition over sub-batches to bound peak memory.
        return chunk_layer(
            self._transition,
            {"m": m, "mask": mask},
            chunk_size=chunk_size,
            no_batch_dims=len(m.shape[:-2]),
        )

    def forward(
        self,
        m: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Args:
            m:
                [*, N_seq, N_res, C_m] MSA activation
            mask:
                [*, N_seq, N_res] MSA mask (all-ones if omitted)
        Returns:
            m:
                [*, N_seq, N_res, C_m] MSA activation update
        """
        # DISCREPANCY: DeepMind forgets to apply the MSA mask here.
        if mask is None:
            mask = m.new_ones(m.shape[:-1])

        mask = mask.unsqueeze(-1)
        m = self.layer_norm(m)

        if chunk_size is None:
            return self._transition(m, mask)
        return self._chunk(m, mask, chunk_size)
118
+
119
+
120
class EvoformerBlockCore(nn.Module):
    """
    The portion of an Evoformer block shared by the main and extra-MSA
    stacks: MSA transition, outer product mean, and the pair-stack
    (triangular) updates.
    """

    def __init__(
        self,
        c_m: int,
        c_z: int,
        c_hidden_opm: int,
        c_hidden_mul: int,
        c_hidden_pair_att: int,
        no_heads_msa: int,
        no_heads_pair: int,
        transition_n: int,
        pair_dropout: float,
        inf: float,
        eps: float,
        _is_extra_msa_stack: bool = False,
    ):
        super(EvoformerBlockCore, self).__init__()

        self.msa_transition = MSATransition(c_m=c_m, n=transition_n)

        self.outer_product_mean = OuterProductMean(c_m, c_z, c_hidden_opm)

        self.tri_mul_out = TriangleMultiplicationOutgoing(c_z, c_hidden_mul)
        self.tri_mul_in = TriangleMultiplicationIncoming(c_z, c_hidden_mul)

        self.tri_att_start = TriangleAttentionStartingNode(
            c_z, c_hidden_pair_att, no_heads_pair, inf=inf
        )
        self.tri_att_end = TriangleAttentionEndingNode(
            c_z, c_hidden_pair_att, no_heads_pair, inf=inf
        )

        self.pair_transition = PairTransition(c_z, transition_n)

        self.ps_dropout_row_layer = DropoutRowwise(pair_dropout)
        self.ps_dropout_col_layer = DropoutColumnwise(pair_dropout)

    def forward(
        self,
        m: torch.Tensor,
        z: torch.Tensor,
        msa_mask: torch.Tensor,
        pair_mask: torch.Tensor,
        chunk_size: Optional[int] = None,
        _mask_trans: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            m: [*, N_seq, N_res, C_m] MSA embedding
            z: [*, N_res, N_res, C_z] pair embedding
            msa_mask: [*, N_seq, N_res] MSA mask
            pair_mask: [*, N_res, N_res] pair mask
        Returns:
            The updated (m, z) pair.
        """
        # DeepMind doesn't mask these transitions in the source, so _mask_trans
        # should be disabled to better approximate the exact activations of
        # the original.
        trans_mask_m = msa_mask if _mask_trans else None
        trans_mask_z = pair_mask if _mask_trans else None

        # Residual MSA transition, then fold MSA information into the pair
        # representation via the outer product mean.
        m = m + self.msa_transition(m, mask=trans_mask_m, chunk_size=chunk_size)
        z = z + self.outer_product_mean(m, mask=msa_mask, chunk_size=chunk_size)

        # Pair-stack residual updates; dropout orientation matches each
        # triangular update's direction.
        z = z + self.ps_dropout_row_layer(self.tri_mul_out(z, mask=pair_mask))
        z = z + self.ps_dropout_row_layer(self.tri_mul_in(z, mask=pair_mask))
        z = z + self.ps_dropout_row_layer(
            self.tri_att_start(z, mask=pair_mask, chunk_size=chunk_size)
        )
        z = z + self.ps_dropout_col_layer(
            self.tri_att_end(z, mask=pair_mask, chunk_size=chunk_size)
        )
        z = z + self.pair_transition(z, mask=trans_mask_z, chunk_size=chunk_size)

        return m, z
213
+
214
+
215
class EvoformerBlock(nn.Module):
    """
    A full Evoformer block: MSA row/column attention followed by the
    shared EvoformerBlockCore updates.
    """

    def __init__(self,
        c_m: int,
        c_z: int,
        c_hidden_msa_att: int,
        c_hidden_opm: int,
        c_hidden_mul: int,
        c_hidden_pair_att: int,
        no_heads_msa: int,
        no_heads_pair: int,
        transition_n: int,
        msa_dropout: float,
        pair_dropout: float,
        inf: float,
        eps: float,
    ):
        super(EvoformerBlock, self).__init__()

        self.msa_att_row = MSARowAttentionWithPairBias(
            c_m=c_m,
            c_z=c_z,
            c_hidden=c_hidden_msa_att,
            no_heads=no_heads_msa,
            inf=inf,
        )

        self.msa_att_col = MSAColumnAttention(
            c_m,
            c_hidden_msa_att,
            no_heads_msa,
            inf=inf,
        )

        self.msa_dropout_layer = DropoutRowwise(msa_dropout)

        self.core = EvoformerBlockCore(
            c_m=c_m,
            c_z=c_z,
            c_hidden_opm=c_hidden_opm,
            c_hidden_mul=c_hidden_mul,
            c_hidden_pair_att=c_hidden_pair_att,
            no_heads_msa=no_heads_msa,
            no_heads_pair=no_heads_pair,
            transition_n=transition_n,
            pair_dropout=pair_dropout,
            inf=inf,
            eps=eps,
        )

    def forward(self,
        m: torch.Tensor,
        z: torch.Tensor,
        msa_mask: torch.Tensor,
        pair_mask: torch.Tensor,
        chunk_size: Optional[int] = None,
        _mask_trans: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            m: [*, N_seq, N_res, C_m] MSA embedding
            z: [*, N_res, N_res, C_z] pair embedding
            msa_mask: [*, N_seq, N_res] MSA mask
            pair_mask: [*, N_res, N_res] pair mask
        Returns:
            Updated (m, z).
        """
        # Residual MSA attention updates (dropout only on the row update),
        # then the shared core.
        row_update = self.msa_att_row(m, z=z, mask=msa_mask, chunk_size=chunk_size)
        m = m + self.msa_dropout_layer(row_update)
        m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)

        return self.core(
            m,
            z,
            msa_mask=msa_mask,
            pair_mask=pair_mask,
            chunk_size=chunk_size,
            _mask_trans=_mask_trans,
        )
286
+
287
+
288
class ExtraMSABlock(nn.Module):
    """
    Almost identical to the standard EvoformerBlock, except in that the
    ExtraMSABlock uses GlobalAttention for MSA column attention and
    requires more fine-grained control over checkpointing. Separated from
    its twin to preserve the TorchScript-ability of the latter.
    """
    def __init__(self,
        c_m: int,
        c_z: int,
        c_hidden_msa_att: int,
        c_hidden_opm: int,
        c_hidden_mul: int,
        c_hidden_pair_att: int,
        no_heads_msa: int,
        no_heads_pair: int,
        transition_n: int,
        msa_dropout: float,
        pair_dropout: float,
        inf: float,
        eps: float,
        ckpt: bool,
    ):
        super(ExtraMSABlock, self).__init__()

        # Whether to activation-checkpoint the column-attention + core part
        # of the forward pass (only effective when grad is enabled).
        self.ckpt = ckpt

        self.msa_att_row = MSARowAttentionWithPairBias(
            c_m=c_m,
            c_z=c_z,
            c_hidden=c_hidden_msa_att,
            no_heads=no_heads_msa,
            inf=inf,
        )

        # Global (rather than full) column attention; this is the main
        # difference from EvoformerBlock.
        self.msa_att_col = MSAColumnGlobalAttention(
            c_in=c_m,
            c_hidden=c_hidden_msa_att,
            no_heads=no_heads_msa,
            inf=inf,
            eps=eps,
        )

        self.msa_dropout_layer = DropoutRowwise(msa_dropout)

        self.core = EvoformerBlockCore(
            c_m=c_m,
            c_z=c_z,
            c_hidden_opm=c_hidden_opm,
            c_hidden_mul=c_hidden_mul,
            c_hidden_pair_att=c_hidden_pair_att,
            no_heads_msa=no_heads_msa,
            no_heads_pair=no_heads_pair,
            transition_n=transition_n,
            pair_dropout=pair_dropout,
            inf=inf,
            eps=eps,
        )

    def forward(self,
        m: torch.Tensor,
        z: torch.Tensor,
        msa_mask: torch.Tensor,
        pair_mask: torch.Tensor,
        chunk_size: Optional[int] = None,
        _chunk_logits: Optional[int] = 1024,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            m: [*, N_extra, N_res, C_m] extra-MSA embedding
            z: [*, N_res, N_res, C_z] pair embedding
            msa_mask: [*, N_extra, N_res] MSA mask
            pair_mask: [*, N_res, N_res] pair mask
            chunk_size: inference-time sub-batch size for chunked ops
            _chunk_logits: attention-logit chunk size passed to the row
                attention when grad is enabled (memory saving)
        Returns:
            Updated (m, z).
        """
        def add(m1, m2):
            # The first operation in a checkpoint can't be in-place, but it's
            # nice to have in-place addition during inference. Thus...
            if(torch.is_grad_enabled()):
                m1 = m1 + m2
            else:
                m1 += m2

            return m1

        # Row attention + dropout, added residually. The clones under grad
        # protect the checkpointed/chunked attention from in-place aliasing.
        m = add(m, self.msa_dropout_layer(
            self.msa_att_row(
                m.clone() if torch.is_grad_enabled() else m,
                z=z.clone() if torch.is_grad_enabled() else z,
                mask=msa_mask,
                chunk_size=chunk_size,
                _chunk_logits=_chunk_logits if torch.is_grad_enabled() else None,
                _checkpoint_chunks=
                    self.ckpt if torch.is_grad_enabled() else False,
            )
        ))

        # Column attention + core, optionally wrapped in a checkpoint.
        def fn(m, z):
            m = add(m, self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size))
            m, z = self.core(
                m, z, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size
            )

            return m, z

        if(torch.is_grad_enabled() and self.ckpt):
            checkpoint_fn = get_checkpoint_fn()
            m, z = checkpoint_fn(fn, m, z)
        else:
            m, z = fn(m, z)

        return m, z
392
+
393
+
394
class EvoformerStack(nn.Module):
    """
    Main Evoformer trunk.

    Implements Algorithm 6.
    """

    def __init__(
        self,
        c_m: int,
        c_z: int,
        c_hidden_msa_att: int,
        c_hidden_opm: int,
        c_hidden_mul: int,
        c_hidden_pair_att: int,
        c_s: int,
        no_heads_msa: int,
        no_heads_pair: int,
        no_blocks: int,
        transition_n: int,
        msa_dropout: float,
        pair_dropout: float,
        blocks_per_ckpt: int,
        inf: float,
        eps: float,
        clear_cache_between_blocks: bool = False,
        **kwargs,
    ):
        """
        Args:
            c_m:
                MSA channel dimension
            c_z:
                Pair channel dimension
            c_hidden_msa_att:
                Hidden dimension in MSA attention
            c_hidden_opm:
                Hidden dimension in outer product mean module
            c_hidden_mul:
                Hidden dimension in multiplicative updates
            c_hidden_pair_att:
                Hidden dimension in triangular attention
            c_s:
                Channel dimension of the output "single" embedding
            no_heads_msa:
                Number of heads used for MSA attention
            no_heads_pair:
                Number of heads used for pair attention
            no_blocks:
                Number of Evoformer blocks in the stack
            transition_n:
                Factor by which to multiply c_m to obtain the MSATransition
                hidden dimension
            msa_dropout:
                Dropout rate for MSA activations
            pair_dropout:
                Dropout used for pair activations
            blocks_per_ckpt:
                Number of Evoformer blocks in each activation checkpoint
            clear_cache_between_blocks:
                Whether to clear CUDA's GPU memory cache between blocks of the
                stack. Slows down each block but can reduce fragmentation
        """
        super(EvoformerStack, self).__init__()

        self.blocks_per_ckpt = blocks_per_ckpt
        self.clear_cache_between_blocks = clear_cache_between_blocks

        # Identical EvoformerBlocks, applied sequentially in forward().
        self.blocks = nn.ModuleList()

        for _ in range(no_blocks):
            block = EvoformerBlock(
                c_m=c_m,
                c_z=c_z,
                c_hidden_msa_att=c_hidden_msa_att,
                c_hidden_opm=c_hidden_opm,
                c_hidden_mul=c_hidden_mul,
                c_hidden_pair_att=c_hidden_pair_att,
                no_heads_msa=no_heads_msa,
                no_heads_pair=no_heads_pair,
                transition_n=transition_n,
                msa_dropout=msa_dropout,
                pair_dropout=pair_dropout,
                inf=inf,
                eps=eps,
            )
            self.blocks.append(block)

        # Projects the first MSA row into the "single" representation s.
        self.linear = Linear(c_m, c_s)

    def forward(self,
        m: torch.Tensor,
        z: torch.Tensor,
        msa_mask: torch.Tensor,
        pair_mask: torch.Tensor,
        chunk_size: int,
        _mask_trans: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            m:
                [*, N_seq, N_res, C_m] MSA embedding
            z:
                [*, N_res, N_res, C_z] pair embedding
            msa_mask:
                [*, N_seq, N_res] MSA mask
            pair_mask:
                [*, N_res, N_res] pair mask
            chunk_size:
                Sub-batch size for chunked computation inside the blocks
        Returns:
            m:
                [*, N_seq, N_res, C_m] MSA embedding
            z:
                [*, N_res, N_res, C_z] pair embedding
            s:
                [*, N_res, C_s] single embedding (or None if extra MSA stack)
        """
        # Bind the masks/chunking args so each block is callable as b(m, z),
        # the shape checkpoint_blocks expects.
        blocks = [
            partial(
                b,
                msa_mask=msa_mask,
                pair_mask=pair_mask,
                chunk_size=chunk_size,
                _mask_trans=_mask_trans,
            )
            for b in self.blocks
        ]

        if(self.clear_cache_between_blocks):
            # Wrap every block so the CUDA cache is emptied before it runs.
            def block_with_cache_clear(block, *args):
                torch.cuda.empty_cache()
                return block(*args)

            blocks = [partial(block_with_cache_clear, b) for b in blocks]

        # Activation checkpointing is only applied during training;
        # blocks_per_ckpt=None disables it.
        m, z = checkpoint_blocks(
            blocks,
            args=(m, z),
            blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
        )

        # The single representation is a projection of the first MSA row.
        s = self.linear(m[..., 0, :, :])

        return m, z, s
537
+
538
+
539
class ExtraMSAStack(nn.Module):
    """
    Stack of ExtraMSABlocks applied to the unclustered ("extra") MSA.

    Implements Algorithm 18.
    """

    def __init__(self,
        c_m: int,
        c_z: int,
        c_hidden_msa_att: int,
        c_hidden_opm: int,
        c_hidden_mul: int,
        c_hidden_pair_att: int,
        no_heads_msa: int,
        no_heads_pair: int,
        no_blocks: int,
        transition_n: int,
        msa_dropout: float,
        pair_dropout: float,
        inf: float,
        eps: float,
        ckpt: bool,
        clear_cache_between_blocks: bool = False,
        **kwargs,
    ):
        """
        Args:
            c_m: extra-MSA channel dimension
            c_z: pair channel dimension
            no_blocks: number of ExtraMSABlocks in the stack
            ckpt: whether blocks checkpoint their column-attention/core pass
            clear_cache_between_blocks:
                Whether to clear CUDA's memory cache after each block
        """
        super(ExtraMSAStack, self).__init__()

        self.clear_cache_between_blocks = clear_cache_between_blocks
        self.blocks = nn.ModuleList()
        for _ in range(no_blocks):
            block = ExtraMSABlock(
                c_m=c_m,
                c_z=c_z,
                c_hidden_msa_att=c_hidden_msa_att,
                c_hidden_opm=c_hidden_opm,
                c_hidden_mul=c_hidden_mul,
                c_hidden_pair_att=c_hidden_pair_att,
                no_heads_msa=no_heads_msa,
                no_heads_pair=no_heads_pair,
                transition_n=transition_n,
                msa_dropout=msa_dropout,
                pair_dropout=pair_dropout,
                inf=inf,
                eps=eps,
                ckpt=ckpt,
            )
            self.blocks.append(block)

    def forward(self,
        m: torch.Tensor,
        z: torch.Tensor,
        chunk_size: int,
        msa_mask: Optional[torch.Tensor] = None,
        pair_mask: Optional[torch.Tensor] = None,
        _mask_trans: bool = True,
    ) -> torch.Tensor:
        """
        Args:
            m:
                [*, N_extra, N_res, C_m] extra MSA embedding
            z:
                [*, N_res, N_res, C_z] pair embedding
            msa_mask:
                Optional [*, N_extra, N_res] MSA mask
            pair_mask:
                Optional [*, N_res, N_res] pair mask
        Returns:
            [*, N_res, N_res, C_z] pair update
        """
        # NOTE(review): _mask_trans is accepted for signature parity with
        # EvoformerStack.forward but is not forwarded to the blocks
        # (ExtraMSABlock.forward takes no such argument).
        for b in self.blocks:
            m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size)

            if(self.clear_cache_between_blocks):
                torch.cuda.empty_cache()

        # Only the pair representation is consumed downstream.
        return z
openfold/model/heads.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+
19
+ from openfold.model.primitives import Linear, LayerNorm
20
+ from openfold.utils.loss import (
21
+ compute_plddt,
22
+ compute_tm,
23
+ compute_predicted_aligned_error,
24
+ )
25
+
26
+
27
class AuxiliaryHeads(nn.Module):
    """
    Collection of output heads producing AlphaFold's auxiliary predictions:
    pLDDT, distogram, masked-MSA, experimentally-resolved and, when
    enabled, TM-score logits.
    """

    def __init__(self, config):
        super(AuxiliaryHeads, self).__init__()

        self.plddt = PerResidueLDDTCaPredictor(**config["lddt"])
        self.distogram = DistogramHead(**config["distogram"])
        self.masked_msa = MaskedMSAHead(**config["masked_msa"])
        self.experimentally_resolved = ExperimentallyResolvedHead(
            **config["experimentally_resolved"],
        )

        # The TM-score head is optional.
        if config.tm.enabled:
            self.tm = TMScoreHead(**config.tm)

        self.config = config

    def forward(self, outputs):
        aux_out = {}

        lddt_logits = self.plddt(outputs["sm"]["single"])
        aux_out["lddt_logits"] = lddt_logits
        # Required for relaxation later on
        aux_out["plddt"] = compute_plddt(lddt_logits)

        aux_out["distogram_logits"] = self.distogram(outputs["pair"])
        aux_out["masked_msa_logits"] = self.masked_msa(outputs["msa"])
        aux_out["experimentally_resolved_logits"] = (
            self.experimentally_resolved(outputs["single"])
        )

        if self.config.tm.enabled:
            tm_logits = self.tm(outputs["pair"])
            aux_out["tm_logits"] = tm_logits
            aux_out["predicted_tm_score"] = compute_tm(
                tm_logits, **self.config.tm
            )
            aux_out.update(
                compute_predicted_aligned_error(tm_logits, **self.config.tm)
            )

        return aux_out
89
+
90
+
91
class PerResidueLDDTCaPredictor(nn.Module):
    """
    Predicts per-residue binned lDDT-Ca logits (consumed by the pLDDT
    computation) from the structure module's single representation.
    """

    def __init__(self, no_bins, c_in, c_hidden):
        super(PerResidueLDDTCaPredictor, self).__init__()

        self.no_bins = no_bins
        self.c_in = c_in
        self.c_hidden = c_hidden

        self.layer_norm = LayerNorm(self.c_in)

        self.linear_1 = Linear(self.c_in, self.c_hidden, init="relu")
        self.linear_2 = Linear(self.c_hidden, self.c_hidden, init="relu")
        self.linear_3 = Linear(self.c_hidden, self.no_bins, init="final")

        self.relu = nn.ReLU()

    def forward(self, s):
        """
        Args:
            s: [*, N_res, C_in] single representation
        Returns:
            [*, N_res, no_bins] lDDT-Ca bin logits
        """
        # LayerNorm followed by a three-layer MLP with ReLU activations.
        hidden = self.relu(self.linear_1(self.layer_norm(s)))
        hidden = self.relu(self.linear_2(hidden))
        return self.linear_3(hidden)
116
+
117
+
118
class DistogramHead(nn.Module):
    """
    Computes a distogram probability distribution.

    For use in computation of distogram loss, subsection 1.9.8
    """

    def __init__(self, c_z, no_bins, **kwargs):
        """
        Args:
            c_z:
                Input channel dimension
            no_bins:
                Number of distogram bins
        """
        super(DistogramHead, self).__init__()

        self.c_z = c_z
        self.no_bins = no_bins

        self.linear = Linear(self.c_z, self.no_bins, init="final")

    def forward(self, z):
        """
        Args:
            z:
                [*, N_res, N_res, C_z] pair embedding
        Returns:
            [*, N_res, N_res, no_bins] distogram logits
        """
        # Project, then symmetrize over the two residue dimensions so that
        # logits[..., i, j, :] == logits[..., j, i, :].
        projected = self.linear(z)
        return projected + projected.transpose(-2, -3)
152
+
153
+
154
class TMScoreHead(nn.Module):
    """
    For use in computation of TM-score, subsection 1.9.7
    """

    def __init__(self, c_z, no_bins, **kwargs):
        """
        Args:
            c_z:
                Input channel dimension
            no_bins:
                Number of bins
        """
        super(TMScoreHead, self).__init__()

        self.c_z = c_z
        self.no_bins = no_bins

        self.linear = Linear(self.c_z, self.no_bins, init="final")

    def forward(self, z):
        """
        Args:
            z:
                [*, N_res, N_res, C_z] pairwise embedding
        Returns:
            [*, N_res, N_res, no_bins] prediction
        """
        # Single linear projection of the pair representation.
        return self.linear(z)
185
+
186
+
187
class MaskedMSAHead(nn.Module):
    """
    For use in computation of masked MSA loss, subsection 1.9.9
    """

    def __init__(self, c_m, c_out, **kwargs):
        """
        Args:
            c_m:
                MSA channel dimension
            c_out:
                Output channel dimension
        """
        super(MaskedMSAHead, self).__init__()

        self.c_m = c_m
        self.c_out = c_out

        self.linear = Linear(self.c_m, self.c_out, init="final")

    def forward(self, m):
        """
        Args:
            m:
                [*, N_seq, N_res, C_m] MSA embedding
        Returns:
            [*, N_seq, N_res, C_out] reconstruction logits
        """
        # Single linear projection of the MSA representation.
        return self.linear(m)
218
+
219
+
220
class ExperimentallyResolvedHead(nn.Module):
    """
    For use in computation of "experimentally resolved" loss, subsection
    1.9.10
    """

    def __init__(self, c_s, c_out, **kwargs):
        """
        Args:
            c_s:
                Input channel dimension
            c_out:
                Output channel dimension
        """
        super(ExperimentallyResolvedHead, self).__init__()

        self.c_s = c_s
        self.c_out = c_out

        self.linear = Linear(self.c_s, self.c_out, init="final")

    def forward(self, s):
        """
        Args:
            s:
                [*, N_res, C_s] single embedding
        Returns:
            [*, N_res, C_out] logits
        """
        # Single linear projection of the single representation.
        return self.linear(s)
openfold/model/model.py ADDED
@@ -0,0 +1,446 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from functools import partial
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from openfold.utils.feats import (
21
+ pseudo_beta_fn,
22
+ build_extra_msa_feat,
23
+ build_template_angle_feat,
24
+ build_template_pair_feat,
25
+ atom14_to_atom37,
26
+ )
27
+ from openfold.model.embedders import (
28
+ InputEmbedder,
29
+ RecyclingEmbedder,
30
+ TemplateAngleEmbedder,
31
+ TemplatePairEmbedder,
32
+ ExtraMSAEmbedder,
33
+ )
34
+ from openfold.model.evoformer import EvoformerStack, ExtraMSAStack
35
+ from openfold.model.heads import AuxiliaryHeads
36
+ import openfold.np.residue_constants as residue_constants
37
+ from openfold.model.structure_module import StructureModule
38
+ from openfold.model.template import (
39
+ TemplatePairStack,
40
+ TemplatePointwiseAttention,
41
+ )
42
+ from openfold.utils.loss import (
43
+ compute_plddt,
44
+ )
45
+ from openfold.utils.tensor_utils import (
46
+ dict_multimap,
47
+ tensor_tree_map,
48
+ )
49
+
50
+
51
+ class AlphaFold(nn.Module):
52
+ """
53
+ Alphafold 2.
54
+
55
+ Implements Algorithm 2 (but with training).
56
+ """
57
+
58
    def __init__(self, config):
        """
        Args:
            config:
                A dict-like config object (like the one in config.py)
        """
        super(AlphaFold, self).__init__()

        # Keep only the sub-configs needed to build the modules; self.config
        # below is config.model, not the full config.
        self.globals = config.globals
        config = config.model
        template_config = config.template
        extra_msa_config = config.extra_msa

        # Main trunk + structure module
        self.input_embedder = InputEmbedder(
            **config["input_embedder"],
        )
        self.recycling_embedder = RecyclingEmbedder(
            **config["recycling_embedder"],
        )
        # Template feature embedders + template trunk
        self.template_angle_embedder = TemplateAngleEmbedder(
            **template_config["template_angle_embedder"],
        )
        self.template_pair_embedder = TemplatePairEmbedder(
            **template_config["template_pair_embedder"],
        )
        self.template_pair_stack = TemplatePairStack(
            **template_config["template_pair_stack"],
        )
        self.template_pointwise_att = TemplatePointwiseAttention(
            **template_config["template_pointwise_attention"],
        )
        # Extra (unclustered) MSA pathway
        self.extra_msa_embedder = ExtraMSAEmbedder(
            **extra_msa_config["extra_msa_embedder"],
        )
        self.extra_msa_stack = ExtraMSAStack(
            **extra_msa_config["extra_msa_stack"],
        )
        self.evoformer = EvoformerStack(
            **config["evoformer_stack"],
        )
        self.structure_module = StructureModule(
            **config["structure_module"],
        )

        # Auxiliary prediction heads (pLDDT, distogram, masked MSA, ...)
        self.aux_heads = AuxiliaryHeads(
            config["heads"],
        )

        self.config = config
108
+
109
    def embed_templates(self, batch, z, pair_mask, templ_dim):
        """
        Embed template features and attend them into the pair representation.

        Args:
            batch: feature dict containing "template_*" features
            z: [*, N, N, C_z] pair embedding
            pair_mask: [*, N, N] pair mask
            templ_dim: index of the template dimension in the template feats
        Returns:
            dict with "template_pair_embedding" and, if angle embedding is
            enabled, "template_angle_embedding"
        """
        # Embed the templates one at a time (with a poor man's vmap)
        template_embeds = []
        n_templ = batch["template_aatype"].shape[templ_dim]
        for i in range(n_templ):
            # Slice out template i from every template feature tensor.
            idx = batch["template_aatype"].new_tensor(i)
            single_template_feats = tensor_tree_map(
                lambda t: torch.index_select(t, templ_dim, idx),
                batch,
            )

            single_template_embeds = {}
            if self.config.template.embed_angles:
                template_angle_feat = build_template_angle_feat(
                    single_template_feats,
                )

                # [*, S_t, N, C_m]
                a = self.template_angle_embedder(template_angle_feat)

                single_template_embeds["angle"] = a

            # [*, S_t, N, N, C_t]
            t = build_template_pair_feat(
                single_template_feats,
                inf=self.config.template.inf,
                eps=self.config.template.eps,
                **self.config.template.distogram,
            ).to(z.dtype)
            t = self.template_pair_embedder(t)

            single_template_embeds.update({"pair": t})

            template_embeds.append(single_template_embeds)

        # Re-assemble the per-template embeddings along the template dim.
        template_embeds = dict_multimap(
            partial(torch.cat, dim=templ_dim),
            template_embeds,
        )

        # [*, S_t, N, N, C_z]
        t = self.template_pair_stack(
            template_embeds["pair"],
            pair_mask.unsqueeze(-3).to(dtype=z.dtype),
            chunk_size=self.globals.chunk_size,
            _mask_trans=self.config._mask_trans,
        )

        # [*, N, N, C_z]
        t = self.template_pointwise_att(
            t,
            z,
            template_mask=batch["template_mask"].to(dtype=z.dtype),
            chunk_size=self.globals.chunk_size,
        )
        # Zero the update when no templates are present at all.
        t = t * (torch.sum(batch["template_mask"]) > 0)

        ret = {}
        if self.config.template.embed_angles:
            ret["template_angle_embedding"] = template_embeds["angle"]

        ret.update({"template_pair_embedding": t})

        return ret
173
+
174
def iteration(self, feats, m_1_prev, z_prev, x_prev, _recycle=True):
    """Run one recycling iteration of the model trunk + structure module.

    Args:
        feats: Feature dict for the current recycling cycle (no recycling
            dimension). Keys accessed here include "target_feat",
            "msa_feat", "seq_mask", "msa_mask", "aatype", etc.
        m_1_prev: [*, N, C_m] first-row MSA embedding from the previous
            cycle, or None on the first cycle.
        z_prev: [*, N, N, C_z] pair embedding from the previous cycle,
            or None on the first cycle.
        x_prev: [*, N, 37, 3] predicted atom positions from the previous
            cycle, or None on the first cycle.
        _recycle: If False, the recycling embeddings are zeroed out. They
            are still computed so every parameter is used (see comment
            below about DDP).

    Returns:
        A tuple (outputs, m_1_prev, z_prev, x_prev): the output dict of
        this iteration plus the three tensors to feed into the next cycle.
    """
    # Primary output dictionary
    outputs = {}

    # This needs to be done manually for DeepSpeed's sake
    dtype = next(self.parameters()).dtype
    for k in feats:
        if(feats[k].dtype == torch.float32):
            feats[k] = feats[k].to(dtype=dtype)

    # Grab some data about the input
    batch_dims = feats["target_feat"].shape[:-2]
    no_batch_dims = len(batch_dims)
    n = feats["target_feat"].shape[-2]
    n_seq = feats["msa_feat"].shape[-3]
    device = feats["target_feat"].device

    # Prep some features
    seq_mask = feats["seq_mask"]
    pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
    msa_mask = feats["msa_mask"]

    # Initialize the MSA and pair representations

    # m: [*, S_c, N, C_m]
    # z: [*, N, N, C_z]
    m, z = self.input_embedder(
        feats["target_feat"],
        feats["residue_index"],
        feats["msa_feat"],
    )

    # Initialize the recycling embeddings, if needs be
    if None in [m_1_prev, z_prev, x_prev]:
        # [*, N, C_m]
        m_1_prev = m.new_zeros(
            (*batch_dims, n, self.config.input_embedder.c_m),
            requires_grad=False,
        )

        # [*, N, N, C_z]
        z_prev = z.new_zeros(
            (*batch_dims, n, n, self.config.input_embedder.c_z),
            requires_grad=False,
        )

        # [*, N, atom_type_num, 3]
        x_prev = z.new_zeros(
            (*batch_dims, n, residue_constants.atom_type_num, 3),
            requires_grad=False,
        )

    # Recycled coordinates enter the network only through the
    # pseudo-beta positions derived here.
    x_prev = pseudo_beta_fn(
        feats["aatype"], x_prev, None
    ).to(dtype=z.dtype)

    # m_1_prev_emb: [*, N, C_m]
    # z_prev_emb: [*, N, N, C_z]
    m_1_prev_emb, z_prev_emb = self.recycling_embedder(
        m_1_prev,
        z_prev,
        x_prev,
    )

    # If the number of recycling iterations is 0, skip recycling
    # altogether. We zero them this way instead of computing them
    # conditionally to avoid leaving parameters unused, which has annoying
    # implications for DDP training.
    if(not _recycle):
        m_1_prev_emb *= 0
        z_prev_emb *= 0

    # [*, S_c, N, C_m]
    m[..., 0, :, :] += m_1_prev_emb

    # [*, N, N, C_z]
    z += z_prev_emb

    # Possibly prevents memory fragmentation
    del m_1_prev, z_prev, x_prev, m_1_prev_emb, z_prev_emb

    # Embed the templates + merge with MSA/pair embeddings
    if self.config.template.enabled:
        template_feats = {
            k: v for k, v in feats.items() if k.startswith("template_")
        }
        template_embeds = self.embed_templates(
            template_feats,
            z,
            pair_mask.to(dtype=z.dtype),
            no_batch_dims,
        )

        # [*, N, N, C_z]
        z = z + template_embeds["template_pair_embedding"]

        if self.config.template.embed_angles:
            # [*, S = S_c + S_t, N, C_m]
            m = torch.cat(
                [m, template_embeds["template_angle_embedding"]],
                dim=-3
            )

            # [*, S, N] — extend the MSA mask to cover template rows
            torsion_angles_mask = feats["template_torsion_angles_mask"]
            msa_mask = torch.cat(
                [feats["msa_mask"], torsion_angles_mask[..., 2]],
                dim=-2
            )

    # Embed extra MSA features + merge with pairwise embeddings
    if self.config.extra_msa.enabled:
        # [*, S_e, N, C_e]
        a = self.extra_msa_embedder(build_extra_msa_feat(feats))

        # [*, N, N, C_z]
        z = self.extra_msa_stack(
            a,
            z,
            msa_mask=feats["extra_msa_mask"].to(dtype=a.dtype),
            chunk_size=self.globals.chunk_size,
            pair_mask=pair_mask.to(dtype=z.dtype),
            _mask_trans=self.config._mask_trans,
        )

    # Run MSA + pair embeddings through the trunk of the network
    # m: [*, S, N, C_m]
    # z: [*, N, N, C_z]
    # s: [*, N, C_s]
    m, z, s = self.evoformer(
        m,
        z,
        msa_mask=msa_mask.to(dtype=m.dtype),
        pair_mask=pair_mask.to(dtype=z.dtype),
        chunk_size=self.globals.chunk_size,
        _mask_trans=self.config._mask_trans,
    )

    # Only the first n_seq rows are genuine MSA rows (template angle rows
    # may have been concatenated above).
    outputs["msa"] = m[..., :n_seq, :, :]
    outputs["pair"] = z
    outputs["single"] = s

    # Predict 3D structure
    outputs["sm"] = self.structure_module(
        s,
        z,
        feats["aatype"],
        mask=feats["seq_mask"].to(dtype=s.dtype),
    )
    outputs["final_atom_positions"] = atom14_to_atom37(
        outputs["sm"]["positions"][-1], feats
    )
    outputs["final_atom_mask"] = feats["atom37_atom_exists"]
    outputs["final_affine_tensor"] = outputs["sm"]["frames"][-1]

    # Save embeddings for use during the next recycling iteration

    # [*, N, C_m]
    m_1_prev = m[..., 0, :, :]

    # [*, N, N, C_z]
    z_prev = z

    # [*, N, 37, 3]
    x_prev = outputs["final_atom_positions"]

    return outputs, m_1_prev, z_prev, x_prev
341
+
342
+ def _disable_activation_checkpointing(self):
343
+ self.template_pair_stack.blocks_per_ckpt = None
344
+ self.evoformer.blocks_per_ckpt = None
345
+
346
+ for b in self.extra_msa_stack.blocks:
347
+ b.ckpt = False
348
+
349
+ def _enable_activation_checkpointing(self):
350
+ self.template_pair_stack.blocks_per_ckpt = (
351
+ self.config.template.template_pair_stack.blocks_per_ckpt
352
+ )
353
+ self.evoformer.blocks_per_ckpt = (
354
+ self.config.evoformer_stack.blocks_per_ckpt
355
+ )
356
+
357
+ for b in self.extra_msa_stack.blocks:
358
+ b.ckpt = self.config.extra_msa.extra_msa_stack.ckpt
359
+
360
def forward(self, batch):
    """
    Args:
        batch:
            Dictionary of arguments outlined in Algorithm 2. Keys must
            include the official names of the features in the
            supplement subsection 1.2.9.

            The final dimension of each input must have length equal to
            the number of recycling iterations.

            Features (without the recycling dimension):

                "aatype" ([*, N_res]):
                    Contrary to the supplement, this tensor of residue
                    indices is not one-hot.
                "target_feat" ([*, N_res, C_tf])
                    One-hot encoding of the target sequence. C_tf is
                    config.model.input_embedder.tf_dim.
                "residue_index" ([*, N_res])
                    Tensor whose final dimension consists of
                    consecutive indices from 0 to N_res.
                "msa_feat" ([*, N_seq, N_res, C_msa])
                    MSA features, constructed as in the supplement.
                    C_msa is config.model.input_embedder.msa_dim.
                "seq_mask" ([*, N_res])
                    1-D sequence mask
                "msa_mask" ([*, N_seq, N_res])
                    MSA mask
                "pair_mask" ([*, N_res, N_res])
                    2-D pair mask
                "extra_msa_mask" ([*, N_extra, N_res])
                    Extra MSA mask
                "template_mask" ([*, N_templ])
                    Template mask (on the level of templates, not
                    residues)
                "template_aatype" ([*, N_templ, N_res])
                    Tensor of template residue indices (indices greater
                    than 19 are clamped to 20 (Unknown))
                "template_all_atom_positions"
                    ([*, N_templ, N_res, 37, 3])
                    Template atom coordinates in atom37 format
                "template_all_atom_mask" ([*, N_templ, N_res, 37])
                    Template atom coordinate mask
                "template_pseudo_beta" ([*, N_templ, N_res, 3])
                    Positions of template carbon "pseudo-beta" atoms
                    (i.e. C_beta for all residues but glycine, for
                    which C_alpha is used instead)
                "template_pseudo_beta_mask" ([*, N_templ, N_res])
                    Pseudo-beta mask

    Returns:
        The output dict of the final recycling iteration, updated with
        the auxiliary-head outputs.
    """
    # Initialize recycling embeddings
    m_1_prev, z_prev, x_prev = None, None, None

    # Disable activation checkpointing for the first few recycling iters
    is_grad_enabled = torch.is_grad_enabled()
    self._disable_activation_checkpointing()

    # Main recycling loop
    num_iters = batch["aatype"].shape[-1]
    for cycle_no in range(num_iters):
        # Select the features for the current recycling cycle.
        # (The lambda is consumed immediately by tensor_tree_map, so the
        # late-binding of cycle_no is not an issue here.)
        fetch_cur_batch = lambda t: t[..., cycle_no]
        feats = tensor_tree_map(fetch_cur_batch, batch)

        # Enable grad iff we're training and it's the final recycling layer
        is_final_iter = cycle_no == (num_iters - 1)
        with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
            if is_final_iter:
                self._enable_activation_checkpointing()
                # Sidestep AMP bug (PyTorch issue #65766)
                if torch.is_autocast_enabled():
                    torch.clear_autocast_cache()

            # Run the next iteration of the model
            outputs, m_1_prev, z_prev, x_prev = self.iteration(
                feats,
                m_1_prev,
                z_prev,
                x_prev,
                _recycle=(num_iters > 1)
            )

    # Run auxiliary heads (confidence, distogram, etc. — defined by
    # self.aux_heads, which is constructed elsewhere in this class)
    outputs.update(self.aux_heads(outputs))

    return outputs
openfold/model/msa.py ADDED
@@ -0,0 +1,392 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+ import torch
18
+ import torch.nn as nn
19
+ from typing import Optional, List, Tuple
20
+
21
+ from openfold.model.primitives import (
22
+ Linear,
23
+ LayerNorm,
24
+ Attention,
25
+ GlobalAttention,
26
+ _attention_chunked_trainable,
27
+ )
28
+ from openfold.utils.checkpointing import get_checkpoint_fn
29
+ from openfold.utils.tensor_utils import (
30
+ chunk_layer,
31
+ permute_final_dims,
32
+ flatten_final_dims,
33
+ )
34
+
35
+
36
class MSAAttention(nn.Module):
    """Gated MSA self-attention, optionally biased by the pair embedding.

    Used row-wise (with pair bias) or column-wise (without) by the
    subclasses/wrappers defined below in this module.
    """

    def __init__(
        self,
        c_in,
        c_hidden,
        no_heads,
        pair_bias=False,
        c_z=None,
        inf=1e9,
    ):
        """
        Args:
            c_in:
                Input channel dimension
            c_hidden:
                Per-head hidden channel dimension
            no_heads:
                Number of attention heads
            pair_bias:
                Whether to use pair embedding bias
            c_z:
                Pair embedding channel dimension. Ignored unless pair_bias
                is true
            inf:
                A large number to be used in computing the attention mask
        """
        super(MSAAttention, self).__init__()

        self.c_in = c_in
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.pair_bias = pair_bias
        self.c_z = c_z
        self.inf = inf

        self.layer_norm_m = LayerNorm(self.c_in)

        # Only materialized when pair_bias is enabled; kept as explicit
        # None attributes for TorchScript's benefit.
        self.layer_norm_z = None
        self.linear_z = None
        if self.pair_bias:
            self.layer_norm_z = LayerNorm(self.c_z)
            self.linear_z = Linear(
                self.c_z, self.no_heads, bias=False, init="normal"
            )

        self.mha = Attention(
            self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
        )

    @torch.jit.ignore
    def _chunk(self,
        m: torch.Tensor,
        biases: List[torch.Tensor],
        chunk_size: int,
    ) -> torch.Tensor:
        # Apply self-attention in chunks along the batch-like dimensions
        # to bound peak memory.
        return chunk_layer(
            self.mha,
            {"q_x": m, "kv_x": m, "biases": biases},
            chunk_size=chunk_size,
            no_batch_dims=len(m.shape[:-2]),
        )

    def _prep_inputs(self,
        m: torch.Tensor,
        z: Optional[torch.Tensor],
        mask: Optional[torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Normalize the MSA embedding and build the additive biases.

        Returns the normalized MSA embedding, the mask bias, and (when
        pair_bias is active) the per-head pair bias (otherwise z is passed
        through unchanged).
        """
        # [*, N_seq, N_res, C_m]
        m = self.layer_norm_m(m)

        n_seq, n_res = m.shape[-3:-1]
        if mask is None:
            # [*, N_seq, N_res]
            mask = m.new_ones(
                m.shape[:-3] + (n_seq, n_res),
            )

        # [*, N_seq, 1, 1, N_res] — masked positions get -inf-like logits
        mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]

        # This step simply returns a larger view of the bias, and does not
        # consume additional memory.
        # [*, N_seq, no_heads, N_res, N_res]
        #bias = bias.expand(
        #    ((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1)
        #)

        if (self.pair_bias and
            z is not None and                       # For the
            self.layer_norm_z is not None and       # benefit of
            self.linear_z is not None               # TorchScript
        ):
            # [*, N_res, N_res, C_z]
            z = self.layer_norm_z(z)

            # [*, N_res, N_res, no_heads]
            z = self.linear_z(z)

            # [*, 1, no_heads, N_res, N_res]
            z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4)

        return m, mask_bias, z

    @torch.jit.ignore
    def _chunked_msa_attn(self,
        m: torch.Tensor,
        z: Optional[torch.Tensor],
        mask: Optional[torch.Tensor],
        chunk_logits: int,
        checkpoint: bool,
    ) -> torch.Tensor:
        """Attention with the logits themselves chunked along the MSA dim.

        Optionally checkpoints the QKV projection and output wrap-up to
        trade compute for activation memory.
        """
        MSA_DIM = -4

        def _get_qkv(m, z):
            m, mask_bias, z = self._prep_inputs(m, z, mask)
            q, k, v = self.mha._prep_qkv(m, m)
            return m, q, k, v, mask_bias, z

        checkpoint_fn = get_checkpoint_fn()

        if(torch.is_grad_enabled() and checkpoint):
            m, q, k, v, mask_bias, z = checkpoint_fn(_get_qkv, m, z)
        else:
            m, q, k, v, mask_bias, z = _get_qkv(m, z)

        o = _attention_chunked_trainable(
            query=q,
            key=k,
            value=v,
            biases=[mask_bias, z],
            chunk_size=chunk_logits,
            chunk_dim=MSA_DIM,
            checkpoint=checkpoint,
        )

        if(torch.is_grad_enabled() and checkpoint):
            # Storing an additional m here is far from ideal
            m = checkpoint_fn(self.mha._wrap_up, o, m)
        else:
            m = self.mha._wrap_up(o, m)

        return m

    def forward(self,
        m: torch.Tensor,
        z: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
        _chunk_logits: Optional[int] = None,
        _checkpoint_chunks: Optional[bool] = None,
    ) -> torch.Tensor:
        """
        Args:
            m:
                [*, N_seq, N_res, C_m] MSA embedding
            z:
                [*, N_res, N_res, C_z] pair embedding. Required only if
                pair_bias is True
            mask:
                [*, N_seq, N_res] MSA mask
            chunk_size:
                Size of chunks into which the inputs are split along their
                batch dimensions. A low value decreases memory overhead at the
                cost of slower execution. Chunking is not performed by default.
            _chunk_logits:
                If set, routes through the memory-lean logit-chunked path.
            _checkpoint_chunks:
                Whether to checkpoint chunks in the logit-chunked path.
        """
        if(_chunk_logits is not None):
            return self._chunked_msa_attn(
                m=m, z=z, mask=mask,
                chunk_logits=_chunk_logits, checkpoint=_checkpoint_chunks
            )

        m, mask_bias, z = self._prep_inputs(m, z, mask)

        biases = [mask_bias]
        if(z is not None):
            biases.append(z)

        if chunk_size is not None:
            m = self._chunk(m, biases, chunk_size)
        else:
            m = self.mha(
                q_x=m,
                kv_x=m,
                biases=biases
            )

        return m
224
+
225
+
226
class MSARowAttentionWithPairBias(MSAAttention):
    """
    Implements Algorithm 7.

    A thin specialization of MSAAttention with the pair-embedding bias
    always enabled.
    """

    def __init__(self, c_m, c_z, c_hidden, no_heads, inf=1e9):
        """
        Args:
            c_m:
                Input channel dimension
            c_z:
                Pair embedding channel dimension
            c_hidden:
                Per-head hidden channel dimension
            no_heads:
                Number of attention heads
            inf:
                Large number used to construct attention masks
        """
        super(MSARowAttentionWithPairBias, self).__init__(
            c_in=c_m,
            c_hidden=c_hidden,
            no_heads=no_heads,
            pair_bias=True,
            c_z=c_z,
            inf=inf,
        )
253
+
254
+
255
class MSAColumnAttention(nn.Module):
    """
    Implements Algorithm 8.

    By rights, this should also be a subclass of MSAAttention. Alas,
    most inheritance isn't supported by TorchScript.
    """

    def __init__(self, c_m, c_hidden, no_heads, inf=1e9):
        """
        Args:
            c_m:
                MSA channel dimension
            c_hidden:
                Per-head hidden channel dimension
            no_heads:
                Number of attention heads
            inf:
                Large number used to construct attention masks
        """
        super(MSAColumnAttention, self).__init__()

        self.c_m = c_m
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.inf = inf

        # Delegate to a plain (unbiased) MSAAttention on transposed input.
        self._msa_att = MSAAttention(
            c_in=c_m,
            c_hidden=c_hidden,
            no_heads=no_heads,
            pair_bias=False,
            c_z=None,
            inf=inf,
        )

    def forward(self,
        m: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None
    ) -> torch.Tensor:
        """
        Args:
            m:
                [*, N_seq, N_res, C_m] MSA embedding
            mask:
                [*, N_seq, N_res] MSA mask
            chunk_size:
                Size of chunks into which the inputs are split along their
                batch dimensions. A low value decreases memory overhead at the
                cost of slower execution. Chunking is not performed by default.
        """
        # Attend along columns by swapping the sequence and residue axes,
        # running row-wise attention, then swapping back.
        # [*, N_res, N_seq, C_in]
        m = m.transpose(-2, -3)
        mask = None if mask is None else mask.transpose(-1, -2)

        m = self._msa_att(m, mask=mask, chunk_size=chunk_size)

        # [*, N_seq, N_res, C_in]
        m = m.transpose(-2, -3)
        # Restore the mask layout too (kept for parity with the original;
        # the value is not used after this point).
        mask = None if mask is None else mask.transpose(-1, -2)

        return m
320
+
321
+
322
class MSAColumnGlobalAttention(nn.Module):
    """Column-wise global MSA attention (used for the extra-MSA stack)."""

    def __init__(
        self, c_in, c_hidden, no_heads, inf=1e9, eps=1e-10,
    ):
        """
        Args:
            c_in: Input channel dimension
            c_hidden: Per-head hidden channel dimension
            no_heads: Number of attention heads
            inf: Large number used to construct attention masks
            eps: Small constant forwarded to GlobalAttention
        """
        super(MSAColumnGlobalAttention, self).__init__()

        self.c_in = c_in
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.inf = inf
        self.eps = eps

        self.layer_norm_m = nn.LayerNorm(c_in)

        self.global_attention = GlobalAttention(
            c_in=c_in,
            c_hidden=c_hidden,
            no_heads=no_heads,
            inf=inf,
            eps=eps,
        )

    @torch.jit.ignore
    def _chunk(self,
        m: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: int,
    ) -> torch.Tensor:
        # Chunked application of global attention over batch-like dims.
        return chunk_layer(
            self.global_attention,
            {"m": m, "mask": mask},
            chunk_size=chunk_size,
            no_batch_dims=len(m.shape[:-2]),
        )

    def forward(
        self,
        m: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Args:
            m: [*, N_seq, N_res, C_in] MSA embedding
            mask: [*, N_seq, N_res] MSA mask (all-ones if omitted)
            chunk_size: Optional chunk size along the batch-like dims
        """
        n_seq, n_res, c_in = m.shape[-3:]

        if mask is None:
            # [*, N_seq, N_res]
            mask = torch.ones(
                m.shape[:-1],
                dtype=m.dtype,
                device=m.device,
            ).detach()

        # Swap to column-major layout: [*, N_res, N_seq, C_in]
        m = m.transpose(-2, -3)
        mask = mask.transpose(-1, -2)

        # [*, N_res, N_seq, C_in]
        m = self.layer_norm_m(m)

        if chunk_size is None:
            m = self.global_attention(m=m, mask=mask)
        else:
            m = self._chunk(m, mask, chunk_size)

        # Back to [*, N_seq, N_res, C_in]
        m = m.transpose(-2, -3)

        return m
openfold/model/outer_product_mean.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from functools import partial
17
+ from typing import Optional
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+
22
+ from openfold.model.primitives import Linear
23
+ from openfold.utils.tensor_utils import chunk_layer
24
+
25
+
26
class OuterProductMean(nn.Module):
    """
    Implements Algorithm 10.
    """

    def __init__(self, c_m, c_z, c_hidden, eps=1e-3):
        """
        Args:
            c_m:
                MSA embedding channel dimension
            c_z:
                Pair embedding channel dimension
            c_hidden:
                Hidden channel dimension
            eps:
                Small constant added to the mask normalizer for stability
        """
        super(OuterProductMean, self).__init__()

        self.c_m = c_m
        self.c_z = c_z
        self.c_hidden = c_hidden
        self.eps = eps

        self.layer_norm = nn.LayerNorm(c_m)
        self.linear_1 = Linear(c_m, c_hidden)
        self.linear_2 = Linear(c_m, c_hidden)
        self.linear_out = Linear(c_hidden ** 2, c_z, init="final")

    def _opm(self, a, b):
        # Outer product of the two projections, summed over the sequence
        # dimension; the mean's normalization happens in forward().
        # [*, N_res, N_res, C, C]
        outer = torch.einsum("...bac,...dae->...bdce", a, b)

        # [*, N_res, N_res, C * C]
        outer = outer.reshape(outer.shape[:-2] + (-1,))

        # [*, N_res, N_res, C_z]
        outer = self.linear_out(outer)

        return outer

    @torch.jit.ignore
    def _chunk(self,
        a: torch.Tensor,
        b: torch.Tensor,
        chunk_size: int
    ) -> torch.Tensor:
        # Since the "batch dim" in this case is not a true batch dimension
        # (in that the shape of the output depends on it), we need to
        # iterate over it ourselves
        a_reshape = a.reshape((-1,) + a.shape[-3:])
        b_reshape = b.reshape((-1,) + b.shape[-3:])
        out = []
        for a_prime, b_prime in zip(a_reshape, b_reshape):
            outer = chunk_layer(
                partial(self._opm, b=b_prime),
                {"a": a_prime},
                chunk_size=chunk_size,
                no_batch_dims=1,
            )
            out.append(outer)
        outer = torch.stack(out, dim=0)
        outer = outer.reshape(a.shape[:-3] + outer.shape[1:])

        return outer

    def forward(self,
        m: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None
    ) -> torch.Tensor:
        """
        Args:
            m:
                [*, N_seq, N_res, C_m] MSA embedding
            mask:
                [*, N_seq, N_res] MSA mask
            chunk_size:
                Optional chunk size for the outer-product computation
        Returns:
            [*, N_res, N_res, C_z] pair embedding update
        """
        if mask is None:
            mask = m.new_ones(m.shape[:-1])

        # [*, N_seq, N_res, C_m]
        m = self.layer_norm(m)

        # [*, N_seq, N_res, C]
        mask = mask.unsqueeze(-1)
        a = self.linear_1(m) * mask
        b = self.linear_2(m) * mask

        a = a.transpose(-2, -3)
        b = b.transpose(-2, -3)

        if chunk_size is not None:
            outer = self._chunk(a, b, chunk_size)
        else:
            outer = self._opm(a, b)

        # [*, N_res, N_res, 1] — count of jointly unmasked sequences,
        # turning the masked sum into a masked mean
        norm = torch.einsum("...abc,...adc->...bdc", mask, mask)

        # [*, N_res, N_res, C_z]
        outer = outer / (self.eps + norm)

        return outer
openfold/model/pair_transition.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from typing import Optional
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from openfold.model.primitives import Linear, LayerNorm
21
+ from openfold.utils.tensor_utils import chunk_layer
22
+
23
+
24
class PairTransition(nn.Module):
    """
    Implements Algorithm 15 (pair feed-forward transition).
    """

    def __init__(self, c_z, n):
        """
        Args:
            c_z:
                Pair transition channel dimension
            n:
                Factor by which c_z is multiplied to obtain hidden channel
                dimension
        """
        super(PairTransition, self).__init__()

        self.c_z = c_z
        self.n = n

        hidden = self.n * self.c_z
        self.layer_norm = LayerNorm(self.c_z)
        self.linear_1 = Linear(self.c_z, hidden, init="relu")
        self.relu = nn.ReLU()
        self.linear_2 = Linear(hidden, c_z, init="final")

    def _transition(self, z, mask):
        # linear -> ReLU -> linear, masked; the LayerNorm is applied by
        # the caller (forward) before dispatching here.
        # [*, N_res, N_res, C_z]
        return self.linear_2(self.relu(self.linear_1(z))) * mask

    @torch.jit.ignore
    def _chunk(self,
        z: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: int,
    ) -> torch.Tensor:
        # Chunked application of the transition over batch-like dims.
        return chunk_layer(
            self._transition,
            {"z": z, "mask": mask},
            chunk_size=chunk_size,
            no_batch_dims=len(z.shape[:-2]),
        )

    def forward(self,
        z: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Args:
            z:
                [*, N_res, N_res, C_z] pair embedding
            mask:
                [*, N_res, N_res] pair mask (all-ones if omitted)
            chunk_size:
                Optional chunk size along the batch-like dims
        Returns:
            [*, N_res, N_res, C_z] pair embedding update
        """
        # DISCREPANCY: DeepMind forgets to apply the mask in this module.
        if mask is None:
            mask = z.new_ones(z.shape[:-1])

        # [*, N_res, N_res, 1]
        mask = mask.unsqueeze(-1)

        # [*, N_res, N_res, C_z]
        z = self.layer_norm(z)

        if chunk_size is None:
            z = self._transition(z=z, mask=mask)
        else:
            z = self._chunk(z, mask, chunk_size)

        return z
openfold/model/primitives.py ADDED
@@ -0,0 +1,587 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from functools import partial
17
+ import math
18
+ from typing import Optional, Callable, List, Tuple, Sequence
19
+ import numpy as np
20
+
21
+ import deepspeed
22
+ import torch
23
+ import torch.nn as nn
24
+ from scipy.stats import truncnorm
25
+
26
+ from openfold.utils.checkpointing import get_checkpoint_fn
27
+ from openfold.utils.tensor_utils import (
28
+ permute_final_dims,
29
+ flatten_final_dims,
30
+ _chunk_slice,
31
+ )
32
+
33
+
34
+ def _prod(nums):
35
+ out = 1
36
+ for n in nums:
37
+ out = out * n
38
+ return out
39
+
40
+
41
+ def _calculate_fan(linear_weight_shape, fan="fan_in"):
42
+ fan_out, fan_in = linear_weight_shape
43
+
44
+ if fan == "fan_in":
45
+ f = fan_in
46
+ elif fan == "fan_out":
47
+ f = fan_out
48
+ elif fan == "fan_avg":
49
+ f = (fan_in + fan_out) / 2
50
+ else:
51
+ raise ValueError("Invalid fan option")
52
+
53
+ return f
54
+
55
+
56
def trunc_normal_init_(weights, scale=1.0, fan="fan_in"):
    """Fill ``weights`` in place from a fan-scaled truncated normal.

    Samples are drawn from a unit normal truncated at +/- 2 standard
    units, then rescaled so the resulting standard deviation equals
    sqrt(scale / fan). The helper fan computation is inlined here so the
    function is self-contained.
    """
    shape = weights.shape

    # Fan over the (fan_out, fan_in) weight shape.
    fan_out, fan_in = shape
    if fan == "fan_in":
        f = fan_in
    elif fan == "fan_out":
        f = fan_out
    elif fan == "fan_avg":
        f = (fan_in + fan_out) / 2
    else:
        raise ValueError("Invalid fan option")

    effective_scale = scale / max(1, f)
    lower, upper = -2, 2
    # Divide by the truncated unit normal's std so the final std is exact.
    std = math.sqrt(effective_scale) / truncnorm.std(a=lower, b=upper, loc=0, scale=1)

    n_elem = 1
    for dim in shape:
        n_elem *= dim
    samples = truncnorm.rvs(a=lower, b=upper, loc=0, scale=std, size=n_elem)
    samples = np.reshape(samples, shape)

    with torch.no_grad():
        weights.copy_(torch.tensor(samples, device=weights.device))
68
+
69
+
70
def lecun_normal_init_(weights):
    """LeCun initialization: fan-in truncated normal with unit scale."""
    trunc_normal_init_(weights, scale=1.0, fan="fan_in")
72
+
73
+
74
+ def he_normal_init_(weights):
75
+ trunc_normal_init_(weights, scale=2.0)
76
+
77
+
78
def glorot_uniform_init_(weights):
    """Glorot (Xavier) fan-average uniform initialization with gain 1."""
    nn.init.xavier_uniform_(weights, gain=1.0)
80
+
81
+
82
def final_init_(weights):
    """Zero initialization, used for output ("final") projections."""
    with torch.no_grad():
        weights.zero_()
85
+
86
+
87
def gating_init_(weights):
    """Zero initialization for gating weights.

    The corresponding bias is set to 1 by the Linear "gating" init branch.
    """
    with torch.no_grad():
        weights.zero_()
90
+
91
+
92
def normal_init_(weights):
    """Kaiming normal initialization with a linear nonlinearity.

    With the "linear" gain this yields std = 1/sqrt(fan_in).
    """
    nn.init.kaiming_normal_(weights, nonlinearity="linear")
94
+
95
+
96
+ def ipa_point_weights_init_(weights):
97
+ with torch.no_grad():
98
+ softplus_inverse_1 = 0.541324854612918
99
+ weights.fill_(softplus_inverse_1)
100
+
101
+
102
class Linear(nn.Linear):
    """
    A Linear layer with built-in nonstandard initializations. Called just
    like torch.nn.Linear.

    Implements the initializers in 1.11.4, plus some additional ones found
    in the code.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        bias: bool = True,
        init: str = "default",
        init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None,
    ):
        """
        Args:
            in_dim:
                The final dimension of inputs to the layer
            out_dim:
                The final dimension of layer outputs
            bias:
                Whether to learn an additive bias. True by default
            init:
                The initializer to use. Choose from:

                "default": LeCun fan-in truncated normal initialization
                "relu": He initialization w/ truncated normal distribution
                "glorot": Fan-average Glorot uniform initialization
                "gating": Weights=0, Bias=1
                "normal": Normal initialization with std=1/sqrt(fan_in)
                "final": Weights=0, Bias=0

                Overridden by init_fn if the latter is not None.
            init_fn:
                A custom initializer taking weight and bias as inputs.
                Overrides init if not None.

        Raises:
            ValueError: If ``init`` is not one of the recognized options.
        """
        super(Linear, self).__init__(in_dim, out_dim, bias=bias)

        # Bias starts at zero for every named scheme; the "gating" branch
        # below overwrites it with ones.
        if bias:
            with torch.no_grad():
                self.bias.fill_(0)

        # A custom init_fn takes precedence over the named schemes.
        if init_fn is not None:
            init_fn(self.weight, self.bias)
        else:
            if init == "default":
                lecun_normal_init_(self.weight)
            elif init == "relu":
                he_normal_init_(self.weight)
            elif init == "glorot":
                glorot_uniform_init_(self.weight)
            elif init == "gating":
                gating_init_(self.weight)
                if bias:
                    with torch.no_grad():
                        self.bias.fill_(1.0)
            elif init == "normal":
                normal_init_(self.weight)
            elif init == "final":
                final_init_(self.weight)
            else:
                raise ValueError("Invalid init string.")
169
+
170
class LayerNorm(nn.Module):
    """Layer normalization with a bfloat16-safe forward pass."""

    def __init__(self, c_in, eps=1e-5):
        super(LayerNorm, self).__init__()

        # Stored as a 1-tuple: the normalized_shape argument of layer_norm.
        self.c_in = (c_in,)
        self.eps = eps

        self.weight = nn.Parameter(torch.ones(c_in))
        self.bias = nn.Parameter(torch.zeros(c_in))

    def forward(self, x):
        dtype = x.dtype
        # For bfloat16 inputs outside of deepspeed, autocast would silently
        # upcast the computation; disable it and keep the affine parameters
        # in the input dtype instead.
        if (dtype is torch.bfloat16 and not deepspeed.utils.is_initialized()):
            with torch.cuda.amp.autocast(enabled=False):
                return nn.functional.layer_norm(
                    x,
                    self.c_in,
                    self.weight.to(dtype=dtype),
                    self.bias.to(dtype=dtype),
                    self.eps,
                )

        return nn.functional.layer_norm(
            x,
            self.c_in,
            self.weight,
            self.bias,
            self.eps,
        )
201
+
202
@torch.jit.ignore
def softmax(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """
    Softmax, but without automatic casting to fp32 when the input is of
    type bfloat16
    """
    if t.dtype is torch.bfloat16 and not deepspeed.utils.is_initialized():
        # Keep the computation in bfloat16 by disabling autocast.
        with torch.cuda.amp.autocast(enabled=False):
            return torch.nn.functional.softmax(t, dim=dim)

    return torch.nn.functional.softmax(t, dim=dim)
216
+
217
+
218
#@torch.jit.script
def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, biases: List[torch.Tensor]) -> torch.Tensor:
    """
    Plain multi-head attention.

    Args:
        query/key/value: [*, Q/K/V, H, C_hidden] tensors
        biases: bias tensors broadcastable to the [*, H, Q, K] logits
    Returns:
        [*, Q, H, C_hidden] attention output
    """
    # Bring heads in front of the sequence dims:
    # q: [*, H, Q, C_hidden], k: [*, H, C_hidden, K], v: [*, H, V, C_hidden]
    q = permute_final_dims(query, (1, 0, 2))
    k = permute_final_dims(key, (1, 2, 0))
    v = permute_final_dims(value, (1, 0, 2))

    # Raw attention logits: [*, H, Q, K]
    logits = torch.matmul(q, k)
    for bias in biases:
        logits += bias

    # Normalize and take the weighted sum of values: [*, H, Q, C_hidden]
    weights = softmax(logits, -1)
    out = torch.matmul(weights, v)

    # Restore the [*, Q, H, C_hidden] layout.
    return out.transpose(-2, -3)
244
+
245
+
246
@torch.jit.ignore
def _attention_chunked_trainable(
    query, key, value, biases, chunk_size, chunk_dim, checkpoint,
):
    """
    Memory-friendly attention that processes the inputs in chunks along
    chunk_dim, optionally gradient-checkpointing each chunk.

    Args:
        query/key/value:
            Attention inputs of shape [*, Q/K/V, H, C_hidden]
        biases:
            List of bias tensors broadcastable to the attention logits
        chunk_size:
            Number of slices per chunk along chunk_dim
        chunk_dim:
            Dimension along which to chunk the inputs
        checkpoint:
            Whether to gradient-checkpoint each chunk. The checkpointed
            path supports at most two bias terms.
    Returns:
        The attention output, re-concatenated over chunks along chunk_dim.
    """
    if(checkpoint and len(biases) > 2):
        # Fix: the original message read "permits only permits two".
        raise ValueError(
            "Checkpointed version permits only two bias terms"
        )

    def _checkpointable_attention(q, k, v, b1, b2):
        # Checkpoint functions need a fixed argument list, so the (up to
        # two) biases are passed positionally and None-filtered here.
        bs = [b for b in [b1, b2] if b is not None]
        return _attention(q, k, v, bs)

    o_chunks = []
    checkpoint_fn = get_checkpoint_fn()
    count = query.shape[chunk_dim]
    for start in range(0, count, chunk_size):
        end = start + chunk_size
        idx = [slice(None)] * len(query.shape)
        idx[chunk_dim] = slice(start, end)
        idx_tup = tuple(idx)
        q_chunk = query[idx_tup]
        k_chunk = key[idx_tup]
        v_chunk = value[idx_tup]

        def _slice_bias(b):
            # Size-1 biases broadcast along chunk_dim; leave them whole.
            idx[chunk_dim] = (
                slice(start, end) if b.shape[chunk_dim] != 1 else slice(None)
            )
            return b[tuple(idx)]

        if(checkpoint):
            bias_1_chunk, bias_2_chunk = [
                _slice_bias(b) if b is not None else None
                for b in (biases + [None, None])[:2]
            ]

            o_chunk = checkpoint_fn(_checkpointable_attention,
                q_chunk, k_chunk, v_chunk, bias_1_chunk, bias_2_chunk
            )
        else:
            bias_chunks = [
                _slice_bias(b) for b in biases
            ]

            o_chunk = _attention(q_chunk, k_chunk, v_chunk, bias_chunks)

        o_chunks.append(o_chunk)

    o = torch.cat(o_chunks, dim=chunk_dim)
    return o
297
+
298
+
299
class Attention(nn.Module):
    """
    Standard multi-head attention using AlphaFold's default layer
    initialization. Allows multiple bias vectors.
    """
    def __init__(
        self,
        c_q: int,
        c_k: int,
        c_v: int,
        c_hidden: int,
        no_heads: int,
        gating: bool = True,
    ):
        """
        Args:
            c_q:
                Input dimension of query data
            c_k:
                Input dimension of key data
            c_v:
                Input dimension of value data
            c_hidden:
                Per-head hidden dimension
            no_heads:
                Number of attention heads
            gating:
                Whether the output should be gated using query data
        """
        super(Attention, self).__init__()

        self.c_q = c_q
        self.c_k = c_k
        self.c_v = c_v
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.gating = gating

        # DISCREPANCY: c_hidden is not the per-head channel dimension, as
        # stated in the supplement, but the overall channel dimension.

        # Q/K/V projections are bias-free with Glorot init; the output
        # projection is zero-initialized ("final") so the module starts
        # as an identity-like residual update.
        self.linear_q = Linear(
            self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot"
        )
        self.linear_k = Linear(
            self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot"
        )
        self.linear_v = Linear(
            self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot"
        )
        self.linear_o = Linear(
            self.c_hidden * self.no_heads, self.c_q, init="final"
        )

        # Optional sigmoid gate computed from the query input; "gating"
        # init (weights 0, bias 1) makes the gate start fully open.
        self.linear_g = None
        if self.gating:
            self.linear_g = Linear(
                self.c_q, self.c_hidden * self.no_heads, init="gating"
            )

        self.sigmoid = nn.Sigmoid()

    def _prep_qkv(self,
        q_x: torch.Tensor,
        kv_x: torch.Tensor
    ) -> Tuple[
        torch.Tensor, torch.Tensor, torch.Tensor
    ]:
        # Project inputs and split the flat channel dim into heads.
        # [*, Q/K/V, H * C_hidden]
        q = self.linear_q(q_x)
        k = self.linear_k(kv_x)
        v = self.linear_v(kv_x)

        # [*, Q/K, H, C_hidden]
        q = q.view(q.shape[:-1] + (self.no_heads, -1))
        k = k.view(k.shape[:-1] + (self.no_heads, -1))
        v = v.view(v.shape[:-1] + (self.no_heads, -1))

        # Scale queries up front instead of scaling the logits.
        q /= math.sqrt(self.c_hidden)

        return q, k, v

    def _wrap_up(self,
        o: torch.Tensor,
        q_x: torch.Tensor
    ) -> torch.Tensor:
        # Optionally gate the per-head outputs with a sigmoid of q_x.
        if(self.linear_g is not None):
            g = self.sigmoid(self.linear_g(q_x))

            # [*, Q, H, C_hidden]
            g = g.view(g.shape[:-1] + (self.no_heads, -1))
            o = o * g

        # [*, Q, H * C_hidden]
        o = flatten_final_dims(o, 2)

        # [*, Q, C_q]
        o = self.linear_o(o)

        return o

    def forward(
        self,
        q_x: torch.Tensor,
        kv_x: torch.Tensor,
        biases: Optional[List[torch.Tensor]] = None,
        use_lma: bool = False,
        q_chunk_size: Optional[int] = None,
        kv_chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Args:
            q_x:
                [*, Q, C_q] query data
            kv_x:
                [*, K, C_k] key data
            biases:
                List of biases that broadcast to [*, H, Q, K]
            use_lma:
                Whether to use low-memory attention
            q_chunk_size:
                Query chunk size (for LMA)
            kv_chunk_size:
                Key/Value chunk size (for LMA)
        Returns
            [*, Q, C_q] attention update
        """
        if(biases is None):
            biases = []
        if(use_lma and (q_chunk_size is None or kv_chunk_size is None)):
            raise ValueError(
                "If use_lma is specified, q_chunk_size and kv_chunk_size must "
                "be provided"
            )

        q, k, v = self._prep_qkv(q_x, kv_x)

        if(use_lma):
            # _lma slices biases along both Q and K, so broadcastable
            # biases must be materialized to full [*, H, Q, K] first.
            biases = [
                b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],))
                for b in biases
            ]

            o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size)
        else:
            o = _attention(q, k, v, biases)

        o = self._wrap_up(o, q_x)

        return o
449
+
450
+
451
class GlobalAttention(nn.Module):
    """
    Global attention in which a single mask-weighted mean query attends
    over the N_seq dimension of the input (shapes per the inline
    comments below). Used for the MSA global column attention.
    """

    def __init__(self, c_in, c_hidden, no_heads, inf, eps):
        """
        Args:
            c_in: input channel dimension
            c_hidden: per-head hidden dimension
            no_heads: number of attention heads
            inf: large value used to mask out padded positions
            eps: small constant guarding the mean-query division
        """
        super(GlobalAttention, self).__init__()

        self.c_in = c_in
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.inf = inf
        self.eps = eps

        self.linear_q = Linear(
            c_in, c_hidden * no_heads, bias=False, init="glorot"
        )

        # Keys/values are shared across heads (single c_hidden, no head
        # factor), unlike the queries above.
        self.linear_k = Linear(
            c_in, c_hidden, bias=False, init="glorot",
        )
        self.linear_v = Linear(
            c_in, c_hidden, bias=False, init="glorot",
        )
        self.linear_g = Linear(c_in, c_hidden * no_heads, init="gating")
        self.linear_o = Linear(c_hidden * no_heads, c_in, init="final")

        self.sigmoid = nn.Sigmoid()

    def forward(self, m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # Mask-weighted mean over the attended dimension produces one
        # query vector per residue.
        # [*, N_res, C_in]
        q = torch.sum(m * mask.unsqueeze(-1), dim=-2) / (
            torch.sum(mask, dim=-1)[..., None] + self.eps
        )

        # [*, N_res, H * C_hidden]
        q = self.linear_q(q)
        q *= (self.c_hidden ** (-0.5))

        # [*, N_res, H, C_hidden]
        q = q.view(q.shape[:-1] + (self.no_heads, -1))

        # [*, N_res, N_seq, C_hidden]
        k = self.linear_k(m)
        v = self.linear_v(m)

        # [*, N_res, H, N_seq]
        a = torch.matmul(
            q,
            k.transpose(-1, -2),  # [*, N_res, C_hidden, N_seq]
        )
        # Masked positions get a -inf-like additive bias before softmax.
        bias = (self.inf * (mask - 1))[..., :, None, :]
        a += bias
        a = softmax(a)

        # [*, N_res, H, C_hidden]
        o = torch.matmul(
            a,
            v,
        )

        # [*, N_res, N_seq, C_hidden]
        g = self.sigmoid(self.linear_g(m))

        # [*, N_res, N_seq, H, C_hidden]
        g = g.view(g.shape[:-1] + (self.no_heads, -1))

        # Broadcast the single attended output back over N_seq via the
        # per-position gate.
        # [*, N_res, N_seq, H, C_hidden]
        o = o.unsqueeze(-3) * g

        # [*, N_res, N_seq, H * C_hidden]
        o = o.reshape(o.shape[:-2] + (-1,))

        # [*, N_res, N_seq, C_in]
        m = self.linear_o(o)

        return m
524
+
525
+
526
def _lma(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    biases: List[torch.Tensor],
    q_chunk_size: int,
    kv_chunk_size: int,
):
    """
    Low-memory attention: computes softmax(qk + biases)v in chunks over
    both the query and key/value dimensions, combining partial results
    with a streaming (max-shifted) softmax so the full [Q, K] logit
    matrix is never materialized.

    Args:
        q/k/v: [*, Q/K/V, H, C_hidden] tensors
        biases: bias tensors of full shape [*, H, Q, K] (pre-expanded
            by the caller; they are sliced along both Q and K here)
        q_chunk_size: chunk length along the query dimension
        kv_chunk_size: chunk length along the key/value dimension
    Returns:
        [*, Q, H, C_hidden] attention output
    """
    no_q, no_kv = q.shape[-3], k.shape[-3]

    # [*, Q, H, C_hidden]
    o = q.new_zeros(q.shape)
    for q_s in range(0, no_q, q_chunk_size):
        q_chunk = q[..., q_s: q_s + q_chunk_size, :, :]
        large_bias_chunks = [
            b[..., q_s: q_s + q_chunk_size, :] for b in biases
        ]

        # Per-kv-chunk partial softmax statistics.
        maxes = []
        weights = []
        values = []
        for kv_s in range(0, no_kv, kv_chunk_size):
            k_chunk = k[..., kv_s: kv_s + kv_chunk_size, :, :]
            v_chunk = v[..., kv_s: kv_s + kv_chunk_size, :, :]
            small_bias_chunks = [
                b[..., kv_s: kv_s + kv_chunk_size] for b in large_bias_chunks
            ]

            a = torch.einsum(
                "...qhd,...khd->...hqk", q_chunk, k_chunk,
            )

            for b in small_bias_chunks:
                a += b

            a = a.transpose(-2, -3)

            # Shift by the chunk max before exponentiating for numerical
            # stability; the shift is undone when chunks are merged below.
            max_a = torch.max(a, dim=-1, keepdim=True)[0]
            exp_a = torch.exp(a - max_a)
            exp_v = torch.einsum("...vhf,...qhv->...qhf", v_chunk, exp_a)

            maxes.append(max_a.detach().squeeze(-1))
            weights.append(torch.sum(exp_a, dim=-1))
            values.append(exp_v)

        chunk_max = torch.stack(maxes, dim=-3)
        chunk_weights = torch.stack(weights, dim=-3)
        chunk_values = torch.stack(values, dim=-4)

        # Re-reference every chunk to the global max, then combine the
        # partial numerators/denominators of the softmax.
        global_max = torch.max(chunk_max, dim=-3, keepdim=True)[0]
        max_diffs = torch.exp(chunk_max - global_max)
        chunk_values *= max_diffs.unsqueeze(-1)
        chunk_weights *= max_diffs

        all_values = torch.sum(chunk_values, dim=-4)
        all_weights = torch.sum(chunk_weights.unsqueeze(-1), dim=-4)

        q_chunk_out = all_values / all_weights

        o[..., q_s: q_s + q_chunk_size, :, :] = q_chunk_out

    return o
openfold/model/structure_module.py ADDED
@@ -0,0 +1,820 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from functools import reduce
16
+ import importlib
17
+ import math
18
+ import sys
19
+ from operator import mul
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ from typing import Optional, Tuple, Sequence
24
+
25
+ from openfold.model.primitives import Linear, LayerNorm, ipa_point_weights_init_
26
+ from openfold.np.residue_constants import (
27
+ restype_rigid_group_default_frame,
28
+ restype_atom14_to_rigid_group,
29
+ restype_atom14_mask,
30
+ restype_atom14_rigid_group_positions,
31
+ )
32
+ from openfold.utils.feats import (
33
+ frames_and_literature_positions_to_atom14_pos,
34
+ torsion_angles_to_frames,
35
+ )
36
+ from openfold.utils.precision_utils import is_fp16_enabled
37
+ from openfold.utils.rigid_utils import Rotation, Rigid
38
+ from openfold.utils.tensor_utils import (
39
+ dict_multimap,
40
+ permute_final_dims,
41
+ flatten_final_dims,
42
+ )
43
+
44
+ # attn_core_inplace_cuda = importlib.import_module("attn_core_inplace_cuda")
45
+
46
+
47
+ class AngleResnetBlock(nn.Module):
48
+ def __init__(self, c_hidden):
49
+ """
50
+ Args:
51
+ c_hidden:
52
+ Hidden channel dimension
53
+ """
54
+ super(AngleResnetBlock, self).__init__()
55
+
56
+ self.c_hidden = c_hidden
57
+
58
+ self.linear_1 = Linear(self.c_hidden, self.c_hidden, init="relu")
59
+ self.linear_2 = Linear(self.c_hidden, self.c_hidden, init="final")
60
+
61
+ self.relu = nn.ReLU()
62
+
63
+ def forward(self, a: torch.Tensor) -> torch.Tensor:
64
+
65
+ s_initial = a
66
+
67
+ a = self.relu(a)
68
+ a = self.linear_1(a)
69
+ a = self.relu(a)
70
+ a = self.linear_2(a)
71
+
72
+ return a + s_initial
73
+
74
+
75
+ class AngleResnet(nn.Module):
76
+ """
77
+ Implements Algorithm 20, lines 11-14
78
+ """
79
+
80
+ def __init__(self, c_in, c_hidden, no_blocks, no_angles, epsilon):
81
+ """
82
+ Args:
83
+ c_in:
84
+ Input channel dimension
85
+ c_hidden:
86
+ Hidden channel dimension
87
+ no_blocks:
88
+ Number of resnet blocks
89
+ no_angles:
90
+ Number of torsion angles to generate
91
+ epsilon:
92
+ Small constant for normalization
93
+ """
94
+ super(AngleResnet, self).__init__()
95
+
96
+ self.c_in = c_in
97
+ self.c_hidden = c_hidden
98
+ self.no_blocks = no_blocks
99
+ self.no_angles = no_angles
100
+ self.eps = epsilon
101
+
102
+ self.linear_in = Linear(self.c_in, self.c_hidden)
103
+ self.linear_initial = Linear(self.c_in, self.c_hidden)
104
+
105
+ self.layers = nn.ModuleList()
106
+ for _ in range(self.no_blocks):
107
+ layer = AngleResnetBlock(c_hidden=self.c_hidden)
108
+ self.layers.append(layer)
109
+
110
+ self.linear_out = Linear(self.c_hidden, self.no_angles * 2)
111
+
112
+ self.relu = nn.ReLU()
113
+
114
+ def forward(
115
+ self, s: torch.Tensor, s_initial: torch.Tensor
116
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
117
+ """
118
+ Args:
119
+ s:
120
+ [*, C_hidden] single embedding
121
+ s_initial:
122
+ [*, C_hidden] single embedding as of the start of the
123
+ StructureModule
124
+ Returns:
125
+ [*, no_angles, 2] predicted angles
126
+ """
127
+ # NOTE: The ReLU's applied to the inputs are absent from the supplement
128
+ # pseudocode but present in the source. For maximal compatibility with
129
+ # the pretrained weights, I'm going with the source.
130
+
131
+ # [*, C_hidden]
132
+ s_initial = self.relu(s_initial)
133
+ s_initial = self.linear_initial(s_initial)
134
+ s = self.relu(s)
135
+ s = self.linear_in(s)
136
+ s = s + s_initial
137
+
138
+ for l in self.layers:
139
+ s = l(s)
140
+
141
+ s = self.relu(s)
142
+
143
+ # [*, no_angles * 2]
144
+ s = self.linear_out(s)
145
+
146
+ # [*, no_angles, 2]
147
+ s = s.view(s.shape[:-1] + (-1, 2))
148
+
149
+ unnormalized_s = s
150
+ norm_denom = torch.sqrt(
151
+ torch.clamp(
152
+ torch.sum(s ** 2, dim=-1, keepdim=True),
153
+ min=self.eps,
154
+ )
155
+ )
156
+ s = s / norm_denom
157
+
158
+ return unnormalized_s, s
159
+
160
+
161
class InvariantPointAttention(nn.Module):
    """
    Implements Algorithm 22.

    Attention over the single representation in which each head also
    emits 3D points; the points are mapped through the per-residue rigid
    frames, and squared point distances contribute (negatively) to the
    attention logits alongside the scalar qk term and a pair-derived
    bias.
    """
    def __init__(
        self,
        c_s: int,
        c_z: int,
        c_hidden: int,
        no_heads: int,
        no_qk_points: int,
        no_v_points: int,
        inf: float = 1e5,
        eps: float = 1e-8,
    ):
        """
        Args:
            c_s:
                Single representation channel dimension
            c_z:
                Pair representation channel dimension
            c_hidden:
                Hidden channel dimension
            no_heads:
                Number of attention heads
            no_qk_points:
                Number of query/key points to generate
            no_v_points:
                Number of value points to generate
            inf:
                Large value used for masking attention logits
            eps:
                Small constant guarding the point-norm sqrt
        """
        super(InvariantPointAttention, self).__init__()

        self.c_s = c_s
        self.c_z = c_z
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.no_qk_points = no_qk_points
        self.no_v_points = no_v_points
        self.inf = inf
        self.eps = eps

        # These linear layers differ from their specifications in the
        # supplement. There, they lack bias and use Glorot initialization.
        # Here as in the official source, they have bias and use the default
        # Lecun initialization.
        hc = self.c_hidden * self.no_heads
        self.linear_q = Linear(self.c_s, hc)
        self.linear_kv = Linear(self.c_s, 2 * hc)

        hpq = self.no_heads * self.no_qk_points * 3
        self.linear_q_points = Linear(self.c_s, hpq)

        hpkv = self.no_heads * (self.no_qk_points + self.no_v_points) * 3
        self.linear_kv_points = Linear(self.c_s, hpkv)

        # NOTE(review): hpv is computed but never used here — value points
        # come out of linear_kv_points above (same as upstream OpenFold).
        hpv = self.no_heads * self.no_v_points * 3

        self.linear_b = Linear(self.c_z, self.no_heads)

        # Learned per-head weights for the point-distance term.
        self.head_weights = nn.Parameter(torch.zeros((no_heads)))
        ipa_point_weights_init_(self.head_weights)

        # Output concatenates scalar values, value points (3 coords +
        # norm), and the pair-attended features, per head.
        concat_out_dim = self.no_heads * (
            self.c_z + self.c_hidden + self.no_v_points * 4
        )
        self.linear_out = Linear(concat_out_dim, self.c_s, init="final")

        self.softmax = nn.Softmax(dim=-1)
        self.softplus = nn.Softplus()

    def forward(
        self,
        s: torch.Tensor,
        z: Optional[torch.Tensor],
        r: Rigid,
        mask: torch.Tensor,
        inplace_safe: bool = False,
        _offload_inference: bool = False,
        _z_reference_list: Optional[Sequence[torch.Tensor]] = None,
    ) -> torch.Tensor:
        """
        Args:
            s:
                [*, N_res, C_s] single representation
            z:
                [*, N_res, N_res, C_z] pair representation
            r:
                [*, N_res] transformation object
            mask:
                [*, N_res] mask
        Returns:
            [*, N_res, C_s] single representation update
        """
        # z is wrapped in a one-element list so the offloaded-inference
        # path can move the single reference between CPU and device.
        if(_offload_inference and inplace_safe):
            z = _z_reference_list
        else:
            z = [z]

        #######################################
        # Generate scalar and point activations
        #######################################
        # [*, N_res, H * C_hidden]
        q = self.linear_q(s)
        kv = self.linear_kv(s)

        # [*, N_res, H, C_hidden]
        q = q.view(q.shape[:-1] + (self.no_heads, -1))

        # [*, N_res, H, 2 * C_hidden]
        kv = kv.view(kv.shape[:-1] + (self.no_heads, -1))

        # [*, N_res, H, C_hidden]
        k, v = torch.split(kv, self.c_hidden, dim=-1)

        # [*, N_res, H * P_q * 3]
        q_pts = self.linear_q_points(s)

        # This is kind of clunky, but it's how the original does it
        # [*, N_res, H * P_q, 3]
        q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1)
        q_pts = torch.stack(q_pts, dim=-1)
        # Map the predicted local-frame points into the global frame.
        q_pts = r[..., None].apply(q_pts)

        # [*, N_res, H, P_q, 3]
        q_pts = q_pts.view(
            q_pts.shape[:-2] + (self.no_heads, self.no_qk_points, 3)
        )

        # [*, N_res, H * (P_q + P_v) * 3]
        kv_pts = self.linear_kv_points(s)

        # [*, N_res, H * (P_q + P_v), 3]
        kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1)
        kv_pts = torch.stack(kv_pts, dim=-1)
        kv_pts = r[..., None].apply(kv_pts)

        # [*, N_res, H, (P_q + P_v), 3]
        kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.no_heads, -1, 3))

        # [*, N_res, H, P_q/P_v, 3]
        k_pts, v_pts = torch.split(
            kv_pts, [self.no_qk_points, self.no_v_points], dim=-2
        )

        ##########################
        # Compute attention scores
        ##########################
        # [*, N_res, N_res, H]
        b = self.linear_b(z[0])

        if(_offload_inference):
            assert(sys.getrefcount(z[0]) == 2)
            z[0] = z[0].cpu()

        # Scalar qk logits; computed in fp32 under mixed precision.
        # [*, H, N_res, N_res]
        if(is_fp16_enabled()):
            with torch.cuda.amp.autocast(enabled=False):
                a = torch.matmul(
                    permute_final_dims(q.float(), (1, 0, 2)),  # [*, H, N_res, C_hidden]
                    permute_final_dims(k.float(), (1, 2, 0)),  # [*, H, C_hidden, N_res]
                )
        else:
            a = torch.matmul(
                permute_final_dims(q, (1, 0, 2)),  # [*, H, N_res, C_hidden]
                permute_final_dims(k, (1, 2, 0)),  # [*, H, C_hidden, N_res]
            )

        # The sqrt(1/3) factors weight the three logit terms (scalar,
        # pair bias, points) as in Algorithm 22.
        a *= math.sqrt(1.0 / (3 * self.c_hidden))
        a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))

        # Squared distances between all query and key points.
        # [*, N_res, N_res, H, P_q, 3]
        pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
        if(inplace_safe):
            pt_att *= pt_att
        else:
            pt_att = pt_att ** 2

        # [*, N_res, N_res, H, P_q]
        pt_att = sum(torch.unbind(pt_att, dim=-1))
        head_weights = self.softplus(self.head_weights).view(
            *((1,) * len(pt_att.shape[:-2]) + (-1, 1))
        )
        head_weights = head_weights * math.sqrt(
            1.0 / (3 * (self.no_qk_points * 9.0 / 2))
        )
        if(inplace_safe):
            pt_att *= head_weights
        else:
            pt_att = pt_att * head_weights

        # [*, N_res, N_res, H]
        pt_att = torch.sum(pt_att, dim=-1) * (-0.5)
        # [*, N_res, N_res]
        square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
        square_mask = self.inf * (square_mask - 1)

        # [*, H, N_res, N_res]
        pt_att = permute_final_dims(pt_att, (2, 0, 1))

        if(inplace_safe):
            a += pt_att
            del pt_att
            a += square_mask.unsqueeze(-3)
            # in-place softmax
            # NOTE(review): the attn_core_inplace_cuda import is commented
            # out at the top of this module, so this branch would raise a
            # NameError if inplace_safe=True is used — confirm before
            # enabling the in-place path.
            attn_core_inplace_cuda.forward_(
                a,
                reduce(mul, a.shape[:-1]),
                a.shape[-1],
            )
        else:
            a = a + pt_att
            a = a + square_mask.unsqueeze(-3)
            a = self.softmax(a)

        ################
        # Compute output
        ################
        # [*, N_res, H, C_hidden]
        o = torch.matmul(
            a, v.transpose(-2, -3).to(dtype=a.dtype)
        ).transpose(-2, -3)

        # [*, N_res, H * C_hidden]
        o = flatten_final_dims(o, 2)

        # [*, H, 3, N_res, P_v]
        if(inplace_safe):
            v_pts = permute_final_dims(v_pts, (1, 3, 0, 2))
            o_pt = [
                torch.matmul(a, v.to(a.dtype))
                for v in torch.unbind(v_pts, dim=-3)
            ]
            o_pt = torch.stack(o_pt, dim=-3)
        else:
            o_pt = torch.sum(
                (
                    a[..., None, :, :, None]
                    * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
                ),
                dim=-2,
            )

        # Map attended points back into each residue's local frame.
        # [*, N_res, H, P_v, 3]
        o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
        o_pt = r[..., None, None].invert_apply(o_pt)

        # [*, N_res, H * P_v]
        o_pt_norm = flatten_final_dims(
            torch.sqrt(torch.sum(o_pt ** 2, dim=-1) + self.eps), 2
        )

        # [*, N_res, H * P_v, 3]
        o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)

        if(_offload_inference):
            z[0] = z[0].to(o_pt.device)

        # [*, N_res, H, C_z]
        o_pair = torch.matmul(a.transpose(-2, -3), z[0].to(dtype=a.dtype))

        # [*, N_res, H * C_z]
        o_pair = flatten_final_dims(o_pair, 2)

        # [*, N_res, C_s]
        s = self.linear_out(
            torch.cat(
                (o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1
            ).to(dtype=z[0].dtype)
        )

        return s
432
+
433
+
434
class BackboneUpdate(nn.Module):
    """
    Implements part of Algorithm 23.

    Projects the single representation down to a 6-dimensional rigid
    update vector per residue.
    """

    def __init__(self, c_s):
        """
        Args:
            c_s:
                Single representation channel dimension
        """
        super(BackboneUpdate, self).__init__()

        self.c_s = c_s

        # "final" init zeroes the projection weights, so the update
        # starts at zero.
        self.linear = Linear(self.c_s, 6, init="final")

    def forward(self, s: torch.Tensor) -> torch.Tensor:
        """
        Args:
            s:
                [*, N_res, C_s] single representation
        Returns:
            [*, N_res, 6] update vector
        """
        # Fix: the return annotation previously claimed
        # Tuple[torch.Tensor, torch.Tensor], but a single tensor is
        # returned (and consumed as such by compose_q_update_vec).
        # [*, 6]
        update = self.linear(s)

        return update
462
+
463
+
464
+ class StructureModuleTransitionLayer(nn.Module):
465
+ def __init__(self, c):
466
+ super(StructureModuleTransitionLayer, self).__init__()
467
+
468
+ self.c = c
469
+
470
+ self.linear_1 = Linear(self.c, self.c, init="relu")
471
+ self.linear_2 = Linear(self.c, self.c, init="relu")
472
+ self.linear_3 = Linear(self.c, self.c, init="final")
473
+
474
+ self.relu = nn.ReLU()
475
+
476
+ def forward(self, s):
477
+ s_initial = s
478
+ s = self.linear_1(s)
479
+ s = self.relu(s)
480
+ s = self.linear_2(s)
481
+ s = self.relu(s)
482
+ s = self.linear_3(s)
483
+
484
+ s = s + s_initial
485
+
486
+ return s
487
+
488
+
489
+ class StructureModuleTransition(nn.Module):
490
+ def __init__(self, c, num_layers, dropout_rate):
491
+ super(StructureModuleTransition, self).__init__()
492
+
493
+ self.c = c
494
+ self.num_layers = num_layers
495
+ self.dropout_rate = dropout_rate
496
+
497
+ self.layers = nn.ModuleList()
498
+ for _ in range(self.num_layers):
499
+ l = StructureModuleTransitionLayer(self.c)
500
+ self.layers.append(l)
501
+
502
+ self.dropout = nn.Dropout(self.dropout_rate)
503
+ self.layer_norm = LayerNorm(self.c)
504
+
505
+ def forward(self, s):
506
+ for l in self.layers:
507
+ s = l(s)
508
+
509
+ s = self.dropout(s)
510
+ s = self.layer_norm(s)
511
+
512
+ return s
513
+
514
+
515
class StructureModule(nn.Module):
    """
    Iteratively refines per-residue rigid frames and torsion angles from
    the evoformer single/pair representations, producing all-atom
    (atom14) coordinates at each block (Algorithm 20).
    """
    def __init__(
        self,
        c_s,
        c_z,
        c_ipa,
        c_resnet,
        no_heads_ipa,
        no_qk_points,
        no_v_points,
        dropout_rate,
        no_blocks,
        no_transition_layers,
        no_resnet_blocks,
        no_angles,
        trans_scale_factor,
        epsilon,
        inf,
        **kwargs,
    ):
        """
        Args:
            c_s:
                Single representation channel dimension
            c_z:
                Pair representation channel dimension
            c_ipa:
                IPA hidden channel dimension
            c_resnet:
                Angle resnet (Alg. 23 lines 11-14) hidden channel dimension
            no_heads_ipa:
                Number of IPA heads
            no_qk_points:
                Number of query/key points to generate during IPA
            no_v_points:
                Number of value points to generate during IPA
            dropout_rate:
                Dropout rate used throughout the layer
            no_blocks:
                Number of structure module blocks
            no_transition_layers:
                Number of layers in the single representation transition
                (Alg. 23 lines 8-9)
            no_resnet_blocks:
                Number of blocks in the angle resnet
            no_angles:
                Number of angles to generate in the angle resnet
            trans_scale_factor:
                Factor by which predicted translations are scaled before
                being emitted (see scale_translation calls in forward)
            epsilon:
                Small number used in angle resnet normalization
            inf:
                Large number used for attention masking
        """
        super(StructureModule, self).__init__()

        self.c_s = c_s
        self.c_z = c_z
        self.c_ipa = c_ipa
        self.c_resnet = c_resnet
        self.no_heads_ipa = no_heads_ipa
        self.no_qk_points = no_qk_points
        self.no_v_points = no_v_points
        self.dropout_rate = dropout_rate
        self.no_blocks = no_blocks
        self.no_transition_layers = no_transition_layers
        self.no_resnet_blocks = no_resnet_blocks
        self.no_angles = no_angles
        self.trans_scale_factor = trans_scale_factor
        self.epsilon = epsilon
        self.inf = inf

        # Buffers to be lazily initialized later
        # self.default_frames
        # self.group_idx
        # self.atom_mask
        # self.lit_positions

        self.layer_norm_s = LayerNorm(self.c_s)
        self.layer_norm_z = LayerNorm(self.c_z)

        self.linear_in = Linear(self.c_s, self.c_s)

        self.ipa = InvariantPointAttention(
            self.c_s,
            self.c_z,
            self.c_ipa,
            self.no_heads_ipa,
            self.no_qk_points,
            self.no_v_points,
            inf=self.inf,
            eps=self.epsilon,
        )

        self.ipa_dropout = nn.Dropout(self.dropout_rate)
        self.layer_norm_ipa = LayerNorm(self.c_s)

        self.transition = StructureModuleTransition(
            self.c_s,
            self.no_transition_layers,
            self.dropout_rate,
        )

        self.bb_update = BackboneUpdate(self.c_s)

        self.angle_resnet = AngleResnet(
            self.c_s,
            self.c_resnet,
            self.no_resnet_blocks,
            self.no_angles,
            self.epsilon,
        )

    def forward(
        self,
        evoformer_output_dict,
        aatype,
        mask=None,
        inplace_safe=False,
        _offload_inference=False,
    ):
        """
        Args:
            evoformer_output_dict:
                Dictionary containing:
                    "single":
                        [*, N_res, C_s] single representation
                    "pair":
                        [*, N_res, N_res, C_z] pair representation
            aatype:
                [*, N_res] amino acid indices
            mask:
                Optional [*, N_res] sequence mask
        Returns:
            A dictionary of outputs: per-block stacks of "frames",
            "sidechain_frames", "unnormalized_angles", "angles",
            "positions", "states", plus the final "single"
            representation.
        """
        s = evoformer_output_dict["single"]

        if mask is None:
            # [*, N]
            mask = s.new_ones(s.shape[:-1])

        # [*, N, C_s]
        s = self.layer_norm_s(s)

        # [*, N, N, C_z]
        z = self.layer_norm_z(evoformer_output_dict["pair"])

        # For offloaded inference, the (only) pair reference is moved to
        # CPU and the normalized z is handed to IPA through a list.
        z_reference_list = None
        if(_offload_inference):
            assert(sys.getrefcount(evoformer_output_dict["pair"]) == 2)
            evoformer_output_dict["pair"] = evoformer_output_dict["pair"].cpu()
            z_reference_list = [z]
            z = None

        # [*, N, C_s]
        s_initial = s
        s = self.linear_in(s)

        # All frames start at the identity transform.
        # [*, N]
        rigids = Rigid.identity(
            s.shape[:-1],
            s.dtype,
            s.device,
            self.training,
            fmt="quat",
        )
        outputs = []
        for i in range(self.no_blocks):
            # [*, N, C_s]
            s = s + self.ipa(
                s,
                z,
                rigids,
                mask,
                inplace_safe=inplace_safe,
                _offload_inference=_offload_inference,
                _z_reference_list=z_reference_list
            )
            s = self.ipa_dropout(s)
            s = self.layer_norm_ipa(s)
            s = self.transition(s)

            # [*, N]
            rigids = rigids.compose_q_update_vec(self.bb_update(s))

            # To hew as closely as possible to AlphaFold, we convert our
            # quaternion-based transformations to rotation-matrix ones
            # here
            backb_to_global = Rigid(
                Rotation(
                    rot_mats=rigids.get_rots().get_rot_mats(),
                    quats=None
                ),
                rigids.get_trans(),
            )

            backb_to_global = backb_to_global.scale_translation(
                self.trans_scale_factor
            )

            # [*, N, 7, 2]
            unnormalized_angles, angles = self.angle_resnet(s, s_initial)

            all_frames_to_global = self.torsion_angles_to_frames(
                backb_to_global,
                angles,
                aatype,
            )

            pred_xyz = self.frames_and_literature_positions_to_atom14_pos(
                all_frames_to_global,
                aatype,
            )

            scaled_rigids = rigids.scale_translation(self.trans_scale_factor)

            preds = {
                "frames": scaled_rigids.to_tensor_7(),
                "sidechain_frames": all_frames_to_global.to_tensor_4x4(),
                "unnormalized_angles": unnormalized_angles,
                "angles": angles,
                "positions": pred_xyz,
                "states": s,
            }

            outputs.append(preds)

            # Rotation gradients are stopped between blocks; only the
            # last block's rotations receive gradients directly.
            rigids = rigids.stop_rot_gradient()

        del z, z_reference_list

        if(_offload_inference):
            evoformer_output_dict["pair"] = (
                evoformer_output_dict["pair"].to(s.device)
            )

        # Stack the per-block predictions along a new leading dimension.
        outputs = dict_multimap(torch.stack, outputs)
        outputs["single"] = s

        return outputs

    def _init_residue_constants(self, float_dtype, device):
        # Residue-constant buffers are registered lazily (non-persistent)
        # on first use, on the dtype/device of the incoming tensors.
        if not hasattr(self, "default_frames"):
            self.register_buffer(
                "default_frames",
                torch.tensor(
                    restype_rigid_group_default_frame,
                    dtype=float_dtype,
                    device=device,
                    requires_grad=False,
                ),
                persistent=False,
            )
        if not hasattr(self, "group_idx"):
            self.register_buffer(
                "group_idx",
                torch.tensor(
                    restype_atom14_to_rigid_group,
                    device=device,
                    requires_grad=False,
                ),
                persistent=False,
            )
        if not hasattr(self, "atom_mask"):
            self.register_buffer(
                "atom_mask",
                torch.tensor(
                    restype_atom14_mask,
                    dtype=float_dtype,
                    device=device,
                    requires_grad=False,
                ),
                persistent=False,
            )
        if not hasattr(self, "lit_positions"):
            self.register_buffer(
                "lit_positions",
                torch.tensor(
                    restype_atom14_rigid_group_positions,
                    dtype=float_dtype,
                    device=device,
                    requires_grad=False,
                ),
                persistent=False,
            )

    def torsion_angles_to_frames(self, r, alpha, f):
        # Lazily initialize the residue constants on the correct device
        self._init_residue_constants(alpha.dtype, alpha.device)
        # Separated purely to make testing less annoying
        return torsion_angles_to_frames(r, alpha, f, self.default_frames)

    def frames_and_literature_positions_to_atom14_pos(
        self, r, f  # [*, N, 8]  # [*, N]
    ):
        # Lazily initialize the residue constants on the correct device
        self._init_residue_constants(r.get_rots().dtype, r.get_rots().device)
        return frames_and_literature_positions_to_atom14_pos(
            r,
            f,
            self.default_frames,
            self.group_idx,
            self.atom_mask,
            self.lit_positions,
        )
openfold/model/template.py ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from functools import partial
16
+ import math
17
+ from typing import Optional, List
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+
22
+ from openfold.model.primitives import Linear, LayerNorm, Attention
23
+ from openfold.model.dropout import (
24
+ DropoutRowwise,
25
+ DropoutColumnwise,
26
+ )
27
+ from openfold.model.pair_transition import PairTransition
28
+ from openfold.model.triangular_attention import (
29
+ TriangleAttentionStartingNode,
30
+ TriangleAttentionEndingNode,
31
+ )
32
+ from openfold.model.triangular_multiplicative_update import (
33
+ TriangleMultiplicationOutgoing,
34
+ TriangleMultiplicationIncoming,
35
+ )
36
+ from openfold.utils.checkpointing import checkpoint_blocks
37
+ from openfold.utils.tensor_utils import (
38
+ chunk_layer,
39
+ permute_final_dims,
40
+ flatten_final_dims,
41
+ )
42
+
43
+
44
class TemplatePointwiseAttention(nn.Module):
    """
    Implements Algorithm 17.

    Each entry of the pair representation attends over the matching
    entries of all template embeddings; absent templates are removed
    via a large negative attention bias.
    """
    def __init__(self, c_t, c_z, c_hidden, no_heads, inf, **kwargs):
        """
        Args:
            c_t:
                Template embedding channel dimension
            c_z:
                Pair embedding channel dimension
            c_hidden:
                Hidden channel dimension
            no_heads:
                Number of attention heads
            inf:
                Large constant used to build the template mask bias
        """
        super(TemplatePointwiseAttention, self).__init__()

        self.c_t = c_t
        self.c_z = c_z
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.inf = inf

        # Queries come from the pair rep; keys/values from the templates.
        self.mha = Attention(
            self.c_z,
            self.c_t,
            self.c_t,
            self.c_hidden,
            self.no_heads,
            gating=False,
        )

    def _chunk(self,
        z: torch.Tensor,
        t: torch.Tensor,
        biases: List[torch.Tensor],
        chunk_size: int,
    ) -> torch.Tensor:
        # Chunked attention evaluation to bound peak memory.
        return chunk_layer(
            self.mha,
            {"q_x": z, "kv_x": t, "biases": biases},
            chunk_size=chunk_size,
            no_batch_dims=len(z.shape[:-2]),
        )

    def forward(self,
        t: torch.Tensor,
        z: torch.Tensor,
        template_mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None
    ) -> torch.Tensor:
        """
        Args:
            t:
                [*, N_templ, N_res, N_res, C_t] template embedding
            z:
                [*, N_res, N_res, C_t] pair embedding
            template_mask:
                [*, N_templ] template mask
            chunk_size:
                Optional chunk size for memory-bounded attention
        Returns:
            [*, N_res, N_res, C_z] pair embedding update
        """
        if template_mask is None:
            template_mask = t.new_ones(t.shape[:-3])

        # 0 bias for present templates, -inf-like for missing ones.
        bias = self.inf * (template_mask[..., None, None, None, None, :] - 1)
        biases = [bias]

        # [*, N_res, N_res, 1, C_z]: one query "token" per pair entry
        z = z.unsqueeze(-2)

        # [*, N_res, N_res, N_templ, C_t]: templates form the kv sequence
        t = permute_final_dims(t, (1, 2, 0, 3))

        if chunk_size is not None:
            z = self._chunk(z, t, biases, chunk_size)
        else:
            z = self.mha(q_x=z, kv_x=t, biases=biases)

        # Drop the singleton query dimension: [*, N_res, N_res, C_z]
        return z.squeeze(-2)
133
+
134
+
135
class TemplatePairStackBlock(nn.Module):
    """One block of the template pair stack (used by Algorithm 16).

    Applies triangular attention, triangular multiplicative updates and
    a pair transition to each template independently.
    """
    def __init__(
        self,
        c_t: int,
        c_hidden_tri_att: int,
        c_hidden_tri_mul: int,
        no_heads: int,
        pair_transition_n: int,
        dropout_rate: float,
        inf: float,
        **kwargs,
    ):
        super(TemplatePairStackBlock, self).__init__()

        self.c_t = c_t
        self.c_hidden_tri_att = c_hidden_tri_att
        self.c_hidden_tri_mul = c_hidden_tri_mul
        self.no_heads = no_heads
        self.pair_transition_n = pair_transition_n
        self.dropout_rate = dropout_rate
        self.inf = inf

        self.dropout_row = DropoutRowwise(self.dropout_rate)
        self.dropout_col = DropoutColumnwise(self.dropout_rate)

        self.tri_att_start = TriangleAttentionStartingNode(
            self.c_t,
            self.c_hidden_tri_att,
            self.no_heads,
            inf=inf,
        )
        self.tri_att_end = TriangleAttentionEndingNode(
            self.c_t,
            self.c_hidden_tri_att,
            self.no_heads,
            inf=inf,
        )

        self.tri_mul_out = TriangleMultiplicationOutgoing(
            self.c_t,
            self.c_hidden_tri_mul,
        )
        self.tri_mul_in = TriangleMultiplicationIncoming(
            self.c_t,
            self.c_hidden_tri_mul,
        )

        self.pair_transition = PairTransition(
            self.c_t,
            self.pair_transition_n,
        )

    def forward(self,
        z: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: Optional[int] = None,
        _mask_trans: bool = True
    ):
        """
        Args:
            z:
                [*, N_templ, N_res, N_res, C_t] template pair embeddings
            mask:
                [*, N_templ, N_res, N_res] mask
        Returns:
            [*, N_templ, N_res, N_res, C_t] updated embeddings
        """
        # Process templates one at a time to limit peak memory.
        singles = [t.unsqueeze(-4) for t in torch.unbind(z, dim=-4)]
        single_masks = [m.unsqueeze(-3) for m in torch.unbind(mask, dim=-3)]

        for idx, (single, single_mask) in enumerate(
            zip(singles, single_masks)
        ):
            single = single + self.dropout_row(
                self.tri_att_start(
                    single, chunk_size=chunk_size, mask=single_mask
                )
            )
            single = single + self.dropout_col(
                self.tri_att_end(
                    single, chunk_size=chunk_size, mask=single_mask
                )
            )
            single = single + self.dropout_row(
                self.tri_mul_out(single, mask=single_mask)
            )
            single = single + self.dropout_row(
                self.tri_mul_in(single, mask=single_mask)
            )
            single = single + self.pair_transition(
                single,
                mask=single_mask if _mask_trans else None,
                chunk_size=chunk_size,
            )
            singles[idx] = single

        return torch.cat(singles, dim=-4)
240
+
241
+
242
class TemplatePairStack(nn.Module):
    """
    Implements Algorithm 16.
    """
    def __init__(
        self,
        c_t,
        c_hidden_tri_att,
        c_hidden_tri_mul,
        no_blocks,
        no_heads,
        pair_transition_n,
        dropout_rate,
        blocks_per_ckpt,
        inf=1e9,
        **kwargs,
    ):
        """
        Args:
            c_t:
                Template embedding channel dimension
            c_hidden_tri_att:
                Per-head hidden dimension for triangular attention
            c_hidden_tri_mul:
                Hidden dimension for triangular multiplication
            no_blocks:
                Number of blocks in the stack
            no_heads:
                Number of triangular-attention heads
            pair_transition_n:
                Scale of pair transition (Alg. 15) hidden dimension
            dropout_rate:
                Dropout rate used throughout the stack
            blocks_per_ckpt:
                Number of blocks per activation checkpoint. None disables
                activation checkpointing
            inf:
                Large constant used for masking
        """
        super(TemplatePairStack, self).__init__()

        self.blocks_per_ckpt = blocks_per_ckpt

        self.blocks = nn.ModuleList(
            [
                TemplatePairStackBlock(
                    c_t=c_t,
                    c_hidden_tri_att=c_hidden_tri_att,
                    c_hidden_tri_mul=c_hidden_tri_mul,
                    no_heads=no_heads,
                    pair_transition_n=pair_transition_n,
                    dropout_rate=dropout_rate,
                    inf=inf,
                )
                for _ in range(no_blocks)
            ]
        )

        self.layer_norm = LayerNorm(c_t)

    def forward(
        self,
        t: torch.tensor,
        mask: torch.tensor,
        chunk_size: int,
        _mask_trans: bool = True,
    ):
        """
        Args:
            t:
                [*, N_templ, N_res, N_res, C_t] template embedding
            mask:
                [*, N_templ, N_res, N_res] mask
        Returns:
            [*, N_templ, N_res, N_res, C_t] template embedding update
        """
        # A single shared mask may be passed; broadcast it over templates.
        if mask.shape[-3] == 1:
            expand_idx = list(mask.shape)
            expand_idx[-3] = t.shape[-4]
            mask = mask.expand(*expand_idx)

        # Bind the per-call arguments so each block is a function of t
        # alone, as required by checkpoint_blocks. Checkpointing is only
        # enabled during training.
        wrapped_blocks = [
            partial(
                block,
                mask=mask,
                chunk_size=chunk_size,
                _mask_trans=_mask_trans,
            )
            for block in self.blocks
        ]
        t, = checkpoint_blocks(
            blocks=wrapped_blocks,
            args=(t,),
            blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
        )

        return self.layer_norm(t)
openfold/model/torchscript.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Optional, Sequence, Tuple
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from openfold.model.dropout import (
21
+ DropoutRowwise,
22
+ DropoutColumnwise,
23
+ )
24
+ from openfold.model.evoformer import (
25
+ EvoformerBlock,
26
+ EvoformerStack,
27
+ )
28
+ from openfold.model.outer_product_mean import OuterProductMean
29
+ from openfold.model.msa import (
30
+ MSARowAttentionWithPairBias,
31
+ MSAColumnAttention,
32
+ MSAColumnGlobalAttention,
33
+ )
34
+ from openfold.model.pair_transition import PairTransition
35
+ from openfold.model.primitives import Attention, GlobalAttention
36
+ from openfold.model.structure_module import (
37
+ InvariantPointAttention,
38
+ BackboneUpdate,
39
+ )
40
+ from openfold.model.template import TemplatePairStackBlock
41
+ from openfold.model.triangular_attention import (
42
+ TriangleAttentionStartingNode,
43
+ TriangleAttentionEndingNode,
44
+ )
45
+ from openfold.model.triangular_multiplicative_update import (
46
+ TriangleMultiplicationOutgoing,
47
+ TriangleMultiplicationIncoming,
48
+ )
49
+
50
+
51
def script_preset_(model: torch.nn.Module):
    """
    TorchScript a handful of low-level but frequently used submodule types
    that are known to be scriptable.

    Args:
        model:
            A torch.nn.Module. It should contain at least some modules from
            this repository, or this function won't do anything.
    """
    known_scriptable = [
        nn.Dropout,
        Attention,
        GlobalAttention,
        EvoformerBlock,
        #TemplatePairStackBlock,
    ]
    script_submodules_(
        model,
        known_scriptable,
        attempt_trace=False,
        batch_dims=None,
    )
73
+
74
+
75
def _get_module_device(module: torch.nn.Module) -> torch.device:
    """
    Fetches the device of a module, assuming that all of the module's
    parameters reside on a single device.

    Args:
        module: A torch.nn.Module (must have at least one parameter)
    Returns:
        The device of the module's first parameter
    """
    first_param = next(module.parameters())
    return first_param.device
86
+
87
+
88
def _trace_module(module, batch_dims=None):
    """
    Trace a single supported module using stand-in inputs.

    Args:
        module:
            One of the module types explicitly handled below
        batch_dims:
            Optional leading batch dimensions for the stand-in inputs
    Returns:
        The torch.jit-traced equivalent of the module
    Raises:
        TypeError: If the module type has no stand-in input recipe
    """
    if batch_dims is None:
        batch_dims = ()

    # Stand-in sizes for the sequence/residue dimensions
    n_seq = 10
    n_res = 10

    device = _get_module_device(module)

    def msa(channel_dim):
        return torch.rand(
            (*batch_dims, n_seq, n_res, channel_dim),
            device=device,
        )

    def pair(channel_dim):
        return torch.rand(
            (*batch_dims, n_res, n_res, channel_dim),
            device=device,
        )

    def msa_mask():
        # BUGFIX: allocate the mask on the module's device like the other
        # stand-in inputs. Previously the masks were created on the CPU,
        # which broke tracing of modules living on any other device.
        return torch.randint(
            0, 2,
            (*batch_dims, n_seq, n_res),
            device=device,
        )

    if isinstance(module, MSARowAttentionWithPairBias):
        inputs = {
            "forward": (
                msa(module.c_in),  # m
                pair(module.c_z),  # z
                msa_mask(),  # mask
            ),
        }
    elif isinstance(module, MSAColumnAttention):
        inputs = {
            "forward": (
                msa(module.c_in),  # m
                msa_mask(),  # mask
            ),
        }
    elif isinstance(module, OuterProductMean):
        inputs = {
            "forward": (
                msa(module.c_m),
                msa_mask(),
            ),
        }
    else:
        raise TypeError(
            f"tracing is not supported for modules of type {type(module)}"
        )

    return torch.jit.trace_module(module, inputs)
147
+
148
+
149
def _script_submodules_helper_(
    model,
    types,
    attempt_trace,
    to_trace,
):
    """
    Recursively script matching child modules of ``model`` in place.

    Children whose scripting fails are either recorded in ``to_trace``
    (when ``attempt_trace`` is set) for later tracing, or the error is
    propagated.
    """
    for name, child in model.named_children():
        matches = types is None or any(isinstance(child, t) for t in types)
        if matches:
            try:
                setattr(model, name, torch.jit.script(child))
                # Successfully scripted; no need to descend further.
                continue
            except (RuntimeError, torch.jit.frontend.NotSupportedError):
                if not attempt_trace:
                    raise
                to_trace.add(type(child))

        _script_submodules_helper_(child, types, attempt_trace, to_trace)
168
+
169
+
170
def _trace_submodules_(
    model,
    types,
    batch_dims=None,
):
    """
    Recursively replace child modules of the given types with traced
    equivalents in place.
    """
    for name, child in model.named_children():
        if any(isinstance(child, t) for t in types):
            setattr(model, name, _trace_module(child, batch_dims=batch_dims))
        else:
            _trace_submodules_(child, types, batch_dims=batch_dims)
181
+
182
+
183
def script_submodules_(
    model: nn.Module,
    types: Optional[Sequence[type]] = None,
    attempt_trace: Optional[bool] = True,
    batch_dims: Optional[Tuple[int]] = None,
):
    """
    Convert all submodules whose types match one of those in the input
    list to recursively scripted equivalents in place. To script the entire
    model, just call torch.jit.script on it directly.

    When types is None, all submodules are scripted.

    Args:
        model:
            A torch.nn.Module
        types:
            A list of types of submodules to script
        attempt_trace:
            Whether to attempt to trace specified modules if scripting
            fails. Recall that tracing eliminates all conditional
            logic---with great tracing comes the mild responsibility of
            having to remember to ensure that the modules in question
            perform the same computations no matter what.
        batch_dims:
            Batch dimensions of the stand-in inputs used when tracing
    """
    to_trace = set()

    # Aggressively script as much as possible first...
    _script_submodules_helper_(model, types, attempt_trace, to_trace)

    # ... and then trace the stragglers that refused to script.
    if attempt_trace and to_trace:
        _trace_submodules_(model, to_trace, batch_dims=batch_dims)
openfold/model/triangular_attention.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from functools import partialmethod, partial
17
+ import math
18
+ from typing import Optional, List
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+
23
+ from openfold.model.primitives import Linear, LayerNorm, Attention
24
+ from openfold.utils.tensor_utils import (
25
+ chunk_layer,
26
+ permute_final_dims,
27
+ flatten_final_dims,
28
+ )
29
+
30
+
31
class TriangleAttention(nn.Module):
    """Triangular self-attention over the pair representation.

    With ``starting=True`` this is Algorithm 13; with ``starting=False``
    the input is transposed first, which yields Algorithm 14.
    """
    def __init__(
        self, c_in, c_hidden, no_heads, starting, inf=1e9
    ):
        """
        Args:
            c_in:
                Input channel dimension
            c_hidden:
                Overall hidden channel dimension (not per-head)
            no_heads:
                Number of attention heads
            starting:
                Whether this is the "starting node" (True) or
                "ending node" (False) variant
            inf:
                Large constant used to build the mask bias
        """
        super(TriangleAttention, self).__init__()

        self.c_in = c_in
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.starting = starting
        self.inf = inf

        self.layer_norm = LayerNorm(self.c_in)

        # Projects the pair rep into a per-head additive attention bias.
        self.linear = Linear(c_in, self.no_heads, bias=False, init="normal")

        self.mha = Attention(
            self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
        )

    @torch.jit.ignore
    def _chunk(self,
        x: torch.Tensor,
        biases: List[torch.Tensor],
        chunk_size: int,
    ) -> torch.Tensor:
        # Chunked attention evaluation to bound peak memory.
        return chunk_layer(
            self.mha,
            {"q_x": x, "kv_x": x, "biases": biases},
            chunk_size=chunk_size,
            no_batch_dims=len(x.shape[:-2]),
        )

    def forward(self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None
    ) -> torch.Tensor:
        """
        Args:
            x:
                [*, I, J, C_in] input tensor (e.g. the pair representation)
            mask:
                [*, I, J] mask
        Returns:
            [*, I, J, C_in] output tensor
        """
        if mask is None:
            # [*, I, J]
            mask = x.new_ones(x.shape[:-1])

        # Shape annotations assume self.starting. Else, I and J are flipped
        if not self.starting:
            x = x.transpose(-2, -3)
            mask = mask.transpose(-1, -2)

        # [*, I, J, C_in]
        x = self.layer_norm(x)

        # [*, I, 1, 1, J] additive mask bias
        mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]

        # [*, 1, H, I, J] pair-derived attention bias
        triangle_bias = permute_final_dims(
            self.linear(x), (2, 0, 1)
        ).unsqueeze(-4)

        biases = [mask_bias, triangle_bias]

        if chunk_size is not None:
            x = self._chunk(x, biases, chunk_size)
        else:
            x = self.mha(q_x=x, kv_x=x, biases=biases)

        # Undo the transposition for the ending-node variant.
        if not self.starting:
            x = x.transpose(-2, -3)

        return x
124
+
125
+
126
class TriangleAttentionStartingNode(TriangleAttention):
    """
    Implements Algorithm 13 (triangle attention around the starting node).

    Identical to TriangleAttention with ``starting`` pre-bound to True.
    """

    __init__ = partialmethod(TriangleAttention.__init__, starting=True)
132
+
133
+
134
class TriangleAttentionEndingNode(TriangleAttention):
    """
    Implements Algorithm 14 (triangle attention around the ending node).

    Identical to TriangleAttention with ``starting`` pre-bound to False.
    """

    __init__ = partialmethod(TriangleAttention.__init__, starting=False)
openfold/model/triangular_multiplicative_update.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from functools import partialmethod
17
+ from typing import Optional
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+
22
+ from openfold.model.primitives import Linear, LayerNorm
23
+ from openfold.utils.tensor_utils import permute_final_dims
24
+
25
+
26
class TriangleMultiplicativeUpdate(nn.Module):
    """
    Implements Algorithms 11 and 12.

    Abstract base: subclasses define how the two gated projections are
    combined (outgoing vs. incoming edges).
    """
    def __init__(self, c_z, c_hidden, _outgoing=True):
        """
        Args:
            c_z:
                Input channel dimension
            c_hidden:
                Hidden channel dimension
        """
        super(TriangleMultiplicativeUpdate, self).__init__()
        self.c_z = c_z
        self.c_hidden = c_hidden
        self._outgoing = _outgoing

        # Value ("_p") and gate ("_g") projections for both operands.
        self.linear_a_p = Linear(self.c_z, self.c_hidden)
        self.linear_a_g = Linear(self.c_z, self.c_hidden, init="gating")
        self.linear_b_p = Linear(self.c_z, self.c_hidden)
        self.linear_b_g = Linear(self.c_z, self.c_hidden, init="gating")
        self.linear_g = Linear(self.c_z, self.c_z, init="gating")
        self.linear_z = Linear(self.c_hidden, self.c_z, init="final")

        self.layer_norm_in = LayerNorm(self.c_z)
        self.layer_norm_out = LayerNorm(self.c_hidden)

        self.sigmoid = nn.Sigmoid()

    def _combine_projections(self,
        a: torch.Tensor,
        b: torch.Tensor,
    ) -> torch.Tensor:
        raise NotImplementedError("This method needs to be overridden")

    def forward(self,
        z: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            z:
                [*, N_res, N_res, C_z] input tensor
            mask:
                [*, N_res, N_res] input mask
        Returns:
            [*, N_res, N_res, C_z] output tensor
        """
        if mask is None:
            mask = z.new_ones(z.shape[:-1])
        mask = mask.unsqueeze(-1)

        z = self.layer_norm_in(z)

        # Two gated, masked projections of the normalized input.
        proj_a = self.linear_a_p(z) * self.sigmoid(self.linear_a_g(z))
        proj_a = proj_a * mask
        proj_b = self.linear_b_p(z) * self.sigmoid(self.linear_b_g(z))
        proj_b = proj_b * mask

        # Combine (subclass-specific), renormalize, project back to C_z.
        combined = self._combine_projections(proj_a, proj_b)
        combined = self.linear_z(self.layer_norm_out(combined))

        # Output gate derived from the normalized input.
        out_gate = self.sigmoid(self.linear_g(z))
        return combined * out_gate
91
+
92
+
93
class TriangleMultiplicationOutgoing(TriangleMultiplicativeUpdate):
    """
    Implements Algorithm 11.
    """
    def _combine_projections(self,
        a: torch.Tensor,  # [*, N_i, N_k, C]
        b: torch.Tensor,  # [*, N_j, N_k, C]
    ):
        # Move channels to the front so the shared edge k can be
        # contracted with one batched matmul: [*, C, N_i, N_j]
        left = permute_final_dims(a, (2, 0, 1))
        right = permute_final_dims(b, (2, 1, 0))
        product = torch.matmul(left, right)

        # Back to channels-last: [*, N_i, N_j, C]
        return permute_final_dims(product, (1, 2, 0))
109
+
110
+
111
class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate):
    """
    Implements Algorithm 12.
    """
    def _combine_projections(self,
        a: torch.Tensor,  # [*, N_k, N_i, C]
        b: torch.Tensor,  # [*, N_k, N_j, C]
    ):
        # Move channels to the front so the shared edge k can be
        # contracted with one batched matmul: [*, C, N_i, N_j]
        left = permute_final_dims(a, (2, 1, 0))
        right = permute_final_dims(b, (2, 0, 1))
        product = torch.matmul(left, right)

        # Back to channels-last: [*, N_i, N_j, C]
        return permute_final_dims(product, (1, 2, 0))
+ return permute_final_dims(p, (1, 2, 0))
127
+
openfold/np/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Dynamically import every sibling .py module of this package and re-export
# each one at package level, so e.g. submodules are reachable as attributes
# of the package without explicit imports.
import os
import glob
import importlib as importlib

_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
__all__ = [
    os.path.basename(_path)[:-3]  # strip the ".py" suffix
    for _path in _files
    if os.path.isfile(_path) and not _path.endswith("__init__.py")
]
_modules = [
    (_name, importlib.import_module("." + _name, __name__))
    for _name in __all__
]
for _m in _modules:
    globals()[_m[0]] = _m[1]

# Avoid needlessly cluttering the global namespace
del _files, _m, _modules
openfold/np/protein.py ADDED
@@ -0,0 +1,438 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Protein data type."""
17
+ import dataclasses
18
+ import io
19
+ from typing import Any, Sequence, Mapping, Optional
20
+ import re
21
+ import string
22
+
23
+ from openfold.np import residue_constants
24
+ from Bio.PDB import PDBParser
25
+ import numpy as np
26
+
27
+
28
+ FeatureDict = Mapping[str, np.ndarray]
29
+ ModelOutput = Mapping[str, Any] # Is a nested dict.
30
+ PICO_TO_ANGSTROM = 0.01
31
+
32
@dataclasses.dataclass(frozen=True)
class Protein:
    """Immutable protein structure representation.

    Arrays are indexed by residue; per-atom axes follow the ordering of
    residue_constants.atom_types.
    """

    # Cartesian atom coordinates in angstroms.
    atom_positions: np.ndarray  # [num_res, num_atom_type, 3]

    # Integer amino-acid type per residue, in [0, 20]; 20 means 'X'/unknown.
    aatype: np.ndarray  # [num_res]

    # 1.0 where an atom is present, 0.0 otherwise; intended for loss masking.
    atom_mask: np.ndarray  # [num_res, num_atom_type]

    # Residue numbering as used in the PDB; not necessarily contiguous
    # or 0-based.
    residue_index: np.ndarray  # [num_res]

    # Per-atom temperature (B-) factors, in square angstroms.
    b_factors: np.ndarray  # [num_res, num_atom_type]

    # Chain index per residue, for multi-chain predictions.
    chain_index: Optional[np.ndarray] = None

    # Free-text remark, emitted as a comment in written PDB files.
    remark: Optional[str] = None

    # Templates used to generate this protein (prediction-only).
    parents: Optional[Sequence[str]] = None

    # Chain index each parent belongs to.
    parents_chain_index: Optional[Sequence[int]] = None
68
+
69
+
70
def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
    """Takes a PDB string and constructs a Protein object.

    WARNING: All non-standard residue types will be converted into UNK. All
    non-standard atoms will be ignored.

    Args:
        pdb_str: The contents of the pdb file
        chain_id: If None, every chain in the file is parsed. If chain_id is
            specified (e.g. A), then only that chain is parsed.

    Returns:
        A new `Protein` parsed from the pdb contents.

    Raises:
        ValueError: If the file contains more than one model, or any
            insertion codes.
    """
    pdb_fh = io.StringIO(pdb_str)
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure("none", pdb_fh)
    models = list(structure.get_models())
    if len(models) != 1:
        raise ValueError(
            f"Only single model PDBs are supported. Found {len(models)} models."
        )
    model = models[0]

    atom_positions = []
    aatype = []
    atom_mask = []
    residue_index = []
    chain_ids = []
    b_factors = []

    for chain in model:
        if chain_id is not None and chain.id != chain_id:
            continue
        for res in chain:
            if res.id[2] != " ":
                raise ValueError(
                    f"PDB contains an insertion code at chain {chain.id} and residue "
                    f"index {res.id[1]}. These are not supported."
                )
            res_shortname = residue_constants.restype_3to1.get(res.resname, "X")
            restype_idx = residue_constants.restype_order.get(
                res_shortname, residue_constants.restype_num
            )
            pos = np.zeros((residue_constants.atom_type_num, 3))
            mask = np.zeros((residue_constants.atom_type_num,))
            res_b_factors = np.zeros((residue_constants.atom_type_num,))
            for atom in res:
                if atom.name not in residue_constants.atom_types:
                    continue
                atom_idx = residue_constants.atom_order[atom.name]
                pos[atom_idx] = atom.coord
                mask[atom_idx] = 1.0
                res_b_factors[atom_idx] = atom.bfactor
            if np.sum(mask) < 0.5:
                # If no known atom positions are reported for the residue
                # then skip it.
                continue
            aatype.append(restype_idx)
            atom_positions.append(pos)
            atom_mask.append(mask)
            residue_index.append(res.id[1])
            chain_ids.append(chain.id)
            b_factors.append(res_b_factors)

    # Optional PARENT records (template provenance), one line per chain.
    parents = None
    parents_chain_index = None
    if "PARENT" in pdb_str:
        parents = []
        parents_chain_index = []
        # Use a dedicated counter instead of clobbering the chain_id
        # parameter (which previously held a chain letter).
        parent_chain_idx = 0
        for l in pdb_str.split("\n"):
            if "PARENT" in l:
                if "N/A" not in l:
                    parent_names = l.split()[1:]
                    parents.extend(parent_names)
                    parents_chain_index.extend(
                        [parent_chain_idx for _ in parent_names]
                    )
                parent_chain_idx += 1

    # Map chain letters (A..Z) to integer chain indices.
    chain_id_mapping = {cid: n for n, cid in enumerate(string.ascii_uppercase)}
    chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids])

    return Protein(
        atom_positions=np.array(atom_positions),
        atom_mask=np.array(atom_mask),
        aatype=np.array(aatype),
        residue_index=np.array(residue_index),
        chain_index=chain_index,
        b_factors=np.array(b_factors),
        parents=parents,
        parents_chain_index=parents_chain_index,
    )
+ )
166
+
167
+
168
def from_proteinnet_string(proteinnet_str: str) -> Protein:
    """Parse a ProteinNet-format record into a Protein.

    Only the [PRIMARY], [TERTIARY] and [MASK] sections are consumed;
    b_factors are left unset and only N/CA/C positions are filled.

    Args:
        proteinnet_str: The contents of a single ProteinNet record
    Returns:
        A new `Protein` built from the record.
    """
    tag_re = r'(\[[A-Z]+\]\n)'
    tags = [
        tag.strip() for tag in re.split(tag_re, proteinnet_str) if len(tag) > 0
    ]
    groups = zip(tags[0::2], [l.split('\n') for l in tags[1::2]])

    atoms = ['N', 'CA', 'C']
    aatype = None
    atom_positions = None
    atom_mask = None
    for g in groups:
        if "[PRIMARY]" == g[0]:
            # BUGFIX: str is immutable, so work on a list of characters.
            # The previous `seq[i] = 'X'` on a str raised TypeError for
            # any residue symbol outside the standard alphabet.
            seq = list(g[1][0].strip())
            for i in range(len(seq)):
                if seq[i] not in residue_constants.restypes:
                    seq[i] = 'X'
            aatype = np.array([
                residue_constants.restype_order.get(
                    res_symbol, residue_constants.restype_num
                ) for res_symbol in seq
            ])
        elif "[TERTIARY]" == g[0]:
            # Three lines of coordinates (x, y, z axes), in picometers.
            tertiary = []
            for axis in range(3):
                tertiary.append(list(map(float, g[1][axis].split())))
            tertiary_np = np.array(tertiary)
            atom_positions = np.zeros(
                (len(tertiary[0]) // 3, residue_constants.atom_type_num, 3)
            ).astype(np.float32)
            for i, atom in enumerate(atoms):
                atom_positions[:, residue_constants.atom_order[atom], :] = (
                    np.transpose(tertiary_np[:, i::3])
                )
            atom_positions *= PICO_TO_ANGSTROM
        elif "[MASK]" == g[0]:
            mask = np.array(list(map({'-': 0, '+': 1}.get, g[1][0].strip())))
            atom_mask = np.zeros(
                (len(mask), residue_constants.atom_type_num,)
            ).astype(np.float32)
            for i, atom in enumerate(atoms):
                atom_mask[:, residue_constants.atom_order[atom]] = 1
            atom_mask *= mask[..., None]

    return Protein(
        atom_positions=atom_positions,
        atom_mask=atom_mask,
        aatype=aatype,
        residue_index=np.arange(len(aatype)),
        b_factors=None,
    )
+ )
219
+
220
+
221
def get_pdb_headers(prot: Protein, chain_id: int = 0) -> Sequence[str]:
    """Build the REMARK/PARENT header lines for one chain of `prot`."""
    headers = []

    if prot.remark is not None:
        headers.append(f"REMARK {prot.remark}")

    parents = prot.parents
    if prot.parents_chain_index is not None:
        # Keep only the parents belonging to the requested chain.
        parents = [
            name
            for idx, name in zip(prot.parents_chain_index, parents)
            if idx == chain_id
        ]

    # Covers both "no parents at all" and "none for this chain".
    if not parents:
        parents = ["N/A"]

    headers.append(f"PARENT {' '.join(parents)}")
    return headers
241
+
242
+
243
def add_pdb_headers(prot: Protein, pdb_str: str) -> str:
    """ Add pdb headers to an existing PDB string. Useful during multi-chain
        recycling

    Args:
        prot: Protein carrying the remark/parents metadata to splice in.
        pdb_str: Existing PDB text; its own REMARK/PARENT lines are replaced.

    Returns:
        The PDB string with per-chain REMARK/PARENT headers inserted.
    """
    out_pdb_lines = []
    lines = pdb_str.split('\n')

    remark = prot.remark
    if(remark is not None):
        out_pdb_lines.append(f"REMARK {remark}")

    # Group parent template names by chain index.
    parents_per_chain = None
    if(prot.parents is not None and len(prot.parents) > 0):
        parents_per_chain = []
        if(prot.parents_chain_index is not None):
            # NOTE: the original also kept an unused `cur_chain` local here.
            parent_dict = {}
            for p, i in zip(prot.parents, prot.parents_chain_index):
                parent_dict.setdefault(str(i), [])
                parent_dict[str(i)].append(p)

            max_idx = max([int(chain_idx) for chain_idx in parent_dict])
            for i in range(max_idx + 1):
                chain_parents = parent_dict.get(str(i), ["N/A"])
                parents_per_chain.append(chain_parents)
        else:
            parents_per_chain.append(prot.parents)
    else:
        parents_per_chain = [["N/A"]]

    make_parent_line = lambda p: f"PARENT {' '.join(p)}"

    out_pdb_lines.append(make_parent_line(parents_per_chain[0]))

    chain_counter = 0
    for i, l in enumerate(lines):
        # Drop pre-existing header lines; copy everything else through.
        if("PARENT" not in l and "REMARK" not in l):
            out_pdb_lines.append(l)
        # BUG FIX: guard the lines[i + 1] lookahead so a "TER" on the very
        # last line no longer raises IndexError.
        if("TER" in l and (i + 1 >= len(lines) or not "END" in lines[i + 1])):
            chain_counter += 1
            if(not chain_counter >= len(parents_per_chain)):
                chain_parents = parents_per_chain[chain_counter]
            else:
                chain_parents = ["N/A"]

            out_pdb_lines.append(make_parent_line(chain_parents))

    return '\n'.join(out_pdb_lines)
291
+
292
+
293
def to_pdb(prot: Protein) -> str:
    """Converts a `Protein` instance to a PDB string.

    Args:
      prot: The protein to convert to PDB.

    Returns:
      PDB string.

    Raises:
      ValueError: If any aatype index exceeds the known residue types.
    """
    restypes = residue_constants.restypes + ["X"]
    # Residue-type index -> 3-letter name; out-of-vocabulary becomes "UNK".
    res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], "UNK")
    atom_types = residue_constants.atom_types

    pdb_lines = []

    atom_mask = prot.atom_mask
    aatype = prot.aatype
    atom_positions = prot.atom_positions
    residue_index = prot.residue_index.astype(int)
    b_factors = prot.b_factors
    chain_index = prot.chain_index

    if np.any(aatype > residue_constants.restype_num):
        raise ValueError("Invalid aatypes.")

    # REMARK / PARENT headers for the first chain.
    headers = get_pdb_headers(prot)
    if(len(headers) > 0):
        pdb_lines.extend(headers)

    n = aatype.shape[0]
    atom_index = 1
    prev_chain_index = 0
    chain_tags = string.ascii_uppercase
    # Add all atom sites.
    for i in range(n):
        res_name_3 = res_1to3(aatype[i])
        for atom_name, pos, mask, b_factor in zip(
            atom_types, atom_positions[i], atom_mask[i], b_factors[i]
        ):
            # Skip atom37 slots that do not exist for this residue.
            if mask < 0.5:
                continue

            record_type = "ATOM"
            name = atom_name if len(atom_name) == 4 else f" {atom_name}"
            alt_loc = ""
            insertion_code = ""
            occupancy = 1.00
            element = atom_name[
                0
            ]  # Protein supports only C, N, O, S, this works.
            charge = ""

            # NOTE(review): chain_tag is assigned inside the atom loop; a
            # residue with all atoms masked would leave it stale for the
            # TER record below — upstream parsing skips empty residues,
            # but confirm for other construction paths.
            chain_tag = "A"
            if(chain_index is not None):
                chain_tag = chain_tags[chain_index[i]]

            # PDB is a columnar format, every space matters here!
            atom_line = (
                f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}"
                f"{res_name_3:>3} {chain_tag:>1}"
                f"{residue_index[i]:>4}{insertion_code:>1}   "
                f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
                f"{occupancy:>6.2f}{b_factor:>6.2f}          "
                f"{element:>2}{charge:>2}"
            )
            pdb_lines.append(atom_line)
            atom_index += 1

        # Terminate at the last residue or at a chain boundary.
        should_terminate = (i == n - 1)
        if(chain_index is not None):
            if(i != n - 1 and chain_index[i + 1] != prev_chain_index):
                should_terminate = True
                prev_chain_index = chain_index[i + 1]

        if(should_terminate):
            # Close the chain.
            chain_end = "TER"
            chain_termination_line = (
                f"{chain_end:<6}{atom_index:>5}      "
                f"{res_1to3(aatype[i]):>3} "
                f"{chain_tag:>1}{residue_index[i]:>4}"
            )
            pdb_lines.append(chain_termination_line)
            atom_index += 1

            if(i != n - 1):
                # "prev" is a misnomer here. This happens at the beginning of
                # each new chain.
                pdb_lines.extend(get_pdb_headers(prot, prev_chain_index))

    pdb_lines.append("END")
    pdb_lines.append("")
    return "\n".join(pdb_lines)
386
+
387
+
388
def ideal_atom_mask(prot: Protein) -> np.ndarray:
    """Computes an ideal atom mask.

    `Protein.atom_mask` typically is defined according to the atoms that are
    reported in the PDB. This function computes a mask according to heavy atoms
    that should be present in the given sequence of amino acids.

    Args:
      prot: `Protein` whose fields are `numpy.ndarray` objects.

    Returns:
      An ideal atom mask.
    """
    # Row lookup of the canonical per-restype heavy-atom mask by aatype.
    return residue_constants.STANDARD_ATOM_MASK[prot.aatype]
402
+
403
+
404
def from_prediction(
    features: FeatureDict,
    result: ModelOutput,
    b_factors: Optional[np.ndarray] = None,
    chain_index: Optional[np.ndarray] = None,
    remark: Optional[str] = None,
    parents: Optional[Sequence[str]] = None,
    parents_chain_index: Optional[Sequence[int]] = None
) -> Protein:
    """Builds a `Protein` from model inputs and outputs.

    Args:
        features: Model input dict; "aatype" and "residue_index" are used.
        result: Model output dict; "final_atom_positions" and
            "final_atom_mask" are used.
        b_factors: Optional per-atom B-factors; defaults to all zeros.
        chain_index: Optional chain indices for multi-chain predictions.
        remark: Optional REMARK string for the prediction.
        parents: Optional list of template names.
        parents_chain_index: Optional chain index per template name.

    Returns:
        A `Protein` instance.
    """
    final_mask = result["final_atom_mask"]
    if b_factors is None:
        # One zero B-factor per atom slot.
        b_factors = np.zeros_like(final_mask)

    return Protein(
        aatype=features["aatype"],
        atom_positions=result["final_atom_positions"],
        atom_mask=final_mask,
        # PDB residue numbering is 1-based.
        residue_index=features["residue_index"] + 1,
        b_factors=b_factors,
        chain_index=chain_index,
        remark=remark,
        parents=parents,
        parents_chain_index=parents_chain_index,
    )
openfold/np/relax/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Dynamically discover, import, and re-export every sibling module of this
# package so `from openfold.np.relax import amber_minimize` etc. all work.
import os
import glob
import importlib as importlib

# Absolute paths of all .py files that live next to this __init__.py.
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
__all__ = [
    os.path.basename(f)[:-3]  # strip the ".py" suffix to get the module name
    for f in _files
    if os.path.isfile(f) and not f.endswith("__init__.py")
]
# Import each submodule relative to this package and bind it here.
_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
for _m in _modules:
    globals()[_m[0]] = _m[1]

# Avoid needlessly cluttering the global namespace
del _files, _m, _modules
openfold/np/relax/amber_minimize.py ADDED
@@ -0,0 +1,612 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Restrained Amber Minimization of a structure."""
17
+
18
+ import io
19
+ import time
20
+ from typing import Collection, Optional, Sequence
21
+
22
+ from absl import logging
23
+ from openfold.np import (
24
+ protein,
25
+ residue_constants,
26
+ )
27
+ import openfold.utils.loss as loss
28
+ from openfold.np.relax import cleanup, utils
29
+ import ml_collections
30
+ import numpy as np
31
+ import openmm
32
+ from openmm import unit
33
+ from openmm import app as openmm_app
34
+ from openmm.app.internal.pdbstructure import PdbStructure
35
+
36
+ ENERGY = unit.kilocalories_per_mole
37
+ LENGTH = unit.angstroms
38
+
39
+
40
def will_restrain(atom: openmm_app.Atom, rset: str) -> bool:
    """Returns True if the atom will be restrained by the given restraint set.

    Args:
        atom: The OpenMM atom under consideration.
        rset: Restraint-set name: "non_hydrogen" or "c_alpha".

    Raises:
        ValueError: If `rset` is not a recognized restraint set.
    """
    if rset == "non_hydrogen":
        return atom.element.name != "hydrogen"
    elif rset == "c_alpha":
        return atom.name == "CA"
    # BUG FIX: previously an unknown rset fell through and returned None,
    # silently restraining nothing. Fail loudly instead.
    raise ValueError(f"Unknown restraint set: {rset}")
47
+
48
+
49
def _add_restraints(
    system: openmm.System,
    reference_pdb: openmm_app.PDBFile,
    stiffness: unit.Unit,
    rset: str,
    exclude_residues: Sequence[int],
):
    """Adds a harmonic potential that restrains the system to a structure.

    Args:
        system: OpenMM system the restraint force is added to (mutated).
        reference_pdb: Structure supplying the reference atom positions.
        stiffness: Spring constant `k` of the harmonic restraint.
        rset: Restraint set, "non_hydrogen" or "c_alpha".
        exclude_residues: Zero-indexed residues exempt from restraints.
    """
    assert rset in ["non_hydrogen", "c_alpha"]

    # Harmonic well centered on each restrained atom's reference position.
    force = openmm.CustomExternalForce(
        "0.5 * k * ((x-x0)^2 + (y-y0)^2 + (z-z0)^2)"
    )
    force.addGlobalParameter("k", stiffness)
    for p in ["x0", "y0", "z0"]:
        force.addPerParticleParameter(p)

    for i, atom in enumerate(reference_pdb.topology.atoms()):
        if atom.residue.index in exclude_residues:
            continue
        if will_restrain(atom, rset):
            # Per-particle parameters are the (x0, y0, z0) reference point.
            force.addParticle(i, reference_pdb.positions[i])
    logging.info(
        "Restraining %d / %d particles.",
        force.getNumParticles(),
        system.getNumParticles(),
    )
    system.addForce(force)
77
+
78
+
79
def _openmm_minimize(
    pdb_str: str,
    max_iterations: int,
    tolerance: unit.Unit,
    stiffness: unit.Unit,
    restraint_set: str,
    exclude_residues: Sequence[int],
    use_gpu: bool,
):
    """Minimize energy via openmm.

    Args:
        pdb_str: PDB-format structure to minimize.
        max_iterations: Max L-BFGS iterations; 0 means no limit.
        tolerance: Energy convergence tolerance.
        stiffness: Restraint spring constant; restraints are only added
            when stiffness > 0.
        restraint_set: Which atoms to restrain ("non_hydrogen"/"c_alpha").
        exclude_residues: Zero-indexed residues exempt from restraints.
        use_gpu: Use the CUDA platform instead of CPU.

    Returns:
        Dict with initial/final potential energies ("einit"/"efinal"),
        initial/final positions ("posinit"/"pos"), and the minimized
        structure as PDB text ("min_pdb").
    """
    pdb_file = io.StringIO(pdb_str)
    pdb = openmm_app.PDBFile(pdb_file)

    force_field = openmm_app.ForceField("amber99sb.xml")
    constraints = openmm_app.HBonds
    system = force_field.createSystem(pdb.topology, constraints=constraints)
    if stiffness > 0 * ENERGY / (LENGTH ** 2):
        _add_restraints(system, pdb, stiffness, restraint_set, exclude_residues)

    # Integrator parameters are irrelevant here: only minimization is run,
    # no dynamics.
    integrator = openmm.LangevinIntegrator(0, 0.01, 0.0)
    platform = openmm.Platform.getPlatformByName("CUDA" if use_gpu else "CPU")
    simulation = openmm_app.Simulation(
        pdb.topology, system, integrator, platform
    )
    simulation.context.setPositions(pdb.positions)

    ret = {}
    # Snapshot the state before minimization for reporting.
    state = simulation.context.getState(getEnergy=True, getPositions=True)
    ret["einit"] = state.getPotentialEnergy().value_in_unit(ENERGY)
    ret["posinit"] = state.getPositions(asNumpy=True).value_in_unit(LENGTH)
    simulation.minimizeEnergy(maxIterations=max_iterations, tolerance=tolerance)
    state = simulation.context.getState(getEnergy=True, getPositions=True)
    ret["efinal"] = state.getPotentialEnergy().value_in_unit(ENERGY)
    ret["pos"] = state.getPositions(asNumpy=True).value_in_unit(LENGTH)
    ret["min_pdb"] = _get_pdb_string(simulation.topology, state.getPositions())
    return ret
116
+
117
+
118
def _get_pdb_string(topology: openmm_app.Topology, positions: unit.Quantity):
    """Serialize an OpenMM topology plus positions to PDB-format text."""
    buffer = io.StringIO()
    try:
        openmm_app.PDBFile.writeFile(topology, positions, buffer)
        return buffer.getvalue()
    finally:
        # Mirror the context-manager cleanup of the original.
        buffer.close()
123
+
124
+
125
def _check_cleaned_atoms(pdb_cleaned_string: str, pdb_ref_string: str):
    """Checks that no atom positions have been altered by cleaning.

    Args:
        pdb_cleaned_string: PDB text after cleanup (may contain added atoms).
        pdb_ref_string: Original PDB text, the ground truth positions.

    Raises:
        ValueError: If any atom present in both structures has moved.
    """
    cleaned = openmm_app.PDBFile(io.StringIO(pdb_cleaned_string))
    reference = openmm_app.PDBFile(io.StringIO(pdb_ref_string))

    cl_xyz = np.array(cleaned.getPositions().value_in_unit(LENGTH))
    ref_xyz = np.array(reference.getPositions().value_in_unit(LENGTH))

    # Residues are assumed to correspond one-to-one, in order.
    for ref_res, cl_res in zip(
        reference.topology.residues(), cleaned.topology.residues()
    ):
        assert ref_res.name == cl_res.name
        # Match atoms by name; atoms that cleanup added have no reference
        # counterpart and are skipped implicitly.
        for rat in ref_res.atoms():
            for cat in cl_res.atoms():
                if cat.name == rat.name:
                    if not np.array_equal(
                        cl_xyz[cat.index], ref_xyz[rat.index]
                    ):
                        raise ValueError(
                            f"Coordinates of cleaned atom {cat} do not match "
                            f"coordinates of reference atom {rat}."
                        )
147
+
148
+
149
def _check_residues_are_well_defined(prot: protein.Protein):
    """Checks that all residues contain non-empty atom sets."""
    # A residue with a zero row-sum in the atom mask has no atoms at all.
    atoms_per_residue = prot.atom_mask.sum(axis=-1)
    if np.any(atoms_per_residue == 0):
        raise ValueError(
            "Amber minimization can only be performed on proteins with"
            " well-defined residues. This protein contains at least"
            " one residue with no atoms."
        )
157
+
158
+
159
def _check_atom_mask_is_ideal(prot):
    """Sanity-check the atom mask is ideal, up to a possible OXT.

    Raises (via the utils assertion) if `prot.atom_mask` differs from the
    heavy-atom mask implied by the sequence, except at terminal atoms.
    """
    atom_mask = prot.atom_mask
    # Mask of heavy atoms that *should* be present for this sequence.
    ideal_atom_mask = protein.ideal_atom_mask(prot)
    utils.assert_equal_nonterminal_atom_types(atom_mask, ideal_atom_mask)
164
+
165
+
166
def clean_protein(prot: protein.Protein, checks: bool = True):
    """Adds missing atoms to Protein instance.

    Args:
      prot: A `protein.Protein` instance.
      checks: A `bool` specifying whether to add additional checks to the cleaning
        process.

    Returns:
      pdb_string: A string of the cleaned protein.
    """
    _check_atom_mask_is_ideal(prot)

    # Clean pdb.
    prot_pdb_string = protein.to_pdb(prot)
    pdb_file = io.StringIO(prot_pdb_string)
    # Populated in place by the cleanup helpers with details of every change.
    alterations_info = {}
    fixed_pdb = cleanup.fix_pdb(pdb_file, alterations_info)
    fixed_pdb_file = io.StringIO(fixed_pdb)
    pdb_structure = PdbStructure(fixed_pdb_file)
    cleanup.clean_structure(pdb_structure, alterations_info)

    logging.info("alterations info: %s", alterations_info)

    # Write pdb file of cleaned structure.
    as_file = openmm_app.PDBFile(pdb_structure)
    pdb_string = _get_pdb_string(as_file.getTopology(), as_file.getPositions())
    if checks:
        # Verify cleanup did not move any pre-existing atom.
        _check_cleaned_atoms(pdb_string, prot_pdb_string)
    return pdb_string
196
+
197
+
198
def make_atom14_positions(prot):
    """Constructs denser atom positions (14 dimensions instead of 37).

    Args:
        prot: Feature dict with at least "aatype", "all_atom_positions" and
            "all_atom_mask" in atom37 layout. Mutated in place.

    Returns:
        The same dict, extended with atom14 ground-truth positions/masks,
        the atom14<->atom37 index maps, and alternative ("renamed") ground
        truth for residues with ambiguous atom naming.
    """
    restype_atom14_to_atom37 = []  # mapping (restype, atom14) --> atom37
    restype_atom37_to_atom14 = []  # mapping (restype, atom37) --> atom14
    restype_atom14_mask = []

    for rt in residue_constants.restypes:
        atom_names = residue_constants.restype_name_to_atom14_names[
            residue_constants.restype_1to3[rt]
        ]

        # Empty atom-name slots map to index 0 (masked out below).
        restype_atom14_to_atom37.append(
            [
                (residue_constants.atom_order[name] if name else 0)
                for name in atom_names
            ]
        )

        atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
        restype_atom37_to_atom14.append(
            [
                (atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
                for name in residue_constants.atom_types
            ]
        )

        restype_atom14_mask.append(
            [(1.0 if name else 0.0) for name in atom_names]
        )

    # Add dummy mapping for restype 'UNK'.
    restype_atom14_to_atom37.append([0] * 14)
    restype_atom37_to_atom14.append([0] * 37)
    restype_atom14_mask.append([0.0] * 14)

    restype_atom14_to_atom37 = np.array(
        restype_atom14_to_atom37, dtype=int
    )
    restype_atom37_to_atom14 = np.array(
        restype_atom37_to_atom14, dtype=int
    )
    restype_atom14_mask = np.array(restype_atom14_mask, dtype=np.float32)

    # Create the mapping for (residx, atom14) --> atom37, i.e. an array
    # with shape (num_res, 14) containing the atom37 indices for this protein.
    residx_atom14_to_atom37 = restype_atom14_to_atom37[prot["aatype"]]
    residx_atom14_mask = restype_atom14_mask[prot["aatype"]]

    # Create a mask for known ground truth positions.
    residx_atom14_gt_mask = residx_atom14_mask * np.take_along_axis(
        prot["all_atom_mask"], residx_atom14_to_atom37, axis=1
    ).astype(np.float32)

    # Gather the ground truth positions.
    residx_atom14_gt_positions = residx_atom14_gt_mask[:, :, None] * (
        np.take_along_axis(
            prot["all_atom_positions"],
            residx_atom14_to_atom37[..., None],
            axis=1,
        )
    )

    prot["atom14_atom_exists"] = residx_atom14_mask
    prot["atom14_gt_exists"] = residx_atom14_gt_mask
    prot["atom14_gt_positions"] = residx_atom14_gt_positions

    prot["residx_atom14_to_atom37"] = residx_atom14_to_atom37.astype(np.int64)

    # Create the gather indices for mapping back.
    residx_atom37_to_atom14 = restype_atom37_to_atom14[prot["aatype"]]
    prot["residx_atom37_to_atom14"] = residx_atom37_to_atom14.astype(np.int64)

    # Create the corresponding mask.
    restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
    for restype, restype_letter in enumerate(residue_constants.restypes):
        restype_name = residue_constants.restype_1to3[restype_letter]
        atom_names = residue_constants.residue_atoms[restype_name]
        for atom_name in atom_names:
            atom_type = residue_constants.atom_order[atom_name]
            restype_atom37_mask[restype, atom_type] = 1

    residx_atom37_mask = restype_atom37_mask[prot["aatype"]]
    prot["atom37_atom_exists"] = residx_atom37_mask

    # As the atom naming is ambiguous for 7 of the 20 amino acids, provide
    # alternative ground truth coordinates where the naming is swapped
    restype_3 = [
        residue_constants.restype_1to3[res]
        for res in residue_constants.restypes
    ]
    restype_3 += ["UNK"]

    # Matrices for renaming ambiguous atoms.
    all_matrices = {res: np.eye(14, dtype=np.float32) for res in restype_3}
    for resname, swap in residue_constants.residue_atom_renaming_swaps.items():
        correspondences = np.arange(14)
        for source_atom_swap, target_atom_swap in swap.items():
            source_index = residue_constants.restype_name_to_atom14_names[
                resname
            ].index(source_atom_swap)
            target_index = residue_constants.restype_name_to_atom14_names[
                resname
            ].index(target_atom_swap)
            correspondences[source_index] = target_index
            correspondences[target_index] = source_index
        # Build a permutation matrix realizing the swap.
        renaming_matrix = np.zeros((14, 14), dtype=np.float32)
        for index, correspondence in enumerate(correspondences):
            renaming_matrix[index, correspondence] = 1.0
        all_matrices[resname] = renaming_matrix.astype(np.float32)
    renaming_matrices = np.stack(
        [all_matrices[restype] for restype in restype_3]
    )

    # Pick the transformation matrices for the given residue sequence
    # shape (num_res, 14, 14).
    renaming_transform = renaming_matrices[prot["aatype"]]

    # Apply it to the ground truth positions. shape (num_res, 14, 3).
    alternative_gt_positions = np.einsum(
        "rac,rab->rbc", residx_atom14_gt_positions, renaming_transform
    )
    prot["atom14_alt_gt_positions"] = alternative_gt_positions

    # Create the mask for the alternative ground truth (differs from the
    # ground truth mask, if only one of the atoms in an ambiguous pair has a
    # ground truth position).
    alternative_gt_mask = np.einsum(
        "ra,rab->rb", residx_atom14_gt_mask, renaming_transform
    )

    prot["atom14_alt_gt_exists"] = alternative_gt_mask

    # Create an ambiguous atoms mask. shape: (21, 14).
    restype_atom14_is_ambiguous = np.zeros((21, 14), dtype=np.float32)
    for resname, swap in residue_constants.residue_atom_renaming_swaps.items():
        for atom_name1, atom_name2 in swap.items():
            restype = residue_constants.restype_order[
                residue_constants.restype_3to1[resname]
            ]
            atom_idx1 = residue_constants.restype_name_to_atom14_names[
                resname
            ].index(atom_name1)
            atom_idx2 = residue_constants.restype_name_to_atom14_names[
                resname
            ].index(atom_name2)
            restype_atom14_is_ambiguous[restype, atom_idx1] = 1
            restype_atom14_is_ambiguous[restype, atom_idx2] = 1

    # From this create an ambiguous_mask for the given sequence.
    prot["atom14_atom_is_ambiguous"] = restype_atom14_is_ambiguous[
        prot["aatype"]
    ]

    return prot
352
+
353
+
354
def find_violations(prot_np: protein.Protein):
    """Analyzes a protein and returns structural violation information.

    Args:
      prot_np: A protein.

    Returns:
      violations: A `dict` of structure components with structural violations.
      violation_metrics: A `dict` of violation metrics.
    """
    # Assemble the feature dict expected by the loss utilities.
    batch = {
        "aatype": prot_np.aatype,
        "all_atom_positions": prot_np.atom_positions.astype(np.float32),
        "all_atom_mask": prot_np.atom_mask.astype(np.float32),
        "residue_index": prot_np.residue_index,
    }

    batch["seq_mask"] = np.ones_like(batch["aatype"], np.float32)
    # Adds atom14 positions/masks in place.
    batch = make_atom14_positions(batch)

    violations = loss.find_structural_violations_np(
        batch=batch,
        atom14_pred_positions=batch["atom14_gt_positions"],
        config=ml_collections.ConfigDict(
            {
                "violation_tolerance_factor": 12,  # Taken from model config.
                "clash_overlap_tolerance": 1.5,  # Taken from model config.
            }
        ),
    )
    violation_metrics = loss.compute_violation_metrics_np(
        batch=batch,
        atom14_pred_positions=batch["atom14_gt_positions"],
        violations=violations,
    )

    return violations, violation_metrics
391
+
392
+
393
def get_violation_metrics(prot: protein.Protein):
    """Computes violation and alignment metrics."""
    violations, metrics = find_violations(prot)
    # Indices of residues flagged by any structural-violation check.
    flagged = np.flatnonzero(
        violations["total_per_residue_violations_mask"]
    )
    metrics["residue_violations"] = flagged
    metrics["num_residue_violations"] = len(flagged)
    metrics["structural_violations"] = violations
    return metrics
404
+
405
+
406
def _run_one_iteration(
    *,
    pdb_string: str,
    max_iterations: int,
    tolerance: float,
    stiffness: float,
    restraint_set: str,
    max_attempts: int,
    exclude_residues: Optional[Collection[int]] = None,
    use_gpu: bool,
):
    """Runs the minimization pipeline.

    Args:
      pdb_string: A pdb string.
      max_iterations: An `int` specifying the maximum number of L-BFGS iterations.
        A value of 0 specifies no limit.
      tolerance: kcal/mol, the energy tolerance of L-BFGS.
      stiffness: kcal/mol A**2, spring constant of heavy atom restraining
        potential.
      restraint_set: The set of atoms to restrain.
      max_attempts: The maximum number of minimization attempts.
      exclude_residues: An optional list of zero-indexed residues to exclude from
        restraints.
      use_gpu: Whether to run relaxation on GPU
    Returns:
      A `dict` of minimization info.

    Raises:
      ValueError: If minimization fails on every attempt; the last OpenMM
        error is chained as the cause.
    """
    exclude_residues = exclude_residues or []

    # Assign physical dimensions.
    tolerance = tolerance * ENERGY
    stiffness = stiffness * ENERGY / (LENGTH ** 2)

    start = time.perf_counter()
    minimized = False
    attempts = 0
    last_exception = None
    while not minimized and attempts < max_attempts:
        attempts += 1
        try:
            logging.info(
                "Minimizing protein, attempt %d of %d.", attempts, max_attempts
            )
            ret = _openmm_minimize(
                pdb_string,
                max_iterations=max_iterations,
                tolerance=tolerance,
                stiffness=stiffness,
                restraint_set=restraint_set,
                exclude_residues=exclude_residues,
                use_gpu=use_gpu,
            )
            minimized = True
        except Exception as e:  # pylint: disable=broad-except
            # IDIOM FIX: report failures through logging only (the original
            # also print()-ed to stdout) and keep the exception so the final
            # error can be chained to its cause.
            last_exception = e
            logging.info(e)
    if not minimized:
        raise ValueError(
            f"Minimization failed after {max_attempts} attempts."
        ) from last_exception
    ret["opt_time"] = time.perf_counter() - start
    ret["min_attempts"] = attempts
    return ret
467
+
468
+
469
def run_pipeline(
    prot: protein.Protein,
    stiffness: float,
    use_gpu: bool,
    max_outer_iterations: int = 1,
    place_hydrogens_every_iteration: bool = True,
    max_iterations: int = 0,
    tolerance: float = 2.39,
    restraint_set: str = "non_hydrogen",
    max_attempts: int = 100,
    checks: bool = True,
    exclude_residues: Optional[Sequence[int]] = None,
):
    """Run iterative amber relax.

    Successive relax iterations are performed until all violations have been
    resolved. Each iteration involves a restrained Amber minimization, with
    restraint exclusions determined by violation-participating residues.

    Args:
      prot: A protein to be relaxed.
      stiffness: kcal/mol A**2, the restraint stiffness.
      use_gpu: Whether to run on GPU
      max_outer_iterations: The maximum number of iterative minimization.
      place_hydrogens_every_iteration: Whether hydrogens are re-initialized
          prior to every minimization.
      max_iterations: An `int` specifying the maximum number of L-BFGS steps
          per relax iteration. A value of 0 specifies no limit.
      tolerance: kcal/mol, the energy tolerance of L-BFGS.
          The default value is the OpenMM default.
      restraint_set: The set of atoms to restrain.
      max_attempts: The maximum number of minimization attempts per iteration.
      checks: Whether to perform cleaning checks.
      exclude_residues: An optional list of zero-indexed residues to exclude from
          restraints.

    Returns:
      out: A dictionary of output values.
    """

    # `protein.to_pdb` will strip any poorly-defined residues so we need to
    # perform this check before `clean_protein`.
    _check_residues_are_well_defined(prot)
    pdb_string = clean_protein(prot, checks=checks)

    exclude_residues = exclude_residues or []
    exclude_residues = set(exclude_residues)
    # Sentinel so the loop always runs at least once.
    violations = np.inf
    iteration = 0

    while violations > 0 and iteration < max_outer_iterations:
        ret = _run_one_iteration(
            pdb_string=pdb_string,
            exclude_residues=exclude_residues,
            max_iterations=max_iterations,
            tolerance=tolerance,
            stiffness=stiffness,
            restraint_set=restraint_set,
            max_attempts=max_attempts,
            use_gpu=use_gpu,
        )
        # Re-parse the minimized structure for the next round.
        prot = protein.from_pdb_string(ret["min_pdb"])
        if place_hydrogens_every_iteration:
            pdb_string = clean_protein(prot, checks=True)
        else:
            pdb_string = ret["min_pdb"]
        ret.update(get_violation_metrics(prot))
        ret.update(
            {
                "num_exclusions": len(exclude_residues),
                "iteration": iteration,
            }
        )
        violations = ret["violations_per_residue"]
        # Residues that still violate are freed from restraints next round.
        exclude_residues = exclude_residues.union(ret["residue_violations"])

        logging.info(
            "Iteration completed: Einit %.2f Efinal %.2f Time %.2f s "
            "num residue violations %d num residue exclusions %d ",
            ret["einit"],
            ret["efinal"],
            ret["opt_time"],
            ret["num_residue_violations"],
            ret["num_exclusions"],
        )
        iteration += 1
    return ret
556
+
557
+
558
def get_initial_energies(
    pdb_strs: Sequence[str],
    stiffness: float = 0.0,
    restraint_set: str = "non_hydrogen",
    exclude_residues: Optional[Sequence[int]] = None,
):
    """Returns initial potential energies for a sequence of PDBs.

    Assumes the input PDBs are ready for minimization, and all have the same
    topology.
    Allows time to be saved by not pdbfixing / rebuilding the system.

    Args:
      pdb_strs: List of PDB strings.
      stiffness: kcal/mol A**2, spring constant of heavy atom restraining
        potential.
      restraint_set: Which atom types to restrain.
      exclude_residues: An optional list of zero-indexed residues to exclude from
        restraints.

    Returns:
      A list of initial energies in the same order as pdb_strs.
    """
    exclude_residues = exclude_residues or []

    openmm_pdbs = [
        openmm_app.PDBFile(PdbStructure(io.StringIO(p))) for p in pdb_strs
    ]
    # Build system/simulation once from the first structure; only positions
    # change per PDB (shared topology is assumed, see docstring).
    force_field = openmm_app.ForceField("amber99sb.xml")
    system = force_field.createSystem(
        openmm_pdbs[0].topology, constraints=openmm_app.HBonds
    )
    stiffness = stiffness * ENERGY / (LENGTH ** 2)
    if stiffness > 0 * ENERGY / (LENGTH ** 2):
        _add_restraints(
            system, openmm_pdbs[0], stiffness, restraint_set, exclude_residues
        )
    simulation = openmm_app.Simulation(
        openmm_pdbs[0].topology,
        system,
        openmm.LangevinIntegrator(0, 0.01, 0.0),
        openmm.Platform.getPlatformByName("CPU"),
    )
    energies = []
    for pdb in openmm_pdbs:
        try:
            simulation.context.setPositions(pdb.positions)
            state = simulation.context.getState(getEnergy=True)
            energies.append(state.getPotentialEnergy().value_in_unit(ENERGY))
        except Exception as e:  # pylint: disable=broad-except
            # Best-effort: a failed structure gets a huge sentinel energy
            # instead of aborting the whole batch.
            logging.error(
                "Error getting initial energy, returning large value %s", e
            )
            energies.append(unit.Quantity(1e20, ENERGY))
    return energies
openfold/np/relax/cleanup.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 DeepMind Technologies Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Cleans up a PDB file using pdbfixer in preparation for OpenMM simulations.
16
+
17
+ fix_pdb uses a third-party tool. We also support fixing some additional edge
18
+ cases like removing chains of length one (see clean_structure).
19
+ """
20
+ import io
21
+
22
+ import pdbfixer
23
+ from simtk.openmm import app
24
+ from simtk.openmm.app import element
25
+
26
+
27
def fix_pdb(pdbfile, alterations_info):
    """Run pdbfixer over a PDB file handle; return the repaired PDB string.

    The fixer pipeline, in order:
      1) replaces nonstandard residues,
      2) strips heterogens (non-protein residues) including water,
      3) adds missing residues and missing atoms within existing residues,
      4) adds hydrogens assuming pH=7.0,
      5) writes with keepIds=True, so the fixer must keep the existing chain
         and residue identifiers (this will fail for some wider-PDB files
         that have invalid IDs).

    Args:
      pdbfile: Input PDB file handle.
      alterations_info: A dict that will store details of changes made.

    Returns:
      A PDB string representing the fixed structure.
    """
    fixer = pdbfixer.PDBFixer(pdbfile=pdbfile)
    # Record, then normalize, nonstandard residues.
    fixer.findNonstandardResidues()
    alterations_info["nonstandard_residues"] = fixer.nonstandardResidues
    fixer.replaceNonstandardResidues()
    _remove_heterogens(fixer, alterations_info, keep_water=False)
    # Record what is missing, then rebuild it (seeded for determinism).
    fixer.findMissingResidues()
    alterations_info["missing_residues"] = fixer.missingResidues
    fixer.findMissingAtoms()
    alterations_info["missing_heavy_atoms"] = fixer.missingAtoms
    alterations_info["missing_terminals"] = fixer.missingTerminals
    fixer.addMissingAtoms(seed=0)
    fixer.addMissingHydrogens()
    buffer = io.StringIO()
    app.PDBFile.writeFile(fixer.topology, fixer.positions, buffer, keepIds=True)
    return buffer.getvalue()
62
+
63
+
64
def clean_structure(pdb_structure, alterations_info):
    """Apply structure-level fixes that pdbfixer does not handle.

    Currently this de-selenates unmarked selenomethionines and drops
    single-residue chains, two edge cases that break downstream force-field
    setup.

    Args:
      pdb_structure: An OpenMM structure to modify and fix (in place).
      alterations_info: A dict that will store details of changes made.
    """
    _replace_met_se(pdb_structure, alterations_info)
    _remove_chains_of_length_one(pdb_structure, alterations_info)
73
+
74
+
75
+ def _remove_heterogens(fixer, alterations_info, keep_water):
76
+ """Removes the residues that Pdbfixer considers to be heterogens.
77
+
78
+ Args:
79
+ fixer: A Pdbfixer instance.
80
+ alterations_info: A dict that will store details of changes made.
81
+ keep_water: If True, water (HOH) is not considered to be a heterogen.
82
+ """
83
+ initial_resnames = set()
84
+ for chain in fixer.topology.chains():
85
+ for residue in chain.residues():
86
+ initial_resnames.add(residue.name)
87
+ fixer.removeHeterogens(keepWater=keep_water)
88
+ final_resnames = set()
89
+ for chain in fixer.topology.chains():
90
+ for residue in chain.residues():
91
+ final_resnames.add(residue.name)
92
+ alterations_info["removed_heterogens"] = initial_resnames.difference(
93
+ final_resnames
94
+ )
95
+
96
+
97
+ def _replace_met_se(pdb_structure, alterations_info):
98
+ """Replace the Se in any MET residues that were not marked as modified."""
99
+ modified_met_residues = []
100
+ for res in pdb_structure.iter_residues():
101
+ name = res.get_name_with_spaces().strip()
102
+ if name == "MET":
103
+ s_atom = res.get_atom("SD")
104
+ if s_atom.element_symbol == "Se":
105
+ s_atom.element_symbol = "S"
106
+ s_atom.element = element.get_by_symbol("S")
107
+ modified_met_residues.append(s_atom.residue_number)
108
+ alterations_info["Se_in_MET"] = modified_met_residues
109
+
110
+
111
+ def _remove_chains_of_length_one(pdb_structure, alterations_info):
112
+ """Removes chains that correspond to a single amino acid.
113
+
114
+ A single amino acid in a chain is both N and C terminus. There is no force
115
+ template for this case.
116
+
117
+ Args:
118
+ pdb_structure: An OpenMM pdb_structure to modify and fix.
119
+ alterations_info: A dict that will store details of changes made.
120
+ """
121
+ removed_chains = {}
122
+ for model in pdb_structure.iter_models():
123
+ valid_chains = [c for c in model.iter_chains() if len(c) > 1]
124
+ invalid_chain_ids = [
125
+ c.chain_id for c in model.iter_chains() if len(c) <= 1
126
+ ]
127
+ model.chains = valid_chains
128
+ for chain_id in invalid_chain_ids:
129
+ model.chains_by_id.pop(chain_id)
130
+ removed_chains[model.number] = invalid_chain_ids
131
+ alterations_info["removed_chains"] = removed_chains
openfold/np/relax/relax.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Amber relaxation."""
17
+ from typing import Any, Dict, Sequence, Tuple
18
+ from openfold.np import protein
19
+ from openfold.np.relax import amber_minimize, utils
20
+ import numpy as np
21
+
22
+
23
class AmberRelaxation(object):
    """Runs AMBER relaxation on a predicted structure."""

    def __init__(
        self,
        *,
        max_iterations: int,
        tolerance: float,
        stiffness: float,
        exclude_residues: Sequence[int],
        max_outer_iterations: int,
        use_gpu: bool,
    ):
        """Initialize Amber Relaxer.

        Args:
          max_iterations: Maximum number of L-BFGS iterations. 0 means no max.
          tolerance: kcal/mol, the energy tolerance of L-BFGS.
          stiffness: kcal/mol A**2, spring constant of the heavy-atom
            restraining potential.
          exclude_residues: Residues to exclude from per-atom restraining.
            Zero-indexed.
          max_outer_iterations: Maximum number of violation-informed relax
            iterations. A value of 1 runs the non-iterative procedure used in
            CASP14. Use 20 so that >95% of the bad cases are relaxed. Relax
            finishes as soon as there are no violations, hence in most cases
            this causes no slowdown; in the worst case 20 outer iterations
            are run.
          use_gpu: Whether to run on GPU.
        """
        self._max_iterations = max_iterations
        self._tolerance = tolerance
        self._stiffness = stiffness
        self._exclude_residues = exclude_residues
        self._max_outer_iterations = max_outer_iterations
        self._use_gpu = use_gpu

    def process(
        self, *, prot: protein.Protein
    ) -> Tuple[str, Dict[str, Any], np.ndarray]:
        """Runs Amber relax on a prediction, adds hydrogens, returns PDB string."""
        out = amber_minimize.run_pipeline(
            prot=prot,
            max_iterations=self._max_iterations,
            tolerance=self._tolerance,
            stiffness=self._stiffness,
            exclude_residues=self._exclude_residues,
            max_outer_iterations=self._max_outer_iterations,
            use_gpu=self._use_gpu,
        )
        minimized = out["pos"]
        initial = out["posinit"]
        # RMSD between pre- and post-minimization coordinates, for debugging.
        rmsd = np.sqrt(((initial - minimized) ** 2).sum() / initial.shape[0])
        debug_data = {
            "initial_energy": out["einit"],
            "final_energy": out["efinal"],
            "attempts": out["min_attempts"],
            "rmsd": rmsd,
        }
        # Rebuild a clean PDB, then splice in the minimized coordinates and
        # the original per-residue B-factors.
        pdb_str = amber_minimize.clean_protein(prot)
        min_pdb = utils.overwrite_pdb_coordinates(pdb_str, minimized)
        min_pdb = utils.overwrite_b_factors(min_pdb, prot.b_factors)
        # Sanity check: relaxation must not add/remove non-terminal atoms.
        utils.assert_equal_nonterminal_atom_types(
            protein.from_pdb_string(min_pdb).atom_mask, prot.atom_mask
        )
        violations = out["structural_violations"][
            "total_per_residue_violations_mask"
        ]
        return min_pdb, debug_data, violations
openfold/np/relax/utils.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Utils for minimization."""
17
+ import io
18
+ from openfold.np import residue_constants
19
+ from Bio import PDB
20
+ import numpy as np
21
+ # simtk.openmm is not supported anymore. Remove simtk.
22
+ # https://github.com/openmm/openmm/releases
23
+ from openmm import app as openmm_app
24
+ from openmm.app.internal.pdbstructure import PdbStructure
25
+
26
+
27
def overwrite_pdb_coordinates(pdb_str: str, pos) -> str:
    """Return a PDB string with the topology of pdb_str but coordinates pos.

    Args:
      pdb_str: Source PDB string; only its topology is reused.
      pos: Replacement atom positions, in the order OpenMM expects.
    """
    structure = PdbStructure(io.StringIO(pdb_str))
    topology = openmm_app.PDBFile(structure).getTopology()
    buffer = io.StringIO()
    openmm_app.PDBFile.writeFile(topology, pos, buffer)
    return buffer.getvalue()
34
+
35
+
36
def overwrite_b_factors(pdb_str: str, bfactors: np.ndarray) -> str:
    """Overwrites the B-factors in pdb_str with contents of bfactors array.

    Args:
      pdb_str: An input PDB string.
      bfactors: A numpy array whose last dimension is atom_type_num (37),
        indexed as bfactors[residue, atom]. We assume the B-factors are per
        residue; only the CA column is read and applied to every atom of the
        residue. (NOTE(review): the upstream docstring claims shape
        [1, n_residues, 37], but the indexing below is 2-D — confirm callers
        pass [n_residues, 37].)

    Returns:
      A new PDB string with the B-factors replaced.

    Raises:
      ValueError: If the last bfactors dimension is wrong, or the structure
        has more residues than bfactors has rows.
    """
    if bfactors.shape[-1] != residue_constants.atom_type_num:
        raise ValueError(
            f"Invalid final dimension size for bfactors: {bfactors.shape[-1]}."
        )

    parser = PDB.PDBParser(QUIET=True)
    handle = io.StringIO(pdb_str)
    structure = parser.get_structure("", handle)

    curr_resid = ("", "", "")
    idx = -1
    for atom in structure.get_atoms():
        atom_resid = atom.parent.get_id()
        if atom_resid != curr_resid:
            idx += 1
            if idx >= bfactors.shape[0]:
                # BUG FIX: the original used "{shape}"/"{idx}" placeholders in
                # a plain (non-f) string, so they were never substituted.
                raise ValueError(
                    "Index into bfactors exceeds number of residues. "
                    f"B-factors shape: {bfactors.shape}, idx: {idx}."
                )
            curr_resid = atom_resid
        # Every atom of the residue gets the residue's CA B-factor.
        atom.bfactor = bfactors[idx, residue_constants.atom_order["CA"]]

    new_pdb = io.StringIO()
    pdb_io = PDB.PDBIO()
    pdb_io.set_structure(structure)
    pdb_io.save(new_pdb)
    return new_pdb.getvalue()
76
+
77
+
78
def assert_equal_nonterminal_atom_types(
    atom_mask: np.ndarray, ref_atom_mask: np.ndarray
):
    """Checks that pre- and post-minimized proteins have same atom set.

    Terminal OXT atoms, which minimization may add, are ignored.

    Args:
      atom_mask: Post-minimization atom mask, last dimension indexed by
        atom type.
      ref_atom_mask: Pre-minimization reference mask of the same shape.

    Raises:
      AssertionError: If the masks differ anywhere outside the OXT column.
    """
    # Ignore any terminal OXT atoms which may have been added by minimization.
    oxt = residue_constants.atom_order["OXT"]
    # BUG FIX: the deprecated alias np.bool was removed in NumPy 1.24 and
    # raised AttributeError; the builtin bool is the documented replacement.
    no_oxt_mask = np.ones(shape=atom_mask.shape, dtype=bool)
    no_oxt_mask[..., oxt] = False
    np.testing.assert_almost_equal(
        ref_atom_mask[no_oxt_mask], atom_mask[no_oxt_mask]
    )
openfold/np/residue_constants.py ADDED
@@ -0,0 +1,1310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Constants used in AlphaFold."""
17
+
18
+ import collections
19
+ import functools
20
+ from typing import Mapping, List, Tuple
21
+ from importlib import resources
22
+
23
+ import numpy as np
24
+ import tree
25
+
26
+ # Internal import (35fd).
27
+
28
+
29
+ # Distance from one CA to next CA [trans configuration: omega = 180].
30
+ ca_ca = 3.80209737096
31
+
32
+ # Format: The list for each AA type contains chi1, chi2, chi3, chi4 in
33
+ # this order (or a relevant subset from chi1 onwards). ALA and GLY don't have
34
+ # chi angles so their chi angle lists are empty.
35
+ chi_angles_atoms = {
36
+ "ALA": [],
37
+ # Chi5 in arginine is always 0 +- 5 degrees, so ignore it.
38
+ "ARG": [
39
+ ["N", "CA", "CB", "CG"],
40
+ ["CA", "CB", "CG", "CD"],
41
+ ["CB", "CG", "CD", "NE"],
42
+ ["CG", "CD", "NE", "CZ"],
43
+ ],
44
+ "ASN": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]],
45
+ "ASP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]],
46
+ "CYS": [["N", "CA", "CB", "SG"]],
47
+ "GLN": [
48
+ ["N", "CA", "CB", "CG"],
49
+ ["CA", "CB", "CG", "CD"],
50
+ ["CB", "CG", "CD", "OE1"],
51
+ ],
52
+ "GLU": [
53
+ ["N", "CA", "CB", "CG"],
54
+ ["CA", "CB", "CG", "CD"],
55
+ ["CB", "CG", "CD", "OE1"],
56
+ ],
57
+ "GLY": [],
58
+ "HIS": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "ND1"]],
59
+ "ILE": [["N", "CA", "CB", "CG1"], ["CA", "CB", "CG1", "CD1"]],
60
+ "LEU": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
61
+ "LYS": [
62
+ ["N", "CA", "CB", "CG"],
63
+ ["CA", "CB", "CG", "CD"],
64
+ ["CB", "CG", "CD", "CE"],
65
+ ["CG", "CD", "CE", "NZ"],
66
+ ],
67
+ "MET": [
68
+ ["N", "CA", "CB", "CG"],
69
+ ["CA", "CB", "CG", "SD"],
70
+ ["CB", "CG", "SD", "CE"],
71
+ ],
72
+ "PHE": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
73
+ "PRO": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"]],
74
+ "SER": [["N", "CA", "CB", "OG"]],
75
+ "THR": [["N", "CA", "CB", "OG1"]],
76
+ "TRP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
77
+ "TYR": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
78
+ "VAL": [["N", "CA", "CB", "CG1"]],
79
+ }
80
+
81
# If chi angles given in fixed-length array, this matrix determines how to mask
# them for each AA type. The order is as per restype_order (see below).
# Rows: the 20 standard restypes; columns: chi1..chi4 (1.0 = angle exists).
chi_angles_mask = [
    [0.0, 0.0, 0.0, 0.0],  # ALA
    [1.0, 1.0, 1.0, 1.0],  # ARG
    [1.0, 1.0, 0.0, 0.0],  # ASN
    [1.0, 1.0, 0.0, 0.0],  # ASP
    [1.0, 0.0, 0.0, 0.0],  # CYS
    [1.0, 1.0, 1.0, 0.0],  # GLN
    [1.0, 1.0, 1.0, 0.0],  # GLU
    [0.0, 0.0, 0.0, 0.0],  # GLY
    [1.0, 1.0, 0.0, 0.0],  # HIS
    [1.0, 1.0, 0.0, 0.0],  # ILE
    [1.0, 1.0, 0.0, 0.0],  # LEU
    [1.0, 1.0, 1.0, 1.0],  # LYS
    [1.0, 1.0, 1.0, 0.0],  # MET
    [1.0, 1.0, 0.0, 0.0],  # PHE
    [1.0, 1.0, 0.0, 0.0],  # PRO
    [1.0, 0.0, 0.0, 0.0],  # SER
    [1.0, 0.0, 0.0, 0.0],  # THR
    [1.0, 1.0, 0.0, 0.0],  # TRP
    [1.0, 1.0, 0.0, 0.0],  # TYR
    [1.0, 0.0, 0.0, 0.0],  # VAL
]

# The following chi angles are pi periodic: they can be rotated by a multiple
# of pi without affecting the structure.
# NOTE(review): this table has 21 rows (includes UNK) while chi_angles_mask
# above has 20 — presumably consumers index chi_pi_periodic by a 21-restype
# ordering; verify against restype usage before relying on row parity.
chi_pi_periodic = [
    [0.0, 0.0, 0.0, 0.0],  # ALA
    [0.0, 0.0, 0.0, 0.0],  # ARG
    [0.0, 0.0, 0.0, 0.0],  # ASN
    [0.0, 1.0, 0.0, 0.0],  # ASP
    [0.0, 0.0, 0.0, 0.0],  # CYS
    [0.0, 0.0, 0.0, 0.0],  # GLN
    [0.0, 0.0, 1.0, 0.0],  # GLU
    [0.0, 0.0, 0.0, 0.0],  # GLY
    [0.0, 0.0, 0.0, 0.0],  # HIS
    [0.0, 0.0, 0.0, 0.0],  # ILE
    [0.0, 0.0, 0.0, 0.0],  # LEU
    [0.0, 0.0, 0.0, 0.0],  # LYS
    [0.0, 0.0, 0.0, 0.0],  # MET
    [0.0, 1.0, 0.0, 0.0],  # PHE
    [0.0, 0.0, 0.0, 0.0],  # PRO
    [0.0, 0.0, 0.0, 0.0],  # SER
    [0.0, 0.0, 0.0, 0.0],  # THR
    [0.0, 0.0, 0.0, 0.0],  # TRP
    [0.0, 1.0, 0.0, 0.0],  # TYR
    [0.0, 0.0, 0.0, 0.0],  # VAL
    [0.0, 0.0, 0.0, 0.0],  # UNK
]
131
+
132
+ # Atoms positions relative to the 8 rigid groups, defined by the pre-omega, phi,
133
+ # psi and chi angles:
134
+ # 0: 'backbone group',
135
+ # 1: 'pre-omega-group', (empty)
136
+ # 2: 'phi-group', (currently empty, because it defines only hydrogens)
137
+ # 3: 'psi-group',
138
+ # 4,5,6,7: 'chi1,2,3,4-group'
139
+ # The atom positions are relative to the axis-end-atom of the corresponding
140
+ # rotation axis. The x-axis is in direction of the rotation axis, and the y-axis
141
+ # is defined such that the dihedral-angle-definiting atom (the last entry in
142
+ # chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate).
143
+ # format: [atomname, group_idx, rel_position]
144
+ rigid_group_atom_positions = {
145
+ "ALA": [
146
+ ["N", 0, (-0.525, 1.363, 0.000)],
147
+ ["CA", 0, (0.000, 0.000, 0.000)],
148
+ ["C", 0, (1.526, -0.000, -0.000)],
149
+ ["CB", 0, (-0.529, -0.774, -1.205)],
150
+ ["O", 3, (0.627, 1.062, 0.000)],
151
+ ],
152
+ "ARG": [
153
+ ["N", 0, (-0.524, 1.362, -0.000)],
154
+ ["CA", 0, (0.000, 0.000, 0.000)],
155
+ ["C", 0, (1.525, -0.000, -0.000)],
156
+ ["CB", 0, (-0.524, -0.778, -1.209)],
157
+ ["O", 3, (0.626, 1.062, 0.000)],
158
+ ["CG", 4, (0.616, 1.390, -0.000)],
159
+ ["CD", 5, (0.564, 1.414, 0.000)],
160
+ ["NE", 6, (0.539, 1.357, -0.000)],
161
+ ["NH1", 7, (0.206, 2.301, 0.000)],
162
+ ["NH2", 7, (2.078, 0.978, -0.000)],
163
+ ["CZ", 7, (0.758, 1.093, -0.000)],
164
+ ],
165
+ "ASN": [
166
+ ["N", 0, (-0.536, 1.357, 0.000)],
167
+ ["CA", 0, (0.000, 0.000, 0.000)],
168
+ ["C", 0, (1.526, -0.000, -0.000)],
169
+ ["CB", 0, (-0.531, -0.787, -1.200)],
170
+ ["O", 3, (0.625, 1.062, 0.000)],
171
+ ["CG", 4, (0.584, 1.399, 0.000)],
172
+ ["ND2", 5, (0.593, -1.188, 0.001)],
173
+ ["OD1", 5, (0.633, 1.059, 0.000)],
174
+ ],
175
+ "ASP": [
176
+ ["N", 0, (-0.525, 1.362, -0.000)],
177
+ ["CA", 0, (0.000, 0.000, 0.000)],
178
+ ["C", 0, (1.527, 0.000, -0.000)],
179
+ ["CB", 0, (-0.526, -0.778, -1.208)],
180
+ ["O", 3, (0.626, 1.062, -0.000)],
181
+ ["CG", 4, (0.593, 1.398, -0.000)],
182
+ ["OD1", 5, (0.610, 1.091, 0.000)],
183
+ ["OD2", 5, (0.592, -1.101, -0.003)],
184
+ ],
185
+ "CYS": [
186
+ ["N", 0, (-0.522, 1.362, -0.000)],
187
+ ["CA", 0, (0.000, 0.000, 0.000)],
188
+ ["C", 0, (1.524, 0.000, 0.000)],
189
+ ["CB", 0, (-0.519, -0.773, -1.212)],
190
+ ["O", 3, (0.625, 1.062, -0.000)],
191
+ ["SG", 4, (0.728, 1.653, 0.000)],
192
+ ],
193
+ "GLN": [
194
+ ["N", 0, (-0.526, 1.361, -0.000)],
195
+ ["CA", 0, (0.000, 0.000, 0.000)],
196
+ ["C", 0, (1.526, 0.000, 0.000)],
197
+ ["CB", 0, (-0.525, -0.779, -1.207)],
198
+ ["O", 3, (0.626, 1.062, -0.000)],
199
+ ["CG", 4, (0.615, 1.393, 0.000)],
200
+ ["CD", 5, (0.587, 1.399, -0.000)],
201
+ ["NE2", 6, (0.593, -1.189, -0.001)],
202
+ ["OE1", 6, (0.634, 1.060, 0.000)],
203
+ ],
204
+ "GLU": [
205
+ ["N", 0, (-0.528, 1.361, 0.000)],
206
+ ["CA", 0, (0.000, 0.000, 0.000)],
207
+ ["C", 0, (1.526, -0.000, -0.000)],
208
+ ["CB", 0, (-0.526, -0.781, -1.207)],
209
+ ["O", 3, (0.626, 1.062, 0.000)],
210
+ ["CG", 4, (0.615, 1.392, 0.000)],
211
+ ["CD", 5, (0.600, 1.397, 0.000)],
212
+ ["OE1", 6, (0.607, 1.095, -0.000)],
213
+ ["OE2", 6, (0.589, -1.104, -0.001)],
214
+ ],
215
+ "GLY": [
216
+ ["N", 0, (-0.572, 1.337, 0.000)],
217
+ ["CA", 0, (0.000, 0.000, 0.000)],
218
+ ["C", 0, (1.517, -0.000, -0.000)],
219
+ ["O", 3, (0.626, 1.062, -0.000)],
220
+ ],
221
+ "HIS": [
222
+ ["N", 0, (-0.527, 1.360, 0.000)],
223
+ ["CA", 0, (0.000, 0.000, 0.000)],
224
+ ["C", 0, (1.525, 0.000, 0.000)],
225
+ ["CB", 0, (-0.525, -0.778, -1.208)],
226
+ ["O", 3, (0.625, 1.063, 0.000)],
227
+ ["CG", 4, (0.600, 1.370, -0.000)],
228
+ ["CD2", 5, (0.889, -1.021, 0.003)],
229
+ ["ND1", 5, (0.744, 1.160, -0.000)],
230
+ ["CE1", 5, (2.030, 0.851, 0.002)],
231
+ ["NE2", 5, (2.145, -0.466, 0.004)],
232
+ ],
233
+ "ILE": [
234
+ ["N", 0, (-0.493, 1.373, -0.000)],
235
+ ["CA", 0, (0.000, 0.000, 0.000)],
236
+ ["C", 0, (1.527, -0.000, -0.000)],
237
+ ["CB", 0, (-0.536, -0.793, -1.213)],
238
+ ["O", 3, (0.627, 1.062, -0.000)],
239
+ ["CG1", 4, (0.534, 1.437, -0.000)],
240
+ ["CG2", 4, (0.540, -0.785, -1.199)],
241
+ ["CD1", 5, (0.619, 1.391, 0.000)],
242
+ ],
243
+ "LEU": [
244
+ ["N", 0, (-0.520, 1.363, 0.000)],
245
+ ["CA", 0, (0.000, 0.000, 0.000)],
246
+ ["C", 0, (1.525, -0.000, -0.000)],
247
+ ["CB", 0, (-0.522, -0.773, -1.214)],
248
+ ["O", 3, (0.625, 1.063, -0.000)],
249
+ ["CG", 4, (0.678, 1.371, 0.000)],
250
+ ["CD1", 5, (0.530, 1.430, -0.000)],
251
+ ["CD2", 5, (0.535, -0.774, 1.200)],
252
+ ],
253
+ "LYS": [
254
+ ["N", 0, (-0.526, 1.362, -0.000)],
255
+ ["CA", 0, (0.000, 0.000, 0.000)],
256
+ ["C", 0, (1.526, 0.000, 0.000)],
257
+ ["CB", 0, (-0.524, -0.778, -1.208)],
258
+ ["O", 3, (0.626, 1.062, -0.000)],
259
+ ["CG", 4, (0.619, 1.390, 0.000)],
260
+ ["CD", 5, (0.559, 1.417, 0.000)],
261
+ ["CE", 6, (0.560, 1.416, 0.000)],
262
+ ["NZ", 7, (0.554, 1.387, 0.000)],
263
+ ],
264
+ "MET": [
265
+ ["N", 0, (-0.521, 1.364, -0.000)],
266
+ ["CA", 0, (0.000, 0.000, 0.000)],
267
+ ["C", 0, (1.525, 0.000, 0.000)],
268
+ ["CB", 0, (-0.523, -0.776, -1.210)],
269
+ ["O", 3, (0.625, 1.062, -0.000)],
270
+ ["CG", 4, (0.613, 1.391, -0.000)],
271
+ ["SD", 5, (0.703, 1.695, 0.000)],
272
+ ["CE", 6, (0.320, 1.786, -0.000)],
273
+ ],
274
+ "PHE": [
275
+ ["N", 0, (-0.518, 1.363, 0.000)],
276
+ ["CA", 0, (0.000, 0.000, 0.000)],
277
+ ["C", 0, (1.524, 0.000, -0.000)],
278
+ ["CB", 0, (-0.525, -0.776, -1.212)],
279
+ ["O", 3, (0.626, 1.062, -0.000)],
280
+ ["CG", 4, (0.607, 1.377, 0.000)],
281
+ ["CD1", 5, (0.709, 1.195, -0.000)],
282
+ ["CD2", 5, (0.706, -1.196, 0.000)],
283
+ ["CE1", 5, (2.102, 1.198, -0.000)],
284
+ ["CE2", 5, (2.098, -1.201, -0.000)],
285
+ ["CZ", 5, (2.794, -0.003, -0.001)],
286
+ ],
287
+ "PRO": [
288
+ ["N", 0, (-0.566, 1.351, -0.000)],
289
+ ["CA", 0, (0.000, 0.000, 0.000)],
290
+ ["C", 0, (1.527, -0.000, 0.000)],
291
+ ["CB", 0, (-0.546, -0.611, -1.293)],
292
+ ["O", 3, (0.621, 1.066, 0.000)],
293
+ ["CG", 4, (0.382, 1.445, 0.0)],
294
+ # ['CD', 5, (0.427, 1.440, 0.0)],
295
+ ["CD", 5, (0.477, 1.424, 0.0)], # manually made angle 2 degrees larger
296
+ ],
297
+ "SER": [
298
+ ["N", 0, (-0.529, 1.360, -0.000)],
299
+ ["CA", 0, (0.000, 0.000, 0.000)],
300
+ ["C", 0, (1.525, -0.000, -0.000)],
301
+ ["CB", 0, (-0.518, -0.777, -1.211)],
302
+ ["O", 3, (0.626, 1.062, -0.000)],
303
+ ["OG", 4, (0.503, 1.325, 0.000)],
304
+ ],
305
+ "THR": [
306
+ ["N", 0, (-0.517, 1.364, 0.000)],
307
+ ["CA", 0, (0.000, 0.000, 0.000)],
308
+ ["C", 0, (1.526, 0.000, -0.000)],
309
+ ["CB", 0, (-0.516, -0.793, -1.215)],
310
+ ["O", 3, (0.626, 1.062, 0.000)],
311
+ ["CG2", 4, (0.550, -0.718, -1.228)],
312
+ ["OG1", 4, (0.472, 1.353, 0.000)],
313
+ ],
314
+ "TRP": [
315
+ ["N", 0, (-0.521, 1.363, 0.000)],
316
+ ["CA", 0, (0.000, 0.000, 0.000)],
317
+ ["C", 0, (1.525, -0.000, 0.000)],
318
+ ["CB", 0, (-0.523, -0.776, -1.212)],
319
+ ["O", 3, (0.627, 1.062, 0.000)],
320
+ ["CG", 4, (0.609, 1.370, -0.000)],
321
+ ["CD1", 5, (0.824, 1.091, 0.000)],
322
+ ["CD2", 5, (0.854, -1.148, -0.005)],
323
+ ["CE2", 5, (2.186, -0.678, -0.007)],
324
+ ["CE3", 5, (0.622, -2.530, -0.007)],
325
+ ["NE1", 5, (2.140, 0.690, -0.004)],
326
+ ["CH2", 5, (3.028, -2.890, -0.013)],
327
+ ["CZ2", 5, (3.283, -1.543, -0.011)],
328
+ ["CZ3", 5, (1.715, -3.389, -0.011)],
329
+ ],
330
+ "TYR": [
331
+ ["N", 0, (-0.522, 1.362, 0.000)],
332
+ ["CA", 0, (0.000, 0.000, 0.000)],
333
+ ["C", 0, (1.524, -0.000, -0.000)],
334
+ ["CB", 0, (-0.522, -0.776, -1.213)],
335
+ ["O", 3, (0.627, 1.062, -0.000)],
336
+ ["CG", 4, (0.607, 1.382, -0.000)],
337
+ ["CD1", 5, (0.716, 1.195, -0.000)],
338
+ ["CD2", 5, (0.713, -1.194, -0.001)],
339
+ ["CE1", 5, (2.107, 1.200, -0.002)],
340
+ ["CE2", 5, (2.104, -1.201, -0.003)],
341
+ ["OH", 5, (4.168, -0.002, -0.005)],
342
+ ["CZ", 5, (2.791, -0.001, -0.003)],
343
+ ],
344
+ "VAL": [
345
+ ["N", 0, (-0.494, 1.373, -0.000)],
346
+ ["CA", 0, (0.000, 0.000, 0.000)],
347
+ ["C", 0, (1.527, -0.000, -0.000)],
348
+ ["CB", 0, (-0.533, -0.795, -1.213)],
349
+ ["O", 3, (0.627, 1.062, -0.000)],
350
+ ["CG1", 4, (0.540, 1.429, -0.000)],
351
+ ["CG2", 4, (0.533, -0.776, 1.203)],
352
+ ],
353
+ }
354
+
355
+ # A list of atoms (excluding hydrogen) for each AA type. PDB naming convention.
356
+ residue_atoms = {
357
+ "ALA": ["C", "CA", "CB", "N", "O"],
358
+ "ARG": ["C", "CA", "CB", "CG", "CD", "CZ", "N", "NE", "O", "NH1", "NH2"],
359
+ "ASP": ["C", "CA", "CB", "CG", "N", "O", "OD1", "OD2"],
360
+ "ASN": ["C", "CA", "CB", "CG", "N", "ND2", "O", "OD1"],
361
+ "CYS": ["C", "CA", "CB", "N", "O", "SG"],
362
+ "GLU": ["C", "CA", "CB", "CG", "CD", "N", "O", "OE1", "OE2"],
363
+ "GLN": ["C", "CA", "CB", "CG", "CD", "N", "NE2", "O", "OE1"],
364
+ "GLY": ["C", "CA", "N", "O"],
365
+ "HIS": ["C", "CA", "CB", "CG", "CD2", "CE1", "N", "ND1", "NE2", "O"],
366
+ "ILE": ["C", "CA", "CB", "CG1", "CG2", "CD1", "N", "O"],
367
+ "LEU": ["C", "CA", "CB", "CG", "CD1", "CD2", "N", "O"],
368
+ "LYS": ["C", "CA", "CB", "CG", "CD", "CE", "N", "NZ", "O"],
369
+ "MET": ["C", "CA", "CB", "CG", "CE", "N", "O", "SD"],
370
+ "PHE": ["C", "CA", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "N", "O"],
371
+ "PRO": ["C", "CA", "CB", "CG", "CD", "N", "O"],
372
+ "SER": ["C", "CA", "CB", "N", "O", "OG"],
373
+ "THR": ["C", "CA", "CB", "CG2", "N", "O", "OG1"],
374
+ "TRP": [
375
+ "C",
376
+ "CA",
377
+ "CB",
378
+ "CG",
379
+ "CD1",
380
+ "CD2",
381
+ "CE2",
382
+ "CE3",
383
+ "CZ2",
384
+ "CZ3",
385
+ "CH2",
386
+ "N",
387
+ "NE1",
388
+ "O",
389
+ ],
390
+ "TYR": [
391
+ "C",
392
+ "CA",
393
+ "CB",
394
+ "CG",
395
+ "CD1",
396
+ "CD2",
397
+ "CE1",
398
+ "CE2",
399
+ "CZ",
400
+ "N",
401
+ "O",
402
+ "OH",
403
+ ],
404
+ "VAL": ["C", "CA", "CB", "CG1", "CG2", "N", "O"],
405
+ }
406
+
407
+ # Naming swaps for ambiguous atom names.
408
+ # Due to symmetries in the amino acids the naming of atoms is ambiguous in
409
+ # 4 of the 20 amino acids.
410
+ # (The LDDT paper lists 7 amino acids as ambiguous, but the naming ambiguities
411
+ # in LEU, VAL and ARG can be resolved by using the 3d constellations of
412
+ # the 'ambiguous' atoms and their neighbours)
413
+ # TODO: ^ interpret this
414
# Maps each affected restype to the atom-name pairs that may be swapped when
# resolving the 180-degree side-chain flip ambiguity.
residue_atom_renaming_swaps = {
    "ASP": {"OD1": "OD2"},
    "GLU": {"OE1": "OE2"},
    "PHE": {"CD1": "CD2", "CE1": "CE2"},
    "TYR": {"CD1": "CD2", "CE1": "CE2"},
}

# Van der Waals radii [Angstroem] of the atoms (from Wikipedia)
van_der_waals_radius = {
    "C": 1.7,
    "N": 1.55,
    "O": 1.52,
    "S": 1.8,
}

# Literature bond-length record: atom pair, mean length (Angstrom), stddev.
Bond = collections.namedtuple(
    "Bond", ["atom1_name", "atom2_name", "length", "stddev"]
)
# Literature bond-angle record (angle in radians).
# NOTE(review): the third field is spelled "atom3name" (missing underscore);
# downstream code accesses `ba.atom3name`, so the typo is load-bearing and
# must not be "fixed" in isolation.
BondAngle = collections.namedtuple(
    "BondAngle",
    ["atom1_name", "atom2_name", "atom3name", "angle_rad", "stddev"],
)
436
+
437
+
438
@functools.lru_cache(maxsize=None)
def load_stereo_chemical_props() -> Tuple[
    Mapping[str, List[Bond]],
    Mapping[str, List[Bond]],
    Mapping[str, List[BondAngle]],
]:
    """Load stereo_chemical_props.txt into a nice structure.

    Load literature values for bond lengths and bond angles and translate
    bond angles into the length of the opposite edge of the triangle
    ("residue_virtual_bonds").

    Returns:
        residue_bonds: dict that maps resname --> list of Bond tuples
        residue_virtual_bonds: dict that maps resname --> list of Bond tuples
        residue_bond_angles: dict that maps resname --> list of BondAngle tuples
    """
    # TODO: this file should be downloaded in a setup script
    # NOTE(review): assumes the packaged resource has two whitespace-delimited
    # sections (bond lengths, then bond angles), each terminated by a lone "-"
    # line — confirm against openfold/resources/stereo_chemical_props.txt.
    stereo_chemical_props = resources.read_text("openfold.resources", "stereo_chemical_props.txt")

    lines_iter = iter(stereo_chemical_props.splitlines())
    # Load bond lengths.
    residue_bonds = {}
    next(lines_iter)  # Skip header line.
    for line in lines_iter:
        if line.strip() == "-":
            break
        bond, resname, length, stddev = line.split()
        atom1, atom2 = bond.split("-")
        if resname not in residue_bonds:
            residue_bonds[resname] = []
        residue_bonds[resname].append(
            Bond(atom1, atom2, float(length), float(stddev))
        )
    residue_bonds["UNK"] = []  # UNK gets an explicit empty entry.

    # Load bond angles.
    residue_bond_angles = {}
    next(lines_iter)  # Skip empty line.
    next(lines_iter)  # Skip header line.
    for line in lines_iter:
        if line.strip() == "-":
            break
        bond, resname, angle_degree, stddev_degree = line.split()
        atom1, atom2, atom3 = bond.split("-")
        if resname not in residue_bond_angles:
            residue_bond_angles[resname] = []
        residue_bond_angles[resname].append(
            BondAngle(
                atom1,
                atom2,
                atom3,
                # Angles are stored in radians.
                float(angle_degree) / 180.0 * np.pi,
                float(stddev_degree) / 180.0 * np.pi,
            )
        )
    residue_bond_angles["UNK"] = []

    def make_bond_key(atom1_name, atom2_name):
        """Unique key to lookup bonds."""
        # Sorted so that the key is order-independent (A-B == B-A).
        return "-".join(sorted([atom1_name, atom2_name]))

    # Translate bond angles into distances ("virtual bonds").
    residue_virtual_bonds = {}
    for resname, bond_angles in residue_bond_angles.items():
        # Create a fast lookup dict for bond lengths.
        bond_cache = {}
        for b in residue_bonds[resname]:
            bond_cache[make_bond_key(b.atom1_name, b.atom2_name)] = b
        residue_virtual_bonds[resname] = []
        for ba in bond_angles:
            bond1 = bond_cache[make_bond_key(ba.atom1_name, ba.atom2_name)]
            bond2 = bond_cache[make_bond_key(ba.atom2_name, ba.atom3name)]

            # Compute distance between atom1 and atom3 using the law of cosines
            # c^2 = a^2 + b^2 - 2ab*cos(gamma).
            gamma = ba.angle_rad
            length = np.sqrt(
                bond1.length ** 2
                + bond2.length ** 2
                - 2 * bond1.length * bond2.length * np.cos(gamma)
            )

            # Propagation of uncertainty assuming uncorrelated errors.
            dl_outer = 0.5 / length
            dl_dgamma = (
                2 * bond1.length * bond2.length * np.sin(gamma)
            ) * dl_outer
            dl_db1 = (
                2 * bond1.length - 2 * bond2.length * np.cos(gamma)
            ) * dl_outer
            dl_db2 = (
                2 * bond2.length - 2 * bond1.length * np.cos(gamma)
            ) * dl_outer
            stddev = np.sqrt(
                (dl_dgamma * ba.stddev) ** 2
                + (dl_db1 * bond1.stddev) ** 2
                + (dl_db2 * bond2.stddev) ** 2
            )
            residue_virtual_bonds[resname].append(
                Bond(ba.atom1_name, ba.atom3name, length, stddev)
            )

    return (residue_bonds, residue_virtual_bonds, residue_bond_angles)
542
+
543
+
544
# Between-residue bond lengths for general bonds (first element) and for
# Proline (second element).
between_res_bond_length_c_n = [1.329, 1.341]
between_res_bond_length_stddev_c_n = [0.014, 0.016]

# Between-residue cos_angles.
between_res_cos_angles_c_n_ca = [-0.5203, 0.0353]  # degrees: 121.352 +- 2.315
between_res_cos_angles_ca_c_n = [-0.4473, 0.0311]  # degrees: 116.568 +- 1.995
552
+
553
# This mapping is used when we need to store atom data in a format that
# requires fixed atom data size for every residue (e.g. a numpy array).
# Order matters: indices into this list are the atom37 representation.
atom_types = (
    "N CA C CB O CG CG1 CG2 OG OG1 SG CD CD1 CD2 ND1 ND2 OD1 OD2 SD "
    "CE CE1 CE2 CE3 NE NE1 NE2 OE1 OE2 CH2 NH1 NH2 OH CZ CZ2 CZ3 NZ OXT"
).split()
atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)}
atom_type_num = len(atom_types)  # := 37.
596
+
597
# A compact atom encoding with 14 columns
# pylint: disable=line-too-long
# pylint: disable=bad-whitespace
def _pad14(names):
    """Right-pad a per-residue atom-name list with "" up to 14 slots."""
    return names + [""] * (14 - len(names))


restype_name_to_atom14_names = {
    "ALA": _pad14(["N", "CA", "C", "O", "CB"]),
    "ARG": _pad14(["N", "CA", "C", "O", "CB", "CG", "CD", "NE", "CZ", "NH1", "NH2"]),
    "ASN": _pad14(["N", "CA", "C", "O", "CB", "CG", "OD1", "ND2"]),
    "ASP": _pad14(["N", "CA", "C", "O", "CB", "CG", "OD1", "OD2"]),
    "CYS": _pad14(["N", "CA", "C", "O", "CB", "SG"]),
    "GLN": _pad14(["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "NE2"]),
    "GLU": _pad14(["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "OE2"]),
    "GLY": _pad14(["N", "CA", "C", "O"]),
    "HIS": _pad14(["N", "CA", "C", "O", "CB", "CG", "ND1", "CD2", "CE1", "NE2"]),
    "ILE": _pad14(["N", "CA", "C", "O", "CB", "CG1", "CG2", "CD1"]),
    "LEU": _pad14(["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2"]),
    "LYS": _pad14(["N", "CA", "C", "O", "CB", "CG", "CD", "CE", "NZ"]),
    "MET": _pad14(["N", "CA", "C", "O", "CB", "CG", "SD", "CE"]),
    "PHE": _pad14(["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ"]),
    "PRO": _pad14(["N", "CA", "C", "O", "CB", "CG", "CD"]),
    "SER": _pad14(["N", "CA", "C", "O", "CB", "OG"]),
    "THR": _pad14(["N", "CA", "C", "O", "CB", "OG1", "CG2"]),
    "TRP": _pad14(["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "NE1", "CE2", "CE3", "CZ2", "CZ3", "CH2"]),
    "TYR": _pad14(["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "OH"]),
    "VAL": _pad14(["N", "CA", "C", "O", "CB", "CG1", "CG2"]),
    "UNK": _pad14([]),
}
# pylint: enable=line-too-long
# pylint: enable=bad-whitespace
850
+
851
+
852
# This is the standard residue order when coding AA type as a number.
# Reproduce it by taking 3-letter AA codes and sorting them alphabetically.
restypes = list("ARNDCQEGHILKMFPSTWYV")
restype_order = {restype: i for i, restype in enumerate(restypes)}
restype_num = len(restypes)  # := 20.
unk_restype_index = restype_num  # Catch-all index for unknown restypes.

restypes_with_x = restypes + ["X"]
restype_order_with_x = {restype: i for i, restype in enumerate(restypes_with_x)}
882
+
883
+
884
def sequence_to_onehot(
    sequence: str, mapping: Mapping[str, int], map_unknown_to_x: bool = False
) -> np.ndarray:
    """Maps the given sequence into a one-hot encoded matrix.

    Args:
      sequence: An amino acid sequence.
      mapping: A dictionary mapping amino acids to integers.
      map_unknown_to_x: If True, any amino acid that is not in the mapping will be
        mapped to the unknown amino acid 'X'. If the mapping doesn't contain
        amino acid 'X', an error will be thrown. If False, any amino acid not in
        the mapping will throw an error.

    Returns:
      A numpy array of shape (seq_len, num_unique_aas) with one-hot encoding of
      the sequence.

    Raises:
      ValueError: If the mapping doesn't contain values from 0 to
        num_unique_aas - 1 without any gaps.
    """
    num_entries = max(mapping.values()) + 1
    # The mapping must be a gap-free enumeration so each value is a column.
    if sorted(set(mapping.values())) != list(range(num_entries)):
        raise ValueError(
            "The mapping must have values from 0 to num_unique_aas-1 "
            "without any gaps. Got: %s" % sorted(mapping.values())
        )

    def _column_for(aa: str) -> int:
        """Resolve one residue letter to its one-hot column index."""
        if not map_unknown_to_x:
            return mapping[aa]
        if aa.isalpha() and aa.isupper():
            return mapping.get(aa, mapping["X"])
        raise ValueError(
            f"Invalid character in the sequence: {aa}"
        )

    one_hot_arr = np.zeros((len(sequence), num_entries), dtype=int)
    for row, aa_type in enumerate(sequence):
        one_hot_arr[row, _column_for(aa_type)] = 1
    return one_hot_arr
928
+
929
+
930
# One-letter -> three-letter residue names, in the standard restype order.
restype_1to3 = dict(
    zip(
        "ARNDCQEGHILKMFPSTWYV",
        [
            "ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY", "HIS", "ILE",
            "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL",
        ],
    )
)
952
+
953
+
954
# NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 by being a simple
# 1-to-1 mapping of 3 letter names to one letter names. The latter contains
# many more, and less common, three letter names as keys and maps many of these
# to the same one letter name (including 'X' and 'U' which we don't use here).
restype_3to1 = {three: one for one, three in restype_1to3.items()}

# Define a restype name for all unknown residues.
unk_restype = "UNK"

# 21 residue names: the 20 standard ones plus UNK, in restype order.
resnames = [restype_1to3[r] for r in restypes] + [unk_restype]
resname_to_idx = {resname: i for i, resname in enumerate(resnames)}
965
+
966
+
967
# The mapping here uses hhblits convention, so that B is mapped to D, J and O
# are mapped to X, U is mapped to C, and Z is mapped to E. Other than that the
# remaining 20 amino acids are kept in alphabetical order.
# There are 2 non-amino acid codes, X (representing any amino acid) and
# "-" representing a missing amino acid in an alignment. The id for these
# codes is put at the end (20 and 21) so that they can easily be ignored if
# desired.
HHBLITS_AA_TO_ID = {
    "A": 0, "B": 2, "C": 1, "D": 2, "E": 3, "F": 4, "G": 5, "H": 6,
    "I": 7, "J": 20, "K": 8, "L": 9, "M": 10, "N": 11, "O": 20, "P": 12,
    "Q": 13, "R": 14, "S": 15, "T": 16, "U": 1, "V": 17, "W": 18, "X": 20,
    "Y": 19, "Z": 3, "-": 21,
}
1003
+
1004
# Partial inversion of HHBLITS_AA_TO_ID: each id maps back to its canonical
# letter (1 also covers U, 2 also covers B, 3 also covers Z, 20 covers J/O).
ID_TO_HHBLITS_AA = {
    0: "A", 1: "C", 2: "D", 3: "E", 4: "F", 5: "G", 6: "H", 7: "I",
    8: "K", 9: "L", 10: "M", 11: "N", 12: "P", 13: "Q", 14: "R", 15: "S",
    16: "T", 17: "V", 18: "W", 19: "Y", 20: "X", 21: "-",
}
1029
+
1030
# Lookup table translating an hhblits aatype id into this module's restype id.
restypes_with_x_and_gap = restypes + ["X", "-"]
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE = tuple(
    restypes_with_x_and_gap.index(ID_TO_HHBLITS_AA[hhblits_id])
    for hhblits_id in range(len(restypes_with_x_and_gap))
)
1035
+
1036
+
1037
def _make_standard_atom_mask() -> np.ndarray:
    """Returns [num_res_types, num_atom_types] mask array."""
    # The "+1" row (index restype_num) is left all-zero for unknown residues.
    mask = np.zeros([restype_num + 1, atom_type_num], dtype=int)
    for res_idx, one_letter in enumerate(restypes):
        for atom_name in residue_atoms[restype_1to3[one_letter]]:
            mask[res_idx, atom_order[atom_name]] = 1
    return mask


STANDARD_ATOM_MASK = _make_standard_atom_mask()
1051
+
1052
+
1053
# A one hot representation for the first and second atoms defining the axis
# of rotation for each chi-angle in each residue.
def chi_angle_atom(atom_index: int) -> np.ndarray:
    """Define chi-angle rigid groups via one-hot representations."""
    # Per residue name: the atom37 index of atom `atom_index` in each chi
    # group, padded with -1 up to 4 chi angles.
    axis_atom_indices = {}
    for resname, chi_groups in chi_angles_atoms.items():
        picked = [atom_types.index(group[atom_index]) for group in chi_groups]
        axis_atom_indices[resname] = picked + [-1] * (4 - len(picked))

    # Index -1 selects the last row of np.eye, as in the original lookup.
    per_restype = [
        np.eye(atom_type_num)[axis_atom_indices[restype_1to3[r]]]
        for r in restypes
    ]
    per_restype.append(np.zeros([4, atom_type_num]))  # Add zeros for residue `X`.
    stacked = np.stack(per_restype, axis=0)
    return np.transpose(stacked, [0, 2, 1])


chi_atom_1_one_hot = chi_angle_atom(1)
chi_atom_2_one_hot = chi_angle_atom(2)
1079
+
1080
# An array like chi_angles_atoms but using indices rather than names.
# (Plain nested comprehensions replace tree.map_structure; the structure is
# a list of lists of 4-atom name lists, so the result is identical.)
chi_angles_atom_indices = [
    [[atom_order[atom_name] for atom_name in chi] for chi in chi_angles_atoms[restype_1to3[r]]]
    for r in restypes
]
# Pad every residue out to 4 chi groups with dummy [0, 0, 0, 0] entries.
chi_angles_atom_indices = np.array(
    [
        chi_atoms + [[0, 0, 0, 0]] * (4 - len(chi_atoms))
        for chi_atoms in chi_angles_atom_indices
    ]
)
1091
+
1092
# Mapping from (res_name, atom_name) pairs to the atom's chi group index
# and atom index within that group.
chi_groups_for_atom = collections.defaultdict(list)
for _resname, _chi_groups in chi_angles_atoms.items():
    for _group_idx, _group in enumerate(_chi_groups):
        for _pos_idx, _atom in enumerate(_group):
            chi_groups_for_atom[(_resname, _atom)].append((_group_idx, _pos_idx))
chi_groups_for_atom = dict(chi_groups_for_atom)
1100
+
1101
+
1102
+ def _make_rigid_transformation_4x4(ex, ey, translation):
1103
+ """Create a rigid 4x4 transformation matrix from two axes and transl."""
1104
+ # Normalize ex.
1105
+ ex_normalized = ex / np.linalg.norm(ex)
1106
+
1107
+ # make ey perpendicular to ex
1108
+ ey_normalized = ey - np.dot(ey, ex_normalized) * ex_normalized
1109
+ ey_normalized /= np.linalg.norm(ey_normalized)
1110
+
1111
+ # compute ez as cross product
1112
+ eznorm = np.cross(ex_normalized, ey_normalized)
1113
+ m = np.stack(
1114
+ [ex_normalized, ey_normalized, eznorm, translation]
1115
+ ).transpose()
1116
+ m = np.concatenate([m, [[0.0, 0.0, 0.0, 1.0]]], axis=0)
1117
+ return m
1118
+
1119
+
1120
# create an array with (restype, atomtype) --> rigid_group_idx
# and an array with (restype, atomtype, coord) for the atom positions
# and compute affine transformation matrices (4,4) from one rigid group to the
# previous group. All are filled in by _make_rigid_group_constants() below.
restype_atom37_to_rigid_group = np.zeros((21, 37), dtype=int)
restype_atom37_mask = np.zeros((21, 37), dtype=np.float32)
restype_atom37_rigid_group_positions = np.zeros((21, 37, 3), dtype=np.float32)
restype_atom14_to_rigid_group = np.zeros((21, 14), dtype=int)
restype_atom14_mask = np.zeros((21, 14), dtype=np.float32)
restype_atom14_rigid_group_positions = np.zeros((21, 14, 3), dtype=np.float32)
restype_rigid_group_default_frame = np.zeros((21, 8, 4, 4), dtype=np.float32)
1131
+
1132
+
1133
def _make_rigid_group_constants():
    """Fill the arrays above.

    Populates, in place, the module-level restype_atom37_* / restype_atom14_*
    lookup arrays and restype_rigid_group_default_frame, using
    rigid_group_atom_positions / chi_angles_atoms / chi_angles_mask
    (defined earlier in this module).
    """
    # Pass 1: record, for every atom of every standard residue, its rigid
    # group index, presence mask and idealized position, in both the atom37
    # and atom14 layouts. Row 20 (UNK) is left at its zero initialization.
    for restype, restype_letter in enumerate(restypes):
        resname = restype_1to3[restype_letter]
        for atomname, group_idx, atom_position in rigid_group_atom_positions[
            resname
        ]:
            atomtype = atom_order[atomname]
            restype_atom37_to_rigid_group[restype, atomtype] = group_idx
            restype_atom37_mask[restype, atomtype] = 1
            restype_atom37_rigid_group_positions[
                restype, atomtype, :
            ] = atom_position

            atom14idx = restype_name_to_atom14_names[resname].index(atomname)
            restype_atom14_to_rigid_group[restype, atom14idx] = group_idx
            restype_atom14_mask[restype, atom14idx] = 1
            restype_atom14_rigid_group_positions[
                restype, atom14idx, :
            ] = atom_position

    # Pass 2: build the 8 default frames per residue (backbone, pre-omega,
    # phi, psi, chi1..chi4) as 4x4 rigid transforms to the previous group.
    for restype, restype_letter in enumerate(restypes):
        resname = restype_1to3[restype_letter]
        atom_positions = {
            name: np.array(pos)
            for name, _, pos in rigid_group_atom_positions[resname]
        }

        # backbone to backbone is the identity transform
        restype_rigid_group_default_frame[restype, 0, :, :] = np.eye(4)

        # pre-omega-frame to backbone (currently dummy identity matrix)
        restype_rigid_group_default_frame[restype, 1, :, :] = np.eye(4)

        # phi-frame to backbone
        mat = _make_rigid_transformation_4x4(
            ex=atom_positions["N"] - atom_positions["CA"],
            ey=np.array([1.0, 0.0, 0.0]),
            translation=atom_positions["N"],
        )
        restype_rigid_group_default_frame[restype, 2, :, :] = mat

        # psi-frame to backbone
        mat = _make_rigid_transformation_4x4(
            ex=atom_positions["C"] - atom_positions["CA"],
            ey=atom_positions["CA"] - atom_positions["N"],
            translation=atom_positions["C"],
        )
        restype_rigid_group_default_frame[restype, 3, :, :] = mat

        # chi1-frame to backbone
        if chi_angles_mask[restype][0]:
            base_atom_names = chi_angles_atoms[resname][0]
            base_atom_positions = [
                atom_positions[name] for name in base_atom_names
            ]
            mat = _make_rigid_transformation_4x4(
                ex=base_atom_positions[2] - base_atom_positions[1],
                ey=base_atom_positions[0] - base_atom_positions[1],
                translation=base_atom_positions[2],
            )
            restype_rigid_group_default_frame[restype, 4, :, :] = mat

        # chi2-frame to chi1-frame
        # chi3-frame to chi2-frame
        # chi4-frame to chi3-frame
        # luckily all rotation axes for the next frame start at (0,0,0) of the
        # previous frame
        for chi_idx in range(1, 4):
            if chi_angles_mask[restype][chi_idx]:
                axis_end_atom_name = chi_angles_atoms[resname][chi_idx][2]
                axis_end_atom_position = atom_positions[axis_end_atom_name]
                mat = _make_rigid_transformation_4x4(
                    ex=axis_end_atom_position,
                    ey=np.array([-1.0, 0.0, 0.0]),
                    translation=axis_end_atom_position,
                )
                restype_rigid_group_default_frame[
                    restype, 4 + chi_idx, :, :
                ] = mat


# Run once at import time to populate the lookup tables.
_make_rigid_group_constants()
1216
+
1217
+
1218
def make_atom14_dists_bounds(
    overlap_tolerance=1.5, bond_length_tolerance_factor=15
):
    """Compute upper and lower bounds for bonds to assess violations.

    Args:
        overlap_tolerance: Angstroem subtracted from the sum of two atoms'
            van der Waals radii to form the default clash lower bound.
        bond_length_tolerance_factor: number of standard deviations around a
            literature bond length used for the bonded lower/upper bounds.

    Returns:
        dict with keys "lower_bound", "upper_bound", "stddev", each a
        float32 array of shape (21, 14, 14) over (restype, atom14, atom14).
    """
    restype_atom14_bond_lower_bound = np.zeros([21, 14, 14], np.float32)
    restype_atom14_bond_upper_bound = np.zeros([21, 14, 14], np.float32)
    restype_atom14_bond_stddev = np.zeros([21, 14, 14], np.float32)
    residue_bonds, residue_virtual_bonds, _ = load_stereo_chemical_props()
    for restype, restype_letter in enumerate(restypes):
        resname = restype_1to3[restype_letter]
        atom_list = restype_name_to_atom14_names[resname]

        # create lower and upper bounds for clashes
        # (every real atom pair gets a vdW-based lower bound and an
        # effectively-infinite upper bound; empty "" slots are skipped)
        for atom1_idx, atom1_name in enumerate(atom_list):
            if not atom1_name:
                continue
            atom1_radius = van_der_waals_radius[atom1_name[0]]
            for atom2_idx, atom2_name in enumerate(atom_list):
                if (not atom2_name) or atom1_idx == atom2_idx:
                    continue
                atom2_radius = van_der_waals_radius[atom2_name[0]]
                lower = atom1_radius + atom2_radius - overlap_tolerance
                upper = 1e10
                restype_atom14_bond_lower_bound[
                    restype, atom1_idx, atom2_idx
                ] = lower
                restype_atom14_bond_lower_bound[
                    restype, atom2_idx, atom1_idx
                ] = lower
                restype_atom14_bond_upper_bound[
                    restype, atom1_idx, atom2_idx
                ] = upper
                restype_atom14_bond_upper_bound[
                    restype, atom2_idx, atom1_idx
                ] = upper

        # overwrite lower and upper bounds for bonds and angles
        # (bonded / angle-related pairs get tight, symmetric bounds derived
        # from literature length +- tolerance_factor * stddev)
        for b in residue_bonds[resname] + residue_virtual_bonds[resname]:
            atom1_idx = atom_list.index(b.atom1_name)
            atom2_idx = atom_list.index(b.atom2_name)
            lower = b.length - bond_length_tolerance_factor * b.stddev
            upper = b.length + bond_length_tolerance_factor * b.stddev
            restype_atom14_bond_lower_bound[
                restype, atom1_idx, atom2_idx
            ] = lower
            restype_atom14_bond_lower_bound[
                restype, atom2_idx, atom1_idx
            ] = lower
            restype_atom14_bond_upper_bound[
                restype, atom1_idx, atom2_idx
            ] = upper
            restype_atom14_bond_upper_bound[
                restype, atom2_idx, atom1_idx
            ] = upper
            restype_atom14_bond_stddev[restype, atom1_idx, atom2_idx] = b.stddev
            restype_atom14_bond_stddev[restype, atom2_idx, atom1_idx] = b.stddev
    return {
        "lower_bound": restype_atom14_bond_lower_bound,  # shape (21,14,14)
        "upper_bound": restype_atom14_bond_upper_bound,  # shape (21,14,14)
        "stddev": restype_atom14_bond_stddev,  # shape (21,14,14)
    }
1279
+
1280
+
1281
# Ambiguity tables in the atom14 layout, filled in by
# _make_atom14_ambiguity_feats() below. The swap-index table starts as the
# identity permutation (each atom maps to itself).
restype_atom14_ambiguous_atoms = np.zeros((21, 14), dtype=np.float32)
restype_atom14_ambiguous_atoms_swap_idx = np.tile(
    np.arange(14, dtype=int), (21, 1)
)
1285
+
1286
+
1287
def _make_atom14_ambiguity_feats():
    """Mark name-swappable atom14 slots and record each one's swap partner."""
    for resname, swaps in residue_atom_renaming_swaps.items():
        res_idx = restype_order[restype_3to1[resname]]
        for first_atom, second_atom in swaps.items():
            first_idx = restype_name_to_atom14_names[resname].index(first_atom)
            second_idx = restype_name_to_atom14_names[resname].index(second_atom)
            restype_atom14_ambiguous_atoms[res_idx, first_idx] = 1
            restype_atom14_ambiguous_atoms[res_idx, second_idx] = 1
            restype_atom14_ambiguous_atoms_swap_idx[res_idx, first_idx] = second_idx
            restype_atom14_ambiguous_atoms_swap_idx[res_idx, second_idx] = first_idx


# Run once at import time to fill the ambiguity tables above.
_make_atom14_ambiguity_feats()
1304
+
1305
+
1306
def aatype_to_str_sequence(aatype):
    """Convert residue-type indices to a one-letter amino-acid string.

    Args:
        aatype: iterable of integer indices into `restypes_with_x`
            (0-19 standard residues, 20 = 'X'); works for lists, tuples
            and 1-D numpy arrays alike.

    Returns:
        The sequence as a str, one letter per index.
    """
    # Iterate the elements directly instead of `range(len(...))`, and feed
    # join a generator instead of materializing an intermediate list.
    return ''.join(restypes_with_x[idx] for idx in aatype)