添加PepFlow模型初始代码
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitignore +180 -0
- LICENSE +21 -0
- README.md +106 -3
- configs/learn_angle.yaml +74 -0
- environment.yml +261 -0
- eval/align.py +17 -0
- eval/energy.py +94 -0
- eval/foldx.py +77 -0
- eval/geometry.py +127 -0
- eval/run_esmfold.py +73 -0
- eval/run_esmif.py +33 -0
- eval/run_mpnn.py +146 -0
- eval/run_rfdiffusion.py +75 -0
- eval/run_scwrl4.py +30 -0
- eval/utils.py +106 -0
- models_con/edge.py +112 -0
- models_con/flow_model.py +472 -0
- models_con/ga.py +127 -0
- models_con/inference.py +101 -0
- models_con/ipa_pytorch.py +687 -0
- models_con/node.py +105 -0
- models_con/pep_dataloader.py +212 -0
- models_con/sample.py +145 -0
- models_con/torsion.py +239 -0
- models_con/torus.py +34 -0
- models_con/utils.py +72 -0
- openfold/config.py +4 -0
- openfold/model/__init__.py +16 -0
- openfold/model/dropout.py +78 -0
- openfold/model/embedders.py +352 -0
- openfold/model/evoformer.py +630 -0
- openfold/model/heads.py +251 -0
- openfold/model/model.py +446 -0
- openfold/model/msa.py +392 -0
- openfold/model/outer_product_mean.py +129 -0
- openfold/model/pair_transition.py +99 -0
- openfold/model/primitives.py +587 -0
- openfold/model/structure_module.py +820 -0
- openfold/model/template.py +333 -0
- openfold/model/torchscript.py +215 -0
- openfold/model/triangular_attention.py +139 -0
- openfold/model/triangular_multiplicative_update.py +127 -0
- openfold/np/__init__.py +16 -0
- openfold/np/protein.py +438 -0
- openfold/np/relax/__init__.py +16 -0
- openfold/np/relax/amber_minimize.py +612 -0
- openfold/np/relax/cleanup.py +131 -0
- openfold/np/relax/relax.py +90 -0
- openfold/np/relax/utils.py +88 -0
- openfold/np/residue_constants.py +1310 -0
.gitignore
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
logs/
|
| 2 |
+
lightning_logs/
|
| 3 |
+
wandb/
|
| 4 |
+
Data/
|
| 5 |
+
wandb*.json
|
| 6 |
+
**/*.ckpt
|
| 7 |
+
**/*.pth
|
| 8 |
+
**/*.json
|
| 9 |
+
.vscode/
|
| 10 |
+
**/*.pdb
|
| 11 |
+
ckpt/
|
| 12 |
+
*.code-workspace
|
| 13 |
+
outputs/
|
| 14 |
+
**/*.txt
|
| 15 |
+
**/lightning_logs/
|
| 16 |
+
**/inference_outputs/
|
| 17 |
+
.hydra
|
| 18 |
+
preprocessed/
|
| 19 |
+
misc/
|
| 20 |
+
|
| 21 |
+
# Byte-compiled / optimized / DLL files
|
| 22 |
+
__pycache__/
|
| 23 |
+
*.py[cod]
|
| 24 |
+
*$py.class
|
| 25 |
+
|
| 26 |
+
# C extensions
|
| 27 |
+
*.so
|
| 28 |
+
|
| 29 |
+
# Distribution / packaging
|
| 30 |
+
.Python
|
| 31 |
+
build/
|
| 32 |
+
develop-eggs/
|
| 33 |
+
dist/
|
| 34 |
+
downloads/
|
| 35 |
+
eggs/
|
| 36 |
+
.eggs/
|
| 37 |
+
lib/
|
| 38 |
+
lib64/
|
| 39 |
+
parts/
|
| 40 |
+
sdist/
|
| 41 |
+
var/
|
| 42 |
+
wheels/
|
| 43 |
+
share/python-wheels/
|
| 44 |
+
*.egg-info/
|
| 45 |
+
.installed.cfg
|
| 46 |
+
*.egg
|
| 47 |
+
MANIFEST
|
| 48 |
+
|
| 49 |
+
# PyInstaller
|
| 50 |
+
# Usually these files are written by a python script from a template
|
| 51 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 52 |
+
*.manifest
|
| 53 |
+
*.spec
|
| 54 |
+
|
| 55 |
+
# Installer logs
|
| 56 |
+
pip-log.txt
|
| 57 |
+
pip-delete-this-directory.txt
|
| 58 |
+
|
| 59 |
+
# Unit test / coverage reports
|
| 60 |
+
htmlcov/
|
| 61 |
+
.tox/
|
| 62 |
+
.nox/
|
| 63 |
+
.coverage
|
| 64 |
+
.coverage.*
|
| 65 |
+
.cache
|
| 66 |
+
nosetests.xml
|
| 67 |
+
coverage.xml
|
| 68 |
+
*.cover
|
| 69 |
+
*.py,cover
|
| 70 |
+
.hypothesis/
|
| 71 |
+
.pytest_cache/
|
| 72 |
+
cover/
|
| 73 |
+
|
| 74 |
+
# Translations
|
| 75 |
+
*.mo
|
| 76 |
+
*.pot
|
| 77 |
+
|
| 78 |
+
# Django stuff:
|
| 79 |
+
*.log
|
| 80 |
+
local_settings.py
|
| 81 |
+
db.sqlite3
|
| 82 |
+
db.sqlite3-journal
|
| 83 |
+
|
| 84 |
+
# Flask stuff:
|
| 85 |
+
instance/
|
| 86 |
+
.webassets-cache
|
| 87 |
+
|
| 88 |
+
# Scrapy stuff:
|
| 89 |
+
.scrapy
|
| 90 |
+
|
| 91 |
+
# Sphinx documentation
|
| 92 |
+
docs/_build/
|
| 93 |
+
|
| 94 |
+
# PyBuilder
|
| 95 |
+
.pybuilder/
|
| 96 |
+
target/
|
| 97 |
+
|
| 98 |
+
# Jupyter Notebook
|
| 99 |
+
.ipynb_checkpoints
|
| 100 |
+
|
| 101 |
+
# IPython
|
| 102 |
+
profile_default/
|
| 103 |
+
ipython_config.py
|
| 104 |
+
|
| 105 |
+
# pyenv
|
| 106 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 107 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 108 |
+
# .python-version
|
| 109 |
+
|
| 110 |
+
# pipenv
|
| 111 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 112 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 113 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 114 |
+
# install all needed dependencies.
|
| 115 |
+
#Pipfile.lock
|
| 116 |
+
|
| 117 |
+
# poetry
|
| 118 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 119 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 120 |
+
# commonly ignored for libraries.
|
| 121 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 122 |
+
#poetry.lock
|
| 123 |
+
|
| 124 |
+
# pdm
|
| 125 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 126 |
+
#pdm.lock
|
| 127 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 128 |
+
# in version control.
|
| 129 |
+
# https://pdm.fming.dev/#use-with-ide
|
| 130 |
+
.pdm.toml
|
| 131 |
+
|
| 132 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 133 |
+
__pypackages__/
|
| 134 |
+
|
| 135 |
+
# Celery stuff
|
| 136 |
+
celerybeat-schedule
|
| 137 |
+
celerybeat.pid
|
| 138 |
+
|
| 139 |
+
# SageMath parsed files
|
| 140 |
+
*.sage.py
|
| 141 |
+
|
| 142 |
+
# Environments
|
| 143 |
+
.env
|
| 144 |
+
.venv
|
| 145 |
+
env/
|
| 146 |
+
venv/
|
| 147 |
+
ENV/
|
| 148 |
+
env.bak/
|
| 149 |
+
venv.bak/
|
| 150 |
+
|
| 151 |
+
# Spyder project settings
|
| 152 |
+
.spyderproject
|
| 153 |
+
.spyproject
|
| 154 |
+
|
| 155 |
+
# Rope project settings
|
| 156 |
+
.ropeproject
|
| 157 |
+
|
| 158 |
+
# mkdocs documentation
|
| 159 |
+
/site
|
| 160 |
+
|
| 161 |
+
# mypy
|
| 162 |
+
.mypy_cache/
|
| 163 |
+
.dmypy.json
|
| 164 |
+
dmypy.json
|
| 165 |
+
|
| 166 |
+
# Pyre type checker
|
| 167 |
+
.pyre/
|
| 168 |
+
|
| 169 |
+
# pytype static type analyzer
|
| 170 |
+
.pytype/
|
| 171 |
+
|
| 172 |
+
# Cython debug symbols
|
| 173 |
+
cython_debug/
|
| 174 |
+
|
| 175 |
+
# PyCharm
|
| 176 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 177 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 178 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 179 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 180 |
+
#.idea/
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2024 Cedlijh
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
CHANGED
|
@@ -1,3 +1,106 @@
|
|
| 1 |
-
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# PepFlow: Full-Atom Peptide Design
|
| 2 |
+
|
| 3 |
+

|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
This repository contains the official implementation of 💡 Full-Atom Peptide Design based on Multi-modal Flow Matching (ICML 2024).
|
| 7 |
+
|
| 8 |
+
You can find our [paper](https://arxiv.org/abs/2406.00735) here. We also appreciate the inspiration from [diffab](https://github.com/luost26/diffab) and [frameflow](https://github.com/microsoft/protein-frame-flow).
|
| 9 |
+
|
| 10 |
+
If you have any questions, please contact lijiahanypc@pku.edu.cn or ced3ljhypc@gmail.com. Thank you! :)
|
| 11 |
+
|
| 12 |
+
## Install
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
### Environment
|
| 16 |
+
|
| 17 |
+
Please replace cuda and torch version to match your machine, here we test our code on CUDA >= 11.7, we also suggest using [micromamba](https://mamba.readthedocs.io/en/latest/installation/micromamba-installation.html) as a replace of conda.
|
| 18 |
+
|
| 19 |
+
```bash
|
| 20 |
+
conda env create -f environment.yml # or use micromamba instead of conda
|
| 21 |
+
|
| 22 |
+
conda activate flow
|
| 23 |
+
|
| 24 |
+
pip install torch-scatter -f https://data.pyg.org/whl/torch-2.0.0+cu117.html
|
| 25 |
+
|
| 26 |
+
pip install joblib lmdb easydict
|
| 27 |
+
|
| 28 |
+
```
|
| 29 |
+
|
| 30 |
+
### Clone Repo### Train
|
| 31 |
+
|
| 32 |
+
```bash
|
| 33 |
+
git clone https://github.com/Ced3-han/PepFlowww.git
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
We suggest adding the code to the Python environment variable, or you can use setup tools.
|
| 37 |
+
|
| 38 |
+
```bash
|
| 39 |
+
export PYTHONPATH=$(pwd):$PYTHONPATH
|
| 40 |
+
python setup.py develop
|
| 41 |
+
```
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
### Data and Weights Download
|
| 45 |
+
|
| 46 |
+
We provide data and pretrained model weights [here](https://drive.google.com/drive/folders/1bHaKDF3uCDPtfsihjZs0zmjwF6UU1uVl?usp=sharing).
|
| 47 |
+
|
| 48 |
+
+ PepMerge_release.zip: 1.2GB
|
| 49 |
+
+ PepMerge_lmdb.zip: 180MB
|
| 50 |
+
+ model1.pt: 80MB
|
| 51 |
+
+ model2.pt: 80MB
|
| 52 |
+
|
| 53 |
+
The ```PepMerge_release.zip``` contains filtered data of peptide-receptor pairs. For example, in the folder ```1a0n_A```, the ```P``` chain in the PDB file ```1a0n``` is the peptide. In this folder, we provide the FASTA and PDB files of the peptide and receptor. The postfix _merge means the peptide and receptor are in the same PDB file. We also extract the binding pocket of the receptor, where our model is trained to generate peptides based on the binding pocket. You can also download [PepBDB](http://huanglab.phys.hust.edu.cn/pepbdb/db/1cta_A/) and [QBioLip](https://yanglab.qd.sdu.edu.cn/Q-BioLiP/Download), and use ```playgrounds/gen_dataset```.ipynb to reproduce the dataset.
|
| 54 |
+
|
| 55 |
+
The ```PepMerge_lmdb.zip``` contains several different splits of the dataset. We use ```mmseqs2``` to cluster complexes based on receptor sequence identity. See ```playgrounds/cluster.ipynb``` for details. The names.txt file contains the names of complexes in the test set. You can use ```models_con/pep_dataloader.py``` to load these datasets. We suggest putting these LMDBs in a single ```Data``` folder.
|
| 56 |
+
|
| 57 |
+
Besides, ```model1.pt``` and ```model2.pt``` are two checkpoints that you can load using ```models_con/flow_model.py``` together with the config file configs/learn_angle.yaml. We suggest using model1 for benchmark evaluation and model2 for real-world peptide design tasks, the latter is trained on a larger dataset.
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
## Usage
|
| 61 |
+
|
| 62 |
+
We will add more user-friendly straightforward pipelines (generation and evaluation) later.
|
| 63 |
+
|
| 64 |
+
### Inference and Generate
|
| 65 |
+
|
| 66 |
+
By default, we support sampling of generated peptides from our processed dataset. You can use ```models_con/sample.py``` to sample, and ```models_con/inference.py``` to reconstruct PDB files.
|
| 67 |
+
|
| 68 |
+
If you want to use your own data, you can organize your data (peptide and pocket) as we did in PepMerge_release and construct a dataset for sampling and reconstruction. You can also use ```models_con/pep_dataloader/preprocess_structure``` to parse a single data point.
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
### Evaluation
|
| 74 |
+
|
| 75 |
+
Our evaluation involves many third-party packages, and we include some useful evaluation scripts in ```eval```. Please refer to our paper for details and download the corresponding packages for evaluation. Please use different python environments for these tools.
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
### Train
|
| 80 |
+
|
| 81 |
+
You can also ```train.py``` on single GPU training and ```train_ddp.py``` for multiple GPT training.
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
## Future Work
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
Future improvements on peptide generation models may include chemical modifications, non-canonical amino acids, pretraining on larger datasets, language models, better sampling methods, etc. Stay tuned and feel free to contact us for collaboration and discussion!
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
## Reference
|
| 92 |
+
|
| 93 |
+
```bibtex
|
| 94 |
+
@InProceedings{pmlr-v235-li24o,
|
| 95 |
+
title={Full-Atom Peptide Design based on Multi-modal Flow Matching},
|
| 96 |
+
author={Li, Jiahan and Cheng, Chaoran and Wu, Zuofan and Guo, Ruihan and Luo, Shitong and Ren, Zhizhou and Peng, Jian and Ma, Jianzhu},
|
| 97 |
+
booktitle={Proceedings of the 41st International Conference on Machine Learning},
|
| 98 |
+
pages={27615--27640},
|
| 99 |
+
year={2024},
|
| 100 |
+
editor={Salakhutdinov, Ruslan and Kolter, Zico and Heller, Katherine and Weller, Adrian and Oliver, Nuria and Scarlett, Jonathan and Berkenkamp, Felix},
|
| 101 |
+
volume={235},
|
| 102 |
+
series={Proceedings of Machine Learning Research},
|
| 103 |
+
month=21--27 Jul},
|
| 104 |
+
publisher={PMLR},
|
| 105 |
+
}
|
| 106 |
+
```
|
configs/learn_angle.yaml
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model:
|
| 2 |
+
encoder:
|
| 3 |
+
node_embed_size: 128
|
| 4 |
+
edge_embed_size: 64
|
| 5 |
+
ipa:
|
| 6 |
+
c_s: 128 #${model.node_embed_size}
|
| 7 |
+
c_z: 64 #${model.edge_embed_size}
|
| 8 |
+
c_hidden: 128
|
| 9 |
+
no_heads: 8
|
| 10 |
+
no_qk_points: 8
|
| 11 |
+
no_v_points: 12
|
| 12 |
+
seq_tfmr_num_heads: 4
|
| 13 |
+
seq_tfmr_num_layers: 2
|
| 14 |
+
num_blocks: 6
|
| 15 |
+
stop_grad: False
|
| 16 |
+
interpolant:
|
| 17 |
+
min_t: 1.e-2
|
| 18 |
+
t_normalization_clip: 0.9
|
| 19 |
+
sample_sequence: True
|
| 20 |
+
sample_structure: True
|
| 21 |
+
rots:
|
| 22 |
+
train_schedule: linear
|
| 23 |
+
sample_schedule: exp
|
| 24 |
+
exp_rate: 10
|
| 25 |
+
trans:
|
| 26 |
+
train_schedule: linear
|
| 27 |
+
sample_schedule: linear
|
| 28 |
+
sigma: 1.0
|
| 29 |
+
seqs:
|
| 30 |
+
num_classes: 20
|
| 31 |
+
simplex_value: 5.0
|
| 32 |
+
sampling:
|
| 33 |
+
num_timesteps: 100
|
| 34 |
+
self_condition: False
|
| 35 |
+
|
| 36 |
+
train:
|
| 37 |
+
loss_weights:
|
| 38 |
+
trans_loss: 0.5 # 1.0 for dreamfold, 0.05 for yim
|
| 39 |
+
rot_loss: 0.5 # 1.0 for dreamfold, 0.5 for yim
|
| 40 |
+
bb_atom_loss: 0.25
|
| 41 |
+
seqs_loss: 1.0
|
| 42 |
+
angle_loss: 1.0
|
| 43 |
+
torsion_loss: 0.5
|
| 44 |
+
max_iters: 400000000
|
| 45 |
+
val_freq: 20000
|
| 46 |
+
batch_size: 32
|
| 47 |
+
accum_grad: 1
|
| 48 |
+
seed: 114514
|
| 49 |
+
max_grad_norm: 100.0
|
| 50 |
+
optimizer:
|
| 51 |
+
type: adam
|
| 52 |
+
lr: 5.e-4 #1.e-4
|
| 53 |
+
weight_decay: 0.0
|
| 54 |
+
beta1: 0.9
|
| 55 |
+
beta2: 0.999
|
| 56 |
+
scheduler:
|
| 57 |
+
type: plateau
|
| 58 |
+
factor: 0.8
|
| 59 |
+
patience: 10
|
| 60 |
+
min_lr: 5.e-6
|
| 61 |
+
|
| 62 |
+
dataset:
|
| 63 |
+
train:
|
| 64 |
+
type: peprec
|
| 65 |
+
structure_dir: /datapool/data2/home/jiahan/Data/PepMerge_new/
|
| 66 |
+
dataset_dir: /datapool/data2/home/jiahan/ResProj/PepDiff/frame-flow/Data/Fixed Data
|
| 67 |
+
name: pep_pocket_train
|
| 68 |
+
reset: False
|
| 69 |
+
val:
|
| 70 |
+
type: peprec
|
| 71 |
+
structure_dir: /datapool/data2/home/jiahan/Data/PepMerge_new/
|
| 72 |
+
dataset_dir: /datapool/data2/home/jiahan/ResProj/PepDiff/frame-flow/Data/Fixed Data
|
| 73 |
+
name: pep_pocket_test
|
| 74 |
+
reset: False
|
environment.yml
ADDED
|
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: flow
|
| 2 |
+
channels:
|
| 3 |
+
- conda-forge
|
| 4 |
+
- nvidia
|
| 5 |
+
- pytorch
|
| 6 |
+
dependencies:
|
| 7 |
+
- _libgcc_mutex==0.1=conda_forge
|
| 8 |
+
- _openmp_mutex==4.5=2_gnu
|
| 9 |
+
- anyio==3.7.1=pyhd8ed1ab_0
|
| 10 |
+
- argon2-cffi==21.3.0=pyhd8ed1ab_0
|
| 11 |
+
- argon2-cffi-bindings==21.2.0=py310h5764c6d_3
|
| 12 |
+
- arrow==1.2.3=pyhd8ed1ab_0
|
| 13 |
+
- asttokens==2.2.1=pyhd8ed1ab_0
|
| 14 |
+
- astunparse==1.6.3=pyhd8ed1ab_0
|
| 15 |
+
- async-lru==2.0.4=pyhd8ed1ab_0
|
| 16 |
+
- attrs==23.1.0=pyh71513ae_1
|
| 17 |
+
- babel==2.12.1=pyhd8ed1ab_1
|
| 18 |
+
- backcall==0.2.0=pyh9f0ad1d_0
|
| 19 |
+
- backports==1.0=pyhd8ed1ab_3
|
| 20 |
+
- backports.functools_lru_cache==1.6.5=pyhd8ed1ab_0
|
| 21 |
+
- beautifulsoup4==4.12.2=pyha770c72_0
|
| 22 |
+
- biopython==1.81=py310h1fa729e_0
|
| 23 |
+
- biotite==0.38.0
|
| 24 |
+
- bleach==6.0.0=pyhd8ed1ab_0
|
| 25 |
+
- blosc==1.21.4=h0f2a231_0
|
| 26 |
+
- brotli==1.0.9=h166bdaf_9
|
| 27 |
+
- brotli-bin==1.0.9=h166bdaf_9
|
| 28 |
+
- brotli-python==1.0.9=py310hd8f1fbe_9
|
| 29 |
+
- bzip2==1.0.8=h7f98852_4
|
| 30 |
+
- c-ares==1.19.1=hd590300_0
|
| 31 |
+
- c-blosc2==2.10.2=hb4ffafa_0
|
| 32 |
+
- ca-certificates==2023.7.22=hbcca054_0
|
| 33 |
+
- cached-property==1.5.2=hd8ed1ab_1
|
| 34 |
+
- cached_property==1.5.2=pyha770c72_1
|
| 35 |
+
- certifi==2023.7.22=pyhd8ed1ab_0
|
| 36 |
+
- cffi==1.15.1=py310h255011f_3
|
| 37 |
+
- charset-normalizer==3.2.0=pyhd8ed1ab_0
|
| 38 |
+
- comm==0.1.4=pyhd8ed1ab_0
|
| 39 |
+
- contourpy==1.1.0=py310hd41b1e2_0
|
| 40 |
+
- cuda==11.6.0=0
|
| 41 |
+
- cuda-cccl==11.6.55=hf6102b2_0
|
| 42 |
+
- cuda-command-line-tools==11.6.2=0
|
| 43 |
+
- cuda-compiler==11.6.2=0
|
| 44 |
+
- cuda-cudart==11.6.55=he381448_0
|
| 45 |
+
- cuda-cudart-dev==11.6.55=h42ad0f4_0
|
| 46 |
+
- cuda-cuobjdump==11.6.124=h2eeebcb_0
|
| 47 |
+
- cuda-cupti==11.6.124=h86345e5_0
|
| 48 |
+
- cuda-cuxxfilt==11.6.124=hecbf4f6_0
|
| 49 |
+
- cuda-driver-dev==11.6.55=0
|
| 50 |
+
- cuda-gdb==12.0.90=hd47b8d6_0
|
| 51 |
+
- cuda-libraries==11.6.2=0
|
| 52 |
+
- cuda-libraries-dev==11.6.0=0
|
| 53 |
+
- cuda-memcheck==11.8.86=0
|
| 54 |
+
- cuda-nsight==12.0.78=ha770c72_0
|
| 55 |
+
- cuda-nsight-compute==12.2.1=0
|
| 56 |
+
- cuda-nvcc==11.6.124=hbba6d2d_0
|
| 57 |
+
- cuda-nvdisasm==12.0.76=h59595ed_0
|
| 58 |
+
- cuda-nvml-dev==11.6.55=haa9ef22_0
|
| 59 |
+
- cuda-nvprof==12.0.90=h59595ed_0
|
| 60 |
+
- cuda-nvprune==11.6.124=he22ec0a_0
|
| 61 |
+
- cuda-nvrtc==11.6.124=h020bade_0
|
| 62 |
+
- cuda-nvrtc-dev==11.6.124=h249d397_0
|
| 63 |
+
- cuda-nvtx==11.6.124=h0630a44_0
|
| 64 |
+
- cuda-nvvp==12.0.90=h59595ed_0
|
| 65 |
+
- cuda-runtime==11.6.2=0
|
| 66 |
+
- cuda-samples==11.6.101=h8efea70_0
|
| 67 |
+
- cuda-sanitizer-api==12.0.90=h59595ed_0
|
| 68 |
+
- cuda-toolkit==11.6.0=0
|
| 69 |
+
- cuda-tools==11.6.0=0
|
| 70 |
+
- cuda-version==12.0=hffde075_2
|
| 71 |
+
- cuda-visual-tools==11.6.0=0
|
| 72 |
+
- cycler==0.11.0=pyhd8ed1ab_0
|
| 73 |
+
- debugpy==1.6.8=py310hc6cd4ac_0
|
| 74 |
+
- decorator==5.1.1=pyhd8ed1ab_0
|
| 75 |
+
- defusedxml==0.7.1=pyhd8ed1ab_0
|
| 76 |
+
- entrypoints==0.4=pyhd8ed1ab_0
|
| 77 |
+
- exceptiongroup==1.1.3=pyhd8ed1ab_0
|
| 78 |
+
- executing==1.2.0=pyhd8ed1ab_0
|
| 79 |
+
- flit-core==3.9.0=pyhd8ed1ab_0
|
| 80 |
+
- fonttools==4.42.0=py310h2372a71_0
|
| 81 |
+
- fqdn==1.5.1=pyhd8ed1ab_0
|
| 82 |
+
- freetype==2.12.1=hca18f0e_1
|
| 83 |
+
- gds-tools==1.5.0.59=hcb278e6_0
|
| 84 |
+
- gmp==6.2.1=h58526e2_0
|
| 85 |
+
- hdf5==1.14.1=nompi_h4f84152_100
|
| 86 |
+
- idna==3.4=pyhd8ed1ab_0
|
| 87 |
+
- importlib-metadata==6.8.0=pyha770c72_0
|
| 88 |
+
- importlib_metadata==6.8.0=hd8ed1ab_0
|
| 89 |
+
- importlib_resources==6.0.1=pyhd8ed1ab_0
|
| 90 |
+
- ipykernel==6.25.1=pyh71e2992_0
|
| 91 |
+
- ipython==8.14.0=pyh41d4057_0
|
| 92 |
+
- isoduration==20.11.0=pyhd8ed1ab_0
|
| 93 |
+
- jedi==0.19.0=pyhd8ed1ab_0
|
| 94 |
+
- jinja2==3.1.2=pyhd8ed1ab_1
|
| 95 |
+
- json5==0.9.14=pyhd8ed1ab_0
|
| 96 |
+
- jsonpointer==2.0=py_0
|
| 97 |
+
- jsonschema==4.19.0=pyhd8ed1ab_1
|
| 98 |
+
- jsonschema-specifications==2023.7.1=pyhd8ed1ab_0
|
| 99 |
+
- jsonschema-with-format-nongpl==4.19.0=pyhd8ed1ab_1
|
| 100 |
+
- jupyter-lsp==2.2.0=pyhd8ed1ab_0
|
| 101 |
+
- jupyter_client==8.3.0=pyhd8ed1ab_0
|
| 102 |
+
- jupyter_core==5.3.1=py310hff52083_0
|
| 103 |
+
- jupyter_events==0.7.0=pyhd8ed1ab_2
|
| 104 |
+
- jupyter_server==2.7.1=pyhd8ed1ab_0
|
| 105 |
+
- jupyter_server_terminals==0.4.4=pyhd8ed1ab_1
|
| 106 |
+
- jupyterlab==4.0.5=pyhd8ed1ab_0
|
| 107 |
+
- jupyterlab_pygments==0.2.2=pyhd8ed1ab_0
|
| 108 |
+
- jupyterlab_server==2.24.0=pyhd8ed1ab_0
|
| 109 |
+
- keyutils==1.6.1=h166bdaf_0
|
| 110 |
+
- kiwisolver==1.4.4=py310hbf28c38_1
|
| 111 |
+
- krb5==1.21.2=h659d440_0
|
| 112 |
+
- lcms2==2.15=haa2dc70_1
|
| 113 |
+
- ld_impl_linux-64==2.40=h41732ed_0
|
| 114 |
+
- lerc==4.0.0=h27087fc_0
|
| 115 |
+
- libaec==1.0.6=hcb278e6_1
|
| 116 |
+
- libblas==3.9.0=17_linux64_openblas
|
| 117 |
+
- libbrotlicommon==1.0.9=h166bdaf_9
|
| 118 |
+
- libbrotlidec==1.0.9=h166bdaf_9
|
| 119 |
+
- libbrotlienc==1.0.9=h166bdaf_9
|
| 120 |
+
- libcblas==3.9.0=17_linux64_openblas
|
| 121 |
+
- libcublas==12.0.1.189=hcb278e6_2
|
| 122 |
+
- libcublas-dev==12.0.1.189=hcb278e6_2
|
| 123 |
+
- libcufft==11.0.0.21=hcb278e6_1
|
| 124 |
+
- libcufft-dev==11.0.0.21=hcb278e6_1
|
| 125 |
+
- libcufile==1.5.0.59=hcb278e6_0
|
| 126 |
+
- libcufile-dev==1.5.0.59=hcb278e6_0
|
| 127 |
+
- libcurand==10.3.1.50=hcb278e6_0
|
| 128 |
+
- libcurand-dev==10.3.1.50=hcb278e6_0
|
| 129 |
+
- libcurl==8.2.1=hca28451_0
|
| 130 |
+
- libcusolver==11.4.2.57=hcb278e6_1
|
| 131 |
+
- libcusparse==12.0.0.76=hcb278e6_1
|
| 132 |
+
- libdeflate==1.18=h0b41bf4_0
|
| 133 |
+
- libedit==3.1.20191231=he28a2e2_2
|
| 134 |
+
- libev==4.33=h516909a_1
|
| 135 |
+
- libffi==3.4.2=h7f98852_5
|
| 136 |
+
- libgcc-ng==13.1.0=he5830b7_0
|
| 137 |
+
- libgfortran-ng==13.1.0=h69a702a_0
|
| 138 |
+
- libgfortran5==13.1.0=h15d22d2_0
|
| 139 |
+
- libgomp==13.1.0=he5830b7_0
|
| 140 |
+
- libjpeg-turbo==2.1.5.1=h0b41bf4_0
|
| 141 |
+
- liblapack==3.9.0=17_linux64_openblas
|
| 142 |
+
- libnghttp2==1.52.0=h61bc06f_0
|
| 143 |
+
- libnpp==12.0.0.30=h59595ed_0
|
| 144 |
+
- libnpp-dev==12.0.0.30=h59595ed_0
|
| 145 |
+
- libnsl==2.0.0=h7f98852_0
|
| 146 |
+
- libnuma==2.0.16=h0b41bf4_1
|
| 147 |
+
- libnvjitlink==12.0.76=hcb278e6_1
|
| 148 |
+
- libnvjpeg==12.0.0.28=hcb278e6_0
|
| 149 |
+
- libnvjpeg-dev==12.0.0.28=ha770c72_0
|
| 150 |
+
- libopenblas==0.3.23=pthreads_h80387f5_0
|
| 151 |
+
- libpng==1.6.39=h753d276_0
|
| 152 |
+
- libsodium==1.0.18=h36c2ea0_1
|
| 153 |
+
- libsqlite==3.42.0=h2797004_0
|
| 154 |
+
- libssh2==1.11.0=h0841786_0
|
| 155 |
+
- libstdcxx-ng==13.1.0=hfd8a6a1_0
|
| 156 |
+
- libtiff==4.5.1=h8b53f26_0
|
| 157 |
+
- libuuid==2.38.1=h0b41bf4_0
|
| 158 |
+
- libwebp-base==1.3.1=hd590300_0
|
| 159 |
+
- libxcb==1.15=h0b41bf4_0
|
| 160 |
+
- libzlib==1.2.13=hd590300_5
|
| 161 |
+
- lz4-c==1.9.4=hcb278e6_0
|
| 162 |
+
- lzo==2.10=h516909a_1000
|
| 163 |
+
- markupsafe==2.1.3=py310h2372a71_0
|
| 164 |
+
- matplotlib-base==3.7.2=py310hf38f957_0
|
| 165 |
+
- matplotlib-inline==0.1.6=pyhd8ed1ab_0
|
| 166 |
+
- mdtraj==1.9.9=py310h8e08b51_0
|
| 167 |
+
- mistune==3.0.1=pyhd8ed1ab_0
|
| 168 |
+
- munkres==1.1.4=pyh9f0ad1d_0
|
| 169 |
+
- nbclient==0.8.0=pyhd8ed1ab_0
|
| 170 |
+
- nbconvert-core==7.7.4=pyhd8ed1ab_0
|
| 171 |
+
- nbformat==5.9.2=pyhd8ed1ab_0
|
| 172 |
+
- ncurses==6.4=hcb278e6_0
|
| 173 |
+
- nest-asyncio==1.5.6=pyhd8ed1ab_0
|
| 174 |
+
- nomkl==1.0=h5ca1d4c_0
|
| 175 |
+
- notebook-shim==0.2.3=pyhd8ed1ab_0
|
| 176 |
+
- nsight-compute==2023.2.1.3=0
|
| 177 |
+
- numexpr==2.8.4=py310hd91493a_101
|
| 178 |
+
- numpy==1.25.2=py310ha4c1d20_0
|
| 179 |
+
- openjpeg==2.5.0=hfec8fc6_2
|
| 180 |
+
- openssl==3.1.2=hd590300_0
|
| 181 |
+
- overrides==7.4.0=pyhd8ed1ab_0
|
| 182 |
+
- packaging==23.1=pyhd8ed1ab_0
|
| 183 |
+
- pandas==2.0.3=py310h7cbd5c2_1
|
| 184 |
+
- pandocfilters==1.5.0=pyhd8ed1ab_0
|
| 185 |
+
- parso==0.8.3=pyhd8ed1ab_0
|
| 186 |
+
- patsy==0.5.3=pyhd8ed1ab_0
|
| 187 |
+
- pexpect==4.8.0=pyh1a96a4e_2
|
| 188 |
+
- pickleshare==0.7.5=py_1003
|
| 189 |
+
- pillow==10.0.0=py310h582fbeb_0
|
| 190 |
+
- pip
|
| 191 |
+
- pkgutil-resolve-name==1.3.10=pyhd8ed1ab_0
|
| 192 |
+
- platformdirs==3.10.0=pyhd8ed1ab_0
|
| 193 |
+
- pooch==1.7.0=pyha770c72_3
|
| 194 |
+
- prometheus_client==0.17.1=pyhd8ed1ab_0
|
| 195 |
+
- prompt-toolkit==3.0.39=pyha770c72_0
|
| 196 |
+
- prompt_toolkit==3.0.39=hd8ed1ab_0
|
| 197 |
+
- psutil==5.9.5=py310h1fa729e_0
|
| 198 |
+
- pthread-stubs==0.4=h36c2ea0_1001
|
| 199 |
+
- ptyprocess==0.7.0=pyhd3deb0d_0
|
| 200 |
+
- pure_eval==0.2.2=pyhd8ed1ab_0
|
| 201 |
+
- py-cpuinfo==9.0.0=pyhd8ed1ab_0
|
| 202 |
+
- pycparser==2.21=pyhd8ed1ab_0
|
| 203 |
+
- pygments==2.16.1=pyhd8ed1ab_0
|
| 204 |
+
- pyparsing==3.0.9=pyhd8ed1ab_0
|
| 205 |
+
- pysocks==1.7.1=pyha2e5f31_6
|
| 206 |
+
- pytables==3.8.0=py310ha028ce3_2
|
| 207 |
+
- python==3.10.12=hd12c33a_0_cpython
|
| 208 |
+
- python-dateutil==2.8.2=pyhd8ed1ab_0
|
| 209 |
+
- python-fastjsonschema==2.18.0=pyhd8ed1ab_0
|
| 210 |
+
- python-json-logger==2.0.7=pyhd8ed1ab_0
|
| 211 |
+
- python-tzdata==2023.3=pyhd8ed1ab_0
|
| 212 |
+
- python_abi==3.10=3_cp310
|
| 213 |
+
- pytorch-cuda==11.6=h867d48c_0
|
| 214 |
+
- pytz==2023.3=pyhd8ed1ab_0
|
| 215 |
+
- pyyaml==6.0=py310h5764c6d_5
|
| 216 |
+
- pyzmq==25.1.1=py310h5bbb5d0_0
|
| 217 |
+
- readline==8.2=h8228510_1
|
| 218 |
+
- referencing==0.30.2=pyhd8ed1ab_0
|
| 219 |
+
- requests==2.31.0=pyhd8ed1ab_0
|
| 220 |
+
- rfc3339-validator==0.1.4=pyhd8ed1ab_0
|
| 221 |
+
- rfc3986-validator==0.1.1=pyh9f0ad1d_0
|
| 222 |
+
- rpds-py==0.9.2=py310hcb5633a_0
|
| 223 |
+
- scipy==1.11.1=py310ha4c1d20_0
|
| 224 |
+
- seaborn==0.12.2=hd8ed1ab_0
|
| 225 |
+
- seaborn-base==0.12.2=pyhd8ed1ab_0
|
| 226 |
+
- send2trash==1.8.2=pyh41d4057_0
|
| 227 |
+
- setuptools==68.1.2=pyhd8ed1ab_0
|
| 228 |
+
- six==1.16.0=pyh6c4a22f_0
|
| 229 |
+
- snappy==1.1.10=h9fff704_0
|
| 230 |
+
- sniffio==1.3.0=pyhd8ed1ab_0
|
| 231 |
+
- soupsieve==2.3.2.post1=pyhd8ed1ab_0
|
| 232 |
+
- stack_data==0.6.2=pyhd8ed1ab_0
|
| 233 |
+
- statsmodels==0.14.0=py310h278f3c1_1
|
| 234 |
+
- terminado==0.17.1=pyh41d4057_0
|
| 235 |
+
- tinycss2==1.2.1=pyhd8ed1ab_0
|
| 236 |
+
- tk==8.6.12=h27826a3_0
|
| 237 |
+
- tomli==2.0.1=pyhd8ed1ab_0
|
| 238 |
+
- tornado==6.3.3=py310h2372a71_0
|
| 239 |
+
- traitlets==5.9.0=pyhd8ed1ab_0
|
| 240 |
+
- typing-extensions==4.7.1=hd8ed1ab_0
|
| 241 |
+
- typing_extensions==4.7.1=pyha770c72_0
|
| 242 |
+
- typing_utils==0.1.0=pyhd8ed1ab_0
|
| 243 |
+
- tzdata==2023c=h71feb2d_0
|
| 244 |
+
- unicodedata2==15.0.0=py310h5764c6d_0
|
| 245 |
+
- uri-template==1.3.0=pyhd8ed1ab_0
|
| 246 |
+
- urllib3==2.0.4=pyhd8ed1ab_0
|
| 247 |
+
- wcwidth==0.2.6=pyhd8ed1ab_0
|
| 248 |
+
- webcolors==1.13=pyhd8ed1ab_0
|
| 249 |
+
- webencodings==0.5.1=py_1
|
| 250 |
+
- websocket-client==1.6.1=pyhd8ed1ab_0
|
| 251 |
+
- wheel==0.41.1=pyhd8ed1ab_0
|
| 252 |
+
- xorg-libxau==1.0.11=hd590300_0
|
| 253 |
+
- xorg-libxdmcp==1.1.3=h7f98852_0
|
| 254 |
+
- xz==5.2.6=h166bdaf_0
|
| 255 |
+
- yaml==0.2.5=h7f98852_2
|
| 256 |
+
- zeromq==4.3.4=h9c3ff4c_1
|
| 257 |
+
- zipp==3.16.2=pyhd8ed1ab_0
|
| 258 |
+
- zlib==1.2.13=hd590300_5
|
| 259 |
+
- zlib-ng==2.0.7=h0b41bf4_0
|
| 260 |
+
- zstd==1.5.2=hfc55251_7
|
| 261 |
+
|
eval/align.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import subprocess
|
| 2 |
+
import re
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
RUNNER = '/datapool/data2/home/jiahan/Tool/TMalign-20180426/MMalign'
|
| 8 |
+
|
| 9 |
+
def align_pdb(pdb1,pdb2,pdb1_out):
|
| 10 |
+
subprocess.run([RUNNER,pdb1,pdb2,'-o',pdb1_out],stdout=subprocess.PIPE)
|
| 11 |
+
|
| 12 |
+
def get_tm_score(pdb1,pdb2):
|
| 13 |
+
cmd = subprocess.run(['TMscore',pdb1,pdb2],stdout=subprocess.PIPE)
|
| 14 |
+
out = cmd.stdout.decode()
|
| 15 |
+
tm_score = re.search(r"TM-score\s+=\s+(\d+\.\d+)", out)
|
| 16 |
+
rmsd = re.search(r"RMSD of the common residues=\s+(\d+\.\d+)", out)
|
| 17 |
+
return float(rmsd.group(1)),float(tm_score.group(1))
|
eval/energy.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pyrosetta
|
| 2 |
+
from pyrosetta import init, pose_from_pdb, get_fa_scorefxn
|
| 3 |
+
from pyrosetta.rosetta.protocols.relax import FastRelax
|
| 4 |
+
from pyrosetta.rosetta.protocols.analysis import InterfaceAnalyzerMover
|
| 5 |
+
from pyrosetta.rosetta.core.pack.task import TaskFactory
|
| 6 |
+
from pyrosetta.rosetta.core.pack.task.operation import RestrictToRepacking
|
| 7 |
+
from pyrosetta.rosetta.protocols.minimization_packing import PackRotamersMover
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import pandas as pd
|
| 11 |
+
import subprocess
|
| 12 |
+
import numpy as np
|
| 13 |
+
import shutil
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
import pickle
|
| 16 |
+
|
| 17 |
+
from joblib import delayed, Parallel
|
| 18 |
+
from utils import *
|
| 19 |
+
|
| 20 |
+
input_dir=".Tests"
|
| 21 |
+
output_dir="./Pack"
|
| 22 |
+
|
| 23 |
+
def get_chain_dic(input_pdb):
|
| 24 |
+
parser = PDBParser()
|
| 25 |
+
structure = parser.get_structure("protein", input_pdb)
|
| 26 |
+
chain_dic = {}
|
| 27 |
+
for model in structure:
|
| 28 |
+
for chain in model:
|
| 29 |
+
chain_dic[chain.id] = len([res for res in chain if is_aa(res) and res.has_id('CA')])
|
| 30 |
+
|
| 31 |
+
return chain_dic
|
| 32 |
+
|
| 33 |
+
def get_rosetta_score_base(pdb_path,chain_id='A'):
|
| 34 |
+
try:
|
| 35 |
+
init()
|
| 36 |
+
pose = pyrosetta.pose_from_pdb(pdb_path)
|
| 37 |
+
chains = list(get_chain_dic(pdb_path).keys())
|
| 38 |
+
chains.remove(chain_id)
|
| 39 |
+
interface = f'{chain_id}_{"".join(chains)}'
|
| 40 |
+
fast_relax = FastRelax() # cant be pickled
|
| 41 |
+
scorefxn = get_fa_scorefxn()
|
| 42 |
+
fast_relax.set_scorefxn(scorefxn)
|
| 43 |
+
mover = InterfaceAnalyzerMover(interface)
|
| 44 |
+
mover.set_pack_separated(True)
|
| 45 |
+
stabs,binds = [],[]
|
| 46 |
+
for i in range(5):
|
| 47 |
+
fast_relax.apply(pose)
|
| 48 |
+
stab = scorefxn(pose)
|
| 49 |
+
mover.apply(pose)
|
| 50 |
+
bind = pose.scores['dG_separated']
|
| 51 |
+
stabs.append(stab)
|
| 52 |
+
binds.append(bind)
|
| 53 |
+
return {'name':pdb_path,'stab':np.array(stabs).mean(),'bind':np.array(binds).mean()}
|
| 54 |
+
except:
|
| 55 |
+
return {'name':pdb_path,'stab':999.0,'bind':999.0}
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def get_rosetta_score(pdb_path,chain='A'):
|
| 59 |
+
try:
|
| 60 |
+
init()
|
| 61 |
+
pose = pyrosetta.pose_from_pdb(pdb_path)
|
| 62 |
+
# chains = list(get_chain_dic(os.path.join(input_dir,name,'pocket_merge_renum.pdb')).keys())
|
| 63 |
+
# chains.remove(chain)
|
| 64 |
+
# interface = f'{chain}_{"".join(chains)}'
|
| 65 |
+
interface='A_B'
|
| 66 |
+
fast_relax = FastRelax() # cant be pickled
|
| 67 |
+
scorefxn = get_fa_scorefxn()
|
| 68 |
+
fast_relax.set_scorefxn(scorefxn)
|
| 69 |
+
mover = InterfaceAnalyzerMover(interface)
|
| 70 |
+
mover.set_pack_separated(True)
|
| 71 |
+
fast_relax.apply(pose)
|
| 72 |
+
energy = scorefxn(pose)
|
| 73 |
+
mover.apply(pose)
|
| 74 |
+
dg = pose.scores['dG_separated']
|
| 75 |
+
return [pdb_path,energy,dg]
|
| 76 |
+
except:
|
| 77 |
+
return [pdb_path,999.0,999.0]
|
| 78 |
+
|
| 79 |
+
def pack_sc(name='1a1m_C',num_samples=10):
|
| 80 |
+
try:
|
| 81 |
+
if os.path.exists(os.path.join(output_dir,name,'rosetta')):
|
| 82 |
+
shutil.rmtree(os.path.join(output_dir,name,'rosetta'))
|
| 83 |
+
os.makedirs(os.path.join(output_dir,name,'rosetta'),exist_ok=True)
|
| 84 |
+
init()
|
| 85 |
+
tf = TaskFactory()
|
| 86 |
+
tf.push_back(RestrictToRepacking()) # Only repack, don't change amino acid types
|
| 87 |
+
packer = PackRotamersMover()
|
| 88 |
+
packer.task_factory(tf)
|
| 89 |
+
for i in range(num_samples):
|
| 90 |
+
pose = pose_from_pdb(os.path.join(input_dir,name,f'pocket_merge_renum_bb.pdb'))
|
| 91 |
+
packer.apply(pose)
|
| 92 |
+
pose.dump_pdb(os.path.join(output_dir,name,'rosetta',f'packed_{i}.pdb'))
|
| 93 |
+
except:
|
| 94 |
+
return None
|
eval/foldx.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import torch
|
| 5 |
+
from joblib import Parallel, delayed
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
import tempfile
|
| 8 |
+
import os
|
| 9 |
+
import shutil
|
| 10 |
+
import subprocess
|
| 11 |
+
|
| 12 |
+
from Bio.PDB import PDBParser
|
| 13 |
+
|
| 14 |
+
def fetch_stability_score(path):
|
| 15 |
+
u = pd.read_csv(path, sep='\t', header=None)
|
| 16 |
+
return u.values[0][1]
|
| 17 |
+
|
| 18 |
+
def fetch_binding_affinity(path):
|
| 19 |
+
with open(path, 'r') as f:
|
| 20 |
+
u = f.readlines()
|
| 21 |
+
return float(u[-1].split("\t")[-3])
|
| 22 |
+
|
| 23 |
+
class FoldXSession(object):
|
| 24 |
+
def __init__(self):
|
| 25 |
+
super().__init__()
|
| 26 |
+
self.tmpdir = tempfile.TemporaryDirectory()
|
| 27 |
+
self.pdb_names = []
|
| 28 |
+
|
| 29 |
+
def cleanup(self):
|
| 30 |
+
self.tmpdir.cleanup()
|
| 31 |
+
self.tmpdir = None
|
| 32 |
+
|
| 33 |
+
def __enter__(self):
|
| 34 |
+
return self
|
| 35 |
+
|
| 36 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
| 37 |
+
self.cleanup()
|
| 38 |
+
|
| 39 |
+
@property
|
| 40 |
+
def workdir(self):
|
| 41 |
+
return self.tmpdir.name
|
| 42 |
+
|
| 43 |
+
def path(self, filename):
|
| 44 |
+
return os.path.join(self.workdir, filename)
|
| 45 |
+
|
| 46 |
+
def preprocess_data(self, pdb_dir, pdb_name):
|
| 47 |
+
shutil.copy(os.path.join(pdb_dir, pdb_name), self.path(pdb_name))
|
| 48 |
+
return self.path(pdb_name)
|
| 49 |
+
|
| 50 |
+
def get_chain_names(pdb_dir,pdb_name):
|
| 51 |
+
pep_chain = pdb_name.split("_")[-1][0]
|
| 52 |
+
parser = PDBParser()
|
| 53 |
+
structure = parser.get_structure("name", os.path.join(pdb_dir,pdb_name))
|
| 54 |
+
chain_names = [chain.get_id() for model in structure for chain in model]
|
| 55 |
+
chains = f"{pep_chain},"
|
| 56 |
+
for chain in chain_names:
|
| 57 |
+
if chain != pep_chain:
|
| 58 |
+
chains += f"{chain}"
|
| 59 |
+
return chains
|
| 60 |
+
|
| 61 |
+
def process_one_file(pdb_dir,pdb_name):
|
| 62 |
+
chains = get_chain_names(pdb_dir,pdb_name)
|
| 63 |
+
with FoldXSession() as session:
|
| 64 |
+
try:
|
| 65 |
+
# print(session.workdir)
|
| 66 |
+
session.preprocess_data(pdb_dir, pdb_name)
|
| 67 |
+
assert(os.path.exists(session.path(pdb_name)))
|
| 68 |
+
# print(os.listdir(session.workdir))
|
| 69 |
+
ret = subprocess.run(['/datapool/data2/home/ruihan/bin/foldx', '--command='+'AnalyseComplex', '--pdb='+pdb_name, f'--analyseComplexChains={chains}'], cwd=session.workdir, stdout=None)
|
| 70 |
+
fxout_path = session.path(f'Summary_{pdb_name.split(".")[0]}_AC.fxout')
|
| 71 |
+
assert(os.path.exists(fxout_path))
|
| 72 |
+
return (pdb_name.split('.')[0],fetch_binding_affinity(fxout_path))
|
| 73 |
+
except:
|
| 74 |
+
print(f"Error in {pdb_name}")
|
| 75 |
+
print(os.path.exists(fxout_path))
|
| 76 |
+
return (pdb_name.split('.')[0],None)
|
| 77 |
+
|
eval/geometry.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from Bio.PDB import PDBParser, Superimposer, is_aa, Select, NeighborSearch
|
| 2 |
+
import tmtools
|
| 3 |
+
import os
|
| 4 |
+
import numpy as np
|
| 5 |
+
import mdtraj as md
|
| 6 |
+
from Bio.SeqUtils import seq1
|
| 7 |
+
|
| 8 |
+
import warnings
|
| 9 |
+
from Bio import BiopythonWarning, SeqIO
|
| 10 |
+
|
| 11 |
+
import difflib
|
| 12 |
+
import torch
|
| 13 |
+
|
| 14 |
+
# 忽略PDBConstructionWarning
|
| 15 |
+
warnings.filterwarnings('ignore', category=BiopythonWarning)
|
| 16 |
+
|
| 17 |
+
def get_chain_from_pdb(pdb_path, chain_id='A'):
|
| 18 |
+
parser = PDBParser()
|
| 19 |
+
structure = parser.get_structure('X', pdb_path)[0]
|
| 20 |
+
for chain in structure:
|
| 21 |
+
if chain.id == chain_id:
|
| 22 |
+
# print(len(chain))
|
| 23 |
+
return chain
|
| 24 |
+
return None
|
| 25 |
+
|
| 26 |
+
def diff_ratio(str1, str2):
|
| 27 |
+
# Create a SequenceMatcher object
|
| 28 |
+
seq_matcher = difflib.SequenceMatcher(None, str1, str2)
|
| 29 |
+
|
| 30 |
+
# Calculate the difference ratio
|
| 31 |
+
return seq_matcher.ratio()
|
| 32 |
+
|
| 33 |
+
#######################################
|
| 34 |
+
|
| 35 |
+
#RMSD and Tm
|
| 36 |
+
|
| 37 |
+
#######################################
|
| 38 |
+
def align_chains(chain1, chain2):
|
| 39 |
+
reslist1 = []
|
| 40 |
+
reslist2 = []
|
| 41 |
+
for residue1,residue2 in zip(chain1.get_residues(),chain2.get_residues()):
|
| 42 |
+
if is_aa(residue1) and residue1.has_id('CA'): # at least have CA
|
| 43 |
+
reslist1.append(residue1)
|
| 44 |
+
reslist2.append(residue2)
|
| 45 |
+
return reslist1,reslist2
|
| 46 |
+
|
| 47 |
+
def get_rmsd(chain1, chain2):
|
| 48 |
+
# chain1 = get_chain_from_pdb(pdb1, chain_id1)
|
| 49 |
+
# chain2 = get_chain_from_pdb(pdb2, chain_id2)
|
| 50 |
+
if chain1 is None or chain2 is None:
|
| 51 |
+
return None
|
| 52 |
+
super_imposer = Superimposer()
|
| 53 |
+
pos1 = np.array([atom.get_coord() for atom in chain1.get_atoms() if atom.name == 'CA'])
|
| 54 |
+
pos2 = np.array([atom.get_coord() for atom in chain2.get_atoms() if atom.name == 'CA'])
|
| 55 |
+
rmsd1 = np.sqrt(np.sum((pos1 - pos2)**2) / len(pos1))
|
| 56 |
+
super_imposer.set_atoms([atom for atom in chain1.get_atoms() if atom.name == 'CA'],
|
| 57 |
+
[atom for atom in chain2.get_atoms() if atom.name == 'CA'])
|
| 58 |
+
rmsd2 = super_imposer.rms
|
| 59 |
+
return rmsd1,rmsd2
|
| 60 |
+
|
| 61 |
+
def get_tm(chain1,chain2):
|
| 62 |
+
# chain1 = get_chain_from_pdb(pdb1, chain_id1)
|
| 63 |
+
# chain2 = get_chain_from_pdb(pdb2, chain_id2)
|
| 64 |
+
pos1 = np.array([atom.get_coord() for atom in chain1.get_atoms() if atom.name == 'CA'])
|
| 65 |
+
pos2 = np.array([atom.get_coord() for atom in chain2.get_atoms() if atom.name == 'CA'])
|
| 66 |
+
tm_results = tmtools.tm_align(pos1, pos2, 'A'*len(pos1), 'A'*len(pos2))
|
| 67 |
+
# print(dir(tm_results))
|
| 68 |
+
return tm_results.tm_norm_chain2
|
| 69 |
+
|
| 70 |
+
def get_traj_chain(pdb, chain):
|
| 71 |
+
parser = PDBParser()
|
| 72 |
+
structure = parser.get_structure('X', pdb)[0]
|
| 73 |
+
chain2id = {chain.id:i for i,chain in enumerate(structure)}
|
| 74 |
+
traj = md.load(pdb)
|
| 75 |
+
chain_indices = traj.topology.select(f"chainid {chain2id[chain]}")
|
| 76 |
+
traj = traj.atom_slice(chain_indices)
|
| 77 |
+
return traj
|
| 78 |
+
|
| 79 |
+
def get_second_stru(pdb,chain):
|
| 80 |
+
parser = PDBParser()
|
| 81 |
+
structure = parser.get_structure('X', pdb)[0]
|
| 82 |
+
chain2id = {chain.id:i for i,chain in enumerate(structure)}
|
| 83 |
+
traj = md.load(pdb)
|
| 84 |
+
chain_indices = traj.topology.select(f"chainid {chain2id[chain]}")
|
| 85 |
+
traj = traj.atom_slice(chain_indices)
|
| 86 |
+
return md.compute_dssp(traj,simplified=True)
|
| 87 |
+
|
| 88 |
+
def get_ss(traj1,traj2):
|
| 89 |
+
# traj1,traj2 = get_traj_chain(pdb1,chain_id1),get_traj_chain(pdb2,chain_id2)
|
| 90 |
+
ss1,ss2 = md.compute_dssp(traj1,simplified=True),md.compute_dssp(traj2,simplified=True)
|
| 91 |
+
return (ss1==ss2).mean()
|
| 92 |
+
|
| 93 |
+
def get_bind_site(pdb,chain_id):
|
| 94 |
+
parser = PDBParser()
|
| 95 |
+
structure = parser.get_structure('X', pdb)[0]
|
| 96 |
+
peps = [atom for res in structure[chain_id] for atom in res if atom.get_name() == 'CA']
|
| 97 |
+
recs = [atom for chain in structure if chain.get_id()!=chain_id for res in chain for atom in res if atom.get_name() == 'CA']
|
| 98 |
+
# print(recs)
|
| 99 |
+
search = NeighborSearch(recs)
|
| 100 |
+
near_res = []
|
| 101 |
+
for atom in peps:
|
| 102 |
+
near_res += search.search(atom.get_coord(), 10.0, level='R')
|
| 103 |
+
near_res = set([res.get_id()[1] for res in near_res])
|
| 104 |
+
return near_res
|
| 105 |
+
|
| 106 |
+
def get_bind_ratio(pdb1, pdb2, chain_id1, chain_id2):
|
| 107 |
+
near_res1,near_res2 = get_bind_site(pdb1,chain_id1),get_bind_site(pdb2,chain_id2)
|
| 108 |
+
# print(near_res1)
|
| 109 |
+
# print(near_res2)
|
| 110 |
+
return len(near_res1.intersection(near_res2))/(len(near_res2)+1e-10) # last one is gt
|
| 111 |
+
|
| 112 |
+
def get_dihedral(pdb,chain):
|
| 113 |
+
traj = get_traj_chain(pdb,chain)
|
| 114 |
+
#TODO: dihedral
|
| 115 |
+
|
| 116 |
+
def get_seq(pdb,chain_id):
|
| 117 |
+
parser = PDBParser()
|
| 118 |
+
chain = parser.get_structure('X', pdb)[0][chain_id]
|
| 119 |
+
return seq1("".join([residue.get_resname() for residue in chain])) # ignore is_aa,used for extract seq from genrated pdb
|
| 120 |
+
|
| 121 |
+
def get_mpnn_seqs(path):
|
| 122 |
+
fastas = []
|
| 123 |
+
for record in SeqIO.parse(path, "fasta"):
|
| 124 |
+
tmp = [c for c in str(record.seq)]
|
| 125 |
+
fastas.append(tmp)
|
| 126 |
+
return fastas
|
| 127 |
+
|
eval/run_esmfold.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import subprocess
|
| 4 |
+
import torch
|
| 5 |
+
import esm
|
| 6 |
+
import numpy as np
|
| 7 |
+
import shutil
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
|
| 10 |
+
from joblib import delayed, Parallel
|
| 11 |
+
|
| 12 |
+
import warnings
|
| 13 |
+
from Bio import BiopythonWarning, SeqIO
|
| 14 |
+
|
| 15 |
+
from geometry import *
|
| 16 |
+
|
| 17 |
+
# 忽略PDBConstructionWarning
|
| 18 |
+
warnings.filterwarnings('ignore', category=BiopythonWarning)
|
| 19 |
+
|
| 20 |
+
input_dir="./Data/Baselines_new/Tests"
|
| 21 |
+
output_dir="/datapool/data2/home/jiahan/ResProj/PepDiff/frame-flow/Data/Baselines_new/Codesign"
|
| 22 |
+
|
| 23 |
+
model = esm.pretrained.esmfold_v1()
|
| 24 |
+
model = model.eval().to('cuda:2')
|
| 25 |
+
|
| 26 |
+
def process_rf(name='1aze_B'):
|
| 27 |
+
input_dir=".Data/Baselines_new/Tests"
|
| 28 |
+
output_dir=".Data/Baselines_new/Codesign"
|
| 29 |
+
struct_dir = os.path.join(output_dir,name,'rfs_refold')
|
| 30 |
+
seq_dir = os.path.join(output_dir,name,'mpnns','seqs')
|
| 31 |
+
os.makedirs(struct_dir,exist_ok=True)
|
| 32 |
+
seqs = {}
|
| 33 |
+
for seq_path in os.listdir(seq_dir):
|
| 34 |
+
tmp_seqs = []
|
| 35 |
+
if seq_path.endswith('.fasta'):
|
| 36 |
+
for record in SeqIO.parse(os.path.join(seq_dir,seq_path), "fasta"):
|
| 37 |
+
tmp_seqs.append(str(record.seq))
|
| 38 |
+
seqs[seq_path.split('.')[0]] = tmp_seqs[-1]
|
| 39 |
+
for seq_name,seq in seqs.items():
|
| 40 |
+
with torch.no_grad():
|
| 41 |
+
output = model.infer_pdb(seq)
|
| 42 |
+
with open(os.path.join(struct_dir,seq_name+'.pdb'),'w') as f:
|
| 43 |
+
f.write(output)
|
| 44 |
+
|
| 45 |
+
def process_pg(name='1aze_B',chain_id='A'):
|
| 46 |
+
input_dir=".Data/Baselines_new/Tests"
|
| 47 |
+
output_dir=".Data/Baselines_new/Codesign"
|
| 48 |
+
struct_dir = os.path.join(output_dir,name,'pgs_refold')
|
| 49 |
+
seq_dir = os.path.join(output_dir,name,'pgs')
|
| 50 |
+
os.makedirs(struct_dir,exist_ok=True)
|
| 51 |
+
seqs = {}
|
| 52 |
+
for seq_path in os.listdir(seq_dir):
|
| 53 |
+
if seq_path.endswith('.pdb'):
|
| 54 |
+
seqs[seq_path.split('.')[0]] = get_seq(os.path.join(seq_dir,seq_path),chain_id)
|
| 55 |
+
for seq_name,seq in seqs.items():
|
| 56 |
+
with torch.no_grad():
|
| 57 |
+
output = model.infer_pdb(seq)
|
| 58 |
+
with open(os.path.join(struct_dir,seq_name+'.pdb'),'w') as f:
|
| 59 |
+
f.write(output)
|
| 60 |
+
|
| 61 |
+
def refold(name,chain_id,sub_dir):
|
| 62 |
+
raw_dir = os.path.join('/datapool/data2/home/jiahan/ResProj/PepDiff/frame-flow/Data/Models_new/Codesign',sub_dir,'pdbs')
|
| 63 |
+
refold_dir = os.path.join('/datapool/data2/home/jiahan/ResProj/PepDiff/frame-flow/Data/Models_new/Codesign',sub_dir,'pdbs_refold')
|
| 64 |
+
os.makedirs(os.path.join(refold_dir,name),exist_ok=True)
|
| 65 |
+
seqs = {}
|
| 66 |
+
for seq_path in os.listdir(os.path.join(raw_dir,name)):
|
| 67 |
+
if seq_path.endswith('.pdb'):
|
| 68 |
+
seqs[seq_path.split('.')[0]] = get_seq(os.path.join(raw_dir,name,seq_path),chain_id)
|
| 69 |
+
for seq_name,seq in seqs.items():
|
| 70 |
+
with torch.no_grad():
|
| 71 |
+
output = model.infer_pdb(seq)
|
| 72 |
+
with open(os.path.join(refold_dir,name,seq_name+'.pdb'),'w') as f:
|
| 73 |
+
f.write(output)
|
eval/run_esmif.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from utils import *
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import subprocess
|
| 6 |
+
import torch
|
| 7 |
+
import numpy as np
|
| 8 |
+
import shutil
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
|
| 11 |
+
from joblib import delayed, Parallel
|
| 12 |
+
|
| 13 |
+
input_dir="./Baselines_new/Tests"
|
| 14 |
+
# output_dir="/datapool/data2/home/jiahan/Res Proj/PepDiff/frame-flow/Data/RF_samples"
|
| 15 |
+
output_dir="./Data/Baselines_new/Fixbb"
|
| 16 |
+
|
| 17 |
+
RUNNER = "/datapool/data2/home/jiahan/Tool/esm/examples/inverse_folding/sample_sequences.py"
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def process_one_item_esmif(name='1a1m_C',chains_to_design="A",num_samples=10,temperature=0.1):
|
| 21 |
+
if not os.path.exists(os.path.join(output_dir,name,'esms')):
|
| 22 |
+
os.makedirs(os.path.join(output_dir,name,'esms'))
|
| 23 |
+
assert os.path.exists(os.path.join(output_dir,name,'esms'))
|
| 24 |
+
# if not os.path.exists(os.path.join(output_dir,name,'pocket_merge_renum.pdb')):
|
| 25 |
+
# chain_dic = renumber_pdb(os.path.join(input_dir,name,'pocket_merge.pdb'),os.path.join(output_dir,name,'pocket_merge_renum.pdb'))
|
| 26 |
+
dirname = os.path.join(output_dir,name,'esms')
|
| 27 |
+
cmd = [
|
| 28 |
+
"python", RUNNER, os.path.join(input_dir,name,'pocket_merge_renum.pdb'),
|
| 29 |
+
"--chain", chains_to_design, "--temperature", f"{temperature}", "--num-samples", f"{num_samples}",
|
| 30 |
+
"--outpath", os.path.join(dirname,'pocket_merge_renum.fasta'),
|
| 31 |
+
"--multichain-backbone", "--nogpu"
|
| 32 |
+
]
|
| 33 |
+
subprocess.run(cmd)
|
eval/run_mpnn.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from utils import *
|
| 2 |
+
from geometry import *
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import subprocess
|
| 7 |
+
import torch
|
| 8 |
+
import numpy as np
|
| 9 |
+
import shutil
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
|
| 12 |
+
from joblib import delayed, Parallel
|
| 13 |
+
|
| 14 |
+
from Bio.PDB import PDBParser, PDBIO, Select
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
HELPERS = "/datapool/data2/home/jiahan/Tool/ProteinMPNN/helper_scripts"
|
| 18 |
+
RUNNER = "/datapool/data2/home/jiahan/Tool/ProteinMPNN/protein_mpnn_run.py"
|
| 19 |
+
|
| 20 |
+
def get_chain_nums(pdb_path,chain_id):
|
| 21 |
+
parser = PDBParser()
|
| 22 |
+
chain = parser.get_structure('X',pdb_path)[0][chain_id]
|
| 23 |
+
residue_nums = [residue.get_id()[1] for residue in chain]
|
| 24 |
+
return residue_nums
|
| 25 |
+
|
| 26 |
+
def process_mpnn_bb(name='1aze_B',chains_to_design="A",num_samples=1):
|
| 27 |
+
input_dir = './Data/Models_new/Codesign/bb/pdbs'
|
| 28 |
+
output_dir = './Data/Models_new/Codesign/bb/seqs'
|
| 29 |
+
if not os.path.exists(os.path.join(output_dir,name)):
|
| 30 |
+
os.makedirs(os.path.join(output_dir,name))
|
| 31 |
+
dirname = os.path.join(output_dir,name)
|
| 32 |
+
# defined dirs
|
| 33 |
+
path_for_parsed_chains=os.path.join(dirname,'parsed_pdbs.jsonl')
|
| 34 |
+
path_for_assigned_chains=os.path.join(dirname,'assigned_pdbs.jsonl')
|
| 35 |
+
path_for_fixed_positions=os.path.join(dirname,'fixed_pdbs.jsonl')
|
| 36 |
+
residue_nums = get_chain_nums(os.path.join(input_dir,name,'gt.pdb'),chains_to_design)
|
| 37 |
+
design_only_positions = " ".join(map(str,residue_nums)) #design only these residues; use flag --specify_non_fixed
|
| 38 |
+
# print(path_for_assigned_chains)
|
| 39 |
+
# print(design_only_positions)
|
| 40 |
+
subprocess.run([
|
| 41 |
+
"python", os.path.join(HELPERS,"parse_multiple_chains.py"),
|
| 42 |
+
"--input_path", os.path.join(input_dir,name),
|
| 43 |
+
"--output_path", path_for_parsed_chains,
|
| 44 |
+
])
|
| 45 |
+
subprocess.run([
|
| 46 |
+
"python", os.path.join(HELPERS,"assign_fixed_chains.py"),
|
| 47 |
+
"--input_path", path_for_parsed_chains,
|
| 48 |
+
"--output_path", path_for_assigned_chains,
|
| 49 |
+
'--chain_list', chains_to_design,
|
| 50 |
+
])
|
| 51 |
+
subprocess.run([
|
| 52 |
+
"python", os.path.join(HELPERS,"make_fixed_positions_dict.py"),
|
| 53 |
+
"--input_path", path_for_parsed_chains,
|
| 54 |
+
"--output_path", path_for_fixed_positions,
|
| 55 |
+
'--chain_list', chains_to_design,
|
| 56 |
+
'--position_list', design_only_positions,
|
| 57 |
+
'--specify_non_fixed'
|
| 58 |
+
])
|
| 59 |
+
# run mpnn
|
| 60 |
+
# print('run mpnns')
|
| 61 |
+
subprocess.run([
|
| 62 |
+
"python", RUNNER,
|
| 63 |
+
"--jsonl_path", path_for_parsed_chains,
|
| 64 |
+
"--chain_id_jsonl", path_for_assigned_chains,
|
| 65 |
+
"--fixed_positions_jsonl", path_for_fixed_positions,
|
| 66 |
+
"--out_folder", dirname,
|
| 67 |
+
"--num_seq_per_target", f"{num_samples}",
|
| 68 |
+
"--sampling_temp", "0.1",
|
| 69 |
+
"--seed", "37",
|
| 70 |
+
"--batch_size","1",
|
| 71 |
+
'--device','cuda:1'
|
| 72 |
+
])
|
| 73 |
+
|
| 74 |
+
def process_one_item_mpnn(name='1a1m_C',chains_to_design="A",num_samples=1):
|
| 75 |
+
input_dir="./Data/Baselines_new/Tests"
|
| 76 |
+
output_dir="./Data/Baselines_new/Codesign"
|
| 77 |
+
if not os.path.exists(os.path.join(output_dir,name,'mpnns')):
|
| 78 |
+
os.makedirs(os.path.join(output_dir,name,'mpnns'))
|
| 79 |
+
# if not os.path.exists(os.path.join(output_dir,name,'pocket_merge_renum.pdb')):
|
| 80 |
+
# chain_dic = renumber_pdb(os.path.join(input_dir,name,'pocket_merge.pdb'),os.path.join(output_dir,name,'pocket_merge_renum.pdb'))
|
| 81 |
+
dirname = os.path.join(output_dir,name,'mpnns')
|
| 82 |
+
# defined dirs
|
| 83 |
+
path_for_parsed_chains=os.path.join(dirname,'parsed_pdbs.jsonl')
|
| 84 |
+
path_for_assigned_chains=os.path.join(dirname,'assigned_pdbs.jsonl')
|
| 85 |
+
path_for_fixed_positions=os.path.join(dirname,'fixed_pdbs.jsonl')
|
| 86 |
+
with open(os.path.join(input_dir,name,'seq.fasta'),'r') as f:
|
| 87 |
+
pep_len = len(f.readlines()[1].strip())
|
| 88 |
+
design_only_positions=" ".join(map(str,list(range(1,pep_len+1)))) #design only these residues; use flag --specify_non_fixed
|
| 89 |
+
# print(design_only_positions)
|
| 90 |
+
# parsed chains
|
| 91 |
+
# print("parsing chains")
|
| 92 |
+
subprocess.run([
|
| 93 |
+
"python", os.path.join(HELPERS,"parse_multiple_chains.py"),
|
| 94 |
+
"--input_path", os.path.join('./Data/Baselines_new/Codesign',name,'rfs'),#os.path.join('/datapool/data2/home/jiahan/ResProj/PepDiff/frame-flow/Data/Baselines/Fixbb/',name),
|
| 95 |
+
"--output_path", path_for_parsed_chains,
|
| 96 |
+
])
|
| 97 |
+
subprocess.run([
|
| 98 |
+
"python", os.path.join(HELPERS,"assign_fixed_chains.py"),
|
| 99 |
+
"--input_path", path_for_parsed_chains,
|
| 100 |
+
"--output_path", path_for_assigned_chains,
|
| 101 |
+
'--chain_list', chains_to_design,
|
| 102 |
+
])
|
| 103 |
+
subprocess.run([
|
| 104 |
+
"python", os.path.join(HELPERS,"make_fixed_positions_dict.py"),
|
| 105 |
+
"--input_path", path_for_parsed_chains,
|
| 106 |
+
"--output_path", path_for_fixed_positions,
|
| 107 |
+
'--chain_list', chains_to_design,
|
| 108 |
+
'--position_list', design_only_positions,
|
| 109 |
+
'--specify_non_fixed'
|
| 110 |
+
])
|
| 111 |
+
# run mpnn
|
| 112 |
+
# print('run mpnns')
|
| 113 |
+
subprocess.run([
|
| 114 |
+
"python", RUNNER,
|
| 115 |
+
"--jsonl_path", path_for_parsed_chains,
|
| 116 |
+
"--chain_id_jsonl", path_for_assigned_chains,
|
| 117 |
+
"--fixed_positions_jsonl", path_for_fixed_positions,
|
| 118 |
+
"--out_folder", dirname,
|
| 119 |
+
"--num_seq_per_target", f"{num_samples}",
|
| 120 |
+
"--sampling_temp", "0.1",
|
| 121 |
+
"--seed", "37",
|
| 122 |
+
"--batch_size","1",
|
| 123 |
+
'--device','cuda:1'
|
| 124 |
+
])
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def write_seq_to_pdb(seq_path,pdb_path,out_path,chain_id):
|
| 128 |
+
# first we should fix GGGGG in rfs with mpnn generated seq
|
| 129 |
+
aa_mapping = {"A": "ALA","C": "CYS","D": "ASP","E": "GLU","F": "PHE","G": "GLY","H": "HIS","I": "ILE","K": "LYS","L": "LEU","M": "MET","N": "ASN","P": "PRO","Q": "GLN","R": "ARG","S": "SER","T": "THR","V": "VAL","W": "TRP","Y": "TYR",
|
| 130 |
+
'X':'UNK'}
|
| 131 |
+
tmps = []
|
| 132 |
+
for record in SeqIO.parse(seq_path, "fasta"):
|
| 133 |
+
tmps.append(str(record.seq))
|
| 134 |
+
seq = tmps[-1]
|
| 135 |
+
|
| 136 |
+
parser = PDBParser()
|
| 137 |
+
structure = parser.get_structure("X", pdb_path)
|
| 138 |
+
model = structure[0]
|
| 139 |
+
for chain in model:
|
| 140 |
+
if chain.id == chain_id: # 假设你要更改的是链A
|
| 141 |
+
for i,res in enumerate(chain):
|
| 142 |
+
if i<len(seq):
|
| 143 |
+
res.resname = aa_mapping[seq[i]]
|
| 144 |
+
io = PDBIO()
|
| 145 |
+
io.set_structure(structure)
|
| 146 |
+
io.save(out_path)
|
eval/run_rfdiffusion.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from utils import *
|
| 2 |
+
from geometry import *
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import subprocess
|
| 7 |
+
import torch
|
| 8 |
+
import numpy as np
|
| 9 |
+
import shutil
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
|
| 12 |
+
from joblib import delayed, Parallel
|
| 13 |
+
|
| 14 |
+
input_dir="./Data/Baselines_new/Tests"
|
| 15 |
+
output_dir=".Data/Baselines_new/Codesign"
|
| 16 |
+
|
| 17 |
+
PROGEN="/datapool/data2/home/jiahan/Tool/protein_generator/inference.py"
|
| 18 |
+
|
| 19 |
+
def process_one_item_rf(name='1a1m_C',num_samples=10):
|
| 20 |
+
if not os.path.exists(os.path.join(output_dir,name,'rfs')):
|
| 21 |
+
os.makedirs(os.path.join(output_dir,name,'rfs'))
|
| 22 |
+
chain_dic = get_chain_dic(os.path.join(input_dir,name,'pocket_renum.pdb'))
|
| 23 |
+
with open(os.path.join(input_dir,name,'seq.fasta'),'r') as f:
|
| 24 |
+
pep_len = len(f.readlines()[1].strip())
|
| 25 |
+
# rfdiffusion
|
| 26 |
+
contigs = []
|
| 27 |
+
for chain,chain_len in chain_dic.items():
|
| 28 |
+
contigs.append(f'{chain}1-{chain_len}/0')
|
| 29 |
+
contigs.append(f'{pep_len}-{pep_len}')
|
| 30 |
+
contigs = " ".join(contigs)
|
| 31 |
+
command = [
|
| 32 |
+
"run_inference.py",
|
| 33 |
+
f"inference.output_prefix='{os.path.join(output_dir,name,'rfs','sample')}'",
|
| 34 |
+
f"inference.input_pdb='{os.path.join(input_dir,name,'pocket_renum.pdb')}'",
|
| 35 |
+
f"contigmap.contigs=[{contigs}]",
|
| 36 |
+
f"inference.num_designs={num_samples}",
|
| 37 |
+
]
|
| 38 |
+
# print(command)
|
| 39 |
+
try:
|
| 40 |
+
result = subprocess.run(command, check=True, capture_output=True, text=True)
|
| 41 |
+
return name
|
| 42 |
+
except:
|
| 43 |
+
return None
|
| 44 |
+
|
| 45 |
+
def process_one_item_pg(name='1a1m_C',num_samples=10):
|
| 46 |
+
if not os.path.exists(os.path.join(output_dir,name,'pgs')):
|
| 47 |
+
os.makedirs(os.path.join(output_dir,name,'pgs'))
|
| 48 |
+
os.makedirs(os.path.join(output_dir,name,'pgs'),exist_ok=True)
|
| 49 |
+
chain_dic = get_chain_dic(os.path.join(input_dir,name,'pocket_renum.pdb'))
|
| 50 |
+
with open(os.path.join(input_dir,name,'seq.fasta'),'r') as f:
|
| 51 |
+
pep_len = len(f.readlines()[1].strip())
|
| 52 |
+
# protein_generator settings
|
| 53 |
+
contigs = []
|
| 54 |
+
for chain,chain_len in chain_dic.items():
|
| 55 |
+
contigs.append(f'{chain}1-{chain_len},0')
|
| 56 |
+
contigs.append(f'{pep_len}-{pep_len}')
|
| 57 |
+
command = [
|
| 58 |
+
"python", PROGEN,
|
| 59 |
+
"--num_designs", f"{num_samples}",
|
| 60 |
+
"--out", os.path.join(output_dir,name,'pgs','sample'),
|
| 61 |
+
"--pdb", os.path.join(input_dir,name,'pocket_renum.pdb'),
|
| 62 |
+
"--T", "25", # default setting
|
| 63 |
+
"--save_best_plddt", # default setting
|
| 64 |
+
"--contigs", *contigs,
|
| 65 |
+
]
|
| 66 |
+
# print(command)
|
| 67 |
+
try:
|
| 68 |
+
result = subprocess.run(command, check=True, capture_output=True, text=True)
|
| 69 |
+
return name
|
| 70 |
+
except:
|
| 71 |
+
return None
|
| 72 |
+
|
| 73 |
+
def process_one_item(name='1a1m_C',num_samples=10):
|
| 74 |
+
process_one_item_pg(name,num_samples)
|
| 75 |
+
process_one_item_rf(name,num_samples)
|
eval/run_scwrl4.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from utils import *
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import subprocess
|
| 6 |
+
import numpy as np
|
| 7 |
+
import shutil
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
|
| 10 |
+
from joblib import delayed, Parallel
|
| 11 |
+
|
| 12 |
+
input_dir="./Data/Baselines_new/Tests"
|
| 13 |
+
output_dir="./Data/Baselines_new/Pack"
|
| 14 |
+
|
| 15 |
+
RUNNER = "/datapool/data2/home/jiahan/Tool/bin/Scwrl4"
|
| 16 |
+
|
| 17 |
+
def process_one_item_scwrl4(name='1a1m_C',num_samples=10):
|
| 18 |
+
if not os.path.exists(os.path.join(output_dir,name,'scwrls')):
|
| 19 |
+
os.makedirs(os.path.join(output_dir,name,'scwrls'))
|
| 20 |
+
# if not os.path.exists(os.path.join(output_dir,name,'pocket_merge_renum.pdb')):
|
| 21 |
+
# chain_dic = renumber_pdb(os.path.join(input_dir,name,'pocket_merge.pdb'),os.path.join(output_dir,name,'pocket_merge_renum.pdb'))
|
| 22 |
+
# keep_backbone_atoms(os.path.join(output_dir,name,'pocket_merge_renum.pdb'),os.path.join(output_dir,name,'pocket_merge_renum_backbone.pdb'))
|
| 23 |
+
dirname = os.path.join(output_dir,name,'scwrls')
|
| 24 |
+
for i in range(num_samples):
|
| 25 |
+
cmd = [
|
| 26 |
+
RUNNER,
|
| 27 |
+
'-i',os.path.join(input_dir,name,'pocket_merge_renum_bb.pdb'),
|
| 28 |
+
'-o',os.path.join(dirname,f'packed_{i}.pdb'),
|
| 29 |
+
]
|
| 30 |
+
subprocess.run(cmd)
|
eval/utils.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import glob
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import subprocess
|
| 5 |
+
from difflib import SequenceMatcher
|
| 6 |
+
|
| 7 |
+
from Bio import SeqIO
|
| 8 |
+
from Bio.PDB import PDBParser, PDBIO, Chain, Select, is_aa
|
| 9 |
+
from Bio.PDB.Polypeptide import PPBuilder
|
| 10 |
+
|
| 11 |
+
from Bio.PDB import PDBParser
|
| 12 |
+
from Bio.SeqUtils import seq1
|
| 13 |
+
|
| 14 |
+
# def parse_pdb_chains(pdb_file):
|
| 15 |
+
# parser = PDBParser()
|
| 16 |
+
# structure = parser.get_structure("protein", pdb_file)
|
| 17 |
+
# pp_builder = PPBuilder()
|
| 18 |
+
|
| 19 |
+
# sequences = {}
|
| 20 |
+
# for model in structure:
|
| 21 |
+
# for chain in model:
|
| 22 |
+
# chain_id = chain.get_id()
|
| 23 |
+
# sequence = "".join([str(pp.get_sequence()) for pp in pp_builder.build_peptides(chain)])
|
| 24 |
+
# print(len(sequence))
|
| 25 |
+
# sequences[chain_id] = sequence
|
| 26 |
+
|
| 27 |
+
# return sequences
|
| 28 |
+
|
| 29 |
+
def get_fasta_from_pdb(pdb_file):
|
| 30 |
+
parser = PDBParser()
|
| 31 |
+
structure = parser.get_structure("pdb", pdb_file)
|
| 32 |
+
|
| 33 |
+
fasta_sequence = {}
|
| 34 |
+
for chain in structure.get_chains():
|
| 35 |
+
seq = ""
|
| 36 |
+
for residue in chain.get_residues():
|
| 37 |
+
seq += seq1(residue.get_resname())
|
| 38 |
+
fasta_sequence[chain.id] = seq
|
| 39 |
+
|
| 40 |
+
return fasta_sequence
|
| 41 |
+
|
| 42 |
+
def parse_fasta(file):
|
| 43 |
+
sequences = {}
|
| 44 |
+
with open(file, "r") as fasta_file:
|
| 45 |
+
for i, record in enumerate(SeqIO.parse(fasta_file, "fasta")):
|
| 46 |
+
sequences[i] = str(record.seq).split("/")
|
| 47 |
+
return sequences
|
| 48 |
+
|
| 49 |
+
def renumber_pdb(input_pdb, output_pdb):
|
| 50 |
+
parser = PDBParser()
|
| 51 |
+
structure = parser.get_structure("protein", input_pdb)
|
| 52 |
+
|
| 53 |
+
chain_dic = {}
|
| 54 |
+
|
| 55 |
+
for model in structure:
|
| 56 |
+
old_chains = []
|
| 57 |
+
new_chains = []
|
| 58 |
+
for chain in model: # this may include HEAATM atoms
|
| 59 |
+
new_chain_id = chain.id + "_renum"
|
| 60 |
+
new_chain = Chain.Chain(new_chain_id)
|
| 61 |
+
for i, residue in enumerate(chain):
|
| 62 |
+
new_residue = residue.copy()
|
| 63 |
+
new_residue_id = (residue.id[0], i + 1, residue.id[2])
|
| 64 |
+
new_residue.id = new_residue_id
|
| 65 |
+
new_chain.add(new_residue)
|
| 66 |
+
old_chains.append(chain)
|
| 67 |
+
new_chains.append(new_chain)
|
| 68 |
+
chain_dic[chain.id] = len(list(chain))
|
| 69 |
+
|
| 70 |
+
for chain, new_chain in zip(old_chains, new_chains):
|
| 71 |
+
model.detach_child(chain.id)
|
| 72 |
+
new_chain.id = chain.id
|
| 73 |
+
model.add(new_chain)
|
| 74 |
+
|
| 75 |
+
io = PDBIO()
|
| 76 |
+
io.set_structure(structure)
|
| 77 |
+
io.save(output_pdb)
|
| 78 |
+
|
| 79 |
+
return chain_dic
|
| 80 |
+
|
| 81 |
+
def get_chain_dic(input_pdb):
|
| 82 |
+
parser = PDBParser()
|
| 83 |
+
structure = parser.get_structure("protein", input_pdb)
|
| 84 |
+
|
| 85 |
+
chain_dic = {}
|
| 86 |
+
|
| 87 |
+
for model in structure:
|
| 88 |
+
for chain in model:
|
| 89 |
+
chain_dic[chain.id] = len([res for res in chain if is_aa(res) and res.has_id('CA')])
|
| 90 |
+
|
| 91 |
+
return chain_dic
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def keep_backbone_atoms(input_file, output_file):
|
| 95 |
+
|
| 96 |
+
class BackboneSelect(Select):
|
| 97 |
+
def accept_atom(self, atom):
|
| 98 |
+
return atom.get_name() in ["N", "CA", "C", "O"]
|
| 99 |
+
|
| 100 |
+
parser = PDBParser()
|
| 101 |
+
io = PDBIO()
|
| 102 |
+
|
| 103 |
+
structure = parser.get_structure("protein", input_file)
|
| 104 |
+
|
| 105 |
+
io.set_structure(structure)
|
| 106 |
+
io.save(output_file, BackboneSelect())
|
models_con/edge.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
from pepflow.modules.common.geometry import angstrom_to_nm, pairwise_dihedrals
|
| 7 |
+
from pepflow.modules.common.layers import AngularEncoding
|
| 8 |
+
from pepflow.modules.protein.constants import BBHeavyAtom, AA
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class EdgeEmbedder(nn.Module):
|
| 12 |
+
|
| 13 |
+
def __init__(self, feat_dim, max_num_atoms, max_aa_types=22, max_relpos=32, num_bins=16):
|
| 14 |
+
super().__init__()
|
| 15 |
+
self.max_num_atoms = max_num_atoms
|
| 16 |
+
self.max_aa_types = max_aa_types
|
| 17 |
+
self.max_relpos = max_relpos
|
| 18 |
+
self.num_bins = num_bins
|
| 19 |
+
self.aa_pair_embed = nn.Embedding(self.max_aa_types*self.max_aa_types, feat_dim)
|
| 20 |
+
self.relpos_embed = nn.Embedding(2*max_relpos+1, feat_dim)
|
| 21 |
+
|
| 22 |
+
self.aapair_to_distcoef = nn.Embedding(self.max_aa_types*self.max_aa_types, max_num_atoms*max_num_atoms)
|
| 23 |
+
nn.init.zeros_(self.aapair_to_distcoef.weight)
|
| 24 |
+
self.distance_embed = nn.Sequential(
|
| 25 |
+
nn.Linear(max_num_atoms*max_num_atoms, feat_dim), nn.ReLU(),
|
| 26 |
+
nn.Linear(feat_dim, feat_dim), nn.ReLU(),
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
self.dihedral_embed = AngularEncoding()
|
| 30 |
+
feat_dihed_dim = self.dihedral_embed.get_out_dim(2) # Phi and Psi
|
| 31 |
+
|
| 32 |
+
infeat_dim = feat_dim + feat_dim + feat_dim + feat_dihed_dim
|
| 33 |
+
self.out_mlp = nn.Sequential(
|
| 34 |
+
nn.Linear(infeat_dim, feat_dim), nn.ReLU(),
|
| 35 |
+
nn.Linear(feat_dim, feat_dim), nn.ReLU(),
|
| 36 |
+
nn.Linear(feat_dim, feat_dim),
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
def forward(self, aa, res_nb, chain_nb, pos_atoms, mask_atoms, structure_mask=None, sequence_mask=None):
|
| 40 |
+
"""
|
| 41 |
+
Args:
|
| 42 |
+
aa: (N, L).
|
| 43 |
+
res_nb: (N, L).
|
| 44 |
+
chain_nb: (N, L).
|
| 45 |
+
pos_atoms: (N, L, A, 3)
|
| 46 |
+
mask_atoms: (N, L, A)
|
| 47 |
+
trans, sc_trans: (N,L,3)
|
| 48 |
+
structure_mask: (N, L)
|
| 49 |
+
sequence_mask: (N, L), mask out unknown amino acids to generate.
|
| 50 |
+
|
| 51 |
+
Returns:
|
| 52 |
+
(N, L, L, feat_dim)
|
| 53 |
+
"""
|
| 54 |
+
N, L = aa.size()
|
| 55 |
+
|
| 56 |
+
# Remove other atoms
|
| 57 |
+
pos_atoms = pos_atoms[:, :, :self.max_num_atoms]
|
| 58 |
+
mask_atoms = mask_atoms[:, :, :self.max_num_atoms]
|
| 59 |
+
|
| 60 |
+
mask_residue = mask_atoms[:, :, BBHeavyAtom.CA] # (N, L)
|
| 61 |
+
mask_pair = mask_residue[:, :, None] * mask_residue[:, None, :]
|
| 62 |
+
pair_structure_mask = structure_mask[:, :, None] * structure_mask[:, None, :] if structure_mask is not None else None
|
| 63 |
+
|
| 64 |
+
# Pair identities
|
| 65 |
+
if sequence_mask is not None:
|
| 66 |
+
# Avoid data leakage at training time
|
| 67 |
+
aa = torch.where(sequence_mask, aa, torch.full_like(aa, fill_value=AA.UNK))
|
| 68 |
+
aa_pair = aa[:,:,None]*self.max_aa_types + aa[:,None,:] # (N, L, L)
|
| 69 |
+
feat_aapair = self.aa_pair_embed(aa_pair)
|
| 70 |
+
|
| 71 |
+
# Relative sequential positions
|
| 72 |
+
same_chain = (chain_nb[:, :, None] == chain_nb[:, None, :])
|
| 73 |
+
relpos = torch.clamp(
|
| 74 |
+
res_nb[:,:,None] - res_nb[:,None,:],
|
| 75 |
+
min=-self.max_relpos, max=self.max_relpos,
|
| 76 |
+
) # (N, L, L)
|
| 77 |
+
feat_relpos = self.relpos_embed(relpos + self.max_relpos) * same_chain[:,:,:,None]
|
| 78 |
+
|
| 79 |
+
# Distances
|
| 80 |
+
d = angstrom_to_nm(torch.linalg.norm(
|
| 81 |
+
pos_atoms[:,:,None,:,None] - pos_atoms[:,None,:,None,:],
|
| 82 |
+
dim = -1, ord = 2,
|
| 83 |
+
)).reshape(N, L, L, -1) # (N, L, L, A*A)
|
| 84 |
+
c = F.softplus(self.aapair_to_distcoef(aa_pair)) # (N, L, L, A*A)
|
| 85 |
+
d_gauss = torch.exp(-1 * c * d**2)
|
| 86 |
+
mask_atom_pair = (mask_atoms[:,:,None,:,None] * mask_atoms[:,None,:,None,:]).reshape(N, L, L, -1)
|
| 87 |
+
feat_dist = self.distance_embed(d_gauss * mask_atom_pair)
|
| 88 |
+
if pair_structure_mask is not None:
|
| 89 |
+
# Avoid data leakage at training time
|
| 90 |
+
feat_dist = feat_dist * pair_structure_mask[:, :, :, None]
|
| 91 |
+
|
| 92 |
+
# Orientations
|
| 93 |
+
dihed = pairwise_dihedrals(pos_atoms) # (N, L, L, 2)
|
| 94 |
+
feat_dihed = self.dihedral_embed(dihed)
|
| 95 |
+
if pair_structure_mask is not None:
|
| 96 |
+
# Avoid data leakage at training time
|
| 97 |
+
feat_dihed = feat_dihed * pair_structure_mask[:, :, :, None]
|
| 98 |
+
|
| 99 |
+
# # trans embed
|
| 100 |
+
# dist_feats = calc_distogram(
|
| 101 |
+
# trans, min_bin=1e-3, max_bin=20.0, num_bins=self.num_bins)
|
| 102 |
+
# if sc_trans == None:
|
| 103 |
+
# sc_trans = torch.zeros_like(trans)
|
| 104 |
+
# sc_feats = calc_distogram(
|
| 105 |
+
# sc_trans, min_bin=1e-3, max_bin=20.0, num_bins=self.num_bins)
|
| 106 |
+
|
| 107 |
+
# All
|
| 108 |
+
feat_all = torch.cat([feat_aapair, feat_relpos, feat_dist, feat_dihed], dim=-1)
|
| 109 |
+
feat_all = self.out_mlp(feat_all) # (N, L, L, F)
|
| 110 |
+
feat_all = feat_all * mask_pair[:, :, :, None]
|
| 111 |
+
|
| 112 |
+
return feat_all
|
models_con/flow_model.py
ADDED
|
@@ -0,0 +1,472 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
import copy
|
| 7 |
+
import math
|
| 8 |
+
from tqdm.auto import tqdm
|
| 9 |
+
import functools
|
| 10 |
+
from torch.utils.data import DataLoader
|
| 11 |
+
import os
|
| 12 |
+
import argparse
|
| 13 |
+
|
| 14 |
+
import pandas as pd
|
| 15 |
+
|
| 16 |
+
from models_con.edge import EdgeEmbedder
|
| 17 |
+
from models_con.node import NodeEmbedder
|
| 18 |
+
from pepflow.modules.common.layers import sample_from, clampped_one_hot
|
| 19 |
+
from models_con.ga import GAEncoder
|
| 20 |
+
from pepflow.modules.protein.constants import AA, BBHeavyAtom, max_num_heavyatoms
|
| 21 |
+
from pepflow.modules.common.geometry import construct_3d_basis
|
| 22 |
+
from pepflow.utils.data import mask_select_data, find_longest_true_segment, PaddingCollate
|
| 23 |
+
from pepflow.utils.misc import seed_all
|
| 24 |
+
from pepflow.utils.train import sum_weighted_losses
|
| 25 |
+
from torch.nn.utils import clip_grad_norm_
|
| 26 |
+
|
| 27 |
+
from pepflow.modules.so3.dist import centered_gaussian,uniform_so3
|
| 28 |
+
from pepflow.modules.common.geometry import batch_align, align
|
| 29 |
+
|
| 30 |
+
from tqdm import tqdm
|
| 31 |
+
|
| 32 |
+
import wandb
|
| 33 |
+
|
| 34 |
+
from data import so3_utils
|
| 35 |
+
from data import all_atom
|
| 36 |
+
|
| 37 |
+
from models_con.pep_dataloader import PepDataset
|
| 38 |
+
|
| 39 |
+
from pepflow.utils.misc import load_config
|
| 40 |
+
from pepflow.utils.train import recursive_to
|
| 41 |
+
from easydict import EasyDict
|
| 42 |
+
|
| 43 |
+
from models_con.utils import process_dic
|
| 44 |
+
from models_con.torsion import get_torsion_angle, torsions_mask
|
| 45 |
+
import models_con.torus as torus
|
| 46 |
+
|
| 47 |
+
import gc
|
| 48 |
+
|
| 49 |
+
from copy import deepcopy
|
| 50 |
+
from pepflow.utils.data import PaddingCollate
|
| 51 |
+
collate_fn = PaddingCollate(eight=False)
|
| 52 |
+
from pepflow.utils.train import recursive_to
|
| 53 |
+
|
| 54 |
+
resolution_to_num_atoms = {
|
| 55 |
+
'backbone+CB': 5,
|
| 56 |
+
'full': max_num_heavyatoms
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
class FlowModel(nn.Module):
|
| 60 |
+
def __init__(self,cfg):
|
| 61 |
+
super().__init__()
|
| 62 |
+
self._model_cfg = cfg.encoder
|
| 63 |
+
self._interpolant_cfg = cfg.interpolant
|
| 64 |
+
|
| 65 |
+
self.node_embedder = NodeEmbedder(cfg.encoder.node_embed_size,max_num_heavyatoms)
|
| 66 |
+
self.edge_embedder = EdgeEmbedder(cfg.encoder.edge_embed_size,max_num_heavyatoms)
|
| 67 |
+
self.ga_encoder = GAEncoder(cfg.encoder.ipa)
|
| 68 |
+
|
| 69 |
+
self.sample_structure = self._interpolant_cfg.sample_structure
|
| 70 |
+
self.sample_sequence = self._interpolant_cfg.sample_sequence
|
| 71 |
+
|
| 72 |
+
self.K = self._interpolant_cfg.seqs.num_classes
|
| 73 |
+
self.k = self._interpolant_cfg.seqs.simplex_value
|
| 74 |
+
|
| 75 |
+
def encode(self, batch):
|
| 76 |
+
rotmats_1 = construct_3d_basis(batch['pos_heavyatom'][:, :, BBHeavyAtom.CA],batch['pos_heavyatom'][:, :, BBHeavyAtom.C],batch['pos_heavyatom'][:, :, BBHeavyAtom.N] )
|
| 77 |
+
trans_1 = batch['pos_heavyatom'][:, :, BBHeavyAtom.CA]
|
| 78 |
+
seqs_1 = batch['aa']
|
| 79 |
+
|
| 80 |
+
# ignore psi
|
| 81 |
+
# batch['torsion_angle'] = batch['torsion_angle'][:,:,1:]
|
| 82 |
+
# batch['torsion_angle_mask'] = batch['torsion_angle_mask'][:,:,1:]
|
| 83 |
+
angles_1 = batch['torsion_angle']
|
| 84 |
+
|
| 85 |
+
context_mask = torch.logical_and(batch['mask_heavyatom'][:, :, BBHeavyAtom.CA], ~batch['generate_mask'])
|
| 86 |
+
structure_mask = context_mask if self.sample_structure else None
|
| 87 |
+
sequence_mask = context_mask if self.sample_sequence else None
|
| 88 |
+
node_embed = self.node_embedder(batch['aa'], batch['res_nb'], batch['chain_nb'], batch['pos_heavyatom'],
|
| 89 |
+
batch['mask_heavyatom'], structure_mask=structure_mask, sequence_mask=sequence_mask)
|
| 90 |
+
edge_embed = self.edge_embedder(batch['aa'], batch['res_nb'], batch['chain_nb'], batch['pos_heavyatom'],
|
| 91 |
+
batch['mask_heavyatom'], structure_mask=structure_mask, sequence_mask=sequence_mask)
|
| 92 |
+
|
| 93 |
+
return rotmats_1, trans_1, angles_1, seqs_1, node_embed, edge_embed
|
| 94 |
+
|
| 95 |
+
def zero_center_part(self,pos,gen_mask,res_mask):
|
| 96 |
+
"""
|
| 97 |
+
move pos by center of gen_mask
|
| 98 |
+
pos: (B,N,3)
|
| 99 |
+
gen_mask, res_mask: (B,N)
|
| 100 |
+
"""
|
| 101 |
+
center = torch.sum(pos * gen_mask[...,None], dim=1) / (torch.sum(gen_mask,dim=-1,keepdim=True) + 1e-8) # (B,N,3)*(B,N,1)->(B,3)/(B,1)->(B,3)
|
| 102 |
+
center = center.unsqueeze(1) # (B,1,3)
|
| 103 |
+
# center = 0. it seems not center didnt influence the result, but its good for training stabilty
|
| 104 |
+
pos = pos - center
|
| 105 |
+
pos = pos * res_mask[...,None]
|
| 106 |
+
return pos,center
|
| 107 |
+
|
| 108 |
+
def seq_to_simplex(self,seqs):
|
| 109 |
+
return clampped_one_hot(seqs, self.K).float() * self.k * 2 - self.k # (B,L,K)
|
| 110 |
+
|
| 111 |
+
def forward(self, batch):
|
| 112 |
+
|
| 113 |
+
num_batch, num_res = batch['aa'].shape
|
| 114 |
+
gen_mask,res_mask,angle_mask = batch['generate_mask'].long(),batch['res_mask'].long(),batch['torsion_angle_mask'].long()
|
| 115 |
+
|
| 116 |
+
#encode
|
| 117 |
+
rotmats_1, trans_1, angles_1, seqs_1, node_embed, edge_embed = self.encode(batch) # no generate mask
|
| 118 |
+
|
| 119 |
+
# prepare for denoise
|
| 120 |
+
trans_1_c,_ = self.zero_center_part(trans_1,gen_mask,res_mask)
|
| 121 |
+
trans_1_c = trans_1 # already centered when constructing dataset
|
| 122 |
+
seqs_1_simplex = self.seq_to_simplex(seqs_1)
|
| 123 |
+
seqs_1_prob = F.softmax(seqs_1_simplex,dim=-1)
|
| 124 |
+
|
| 125 |
+
with torch.no_grad():
|
| 126 |
+
t = torch.rand((num_batch,1), device=batch['aa'].device)
|
| 127 |
+
t = t*(1-2 * self._interpolant_cfg.t_normalization_clip) + self._interpolant_cfg.t_normalization_clip # avoid 0
|
| 128 |
+
if self.sample_structure:
|
| 129 |
+
# corrupt trans
|
| 130 |
+
trans_0 = torch.randn((num_batch,num_res,3), device=batch['aa'].device) * self._interpolant_cfg.trans.sigma # scale with sigma?
|
| 131 |
+
trans_0_c,_ = self.zero_center_part(trans_0,gen_mask,res_mask)
|
| 132 |
+
trans_t = (1-t[...,None])*trans_0_c + t[...,None]*trans_1_c
|
| 133 |
+
trans_t_c = torch.where(batch['generate_mask'][...,None],trans_t,trans_1_c)
|
| 134 |
+
# corrupt rotmats
|
| 135 |
+
rotmats_0 = uniform_so3(num_batch,num_res,device=batch['aa'].device)
|
| 136 |
+
rotmats_t = so3_utils.geodesic_t(t[..., None], rotmats_1, rotmats_0)
|
| 137 |
+
rotmats_t = torch.where(batch['generate_mask'][...,None,None],rotmats_t,rotmats_1)
|
| 138 |
+
# corrup angles
|
| 139 |
+
angles_0 = torus.tor_random_uniform(angles_1.shape, device=batch['aa'].device, dtype=angles_1.dtype) # (B,L,5)
|
| 140 |
+
angles_t = torus.tor_geodesic_t(t[..., None], angles_1, angles_0)
|
| 141 |
+
angles_t = torch.where(batch['generate_mask'][...,None],angles_t,angles_1)
|
| 142 |
+
else:
|
| 143 |
+
trans_t_c = trans_1_c.detach().clone()
|
| 144 |
+
rotmats_t = rotmats_1.detach().clone()
|
| 145 |
+
angles_t = angles_1.detach().clone()
|
| 146 |
+
if self.sample_sequence:
|
| 147 |
+
# corrupt seqs
|
| 148 |
+
seqs_0_simplex = self.k * torch.randn_like(seqs_1_simplex) # (B,L,K)
|
| 149 |
+
seqs_0_prob = F.softmax(seqs_0_simplex,dim=-1) # (B,L,K)
|
| 150 |
+
seqs_t_simplex = ((1 - t[..., None]) * seqs_0_simplex) + (t[..., None] * seqs_1_simplex) # (B,L,K)
|
| 151 |
+
seqs_t_simplex = torch.where(batch['generate_mask'][...,None],seqs_t_simplex,seqs_1_simplex)
|
| 152 |
+
seqs_t_prob = F.softmax(seqs_t_simplex,dim=-1) # (B,L,K)
|
| 153 |
+
seqs_t = sample_from(seqs_t_prob) # (B,L)
|
| 154 |
+
seqs_t = torch.where(batch['generate_mask'],seqs_t,seqs_1)
|
| 155 |
+
else:
|
| 156 |
+
seqs_t = seqs_1.detach().clone()
|
| 157 |
+
seqs_t_simplex = seqs_1_simplex.detach().clone()
|
| 158 |
+
seqs_t_prob = seqs_1_prob.detach().clone()
|
| 159 |
+
|
| 160 |
+
# denoise
|
| 161 |
+
pred_rotmats_1, pred_trans_1, pred_angles_1, pred_seqs_1_prob = self.ga_encoder(t, rotmats_t, trans_t_c, angles_t, seqs_t, node_embed, edge_embed, gen_mask, res_mask)
|
| 162 |
+
pred_seqs_1 = sample_from(F.softmax(pred_seqs_1_prob,dim=-1))
|
| 163 |
+
pred_seqs_1 = torch.where(batch['generate_mask'],pred_seqs_1,torch.clamp(seqs_1,0,19))
|
| 164 |
+
pred_trans_1_c,_ = self.zero_center_part(pred_trans_1,gen_mask,res_mask)
|
| 165 |
+
pred_trans_1_c = pred_trans_1 # implicitly enforce zero center in gen_mask, in this way, we dont need to move receptor when sampling
|
| 166 |
+
|
| 167 |
+
norm_scale = 1 / (1 - torch.min(t[...,None], torch.tensor(self._interpolant_cfg.t_normalization_clip))) # yim etal.trick, 1/1-t
|
| 168 |
+
|
| 169 |
+
# trans vf loss
|
| 170 |
+
trans_loss = torch.sum((pred_trans_1_c - trans_1_c)**2*gen_mask[...,None],dim=(-1,-2)) / (torch.sum(gen_mask,dim=-1) + 1e-8) # (B,)
|
| 171 |
+
trans_loss = torch.mean(trans_loss)
|
| 172 |
+
|
| 173 |
+
# rots vf loss
|
| 174 |
+
gt_rot_vf = so3_utils.calc_rot_vf(rotmats_t, rotmats_1)
|
| 175 |
+
pred_rot_vf = so3_utils.calc_rot_vf(rotmats_t, pred_rotmats_1)
|
| 176 |
+
rot_loss = torch.sum(((gt_rot_vf - pred_rot_vf) * norm_scale)**2*gen_mask[...,None],dim=(-1,-2)) / (torch.sum(gen_mask,dim=-1) + 1e-8) # (B,)
|
| 177 |
+
rot_loss = torch.mean(rot_loss)
|
| 178 |
+
|
| 179 |
+
# bb aux loss
|
| 180 |
+
gt_bb_atoms = all_atom.to_atom37(trans_1_c, rotmats_1)[:, :, :3]
|
| 181 |
+
pred_bb_atoms = all_atom.to_atom37(pred_trans_1_c, pred_rotmats_1)[:, :, :3]
|
| 182 |
+
# gt_bb_atoms = all_atom.to_bb_atoms(trans_1_c, rotmats_1, angles_1[:,:,0]) # N,CA,C,O,CB
|
| 183 |
+
# pred_bb_atoms = all_atom.to_bb_atoms(pred_trans_1_c, pred_rotmats_1, pred_angles_1[:,:,0])
|
| 184 |
+
# print(gt_bb_atoms.shape)
|
| 185 |
+
bb_atom_loss = torch.sum(
|
| 186 |
+
(gt_bb_atoms - pred_bb_atoms) ** 2 * gen_mask[..., None, None],
|
| 187 |
+
dim=(-1, -2, -3)
|
| 188 |
+
) / (torch.sum(gen_mask,dim=-1) + 1e-8) # (B,)
|
| 189 |
+
bb_atom_loss = torch.mean(bb_atom_loss)
|
| 190 |
+
# bb_atom_loss = torch.mean(torch.where(t[:,0]>=0.75,bb_atom_loss,torch.zeros_like(bb_atom_loss))) # penalty for near gt point
|
| 191 |
+
|
| 192 |
+
# seqs vf loss
|
| 193 |
+
seqs_loss = F.cross_entropy(pred_seqs_1_prob.view(-1,pred_seqs_1_prob.shape[-1]),torch.clamp(seqs_1,0,19).view(-1), reduction='none').view(pred_seqs_1_prob.shape[:-1]) # (N,L), not softmax
|
| 194 |
+
seqs_loss = torch.sum(seqs_loss * gen_mask, dim=-1) / (torch.sum(gen_mask,dim=-1) + 1e-8)
|
| 195 |
+
seqs_loss = torch.mean(seqs_loss)
|
| 196 |
+
|
| 197 |
+
# we should not use angle mask, as you dont know aa type when generating
|
| 198 |
+
# angle_mask_loss = torch.cat([angle_mask,angle_mask],dim=-1) # (B,L,10)
|
| 199 |
+
# angle vf loss
|
| 200 |
+
angle_mask_loss = torsions_mask.to(batch['aa'].device)
|
| 201 |
+
angle_mask_loss = angle_mask_loss[pred_seqs_1.reshape(-1)].reshape(num_batch,num_res,-1) # (B,L,5)
|
| 202 |
+
angle_mask_loss = torch.cat([angle_mask_loss,angle_mask_loss],dim=-1) # (B,L,10)
|
| 203 |
+
angle_mask_loss = torch.logical_and(batch['generate_mask'][...,None].bool(),angle_mask_loss)
|
| 204 |
+
gt_angle_vf = torus.tor_logmap(angles_t, angles_1)
|
| 205 |
+
gt_angle_vf_vec = torch.cat([torch.sin(gt_angle_vf),torch.cos(gt_angle_vf)],dim=-1)
|
| 206 |
+
pred_angle_vf = torus.tor_logmap(angles_t, pred_angles_1)
|
| 207 |
+
pred_angle_vf_vec = torch.cat([torch.sin(pred_angle_vf),torch.cos(pred_angle_vf)],dim=-1)
|
| 208 |
+
# angle_loss = torch.sum(((gt_angle_vf_vec - pred_angle_vf_vec) * norm_scale)**2*gen_mask[...,None],dim=(-1,-2)) / ((torch.sum(gen_mask,dim=-1)) + 1e-8) # (B,)
|
| 209 |
+
angle_loss = torch.sum(((gt_angle_vf_vec - pred_angle_vf_vec) * norm_scale)**2*angle_mask_loss,dim=(-1,-2)) / (torch.sum(angle_mask_loss,dim=(-1,-2)) + 1e-8) # (B,)
|
| 210 |
+
angle_loss = torch.mean(angle_loss)
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
# angle aux loss
|
| 214 |
+
angles_1_vec = torch.cat([torch.sin(angles_1),torch.cos(angles_1)],dim=-1)
|
| 215 |
+
pred_angles_1_vec = torch.cat([torch.sin(pred_angles_1),torch.cos(pred_angles_1)],dim=-1)
|
| 216 |
+
# torsion_loss = torch.sum((pred_angles_1_vec - angles_1_vec)**2*gen_mask[...,None],dim=(-1,-2)) / (torch.sum(gen_mask,dim=-1) + 1e-8) # (B,)
|
| 217 |
+
torsion_loss = torch.sum((pred_angles_1_vec - angles_1_vec)**2*angle_mask_loss,dim=(-1,-2)) / (torch.sum(angle_mask_loss,dim=(-1,-2)) + 1e-8) # (B,)
|
| 218 |
+
torsion_loss = torch.mean(torsion_loss)
|
| 219 |
+
|
| 220 |
+
return {
|
| 221 |
+
"trans_loss": trans_loss,
|
| 222 |
+
'rot_loss': rot_loss,
|
| 223 |
+
'bb_atom_loss': bb_atom_loss,
|
| 224 |
+
'seqs_loss': seqs_loss,
|
| 225 |
+
'angle_loss': angle_loss,
|
| 226 |
+
'torsion_loss': torsion_loss,
|
| 227 |
+
}
|
| 228 |
+
|
| 229 |
+
@torch.no_grad()
|
| 230 |
+
def sample(self, batch, num_steps = 100, sample_bb=True, sample_ang=True, sample_seq=True):
|
| 231 |
+
|
| 232 |
+
num_batch, num_res = batch['aa'].shape
|
| 233 |
+
gen_mask,res_mask = batch['generate_mask'],batch['res_mask']
|
| 234 |
+
K = self._interpolant_cfg.seqs.num_classes
|
| 235 |
+
k = self._interpolant_cfg.seqs.simplex_value
|
| 236 |
+
angle_mask_loss = torsions_mask.to(batch['aa'].device)
|
| 237 |
+
|
| 238 |
+
#encode
|
| 239 |
+
rotmats_1, trans_1, angles_1, seqs_1, node_embed, edge_embed = self.encode(batch)
|
| 240 |
+
# trans_1_c,center = self.zero_center_part(trans_1,gen_mask,res_mask)
|
| 241 |
+
trans_1_c = trans_1
|
| 242 |
+
seqs_1_simplex = self.seq_to_simplex(seqs_1)
|
| 243 |
+
seqs_1_prob = F.softmax(seqs_1_simplex,dim=-1)
|
| 244 |
+
|
| 245 |
+
# # # only sample bb, angle and seq with noise
|
| 246 |
+
# angles_1 = torch.where(batch['generate_mask'][...,None],angles_1,torus.tor_random_uniform(angles_1.shape, device=batch['aa'].device, dtype=angles_1.dtype))
|
| 247 |
+
# seqs_1 = torch.where(batch['generate_mask'],seqs_1,torch.randint_like(seqs_1,0,20))
|
| 248 |
+
# seqs_1_simplex = self.seq_to_simplex(seqs_1)
|
| 249 |
+
# seqs_1_prob = F.softmax(seqs_1_simplex,dim=-1)
|
| 250 |
+
|
| 251 |
+
#initial noise
|
| 252 |
+
if sample_bb:
|
| 253 |
+
rotmats_0 = uniform_so3(num_batch,num_res,device=batch['aa'].device)
|
| 254 |
+
rotmats_0 = torch.where(batch['generate_mask'][...,None,None],rotmats_0,rotmats_1)
|
| 255 |
+
trans_0 = torch.randn((num_batch,num_res,3), device=batch['aa'].device) # scale with sigma?
|
| 256 |
+
# move center and receptor
|
| 257 |
+
trans_0_c,center = self.zero_center_part(trans_0,gen_mask,res_mask)
|
| 258 |
+
trans_0_c = torch.where(batch['generate_mask'][...,None],trans_0_c,trans_1_c)
|
| 259 |
+
else:
|
| 260 |
+
rotmats_0 = rotmats_1.detach().clone()
|
| 261 |
+
trans_0_c = trans_1_c.detach().clone()
|
| 262 |
+
if sample_ang:
|
| 263 |
+
# angle noise
|
| 264 |
+
angles_0 = torus.tor_random_uniform(angles_1.shape, device=batch['aa'].device, dtype=angles_1.dtype) # (B,L,5)
|
| 265 |
+
angles_0 = torch.where(batch['generate_mask'][...,None],angles_0,angles_1)
|
| 266 |
+
else:
|
| 267 |
+
angles_0 = angles_1.detach().clone()
|
| 268 |
+
if sample_seq:
|
| 269 |
+
seqs_0_simplex = k * torch.randn((num_batch,num_res,K), device=batch['aa'].device)
|
| 270 |
+
seqs_0_prob = F.softmax(seqs_0_simplex,dim=-1)
|
| 271 |
+
seqs_0 = sample_from(seqs_0_prob)
|
| 272 |
+
seqs_0 = torch.where(batch['generate_mask'],seqs_0,seqs_1)
|
| 273 |
+
seqs_0_simplex = torch.where(batch['generate_mask'][...,None],seqs_0_simplex,seqs_1_simplex)
|
| 274 |
+
else:
|
| 275 |
+
seqs_0 = seqs_1.detach().clone()
|
| 276 |
+
seqs_0_prob = seqs_1_prob.detach().clone()
|
| 277 |
+
seqs_0_simplex = seqs_1_simplex.detach().clone()
|
| 278 |
+
|
| 279 |
+
# Set-up time
|
| 280 |
+
ts = torch.linspace(1.e-2, 1.0, num_steps)
|
| 281 |
+
t_1 = ts[0]
|
| 282 |
+
# prot_traj = [{'rotmats':rotmats_0,'trans':trans_0_c,'seqs':seqs_0,'seqs_simplex':seqs_0_simplex,'rotmats_1':rotmats_1,'trans_1':trans_1-center,'seqs_1':seqs_1}]
|
| 283 |
+
clean_traj = []
|
| 284 |
+
rotmats_t_1, trans_t_1_c, angles_t_1, seqs_t_1, seqs_t_1_simplex = rotmats_0, trans_0_c, angles_0, seqs_0, seqs_0_simplex
|
| 285 |
+
|
| 286 |
+
# denoise loop
|
| 287 |
+
for t_2 in ts[1:]:
|
| 288 |
+
t = torch.ones((num_batch, 1), device=batch['aa'].device) * t_1
|
| 289 |
+
# rots
|
| 290 |
+
pred_rotmats_1, pred_trans_1, pred_angles_1, pred_seqs_1_prob = self.ga_encoder(t, rotmats_t_1, trans_t_1_c, angles_t_1, seqs_t_1, node_embed, edge_embed, batch['generate_mask'].long(), batch['res_mask'].long())
|
| 291 |
+
pred_rotmats_1 = torch.where(batch['generate_mask'][...,None,None],pred_rotmats_1,rotmats_1)
|
| 292 |
+
# trans, move center
|
| 293 |
+
# pred_trans_1_c,center = self.zero_center_part(pred_trans_1,gen_mask,res_mask)
|
| 294 |
+
pred_trans_1_c = torch.where(batch['generate_mask'][...,None],pred_trans_1,trans_1_c) # move receptor also
|
| 295 |
+
# angles
|
| 296 |
+
pred_angles_1 = torch.where(batch['generate_mask'][...,None],pred_angles_1,angles_1)
|
| 297 |
+
# seqs
|
| 298 |
+
pred_seqs_1 = sample_from(F.softmax(pred_seqs_1_prob,dim=-1))
|
| 299 |
+
pred_seqs_1 = torch.where(batch['generate_mask'],pred_seqs_1,seqs_1)
|
| 300 |
+
pred_seqs_1_simplex = self.seq_to_simplex(pred_seqs_1)
|
| 301 |
+
# seq-angle
|
| 302 |
+
torsion_mask = angle_mask_loss[pred_seqs_1.reshape(-1)].reshape(num_batch,num_res,-1) # (B,L,5)
|
| 303 |
+
pred_angles_1 = torch.where(torsion_mask.bool(),pred_angles_1,torch.zeros_like(pred_angles_1))
|
| 304 |
+
if not sample_bb:
|
| 305 |
+
pred_trans_1_c = trans_1_c.detach().clone()
|
| 306 |
+
# _,center = self.zero_center_part(trans_1,gen_mask,res_mask)
|
| 307 |
+
pred_rotmats_1 = rotmats_1.detach().clone()
|
| 308 |
+
if not sample_ang:
|
| 309 |
+
pred_angles_1 = angles_1.detach().clone()
|
| 310 |
+
if not sample_seq:
|
| 311 |
+
pred_seqs_1 = seqs_1.detach().clone()
|
| 312 |
+
pred_seqs_1_simplex = seqs_1_simplex.detach().clone()
|
| 313 |
+
clean_traj.append({'rotmats':pred_rotmats_1.cpu(),'trans':pred_trans_1_c.cpu(),'angles':pred_angles_1.cpu(),'seqs':pred_seqs_1.cpu(),'seqs_simplex':pred_seqs_1_simplex.cpu(),
|
| 314 |
+
'rotmats_1':rotmats_1.cpu(),'trans_1':trans_1_c.cpu(),'angles_1':angles_1.cpu(),'seqs_1':seqs_1.cpu()})
|
| 315 |
+
# reverse step, also only for gen mask region
|
| 316 |
+
d_t = (t_2-t_1) * torch.ones((num_batch, 1), device=batch['aa'].device)
|
| 317 |
+
# Euler step
|
| 318 |
+
trans_t_2 = trans_t_1_c + (pred_trans_1_c-trans_0_c)*d_t[...,None]
|
| 319 |
+
# trans_t_2_c,center = self.zero_center_part(trans_t_2,gen_mask,res_mask)
|
| 320 |
+
trans_t_2_c = torch.where(batch['generate_mask'][...,None],trans_t_2,trans_1_c) # move receptor also
|
| 321 |
+
# rotmats_t_2 = so3_utils.geodesic_t(d_t[...,None] / (1-t[...,None]), pred_rotmats_1, rotmats_t_1)
|
| 322 |
+
rotmats_t_2 = so3_utils.geodesic_t(d_t[...,None] * 10, pred_rotmats_1, rotmats_t_1)
|
| 323 |
+
rotmats_t_2 = torch.where(batch['generate_mask'][...,None,None],rotmats_t_2,rotmats_1)
|
| 324 |
+
# angles
|
| 325 |
+
angles_t_2 = torus.tor_geodesic_t(d_t[...,None],pred_angles_1, angles_t_1)
|
| 326 |
+
angles_t_2 = torch.where(batch['generate_mask'][...,None],angles_t_2,angles_1)
|
| 327 |
+
# seqs
|
| 328 |
+
seqs_t_2_simplex = seqs_t_1_simplex + (pred_seqs_1_simplex - seqs_0_simplex) * d_t[...,None]
|
| 329 |
+
seqs_t_2 = sample_from(F.softmax(seqs_t_2_simplex,dim=-1))
|
| 330 |
+
seqs_t_2 = torch.where(batch['generate_mask'],seqs_t_2,seqs_1)
|
| 331 |
+
# seq-angle
|
| 332 |
+
torsion_mask = angle_mask_loss[seqs_t_2.reshape(-1)].reshape(num_batch,num_res,-1) # (B,L,5)
|
| 333 |
+
angles_t_2 = torch.where(torsion_mask.bool(),angles_t_2,torch.zeros_like(angles_t_2))
|
| 334 |
+
|
| 335 |
+
if not sample_bb:
|
| 336 |
+
trans_t_2_c = trans_1_c.detach().clone()
|
| 337 |
+
rotmats_t_2 = rotmats_1.detach().clone()
|
| 338 |
+
if not sample_ang:
|
| 339 |
+
angles_t_2 = angles_1.detach().clone()
|
| 340 |
+
if not sample_seq:
|
| 341 |
+
seqs_t_2 = seqs_1.detach().clone()
|
| 342 |
+
rotmats_t_1, trans_t_1_c, angles_t_1, seqs_t_1, seqs_t_1_simplex = rotmats_t_2, trans_t_2_c, angles_t_2, seqs_t_2, seqs_t_2_simplex
|
| 343 |
+
t_1 = t_2
|
| 344 |
+
|
| 345 |
+
# final step
|
| 346 |
+
t_1 = ts[-1]
|
| 347 |
+
t = torch.ones((num_batch, 1), device=batch['aa'].device) * t_1
|
| 348 |
+
pred_rotmats_1, pred_trans_1, pred_angles_1, pred_seqs_1_prob = self.ga_encoder(t, rotmats_t_1, trans_t_1_c, angles_t_1, seqs_t_1, node_embed, edge_embed, batch['generate_mask'].long(), batch['res_mask'].long())
|
| 349 |
+
pred_rotmats_1 = torch.where(batch['generate_mask'][...,None,None],pred_rotmats_1,rotmats_1)
|
| 350 |
+
# move center
|
| 351 |
+
# pred_trans_1_c,center = self.zero_center_part(pred_trans_1,gen_mask,res_mask)
|
| 352 |
+
pred_trans_1_c = torch.where(batch['generate_mask'][...,None],pred_trans_1,trans_1_c) # move receptor also
|
| 353 |
+
# angles
|
| 354 |
+
pred_angles_1 = torch.where(batch['generate_mask'][...,None],pred_angles_1,angles_1)
|
| 355 |
+
# seqs
|
| 356 |
+
pred_seqs_1 = sample_from(F.softmax(pred_seqs_1_prob,dim=-1))
|
| 357 |
+
pred_seqs_1 = torch.where(batch['generate_mask'],pred_seqs_1,seqs_1)
|
| 358 |
+
pred_seqs_1_simplex = self.seq_to_simplex(pred_seqs_1)
|
| 359 |
+
# seq-angle
|
| 360 |
+
torsion_mask = angle_mask_loss[pred_seqs_1.reshape(-1)].reshape(num_batch,num_res,-1) # (B,L,5)
|
| 361 |
+
pred_angles_1 = torch.where(torsion_mask.bool(),pred_angles_1,torch.zeros_like(pred_angles_1))
|
| 362 |
+
if not sample_bb:
|
| 363 |
+
pred_trans_1_c = trans_1_c.detach().clone()
|
| 364 |
+
# _,center = self.zero_center_part(trans_1,gen_mask,res_mask)
|
| 365 |
+
pred_rotmats_1 = rotmats_1.detach().clone()
|
| 366 |
+
if not sample_ang:
|
| 367 |
+
pred_angles_1 = angles_1.detach().clone()
|
| 368 |
+
if not sample_seq:
|
| 369 |
+
pred_seqs_1 = seqs_1.detach().clone()
|
| 370 |
+
pred_seqs_1_simplex = seqs_1_simplex.detach().clone()
|
| 371 |
+
clean_traj.append({'rotmats':pred_rotmats_1.cpu(),'trans':pred_trans_1_c.cpu(),'angles':pred_angles_1.cpu(),'seqs':pred_seqs_1.cpu(),'seqs_simplex':pred_seqs_1_simplex.cpu(),
|
| 372 |
+
'rotmats_1':rotmats_1.cpu(),'trans_1':trans_1_c.cpu(),'angles_1':angles_1.cpu(),'seqs_1':seqs_1.cpu()})
|
| 373 |
+
|
| 374 |
+
return clean_traj
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
# if __name__ == '__main__':
|
| 378 |
+
# prefix_dir = './pepflowww'
|
| 379 |
+
# # config,cfg_name = load_config("../configs/angle/learn_sc.yaml")
|
| 380 |
+
# config,cfg_name = load_config(os.path.join(prefix_dir,"configs/angle/learn_sc.yaml"))
|
| 381 |
+
# # print(config)
|
| 382 |
+
# device = 'cuda:0'
|
| 383 |
+
# dataset = PepDataset(structure_dir = config.dataset.val.structure_dir, dataset_dir = config.dataset.val.dataset_dir,
|
| 384 |
+
# name = config.dataset.val.name, transform=None, reset=config.dataset.val.reset)
|
| 385 |
+
# dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=PaddingCollate(eight=False), num_workers=4, pin_memory=True)
|
| 386 |
+
# ckpt = torch.load("./checkpoints/600000.pt", map_location=device)
|
| 387 |
+
# seed_all(114514)
|
| 388 |
+
# model = FlowModel(config.model).to(device)
|
| 389 |
+
# model.load_state_dict(process_dic(ckpt['model']))
|
| 390 |
+
# model.eval()
|
| 391 |
+
|
| 392 |
+
# # print(model)
|
| 393 |
+
|
| 394 |
+
# # print(dataset[0]['chain_id'])
|
| 395 |
+
# # print(dataset[0]['id'])
|
| 396 |
+
# # print(dataset[0]['resseq'])
|
| 397 |
+
# # print(dataset[0]['res_nb'])
|
| 398 |
+
# # print(dataset[0]['icode'])
|
| 399 |
+
|
| 400 |
+
# dic = {'id':[],'len':[],'tran':[],'aar':[],'rot':[],'trans_loss':[],'rot_loss':[]}
|
| 401 |
+
|
| 402 |
+
# # for batch in tqdm(dataloader):
|
| 403 |
+
# # batch = recursive_to(batch,device)
|
| 404 |
+
# for i in tqdm(range(len(dataset))):
|
| 405 |
+
# item = dataset[i]
|
| 406 |
+
# data_list = [deepcopy(item) for _ in range(16)]
|
| 407 |
+
# batch = recursive_to(collate_fn(data_list),device)
|
| 408 |
+
# loss_dic = model(batch)
|
| 409 |
+
# # traj_1 = model.sample(batch,num_steps=50,sample_bb=False,sample_ang=True,sample_seq=False)
|
| 410 |
+
# traj_1 = model.sample(batch,num_steps=50,sample_bb=True,sample_ang=True,sample_seq=True)
|
| 411 |
+
# ca_dist = torch.sqrt(torch.sum((traj_1[-1]['trans']-traj_1[-1]['trans_1'])**2*batch['generate_mask'][...,None].cpu().long()) / (torch.sum(batch['generate_mask']) + 1e-8).cpu()) # rmsd
|
| 412 |
+
# rot_dist = torch.sqrt(torch.sum((traj_1[-1]['rotmats']-traj_1[-1]['rotmats_1'])**2*batch['generate_mask'][...,None,None].long().cpu()) / (torch.sum(batch['generate_mask']) + 1e-8).cpu()) # rmsd
|
| 413 |
+
# aar = torch.sum((traj_1[-1]['seqs']==traj_1[-1]['seqs_1']) * batch['generate_mask'].long().cpu()) / (torch.sum(batch['generate_mask']).cpu() + 1e-8)
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
# print(loss_dic)
|
| 417 |
+
# print(f'tran:{ca_dist},rot:{rot_dist},aar:{aar},len:{batch["generate_mask"].sum().item()}')
|
| 418 |
+
|
| 419 |
+
# # free
|
| 420 |
+
# torch.cuda.empty_cache()
|
| 421 |
+
# gc.collect()
|
| 422 |
+
|
| 423 |
+
# # dic['tran'].append(ca_dist.item())
|
| 424 |
+
# # dic['rot'].append(rot_dist.item())
|
| 425 |
+
# dic['aar'].append(aar.item())
|
| 426 |
+
# dic['trans_loss'].append(loss_dic['trans_loss'].item())
|
| 427 |
+
# dic['rot_loss'].append(loss_dic['rot_loss'].item())
|
| 428 |
+
# dic['id'].append(batch['id'][0])
|
| 429 |
+
# dic['len'].append(batch['generate_mask'].sum().item())
|
| 430 |
+
# # # break
|
| 431 |
+
|
| 432 |
+
# # traj_1[-1]['batch'] = batch
|
| 433 |
+
# # torch.save(traj_1[-1],f'/datapool/data2/home/jiahan/ResProj/PepDiff/frame-flow/Data/Models_new/Pack_new/outputs/{batch["id"][0]}.pt')
|
| 434 |
+
|
| 435 |
+
# # print(dic)
|
| 436 |
+
# # dic = pd.DataFrame(dic)
|
| 437 |
+
# # dic.to_csv(f'/datapool/data2/home/jiahan/ResProj/PepDiff/frame-flow/Data/Models_new/Pack/outputs.csv',index=None)
|
| 438 |
+
|
| 439 |
+
# print(np.mean(dic['aar']))
|
| 440 |
+
# print(np.mean(dic['trans_loss']))
|
| 441 |
+
|
| 442 |
+
# if __name__ == '__main__':
|
| 443 |
+
# config,cfg_name = load_config("./configs/angle/learn_angle.yaml")
|
| 444 |
+
# seed_all(114514)
|
| 445 |
+
# device = 'cpu'
|
| 446 |
+
# dataset = PepDataset(structure_dir = config.dataset.train.structure_dir, dataset_dir = config.dataset.train.dataset_dir,
|
| 447 |
+
# name = config.dataset.train.name, transform=None, reset=config.dataset.train.reset)
|
| 448 |
+
# dataloader = DataLoader(dataset, batch_size=2, shuffle=False, collate_fn=PaddingCollate(), num_workers=4, pin_memory=True)
|
| 449 |
+
# model = FlowModel(config.model).to(device)
|
| 450 |
+
# optimizer = torch.optim.Adam(model.parameters(),lr=1.e-4)
|
| 451 |
+
|
| 452 |
+
# # ckpt = torch.load('./checkpoints/90000.pt',map_location=device)
|
| 453 |
+
# # model.load_state_dict(process_dic(ckpt['model']))
|
| 454 |
+
# # optimizer.load_state_dict(ckpt['optimizer'])
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
# # torch.autograd.set_detect_anomaly(True)
|
| 458 |
+
# for i,batch in tqdm(enumerate(dataloader)):
|
| 459 |
+
# batch = recursive_to(batch,device)
|
| 460 |
+
# loss_dict = model(batch)
|
| 461 |
+
# loss = sum_weighted_losses(loss_dict, config.train.loss_weights)
|
| 462 |
+
# # if torch.isnan(loss):
|
| 463 |
+
# # print(i)
|
| 464 |
+
# # print(batch['id'])
|
| 465 |
+
|
| 466 |
+
# loss.backward()
|
| 467 |
+
# orig_grad_norm = clip_grad_norm_(model.parameters(), config.train.max_grad_norm)
|
| 468 |
+
|
| 469 |
+
# print(f'{loss_dict},{loss},{orig_grad_norm}')
|
| 470 |
+
|
| 471 |
+
# optimizer.step()
|
| 472 |
+
# optimizer.zero_grad()
|
models_con/ga.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
|
| 4 |
+
from models_con import ipa_pytorch as ipa_pytorch
|
| 5 |
+
from data import utils as du
|
| 6 |
+
|
| 7 |
+
from models_con.utils import get_index_embedding, get_time_embedding
|
| 8 |
+
|
| 9 |
+
from pepflow.modules.protein.constants import ANG_TO_NM_SCALE, NM_TO_ANG_SCALE
|
| 10 |
+
from pepflow.modules.common.layers import AngularEncoding
|
| 11 |
+
|
| 12 |
+
import math
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class GAEncoder(nn.Module):
|
| 16 |
+
def __init__(self, ipa_conf):
|
| 17 |
+
super().__init__()
|
| 18 |
+
self._ipa_conf = ipa_conf
|
| 19 |
+
|
| 20 |
+
# angles
|
| 21 |
+
self.angles_embedder = AngularEncoding(num_funcs=12) # 25*5=120, for competitive embedding size
|
| 22 |
+
self.angle_net = nn.Sequential(
|
| 23 |
+
nn.Linear(self._ipa_conf.c_s, self._ipa_conf.c_s),nn.ReLU(),
|
| 24 |
+
nn.Linear(self._ipa_conf.c_s, self._ipa_conf.c_s),nn.ReLU(),
|
| 25 |
+
nn.Linear(self._ipa_conf.c_s, 5)
|
| 26 |
+
# nn.Linear(self._ipa_conf.c_s, 22)
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
# for condition on current seq
|
| 30 |
+
self.current_seq_embedder = nn.Embedding(22, self._ipa_conf.c_s)
|
| 31 |
+
self.seq_net = nn.Sequential(
|
| 32 |
+
nn.Linear(self._ipa_conf.c_s, self._ipa_conf.c_s),nn.ReLU(),
|
| 33 |
+
nn.Linear(self._ipa_conf.c_s, self._ipa_conf.c_s),nn.ReLU(),
|
| 34 |
+
nn.Linear(self._ipa_conf.c_s, 20)
|
| 35 |
+
# nn.Linear(self._ipa_conf.c_s, 22)
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
# mixer
|
| 39 |
+
self.res_feat_mixer = nn.Sequential(
|
| 40 |
+
nn.Linear(3 * self._ipa_conf.c_s + self.angles_embedder.get_out_dim(in_dim=5), self._ipa_conf.c_s),
|
| 41 |
+
nn.ReLU(),
|
| 42 |
+
nn.Linear(self._ipa_conf.c_s, self._ipa_conf.c_s),
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
self.feat_dim = self._ipa_conf.c_s
|
| 46 |
+
|
| 47 |
+
# Attention trunk
|
| 48 |
+
self.trunk = nn.ModuleDict()
|
| 49 |
+
for b in range(self._ipa_conf.num_blocks):
|
| 50 |
+
self.trunk[f'ipa_{b}'] = ipa_pytorch.InvariantPointAttention(self._ipa_conf)
|
| 51 |
+
self.trunk[f'ipa_ln_{b}'] = nn.LayerNorm(self._ipa_conf.c_s)
|
| 52 |
+
tfmr_in = self._ipa_conf.c_s
|
| 53 |
+
tfmr_layer = torch.nn.TransformerEncoderLayer(
|
| 54 |
+
d_model=tfmr_in,
|
| 55 |
+
nhead=self._ipa_conf.seq_tfmr_num_heads,
|
| 56 |
+
dim_feedforward=tfmr_in,
|
| 57 |
+
batch_first=True,
|
| 58 |
+
dropout=0.0,
|
| 59 |
+
norm_first=False
|
| 60 |
+
)
|
| 61 |
+
self.trunk[f'seq_tfmr_{b}'] = torch.nn.TransformerEncoder(
|
| 62 |
+
tfmr_layer, self._ipa_conf.seq_tfmr_num_layers, enable_nested_tensor=False)
|
| 63 |
+
self.trunk[f'post_tfmr_{b}'] = ipa_pytorch.Linear(
|
| 64 |
+
tfmr_in, self._ipa_conf.c_s, init="final")
|
| 65 |
+
self.trunk[f'node_transition_{b}'] = ipa_pytorch.StructureModuleTransition(
|
| 66 |
+
c=self._ipa_conf.c_s)
|
| 67 |
+
self.trunk[f'bb_update_{b}'] = ipa_pytorch.BackboneUpdate(
|
| 68 |
+
self._ipa_conf.c_s, use_rot_updates=True)
|
| 69 |
+
|
| 70 |
+
if b < self._ipa_conf.num_blocks-1:
|
| 71 |
+
# No edge update on the last block.
|
| 72 |
+
edge_in = self._ipa_conf.c_z
|
| 73 |
+
self.trunk[f'edge_transition_{b}'] = ipa_pytorch.EdgeTransition(
|
| 74 |
+
node_embed_size=self._ipa_conf.c_s,
|
| 75 |
+
edge_embed_in=edge_in,
|
| 76 |
+
edge_embed_out=self._ipa_conf.c_z,
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
def embed_t(self, timesteps, mask):
|
| 80 |
+
timestep_emb = get_time_embedding(
|
| 81 |
+
timesteps[:, 0],
|
| 82 |
+
self.feat_dim,
|
| 83 |
+
max_positions=2056
|
| 84 |
+
)[:, None, :].repeat(1, mask.shape[1], 1)
|
| 85 |
+
return timestep_emb
|
| 86 |
+
|
| 87 |
+
def forward(self, t, rotmats_t, trans_t, angles_t, seqs_t, node_embed, edge_embed, generate_mask, res_mask):
|
| 88 |
+
num_batch, num_res = seqs_t.shape
|
| 89 |
+
|
| 90 |
+
# incorperate current seq and timesteps
|
| 91 |
+
node_mask = res_mask
|
| 92 |
+
edge_mask = node_mask[:, None] * node_mask[:, :, None]
|
| 93 |
+
|
| 94 |
+
node_embed = self.res_feat_mixer(torch.cat([node_embed, self.current_seq_embedder(seqs_t), self.embed_t(t,node_mask), self.angles_embedder(angles_t).reshape(num_batch,num_res,-1)],dim=-1))
|
| 95 |
+
node_embed = node_embed * node_mask[..., None]
|
| 96 |
+
curr_rigids = du.create_rigid(rotmats_t, trans_t)
|
| 97 |
+
for b in range(self._ipa_conf.num_blocks):
|
| 98 |
+
ipa_embed = self.trunk[f'ipa_{b}'](
|
| 99 |
+
node_embed,
|
| 100 |
+
edge_embed,
|
| 101 |
+
curr_rigids,
|
| 102 |
+
node_mask)
|
| 103 |
+
ipa_embed *= node_mask[..., None]
|
| 104 |
+
node_embed = self.trunk[f'ipa_ln_{b}'](node_embed + ipa_embed)
|
| 105 |
+
seq_tfmr_out = self.trunk[f'seq_tfmr_{b}'](
|
| 106 |
+
node_embed, src_key_padding_mask=(1 - node_mask).bool())
|
| 107 |
+
node_embed = node_embed + self.trunk[f'post_tfmr_{b}'](seq_tfmr_out)
|
| 108 |
+
node_embed = self.trunk[f'node_transition_{b}'](node_embed)
|
| 109 |
+
node_embed = node_embed * node_mask[..., None]
|
| 110 |
+
rigid_update = self.trunk[f'bb_update_{b}'](
|
| 111 |
+
node_embed * node_mask[..., None])
|
| 112 |
+
curr_rigids = curr_rigids.compose_q_update_vec(
|
| 113 |
+
rigid_update, node_mask[..., None])
|
| 114 |
+
|
| 115 |
+
if b < self._ipa_conf.num_blocks-1:
|
| 116 |
+
edge_embed = self.trunk[f'edge_transition_{b}'](
|
| 117 |
+
node_embed, edge_embed)
|
| 118 |
+
edge_embed *= edge_mask[..., None]
|
| 119 |
+
|
| 120 |
+
# curr_rigids = self.rigids_nm_to_ang(curr_rigids)
|
| 121 |
+
pred_trans1 = curr_rigids.get_trans()
|
| 122 |
+
pred_rotmats1 = curr_rigids.get_rots().get_rot_mats()
|
| 123 |
+
pred_seqs1_prob = self.seq_net(node_embed)
|
| 124 |
+
pred_angles1 = self.angle_net(node_embed)
|
| 125 |
+
pred_angles1 = pred_angles1 % (2*math.pi) # inductive bias to bound between (0,2pi)
|
| 126 |
+
|
| 127 |
+
return pred_rotmats1, pred_trans1, pred_angles1, pred_seqs1_prob
|
models_con/inference.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from torch.utils.data import DataLoader
|
| 6 |
+
|
| 7 |
+
import copy
|
| 8 |
+
import math
|
| 9 |
+
from tqdm.auto import tqdm
|
| 10 |
+
import functools
|
| 11 |
+
import os
|
| 12 |
+
import argparse
|
| 13 |
+
import pandas as pd
|
| 14 |
+
from copy import deepcopy
|
| 15 |
+
|
| 16 |
+
from models_con.pep_dataloader import PepDataset
|
| 17 |
+
|
| 18 |
+
from pepflow.utils.misc import load_config
|
| 19 |
+
from pepflow.utils.train import recursive_to
|
| 20 |
+
|
| 21 |
+
from pepflow.modules.common.geometry import reconstruct_backbone, reconstruct_backbone_partially, align, batch_align
|
| 22 |
+
from pepflow.modules.protein.writers import save_pdb
|
| 23 |
+
|
| 24 |
+
from pepflow.utils.data import PaddingCollate
|
| 25 |
+
|
| 26 |
+
from models_con.utils import process_dic
|
| 27 |
+
|
| 28 |
+
import gc
|
| 29 |
+
|
| 30 |
+
from models_con.flow_model import FlowModel
|
| 31 |
+
|
| 32 |
+
from pepflow.utils.misc import seed_all
|
| 33 |
+
|
| 34 |
+
from models_con.torsion import full_atom_reconstruction, get_heavyatom_mask
|
| 35 |
+
|
| 36 |
+
collate_fn = PaddingCollate(eight=False)
|
| 37 |
+
|
| 38 |
+
import argparse
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
if __name__ == '__main__':
|
| 42 |
+
args = argparse.ArgumentParser()
|
| 43 |
+
args.add_argument('--config', type=str)
|
| 44 |
+
args.add_argument('--device', type=str)
|
| 45 |
+
args.add_argument('--ckpt', type=str)
|
| 46 |
+
args.add_argument('--output', type=str)
|
| 47 |
+
args.add_argument('--num_steps', type=int, default=200)
|
| 48 |
+
args.add_argument('--num_samples', type=int, default=64)
|
| 49 |
+
args.add_argument('--sample_bb', type=bool, default=True)
|
| 50 |
+
args.add_argument('--sample_ang', type=bool, default=True)
|
| 51 |
+
args.add_argument('--sample_seq', type=bool, default=True)
|
| 52 |
+
args.add_argument('--num_samples', type=int, default=64)
|
| 53 |
+
args.add_argument('--num_samples', type=int, default=64)
|
| 54 |
+
parser = args.parse_args()
|
| 55 |
+
|
| 56 |
+
config,cfg_name = load_config(parser.config)
|
| 57 |
+
device = parser.device
|
| 58 |
+
dataset = PepDataset(structure_dir = config.dataset.val.structure_dir, dataset_dir = config.dataset.val.dataset_dir,
|
| 59 |
+
name = config.dataset.val.name, transform=None, reset=config.dataset.val.reset)
|
| 60 |
+
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=PaddingCollate(eight=False), num_workers=4, pin_memory=True)
|
| 61 |
+
ckpt = torch.load(parser.ckpt, map_location=device)
|
| 62 |
+
|
| 63 |
+
seed_all(114514)
|
| 64 |
+
model = FlowModel(config.model).to(device)
|
| 65 |
+
model.load_state_dict(process_dic(ckpt['model']))
|
| 66 |
+
model.eval()
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
dic = {'id':[],'len':[],'tran':[],'aar':[],'rot':[],'trans_loss':[],'rot_loss':[]}
|
| 70 |
+
|
| 71 |
+
for i in tqdm(range(len(dataset))):
|
| 72 |
+
item = dataset[i]
|
| 73 |
+
data_list = [deepcopy(item) for _ in range(parser.num_samples)]
|
| 74 |
+
batch = recursive_to(collate_fn(data_list),device)
|
| 75 |
+
loss_dic = model(batch)
|
| 76 |
+
traj_1 = model.sample(batch,num_steps=parser.num_steps,sample_bb=parser.sample_bb,sample_ang=parser.sample_ang,sample_seq=parser.sample_seq)
|
| 77 |
+
ca_dist = torch.sqrt(torch.sum((traj_1[-1]['trans']-traj_1[-1]['trans_1'])**2*batch['generate_mask'][...,None].cpu().long()) / (torch.sum(batch['generate_mask']) + 1e-8).cpu()) # rmsd
|
| 78 |
+
rot_dist = torch.sqrt(torch.sum((traj_1[-1]['rotmats']-traj_1[-1]['rotmats_1'])**2*batch['generate_mask'][...,None,None].long().cpu()) / (torch.sum(batch['generate_mask']) + 1e-8).cpu()) # rmsd
|
| 79 |
+
aar = torch.sum((traj_1[-1]['seqs']==traj_1[-1]['seqs_1']) * batch['generate_mask'].long().cpu()) / (torch.sum(batch['generate_mask']).cpu() + 1e-8)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
print(loss_dic)
|
| 83 |
+
print(f'tran:{ca_dist},rot:{rot_dist},aar:{aar},len:{batch["generate_mask"].sum().item()}')
|
| 84 |
+
|
| 85 |
+
# free
|
| 86 |
+
torch.cuda.empty_cache()
|
| 87 |
+
gc.collect()
|
| 88 |
+
|
| 89 |
+
dic['tran'].append(ca_dist.item())
|
| 90 |
+
dic['rot'].append(rot_dist.item())
|
| 91 |
+
dic['aar'].append(aar.item())
|
| 92 |
+
dic['trans_loss'].append(loss_dic['trans_loss'].item())
|
| 93 |
+
dic['rot_loss'].append(loss_dic['rot_loss'].item())
|
| 94 |
+
dic['id'].append(batch['id'][0])
|
| 95 |
+
dic['len'].append(batch['generate_mask'].sum().item())
|
| 96 |
+
# break
|
| 97 |
+
|
| 98 |
+
traj_1[-1]['batch'] = batch
|
| 99 |
+
torch.save(traj_1[-1],f'{parser.output}/outputs/{batch["id"][0]}.pt')
|
| 100 |
+
dic = pd.DataFrame(dic)
|
| 101 |
+
dic.to_csv(f'{parser.output}/outputs.csv',index=None)
|
models_con/ipa_pytorch.py
ADDED
|
@@ -0,0 +1,687 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""Modified code of Openfold's IPA."""
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
import torch
|
| 20 |
+
import math
|
| 21 |
+
from scipy.stats import truncnorm
|
| 22 |
+
import torch.nn as nn
|
| 23 |
+
from typing import Optional, Callable, List, Sequence
|
| 24 |
+
from openfold.utils.rigid_utils import Rigid
|
| 25 |
+
from data import all_atom
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
|
| 29 |
+
zero_index = -1 * len(inds)
|
| 30 |
+
first_inds = list(range(len(tensor.shape[:zero_index])))
|
| 31 |
+
return tensor.permute(first_inds + [zero_index + i for i in inds])
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def flatten_final_dims(t: torch.Tensor, no_dims: int):
|
| 35 |
+
return t.reshape(t.shape[:-no_dims] + (-1,))
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def ipa_point_weights_init_(weights):
|
| 39 |
+
with torch.no_grad():
|
| 40 |
+
softplus_inverse_1 = 0.541324854612918
|
| 41 |
+
weights.fill_(softplus_inverse_1)
|
| 42 |
+
|
| 43 |
+
def _prod(nums):
|
| 44 |
+
out = 1
|
| 45 |
+
for n in nums:
|
| 46 |
+
out = out * n
|
| 47 |
+
return out
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _calculate_fan(linear_weight_shape, fan="fan_in"):
|
| 51 |
+
fan_out, fan_in = linear_weight_shape
|
| 52 |
+
|
| 53 |
+
if fan == "fan_in":
|
| 54 |
+
f = fan_in
|
| 55 |
+
elif fan == "fan_out":
|
| 56 |
+
f = fan_out
|
| 57 |
+
elif fan == "fan_avg":
|
| 58 |
+
f = (fan_in + fan_out) / 2
|
| 59 |
+
else:
|
| 60 |
+
raise ValueError("Invalid fan option")
|
| 61 |
+
|
| 62 |
+
return f
|
| 63 |
+
|
| 64 |
+
def trunc_normal_init_(weights, scale=1.0, fan="fan_in"):
|
| 65 |
+
shape = weights.shape
|
| 66 |
+
f = _calculate_fan(shape, fan)
|
| 67 |
+
scale = scale / max(1, f)
|
| 68 |
+
a = -2
|
| 69 |
+
b = 2
|
| 70 |
+
std = math.sqrt(scale) / truncnorm.std(a=a, b=b, loc=0, scale=1)
|
| 71 |
+
size = _prod(shape)
|
| 72 |
+
samples = truncnorm.rvs(a=a, b=b, loc=0, scale=std, size=size)
|
| 73 |
+
samples = np.reshape(samples, shape)
|
| 74 |
+
with torch.no_grad():
|
| 75 |
+
weights.copy_(torch.tensor(samples, device=weights.device))
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def lecun_normal_init_(weights):
|
| 79 |
+
trunc_normal_init_(weights, scale=1.0)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def he_normal_init_(weights):
|
| 83 |
+
trunc_normal_init_(weights, scale=2.0)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def glorot_uniform_init_(weights):
|
| 87 |
+
nn.init.xavier_uniform_(weights, gain=1)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def final_init_(weights):
|
| 91 |
+
with torch.no_grad():
|
| 92 |
+
weights.fill_(0.0)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def gating_init_(weights):
|
| 96 |
+
with torch.no_grad():
|
| 97 |
+
weights.fill_(0.0)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def normal_init_(weights):
|
| 101 |
+
torch.nn.init.kaiming_normal_(weights, nonlinearity="linear")
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def compute_angles(ca_pos, pts):
|
| 105 |
+
batch_size, num_res, num_heads, num_pts, _ = pts.shape
|
| 106 |
+
calpha_vecs = (ca_pos[:, :, None, :] - ca_pos[:, None, :, :]) + 1e-10
|
| 107 |
+
calpha_vecs = torch.tile(calpha_vecs[:, :, :, None, None, :], (1, 1, 1, num_heads, num_pts, 1))
|
| 108 |
+
ipa_pts = pts[:, :, None, :, :, :] - torch.tile(ca_pos[:, :, None, None, None, :], (1, 1, num_res, num_heads, num_pts, 1))
|
| 109 |
+
phi_angles = all_atom.calculate_neighbor_angles(
|
| 110 |
+
calpha_vecs.reshape(-1, 3),
|
| 111 |
+
ipa_pts.reshape(-1, 3)
|
| 112 |
+
).reshape(batch_size, num_res, num_res, num_heads, num_pts)
|
| 113 |
+
return phi_angles
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class Linear(nn.Linear):
|
| 117 |
+
"""
|
| 118 |
+
A Linear layer with built-in nonstandard initializations. Called just
|
| 119 |
+
like torch.nn.Linear.
|
| 120 |
+
|
| 121 |
+
Implements the initializers in 1.11.4, plus some additional ones found
|
| 122 |
+
in the code.
|
| 123 |
+
"""
|
| 124 |
+
|
| 125 |
+
def __init__(
|
| 126 |
+
self,
|
| 127 |
+
in_dim: int,
|
| 128 |
+
out_dim: int,
|
| 129 |
+
bias: bool = True,
|
| 130 |
+
init: str = "default",
|
| 131 |
+
init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None,
|
| 132 |
+
):
|
| 133 |
+
"""
|
| 134 |
+
Args:
|
| 135 |
+
in_dim:
|
| 136 |
+
The final dimension of inputs to the layer
|
| 137 |
+
out_dim:
|
| 138 |
+
The final dimension of layer outputs
|
| 139 |
+
bias:
|
| 140 |
+
Whether to learn an additive bias. True by default
|
| 141 |
+
init:
|
| 142 |
+
The initializer to use. Choose from:
|
| 143 |
+
|
| 144 |
+
"default": LeCun fan-in truncated normal initialization
|
| 145 |
+
"relu": He initialization w/ truncated normal distribution
|
| 146 |
+
"glorot": Fan-average Glorot uniform initialization
|
| 147 |
+
"gating": Weights=0, Bias=1
|
| 148 |
+
"normal": Normal initialization with std=1/sqrt(fan_in)
|
| 149 |
+
"final": Weights=0, Bias=0
|
| 150 |
+
|
| 151 |
+
Overridden by init_fn if the latter is not None.
|
| 152 |
+
init_fn:
|
| 153 |
+
A custom initializer taking weight and bias as inputs.
|
| 154 |
+
Overrides init if not None.
|
| 155 |
+
"""
|
| 156 |
+
super(Linear, self).__init__(in_dim, out_dim, bias=bias)
|
| 157 |
+
|
| 158 |
+
if bias:
|
| 159 |
+
with torch.no_grad():
|
| 160 |
+
self.bias.fill_(0)
|
| 161 |
+
|
| 162 |
+
if init_fn is not None:
|
| 163 |
+
init_fn(self.weight, self.bias)
|
| 164 |
+
else:
|
| 165 |
+
if init == "default":
|
| 166 |
+
lecun_normal_init_(self.weight)
|
| 167 |
+
elif init == "relu":
|
| 168 |
+
he_normal_init_(self.weight)
|
| 169 |
+
elif init == "glorot":
|
| 170 |
+
glorot_uniform_init_(self.weight)
|
| 171 |
+
elif init == "gating":
|
| 172 |
+
gating_init_(self.weight)
|
| 173 |
+
if bias:
|
| 174 |
+
with torch.no_grad():
|
| 175 |
+
self.bias.fill_(1.0)
|
| 176 |
+
elif init == "normal":
|
| 177 |
+
normal_init_(self.weight)
|
| 178 |
+
elif init == "final":
|
| 179 |
+
final_init_(self.weight)
|
| 180 |
+
else:
|
| 181 |
+
raise ValueError("Invalid init string.")
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class StructureModuleTransition(nn.Module):
|
| 185 |
+
def __init__(self, c):
|
| 186 |
+
super(StructureModuleTransition, self).__init__()
|
| 187 |
+
|
| 188 |
+
self.c = c
|
| 189 |
+
|
| 190 |
+
self.linear_1 = Linear(self.c, self.c, init="relu")
|
| 191 |
+
self.linear_2 = Linear(self.c, self.c, init="relu")
|
| 192 |
+
self.linear_3 = Linear(self.c, self.c, init="final")
|
| 193 |
+
self.relu = nn.ReLU()
|
| 194 |
+
self.ln = nn.LayerNorm(self.c)
|
| 195 |
+
|
| 196 |
+
def forward(self, s):
|
| 197 |
+
s_initial = s
|
| 198 |
+
s = self.linear_1(s)
|
| 199 |
+
s = self.relu(s)
|
| 200 |
+
s = self.linear_2(s)
|
| 201 |
+
s = self.relu(s)
|
| 202 |
+
s = self.linear_3(s)
|
| 203 |
+
s = s + s_initial
|
| 204 |
+
s = self.ln(s)
|
| 205 |
+
|
| 206 |
+
return s
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
class EdgeTransition(nn.Module):
|
| 210 |
+
def __init__(
|
| 211 |
+
self,
|
| 212 |
+
*,
|
| 213 |
+
node_embed_size,
|
| 214 |
+
edge_embed_in,
|
| 215 |
+
edge_embed_out,
|
| 216 |
+
num_layers=2,
|
| 217 |
+
node_dilation=2
|
| 218 |
+
):
|
| 219 |
+
super(EdgeTransition, self).__init__()
|
| 220 |
+
|
| 221 |
+
bias_embed_size = node_embed_size // node_dilation
|
| 222 |
+
self.initial_embed = Linear(
|
| 223 |
+
node_embed_size, bias_embed_size, init="relu")
|
| 224 |
+
hidden_size = bias_embed_size * 2 + edge_embed_in
|
| 225 |
+
trunk_layers = []
|
| 226 |
+
for _ in range(num_layers):
|
| 227 |
+
trunk_layers.append(Linear(hidden_size, hidden_size, init="relu"))
|
| 228 |
+
trunk_layers.append(nn.ReLU())
|
| 229 |
+
self.trunk = nn.Sequential(*trunk_layers)
|
| 230 |
+
self.final_layer = Linear(hidden_size, edge_embed_out, init="final")
|
| 231 |
+
self.layer_norm = nn.LayerNorm(edge_embed_out)
|
| 232 |
+
|
| 233 |
+
def forward(self, node_embed, edge_embed):
|
| 234 |
+
node_embed = self.initial_embed(node_embed)
|
| 235 |
+
batch_size, num_res, _ = node_embed.shape
|
| 236 |
+
edge_bias = torch.cat([
|
| 237 |
+
torch.tile(node_embed[:, :, None, :], (1, 1, num_res, 1)),
|
| 238 |
+
torch.tile(node_embed[:, None, :, :], (1, num_res, 1, 1)),
|
| 239 |
+
], axis=-1)
|
| 240 |
+
edge_embed = torch.cat(
|
| 241 |
+
[edge_embed, edge_bias], axis=-1).reshape(
|
| 242 |
+
batch_size * num_res**2, -1)
|
| 243 |
+
edge_embed = self.final_layer(self.trunk(edge_embed) + edge_embed)
|
| 244 |
+
edge_embed = self.layer_norm(edge_embed)
|
| 245 |
+
edge_embed = edge_embed.reshape(
|
| 246 |
+
batch_size, num_res, num_res, -1
|
| 247 |
+
)
|
| 248 |
+
return edge_embed
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
class InvariantPointAttention(nn.Module):
|
| 252 |
+
"""
|
| 253 |
+
Implements Algorithm 22.
|
| 254 |
+
"""
|
| 255 |
+
def __init__(
|
| 256 |
+
self,
|
| 257 |
+
ipa_conf,
|
| 258 |
+
inf: float = 1e5,
|
| 259 |
+
eps: float = 1e-8,
|
| 260 |
+
):
|
| 261 |
+
"""
|
| 262 |
+
Args:
|
| 263 |
+
c_s:
|
| 264 |
+
Single representation channel dimension
|
| 265 |
+
c_z:
|
| 266 |
+
Pair representation channel dimension
|
| 267 |
+
c_hidden:
|
| 268 |
+
Hidden channel dimension
|
| 269 |
+
no_heads:
|
| 270 |
+
Number of attention heads
|
| 271 |
+
no_qk_points:
|
| 272 |
+
Number of query/key points to generate
|
| 273 |
+
no_v_points:
|
| 274 |
+
Number of value points to generate
|
| 275 |
+
"""
|
| 276 |
+
super(InvariantPointAttention, self).__init__()
|
| 277 |
+
self._ipa_conf = ipa_conf
|
| 278 |
+
|
| 279 |
+
self.c_s = ipa_conf.c_s
|
| 280 |
+
self.c_z = ipa_conf.c_z
|
| 281 |
+
self.c_hidden = ipa_conf.c_hidden
|
| 282 |
+
self.no_heads = ipa_conf.no_heads
|
| 283 |
+
self.no_qk_points = ipa_conf.no_qk_points
|
| 284 |
+
self.no_v_points = ipa_conf.no_v_points
|
| 285 |
+
self.inf = inf
|
| 286 |
+
self.eps = eps
|
| 287 |
+
|
| 288 |
+
# These linear layers differ from their specifications in the
|
| 289 |
+
# supplement. There, they lack bias and use Glorot initialization.
|
| 290 |
+
# Here as in the official source, they have bias and use the default
|
| 291 |
+
# Lecun initialization.
|
| 292 |
+
hc = self.c_hidden * self.no_heads
|
| 293 |
+
self.linear_q = Linear(self.c_s, hc)
|
| 294 |
+
self.linear_kv = Linear(self.c_s, 2 * hc)
|
| 295 |
+
|
| 296 |
+
hpq = self.no_heads * self.no_qk_points * 3
|
| 297 |
+
self.linear_q_points = Linear(self.c_s, hpq)
|
| 298 |
+
|
| 299 |
+
hpkv = self.no_heads * (self.no_qk_points + self.no_v_points) * 3
|
| 300 |
+
self.linear_kv_points = Linear(self.c_s, hpkv)
|
| 301 |
+
|
| 302 |
+
self.linear_b = Linear(self.c_z, self.no_heads)
|
| 303 |
+
self.down_z = Linear(self.c_z, self.c_z // 4)
|
| 304 |
+
|
| 305 |
+
self.head_weights = nn.Parameter(torch.zeros((ipa_conf.no_heads)))
|
| 306 |
+
ipa_point_weights_init_(self.head_weights)
|
| 307 |
+
|
| 308 |
+
concat_out_dim = (
|
| 309 |
+
self.c_z // 4 + self.c_hidden + self.no_v_points * 4
|
| 310 |
+
)
|
| 311 |
+
self.linear_out = Linear(self.no_heads * concat_out_dim, self.c_s, init="final")
|
| 312 |
+
|
| 313 |
+
self.softmax = nn.Softmax(dim=-1)
|
| 314 |
+
self.softplus = nn.Softplus()
|
| 315 |
+
|
| 316 |
+
def forward(
|
| 317 |
+
self,
|
| 318 |
+
s: torch.Tensor,
|
| 319 |
+
z: Optional[torch.Tensor],
|
| 320 |
+
r: Rigid,
|
| 321 |
+
mask: torch.Tensor,
|
| 322 |
+
_offload_inference: bool = False,
|
| 323 |
+
_z_reference_list: Optional[Sequence[torch.Tensor]] = None,
|
| 324 |
+
) -> torch.Tensor:
|
| 325 |
+
"""
|
| 326 |
+
Args:
|
| 327 |
+
s:
|
| 328 |
+
[*, N_res, C_s] single representation
|
| 329 |
+
z:
|
| 330 |
+
[*, N_res, N_res, C_z] pair representation
|
| 331 |
+
r:
|
| 332 |
+
[*, N_res] transformation object
|
| 333 |
+
mask:
|
| 334 |
+
[*, N_res] mask
|
| 335 |
+
Returns:
|
| 336 |
+
[*, N_res, C_s] single representation update
|
| 337 |
+
"""
|
| 338 |
+
if _offload_inference:
|
| 339 |
+
z = _z_reference_list
|
| 340 |
+
else:
|
| 341 |
+
z = [z]
|
| 342 |
+
|
| 343 |
+
#######################################
|
| 344 |
+
# Generate scalar and point activations
|
| 345 |
+
#######################################
|
| 346 |
+
# [*, N_res, H * C_hidden]
|
| 347 |
+
q = self.linear_q(s)
|
| 348 |
+
kv = self.linear_kv(s)
|
| 349 |
+
|
| 350 |
+
# [*, N_res, H, C_hidden]
|
| 351 |
+
q = q.view(q.shape[:-1] + (self.no_heads, -1))
|
| 352 |
+
|
| 353 |
+
# [*, N_res, H, 2 * C_hidden]
|
| 354 |
+
kv = kv.view(kv.shape[:-1] + (self.no_heads, -1))
|
| 355 |
+
|
| 356 |
+
# [*, N_res, H, C_hidden]
|
| 357 |
+
k, v = torch.split(kv, self.c_hidden, dim=-1)
|
| 358 |
+
|
| 359 |
+
# [*, N_res, H * P_q * 3]
|
| 360 |
+
q_pts = self.linear_q_points(s)
|
| 361 |
+
|
| 362 |
+
# This is kind of clunky, but it's how the original does it
|
| 363 |
+
# [*, N_res, H * P_q, 3]
|
| 364 |
+
q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1)
|
| 365 |
+
q_pts = torch.stack(q_pts, dim=-1)
|
| 366 |
+
q_pts = r[..., None].apply(q_pts)
|
| 367 |
+
|
| 368 |
+
# [*, N_res, H, P_q, 3]
|
| 369 |
+
q_pts = q_pts.view(
|
| 370 |
+
q_pts.shape[:-2] + (self.no_heads, self.no_qk_points, 3)
|
| 371 |
+
)
|
| 372 |
+
|
| 373 |
+
# [*, N_res, H * (P_q + P_v) * 3]
|
| 374 |
+
kv_pts = self.linear_kv_points(s)
|
| 375 |
+
|
| 376 |
+
# [*, N_res, H * (P_q + P_v), 3]
|
| 377 |
+
kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1)
|
| 378 |
+
kv_pts = torch.stack(kv_pts, dim=-1)
|
| 379 |
+
kv_pts = r[..., None].apply(kv_pts)
|
| 380 |
+
|
| 381 |
+
# [*, N_res, H, (P_q + P_v), 3]
|
| 382 |
+
kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.no_heads, -1, 3))
|
| 383 |
+
|
| 384 |
+
# [*, N_res, H, P_q/P_v, 3]
|
| 385 |
+
k_pts, v_pts = torch.split(
|
| 386 |
+
kv_pts, [self.no_qk_points, self.no_v_points], dim=-2
|
| 387 |
+
)
|
| 388 |
+
|
| 389 |
+
##########################
|
| 390 |
+
# Compute attention scores
|
| 391 |
+
##########################
|
| 392 |
+
# [*, N_res, N_res, H]
|
| 393 |
+
b = self.linear_b(z[0])
|
| 394 |
+
|
| 395 |
+
if(_offload_inference):
|
| 396 |
+
z[0] = z[0].cpu()
|
| 397 |
+
|
| 398 |
+
# [*, H, N_res, N_res]
|
| 399 |
+
a = torch.matmul(
|
| 400 |
+
permute_final_dims(q, (1, 0, 2)), # [*, H, N_res, C_hidden]
|
| 401 |
+
permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res]
|
| 402 |
+
)
|
| 403 |
+
a *= math.sqrt(1.0 / (3 * self.c_hidden))
|
| 404 |
+
a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))
|
| 405 |
+
|
| 406 |
+
# [*, N_res, N_res, H, P_q, 3]
|
| 407 |
+
pt_displacement = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
|
| 408 |
+
pt_att = pt_displacement ** 2
|
| 409 |
+
|
| 410 |
+
# [*, N_res, N_res, H, P_q]
|
| 411 |
+
pt_att = sum(torch.unbind(pt_att, dim=-1))
|
| 412 |
+
head_weights = self.softplus(self.head_weights).view(
|
| 413 |
+
*((1,) * len(pt_att.shape[:-2]) + (-1, 1))
|
| 414 |
+
)
|
| 415 |
+
head_weights = head_weights * math.sqrt(
|
| 416 |
+
1.0 / (3 * (self.no_qk_points * 9.0 / 2))
|
| 417 |
+
)
|
| 418 |
+
pt_att = pt_att * head_weights
|
| 419 |
+
|
| 420 |
+
# [*, N_res, N_res, H]
|
| 421 |
+
pt_att = torch.sum(pt_att, dim=-1) * (-0.5)
|
| 422 |
+
# [*, N_res, N_res]
|
| 423 |
+
square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
|
| 424 |
+
square_mask = self.inf * (square_mask - 1)
|
| 425 |
+
|
| 426 |
+
# [*, H, N_res, N_res]
|
| 427 |
+
pt_att = permute_final_dims(pt_att, (2, 0, 1))
|
| 428 |
+
|
| 429 |
+
a = a + pt_att
|
| 430 |
+
a = a + square_mask.unsqueeze(-3)
|
| 431 |
+
a = self.softmax(a)
|
| 432 |
+
|
| 433 |
+
################
|
| 434 |
+
# Compute output
|
| 435 |
+
################
|
| 436 |
+
# [*, N_res, H, C_hidden]
|
| 437 |
+
o = torch.matmul(
|
| 438 |
+
a, v.transpose(-2, -3)
|
| 439 |
+
).transpose(-2, -3)
|
| 440 |
+
|
| 441 |
+
# [*, N_res, H * C_hidden]
|
| 442 |
+
o = flatten_final_dims(o, 2)
|
| 443 |
+
|
| 444 |
+
# [*, H, 3, N_res, P_v]
|
| 445 |
+
o_pt = torch.sum(
|
| 446 |
+
(
|
| 447 |
+
a[..., None, :, :, None]
|
| 448 |
+
* permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
|
| 449 |
+
),
|
| 450 |
+
dim=-2,
|
| 451 |
+
)
|
| 452 |
+
|
| 453 |
+
# [*, N_res, H, P_v, 3]
|
| 454 |
+
o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
|
| 455 |
+
o_pt = r[..., None, None].invert_apply(o_pt)
|
| 456 |
+
|
| 457 |
+
# [*, N_res, H * P_v]
|
| 458 |
+
o_pt_dists = torch.sqrt(torch.sum(o_pt ** 2, dim=-1) + self.eps)
|
| 459 |
+
o_pt_norm_feats = flatten_final_dims(
|
| 460 |
+
o_pt_dists, 2)
|
| 461 |
+
|
| 462 |
+
# [*, N_res, H * P_v, 3]
|
| 463 |
+
o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)
|
| 464 |
+
|
| 465 |
+
if(_offload_inference):
|
| 466 |
+
z[0] = z[0].to(o_pt.device)
|
| 467 |
+
|
| 468 |
+
# [*, N_res, H, C_z // 4]
|
| 469 |
+
pair_z = self.down_z(z[0])
|
| 470 |
+
o_pair = torch.matmul(a.transpose(-2, -3), pair_z)
|
| 471 |
+
|
| 472 |
+
# [*, N_res, H * C_z // 4]
|
| 473 |
+
o_pair = flatten_final_dims(o_pair, 2)
|
| 474 |
+
|
| 475 |
+
o_feats = [o, *torch.unbind(o_pt, dim=-1), o_pt_norm_feats, o_pair]
|
| 476 |
+
|
| 477 |
+
# [*, N_res, C_s]
|
| 478 |
+
s = self.linear_out(
|
| 479 |
+
torch.cat(
|
| 480 |
+
o_feats, dim=-1
|
| 481 |
+
)
|
| 482 |
+
)
|
| 483 |
+
|
| 484 |
+
return s
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
class TorsionAngles(nn.Module):
|
| 488 |
+
def __init__(self, c, num_torsions, eps=1e-8):
|
| 489 |
+
super(TorsionAngles, self).__init__()
|
| 490 |
+
|
| 491 |
+
self.c = c
|
| 492 |
+
self.eps = eps
|
| 493 |
+
self.num_torsions = num_torsions
|
| 494 |
+
|
| 495 |
+
self.linear_1 = Linear(self.c, self.c, init="relu")
|
| 496 |
+
self.linear_2 = Linear(self.c, self.c, init="relu")
|
| 497 |
+
# TODO: Remove after published checkpoint is updated without these weights.
|
| 498 |
+
self.linear_3 = Linear(self.c, self.c, init="final")
|
| 499 |
+
self.linear_final = Linear(
|
| 500 |
+
self.c, self.num_torsions * 2, init="final")
|
| 501 |
+
|
| 502 |
+
self.relu = nn.ReLU()
|
| 503 |
+
|
| 504 |
+
def forward(self, s):
|
| 505 |
+
s_initial = s
|
| 506 |
+
s = self.linear_1(s)
|
| 507 |
+
s = self.relu(s)
|
| 508 |
+
s = self.linear_2(s)
|
| 509 |
+
|
| 510 |
+
s = s + s_initial
|
| 511 |
+
unnormalized_s = self.linear_final(s)
|
| 512 |
+
norm_denom = torch.sqrt(
|
| 513 |
+
torch.clamp(
|
| 514 |
+
torch.sum(unnormalized_s ** 2, dim=-1, keepdim=True),
|
| 515 |
+
min=self.eps,
|
| 516 |
+
)
|
| 517 |
+
)
|
| 518 |
+
normalized_s = unnormalized_s / norm_denom
|
| 519 |
+
|
| 520 |
+
return unnormalized_s, normalized_s
|
| 521 |
+
|
| 522 |
+
|
| 523 |
+
class RotationVFLayer(nn.Module):
|
| 524 |
+
def __init__(self, dim):
|
| 525 |
+
super(RotationVFLayer, self).__init__()
|
| 526 |
+
|
| 527 |
+
self.linear_1 = Linear(dim, dim, init="relu")
|
| 528 |
+
self.linear_2 = Linear(dim, dim, init="relu")
|
| 529 |
+
self.linear_3 = Linear(dim, dim)
|
| 530 |
+
self.final_linear = Linear(dim, 6, init="final")
|
| 531 |
+
self.relu = nn.ReLU()
|
| 532 |
+
|
| 533 |
+
def forward(self, s):
|
| 534 |
+
s_initial = s
|
| 535 |
+
s = self.linear_1(s)
|
| 536 |
+
s = self.relu(s)
|
| 537 |
+
s = self.linear_2(s)
|
| 538 |
+
s = self.relu(s)
|
| 539 |
+
s = self.linear_3(s)
|
| 540 |
+
s = s + s_initial
|
| 541 |
+
return self.final_linear(s)
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
class BackboneUpdate(nn.Module):
|
| 545 |
+
"""
|
| 546 |
+
Implements part of Algorithm 23.
|
| 547 |
+
"""
|
| 548 |
+
|
| 549 |
+
def __init__(self, c_s, use_rot_updates):
|
| 550 |
+
"""
|
| 551 |
+
Args:
|
| 552 |
+
c_s:
|
| 553 |
+
Single representation channel dimension
|
| 554 |
+
"""
|
| 555 |
+
super(BackboneUpdate, self).__init__()
|
| 556 |
+
|
| 557 |
+
self.c_s = c_s
|
| 558 |
+
self._use_rot_updates = use_rot_updates
|
| 559 |
+
update_dim = 6 if use_rot_updates else 3
|
| 560 |
+
self.linear = Linear(self.c_s, update_dim, init="final")
|
| 561 |
+
|
| 562 |
+
def forward(self, s: torch.Tensor):
|
| 563 |
+
"""
|
| 564 |
+
Args:
|
| 565 |
+
[*, N_res, C_s] single representation
|
| 566 |
+
Returns:
|
| 567 |
+
[*, N_res, 6] update vector
|
| 568 |
+
"""
|
| 569 |
+
# [*, 6]
|
| 570 |
+
update = self.linear(s)
|
| 571 |
+
|
| 572 |
+
return update
|
| 573 |
+
|
| 574 |
+
class IpaScore(nn.Module):
|
| 575 |
+
|
| 576 |
+
def __init__(self, model_conf, diffuser):
|
| 577 |
+
super(IpaScore, self).__init__()
|
| 578 |
+
self._model_conf = model_conf
|
| 579 |
+
ipa_conf = model_conf.ipa
|
| 580 |
+
self._ipa_conf = ipa_conf
|
| 581 |
+
self.diffuser = diffuser
|
| 582 |
+
|
| 583 |
+
self.scale_pos = lambda x: x * ipa_conf.coordinate_scaling
|
| 584 |
+
self.scale_rigids = lambda x: x.apply_trans_fn(self.scale_pos)
|
| 585 |
+
|
| 586 |
+
self.unscale_pos = lambda x: x / ipa_conf.coordinate_scaling
|
| 587 |
+
self.unscale_rigids = lambda x: x.apply_trans_fn(self.unscale_pos)
|
| 588 |
+
self.trunk = nn.ModuleDict()
|
| 589 |
+
|
| 590 |
+
for b in range(ipa_conf.num_blocks):
|
| 591 |
+
self.trunk[f'ipa_{b}'] = InvariantPointAttention(ipa_conf)
|
| 592 |
+
self.trunk[f'ipa_ln_{b}'] = nn.LayerNorm(ipa_conf.c_s)
|
| 593 |
+
self.trunk[f'skip_embed_{b}'] = Linear(
|
| 594 |
+
self._model_conf.node_embed_size,
|
| 595 |
+
self._ipa_conf.c_skip,
|
| 596 |
+
init="final"
|
| 597 |
+
)
|
| 598 |
+
tfmr_in = ipa_conf.c_s + self._ipa_conf.c_skip
|
| 599 |
+
tfmr_layer = torch.nn.TransformerEncoderLayer(
|
| 600 |
+
d_model=tfmr_in,
|
| 601 |
+
nhead=ipa_conf.seq_tfmr_num_heads,
|
| 602 |
+
dim_feedforward=tfmr_in,
|
| 603 |
+
batch_first=True,
|
| 604 |
+
dropout=0.0,
|
| 605 |
+
norm_first=False
|
| 606 |
+
)
|
| 607 |
+
self.trunk[f'seq_tfmr_{b}'] = torch.nn.TransformerEncoder(
|
| 608 |
+
tfmr_layer, ipa_conf.seq_tfmr_num_layers)
|
| 609 |
+
self.trunk[f'post_tfmr_{b}'] = Linear(
|
| 610 |
+
tfmr_in, ipa_conf.c_s, init="final")
|
| 611 |
+
self.trunk[f'node_transition_{b}'] = StructureModuleTransition(
|
| 612 |
+
c=ipa_conf.c_s)
|
| 613 |
+
self.trunk[f'bb_update_{b}'] = BackboneUpdate(ipa_conf.c_s)
|
| 614 |
+
|
| 615 |
+
if b < ipa_conf.num_blocks-1:
|
| 616 |
+
# No edge update on the last block.
|
| 617 |
+
edge_in = self._model_conf.edge_embed_size
|
| 618 |
+
self.trunk[f'edge_transition_{b}'] = EdgeTransition(
|
| 619 |
+
node_embed_size=ipa_conf.c_s,
|
| 620 |
+
edge_embed_in=edge_in,
|
| 621 |
+
edge_embed_out=self._model_conf.edge_embed_size,
|
| 622 |
+
)
|
| 623 |
+
|
| 624 |
+
self.torsion_pred = TorsionAngles(ipa_conf.c_s, 1)
|
| 625 |
+
|
| 626 |
+
def forward(self, init_node_embed, edge_embed, input_feats):
|
| 627 |
+
node_mask = input_feats['res_mask'].type(torch.float32)
|
| 628 |
+
diffuse_mask = (1 - input_feats['fixed_mask'].type(torch.float32)) * node_mask
|
| 629 |
+
edge_mask = node_mask[..., None] * node_mask[..., None, :]
|
| 630 |
+
init_frames = input_feats['rigids_t'].type(torch.float32)
|
| 631 |
+
|
| 632 |
+
curr_rigids = Rigid.from_tensor_7(torch.clone(init_frames))
|
| 633 |
+
init_rigids = Rigid.from_tensor_7(init_frames)
|
| 634 |
+
init_rots = init_rigids.get_rots()
|
| 635 |
+
|
| 636 |
+
# Main trunk
|
| 637 |
+
curr_rigids = self.scale_rigids(curr_rigids)
|
| 638 |
+
init_node_embed = init_node_embed * node_mask[..., None]
|
| 639 |
+
node_embed = init_node_embed * node_mask[..., None]
|
| 640 |
+
for b in range(self._ipa_conf.num_blocks):
|
| 641 |
+
ipa_embed = self.trunk[f'ipa_{b}'](
|
| 642 |
+
node_embed,
|
| 643 |
+
edge_embed,
|
| 644 |
+
curr_rigids,
|
| 645 |
+
node_mask)
|
| 646 |
+
ipa_embed *= node_mask[..., None]
|
| 647 |
+
node_embed = self.trunk[f'ipa_ln_{b}'](node_embed + ipa_embed)
|
| 648 |
+
seq_tfmr_in = torch.cat([
|
| 649 |
+
node_embed, self.trunk[f'skip_embed_{b}'](init_node_embed)
|
| 650 |
+
], dim=-1)
|
| 651 |
+
seq_tfmr_out = self.trunk[f'seq_tfmr_{b}'](
|
| 652 |
+
seq_tfmr_in, src_key_padding_mask=1 - node_mask)
|
| 653 |
+
node_embed = node_embed + self.trunk[f'post_tfmr_{b}'](seq_tfmr_out)
|
| 654 |
+
node_embed = self.trunk[f'node_transition_{b}'](node_embed)
|
| 655 |
+
node_embed = node_embed * node_mask[..., None]
|
| 656 |
+
rigid_update = self.trunk[f'bb_update_{b}'](
|
| 657 |
+
node_embed * diffuse_mask[..., None])
|
| 658 |
+
curr_rigids = curr_rigids.compose_q_update_vec(
|
| 659 |
+
rigid_update, diffuse_mask[..., None])
|
| 660 |
+
|
| 661 |
+
if b < self._ipa_conf.num_blocks-1:
|
| 662 |
+
edge_embed = self.trunk[f'edge_transition_{b}'](
|
| 663 |
+
node_embed, edge_embed)
|
| 664 |
+
edge_embed *= edge_mask[..., None]
|
| 665 |
+
rot_score = self.diffuser.calc_rot_score(
|
| 666 |
+
init_rigids.get_rots(),
|
| 667 |
+
curr_rigids.get_rots(),
|
| 668 |
+
input_feats['t']
|
| 669 |
+
)
|
| 670 |
+
rot_score = rot_score * node_mask[..., None]
|
| 671 |
+
|
| 672 |
+
curr_rigids = self.unscale_rigids(curr_rigids)
|
| 673 |
+
trans_score = self.diffuser.calc_trans_score(
|
| 674 |
+
init_rigids.get_trans(),
|
| 675 |
+
curr_rigids.get_trans(),
|
| 676 |
+
input_feats['t'][:, None, None],
|
| 677 |
+
use_torch=True,
|
| 678 |
+
)
|
| 679 |
+
trans_score = trans_score * node_mask[..., None]
|
| 680 |
+
_, psi_pred = self.torsion_pred(node_embed)
|
| 681 |
+
model_out = {
|
| 682 |
+
'psi': psi_pred,
|
| 683 |
+
'rot_score': rot_score,
|
| 684 |
+
'trans_score': trans_score,
|
| 685 |
+
'final_rigids': curr_rigids,
|
| 686 |
+
}
|
| 687 |
+
return model_out
|
models_con/node.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
|
| 4 |
+
from pepflow.modules.common.geometry import construct_3d_basis, global_to_local, get_backbone_dihedral_angles
|
| 5 |
+
from pepflow.modules.common.layers import AngularEncoding
|
| 6 |
+
from pepflow.modules.protein.constants import BBHeavyAtom, AA
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class NodeEmbedder(nn.Module):
|
| 10 |
+
|
| 11 |
+
def __init__(self, feat_dim, max_num_atoms, max_aa_types=22):
|
| 12 |
+
super().__init__()
|
| 13 |
+
self.max_num_atoms = max_num_atoms
|
| 14 |
+
self.max_aa_types = max_aa_types
|
| 15 |
+
self.feat_dim = feat_dim
|
| 16 |
+
self.aatype_embed = nn.Embedding(self.max_aa_types, feat_dim)
|
| 17 |
+
self.dihed_embed = AngularEncoding()
|
| 18 |
+
|
| 19 |
+
infeat_dim = feat_dim + (self.max_aa_types*max_num_atoms*3) + self.dihed_embed.get_out_dim(3)
|
| 20 |
+
self.mlp = nn.Sequential(
|
| 21 |
+
nn.Linear(infeat_dim, feat_dim * 2), nn.ReLU(),
|
| 22 |
+
nn.Linear(feat_dim * 2, feat_dim), nn.ReLU(),
|
| 23 |
+
nn.Linear(feat_dim, feat_dim), nn.ReLU(),
|
| 24 |
+
nn.Linear(feat_dim, feat_dim)
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
# def embed_t(self, timesteps, mask):
|
| 28 |
+
# timestep_emb = get_time_embedding(
|
| 29 |
+
# timesteps[:, 0],
|
| 30 |
+
# self.feat_dim,
|
| 31 |
+
# max_positions=2056
|
| 32 |
+
# )[:, None, :].repeat(1, mask.shape[1], 1)
|
| 33 |
+
# return timestep_emb
|
| 34 |
+
|
| 35 |
+
def forward(self, aa, res_nb, chain_nb, pos_atoms, mask_atoms, structure_mask=None, sequence_mask=None):
|
| 36 |
+
"""
|
| 37 |
+
Args:
|
| 38 |
+
aa: (N, L).
|
| 39 |
+
res_nb: (N, L).
|
| 40 |
+
chain_nb: (N, L).
|
| 41 |
+
pos_atoms: (N, L, A, 3).
|
| 42 |
+
mask_atoms: (N, L, A).
|
| 43 |
+
structure_mask: (N, L), mask out unknown structures to generate.
|
| 44 |
+
sequence_mask: (N, L), mask out unknown amino acids to generate.
|
| 45 |
+
"""
|
| 46 |
+
N, L = aa.size()
|
| 47 |
+
mask_residue = mask_atoms[:, :, BBHeavyAtom.CA] # (N, L)
|
| 48 |
+
|
| 49 |
+
# Remove other atoms
|
| 50 |
+
pos_atoms = pos_atoms[:, :, :self.max_num_atoms]
|
| 51 |
+
mask_atoms = mask_atoms[:, :, :self.max_num_atoms]
|
| 52 |
+
|
| 53 |
+
# Amino acid identity features
|
| 54 |
+
if sequence_mask is not None:
|
| 55 |
+
# Avoid data leakage at training time
|
| 56 |
+
aa = torch.where(sequence_mask, aa, torch.full_like(aa, fill_value=AA.UNK))
|
| 57 |
+
aa_feat = self.aatype_embed(aa) # (N, L, feat)
|
| 58 |
+
|
| 59 |
+
# Coordinate features
|
| 60 |
+
R = construct_3d_basis(
|
| 61 |
+
pos_atoms[:, :, BBHeavyAtom.CA],
|
| 62 |
+
pos_atoms[:, :, BBHeavyAtom.C],
|
| 63 |
+
pos_atoms[:, :, BBHeavyAtom.N]
|
| 64 |
+
)
|
| 65 |
+
t = pos_atoms[:, :, BBHeavyAtom.CA]
|
| 66 |
+
crd = global_to_local(R, t, pos_atoms) # (N, L, A, 3)
|
| 67 |
+
crd_mask = mask_atoms[:, :, :, None].expand_as(crd)
|
| 68 |
+
crd = torch.where(crd_mask, crd, torch.zeros_like(crd))
|
| 69 |
+
|
| 70 |
+
aa_expand = aa[:, :, None, None, None].expand(N, L, self.max_aa_types, self.max_num_atoms, 3)
|
| 71 |
+
rng_expand = torch.arange(0, self.max_aa_types)[None, None, :, None, None].expand(N, L, self.max_aa_types, self.max_num_atoms, 3).to(aa_expand)
|
| 72 |
+
place_mask = (aa_expand == rng_expand)
|
| 73 |
+
crd_expand = crd[:, :, None, :, :].expand(N, L, self.max_aa_types, self.max_num_atoms, 3)
|
| 74 |
+
crd_expand = torch.where(place_mask, crd_expand, torch.zeros_like(crd_expand))
|
| 75 |
+
crd_feat = crd_expand.reshape(N, L, self.max_aa_types*self.max_num_atoms*3)
|
| 76 |
+
if structure_mask is not None:
|
| 77 |
+
# Avoid data leakage at training time
|
| 78 |
+
crd_feat = crd_feat * structure_mask[:, :, None]
|
| 79 |
+
|
| 80 |
+
# Backbone dihedral features
|
| 81 |
+
bb_dihedral, mask_bb_dihed = get_backbone_dihedral_angles(pos_atoms, chain_nb=chain_nb, res_nb=res_nb, mask=mask_residue)
|
| 82 |
+
dihed_feat = self.dihed_embed(bb_dihedral[:, :, :, None]) * mask_bb_dihed[:, :, :, None] # (N, L, 3, dihed/3)
|
| 83 |
+
dihed_feat = dihed_feat.reshape(N, L, -1)
|
| 84 |
+
if structure_mask is not None:
|
| 85 |
+
# Avoid data leakage at training time
|
| 86 |
+
dihed_mask = torch.logical_and(
|
| 87 |
+
structure_mask,
|
| 88 |
+
torch.logical_and(
|
| 89 |
+
torch.roll(structure_mask, shifts=+1, dims=1),
|
| 90 |
+
torch.roll(structure_mask, shifts=-1, dims=1)
|
| 91 |
+
),
|
| 92 |
+
) # Avoid slight data leakage via dihedral angles of anchor residues
|
| 93 |
+
dihed_feat = dihed_feat * dihed_mask[:, :, None]
|
| 94 |
+
|
| 95 |
+
# # timestep
|
| 96 |
+
# timestep_emb = self.embed_t(timesteps, mask_residue)
|
| 97 |
+
|
| 98 |
+
out_feat = self.mlp(torch.cat([aa_feat, crd_feat, dihed_feat], dim=-1)) # (N, L, F)
|
| 99 |
+
out_feat = out_feat * mask_residue[:, :, None]
|
| 100 |
+
|
| 101 |
+
# print(f'aa_seq:{aa},aa:{aa_feat},crd:{crd_feat},dihed:{dihed_feat},time:{timestep_emb}')
|
| 102 |
+
|
| 103 |
+
# print(f'weight:{self.aatype_embed.weight}') # nan, why?
|
| 104 |
+
|
| 105 |
+
return out_feat
|
models_con/pep_dataloader.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""pep-rec dataset"""
|
| 2 |
+
import os
|
| 3 |
+
import logging
|
| 4 |
+
import joblib
|
| 5 |
+
import pickle
|
| 6 |
+
import lmdb
|
| 7 |
+
from Bio import PDB
|
| 8 |
+
from Bio.PDB import PDBExceptions
|
| 9 |
+
from torch.utils.data import Dataset
|
| 10 |
+
from tqdm.auto import tqdm
|
| 11 |
+
|
| 12 |
+
from pepflow.modules.protein.parsers import parse_pdb
|
| 13 |
+
from pepflow.modules.common.geometry import *
|
| 14 |
+
from pepflow.modules.protein.constants import *
|
| 15 |
+
from pepflow.utils.data import mask_select_data, find_longest_true_segment, PaddingCollate
|
| 16 |
+
from torch.utils.data import DataLoader
|
| 17 |
+
|
| 18 |
+
from omegaconf import OmegaConf
|
| 19 |
+
from easydict import EasyDict
|
| 20 |
+
|
| 21 |
+
from torch.utils.data import DataLoader, Dataset
|
| 22 |
+
from torch.utils.data.distributed import DistributedSampler, dist
|
| 23 |
+
|
| 24 |
+
from pepflow.utils.misc import load_config
|
| 25 |
+
from pepflow.utils.train import recursive_to
|
| 26 |
+
|
| 27 |
+
from models_con.torsion import get_torsion_angle
|
| 28 |
+
|
| 29 |
+
import torch
|
| 30 |
+
|
| 31 |
+
from pepflow.modules.protein.writers import save_pdb
|
| 32 |
+
|
| 33 |
+
# bind_dic = torch.load("/datapool/data2/home/jiahan/ResProj/PepDiff/frame-flow/misc/affinity_dict.pt")
|
| 34 |
+
|
| 35 |
+
# testset
|
| 36 |
+
names = []
|
| 37 |
+
with open('/datapool/data2/home/ruihan/data/jiahan/ResProj/PepDiff/pepflowww/Data/names.txt','r') as f:
|
| 38 |
+
for line in f:
|
| 39 |
+
names.append(line.strip())
|
| 40 |
+
|
| 41 |
+
def preprocess_structure(task):
|
| 42 |
+
|
| 43 |
+
try:
|
| 44 |
+
if task['id'] in names:
|
| 45 |
+
raise ValueError(f'{task["id"]} not in names')
|
| 46 |
+
pdb_path = task['pdb_path']
|
| 47 |
+
# pep
|
| 48 |
+
# process peptide and find center of mass
|
| 49 |
+
pep = parse_pdb(os.path.join(pdb_path,'peptide.pdb'))[0]
|
| 50 |
+
center = torch.sum(pep['pos_heavyatom'][pep['mask_heavyatom'][:, BBHeavyAtom.CA], BBHeavyAtom.CA], dim=0) / (torch.sum(pep['mask_heavyatom'][:, BBHeavyAtom.CA]) + 1e-8)
|
| 51 |
+
pep['pos_heavyatom'] = pep['pos_heavyatom'] - center[None, None, :]
|
| 52 |
+
pep['torsion_angle'],pep['torsion_angle_mask'] = get_torsion_angle(pep['pos_heavyatom'],pep['aa']) # calc angles after translation
|
| 53 |
+
if len(pep['aa'])<3 or len(pep['aa'])>25:
|
| 54 |
+
raise ValueError('peptide length not in [3,25]')
|
| 55 |
+
# rec
|
| 56 |
+
rec = parse_pdb(os.path.join(pdb_path,'pocket.pdb'))[0]
|
| 57 |
+
rec['pos_heavyatom'] = rec['pos_heavyatom'] - center[None, None, :]
|
| 58 |
+
rec['torsion_angle'],rec['torsion_angle_mask'] = get_torsion_angle(rec['pos_heavyatom'],rec['aa']) # calc angles after translation
|
| 59 |
+
rec['chain_nb'] += 1
|
| 60 |
+
# meta data
|
| 61 |
+
data = {}
|
| 62 |
+
data['id'] = task['id']
|
| 63 |
+
data['generate_mask'] = torch.cat([torch.zeros_like(rec['aa']), torch.ones_like(pep['aa'])], dim=0).bool()
|
| 64 |
+
for k in rec.keys():
|
| 65 |
+
if isinstance(rec[k], torch.Tensor):
|
| 66 |
+
data[k] = torch.cat([rec[k], pep[k]], dim=0)
|
| 67 |
+
elif isinstance(rec[k], list):
|
| 68 |
+
data[k] = rec[k] + pep[k]
|
| 69 |
+
else:
|
| 70 |
+
raise ValueError(f'Unknown type of {rec[k]}')
|
| 71 |
+
return data
|
| 72 |
+
|
| 73 |
+
except (
|
| 74 |
+
PDBExceptions.PDBConstructionException,
|
| 75 |
+
KeyError,
|
| 76 |
+
ValueError,
|
| 77 |
+
TypeError
|
| 78 |
+
) as e:
|
| 79 |
+
logging.warning('[{}] {}: {}'.format(
|
| 80 |
+
task['id'],
|
| 81 |
+
e.__class__.__name__,
|
| 82 |
+
str(e)
|
| 83 |
+
))
|
| 84 |
+
return None
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class PepDataset(Dataset):
|
| 88 |
+
|
| 89 |
+
MAP_SIZE = 32*(1024*1024*1024) # 32GB
|
| 90 |
+
|
| 91 |
+
def __init__(self, structure_dir = "./Data/PepMerge_new/", dataset_dir = "./Data/",
|
| 92 |
+
name = 'pep', transform=None, reset=False):
|
| 93 |
+
|
| 94 |
+
super().__init__()
|
| 95 |
+
self.structure_dir = structure_dir
|
| 96 |
+
self.dataset_dir = dataset_dir
|
| 97 |
+
self.transform = transform
|
| 98 |
+
self.name = name
|
| 99 |
+
|
| 100 |
+
self.db_conn = None
|
| 101 |
+
self.db_ids = None
|
| 102 |
+
self._load_structures(reset)
|
| 103 |
+
|
| 104 |
+
@property
|
| 105 |
+
def _cache_db_path(self):
|
| 106 |
+
return os.path.join(self.dataset_dir, f'{self.name}_structure_cache.lmdb')
|
| 107 |
+
|
| 108 |
+
def _connect_db(self):
|
| 109 |
+
self._close_db()
|
| 110 |
+
self.db_conn = lmdb.open(
|
| 111 |
+
self._cache_db_path,
|
| 112 |
+
map_size=self.MAP_SIZE,
|
| 113 |
+
create=False,
|
| 114 |
+
subdir=False,
|
| 115 |
+
readonly=True,
|
| 116 |
+
lock=False,
|
| 117 |
+
readahead=False,
|
| 118 |
+
meminit=False,
|
| 119 |
+
)
|
| 120 |
+
with self.db_conn.begin() as txn:
|
| 121 |
+
keys = [k.decode() for k in txn.cursor().iternext(values=False)]
|
| 122 |
+
self.db_ids = keys
|
| 123 |
+
|
| 124 |
+
def _close_db(self):
|
| 125 |
+
if self.db_conn is not None:
|
| 126 |
+
self.db_conn.close()
|
| 127 |
+
self.db_conn = None
|
| 128 |
+
self.db_ids = None
|
| 129 |
+
|
| 130 |
+
def _load_structures(self, reset):
|
| 131 |
+
all_pdbs = os.listdir(self.structure_dir)
|
| 132 |
+
|
| 133 |
+
if reset:
|
| 134 |
+
if os.path.exists(self._cache_db_path):
|
| 135 |
+
os.remove(self._cache_db_path)
|
| 136 |
+
lock_file = self._cache_db_path + "-lock"
|
| 137 |
+
if os.path.exists(lock_file):
|
| 138 |
+
os.remove(lock_file)
|
| 139 |
+
self._close_db()
|
| 140 |
+
todo_pdbs = all_pdbs
|
| 141 |
+
else:
|
| 142 |
+
if not os.path.exists(self._cache_db_path):
|
| 143 |
+
todo_pdbs = all_pdbs
|
| 144 |
+
else:
|
| 145 |
+
todo_pdbs = []
|
| 146 |
+
# self._connect_db()
|
| 147 |
+
# processed_pdbs = self.db_ids
|
| 148 |
+
# self._close_db()
|
| 149 |
+
# todo_pdbs = list(set(all_pdbs) - set(processed_pdbs))
|
| 150 |
+
|
| 151 |
+
if len(todo_pdbs) > 0:
|
| 152 |
+
self._preprocess_structures(todo_pdbs)
|
| 153 |
+
|
| 154 |
+
def _preprocess_structures(self, pdb_list):
|
| 155 |
+
tasks = []
|
| 156 |
+
for pdb_fname in pdb_list:
|
| 157 |
+
pdb_path = os.path.join(self.structure_dir, pdb_fname)
|
| 158 |
+
tasks.append({
|
| 159 |
+
'id': pdb_fname,
|
| 160 |
+
'pdb_path': pdb_path,
|
| 161 |
+
})
|
| 162 |
+
|
| 163 |
+
data_list = joblib.Parallel(
|
| 164 |
+
n_jobs = max(joblib.cpu_count() // 2, 1),
|
| 165 |
+
)(
|
| 166 |
+
joblib.delayed(preprocess_structure)(task)
|
| 167 |
+
for task in tqdm(tasks, dynamic_ncols=True, desc='Preprocess')
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
db_conn = lmdb.open(
|
| 171 |
+
self._cache_db_path,
|
| 172 |
+
map_size = self.MAP_SIZE,
|
| 173 |
+
create=True,
|
| 174 |
+
subdir=False,
|
| 175 |
+
readonly=False,
|
| 176 |
+
)
|
| 177 |
+
ids = []
|
| 178 |
+
with db_conn.begin(write=True, buffers=True) as txn:
|
| 179 |
+
for data in tqdm(data_list, dynamic_ncols=True, desc='Write to LMDB'):
|
| 180 |
+
if data is None:
|
| 181 |
+
continue
|
| 182 |
+
ids.append(data['id'])
|
| 183 |
+
txn.put(data['id'].encode('utf-8'), pickle.dumps(data))
|
| 184 |
+
|
| 185 |
+
def __len__(self):
|
| 186 |
+
self._connect_db() # make sure db_ids is not None
|
| 187 |
+
return len(self.db_ids)
|
| 188 |
+
|
| 189 |
+
def __getitem__(self, index):
|
| 190 |
+
self._connect_db()
|
| 191 |
+
id = self.db_ids[index]
|
| 192 |
+
with self.db_conn.begin() as txn:
|
| 193 |
+
data = pickle.loads(txn.get(id.encode()))
|
| 194 |
+
if self.transform is not None:
|
| 195 |
+
data = self.transform(data)
|
| 196 |
+
return data
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
if __name__ == '__main__':
|
| 201 |
+
device = 'cuda:1'
|
| 202 |
+
config,cfg_name = load_config("./configs/learn/learn_all.yaml")
|
| 203 |
+
dataset = PepDataset(structure_dir = "./Data/PepMerge_new/", dataset_dir = "/Data/Fixed Data",
|
| 204 |
+
name = 'pep_pocket_test', transform=None, reset=True)
|
| 205 |
+
print(len(dataset))
|
| 206 |
+
print(dataset[0])
|
| 207 |
+
|
| 208 |
+
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4, collate_fn=PaddingCollate(eight=False))
|
| 209 |
+
|
| 210 |
+
batch = next(iter(dataloader))
|
| 211 |
+
print(batch['torsion_angle'].shape)
|
| 212 |
+
print(batch['torsion_angle_mask'].shape)
|
models_con/sample.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from torch.utils.data import DataLoader
|
| 6 |
+
|
| 7 |
+
import copy
|
| 8 |
+
import math
|
| 9 |
+
from tqdm.auto import tqdm
|
| 10 |
+
import functools
|
| 11 |
+
import os
|
| 12 |
+
import argparse
|
| 13 |
+
import pandas as pd
|
| 14 |
+
from copy import deepcopy
|
| 15 |
+
|
| 16 |
+
from models_con.pep_dataloader import PepDataset
|
| 17 |
+
|
| 18 |
+
from pepflow.utils.train import recursive_to
|
| 19 |
+
|
| 20 |
+
from pepflow.modules.common.geometry import reconstruct_backbone, reconstruct_backbone_partially, align, batch_align
|
| 21 |
+
from pepflow.modules.protein.writers import save_pdb
|
| 22 |
+
|
| 23 |
+
from pepflow.utils.data import PaddingCollate
|
| 24 |
+
|
| 25 |
+
from models_con.utils import process_dic
|
| 26 |
+
|
| 27 |
+
from models_con.flow_model import FlowModel
|
| 28 |
+
|
| 29 |
+
from models_con.torsion import full_atom_reconstruction, get_heavyatom_mask
|
| 30 |
+
|
| 31 |
+
collate_fn = PaddingCollate(eight=False)
|
| 32 |
+
|
| 33 |
+
import argparse
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def item_to_batch(item, nums=32):
|
| 37 |
+
data_list = [deepcopy(item) for i in range(nums)]
|
| 38 |
+
return collate_fn(data_list)
|
| 39 |
+
|
| 40 |
+
def sample_for_data_bb(data, model, device, save_root, num_steps=200, sample_structure=True, sample_sequence=True, nums=8):
|
| 41 |
+
if not os.path.exists(os.path.join(save_root,data["id"])):
|
| 42 |
+
os.makedirs(os.path.join(save_root,data["id"]))
|
| 43 |
+
batch = recursive_to(item_to_batch(data, nums=nums),device=device)
|
| 44 |
+
traj = model.sample(batch, num_steps=num_steps, sample_structure=sample_structure, sample_sequence=sample_sequence)
|
| 45 |
+
final = recursive_to(traj[-1], device=device)
|
| 46 |
+
pos_bb = reconstruct_backbone(R=final['rotmats'],t=final['trans'],aa=final['seqs'],chain_nb=batch['chain_nb'],res_nb=batch['res_nb'],mask=batch['res_mask']) # (32,L,4,3)
|
| 47 |
+
pos_ha = F.pad(pos_bb, pad=(0,0,0,15-4), value=0.) # (32,L,A,3) pos14 A=14
|
| 48 |
+
pos_new = torch.where(batch['generate_mask'][:,:,None,None],pos_ha,batch['pos_heavyatom'])
|
| 49 |
+
mask_bb_atoms = torch.zeros_like(batch['mask_heavyatom'])
|
| 50 |
+
mask_bb_atoms[:,:,:4] = True
|
| 51 |
+
mask_new = torch.where(batch['generate_mask'][:,:,None],mask_bb_atoms,batch['mask_heavyatom'])
|
| 52 |
+
aa_new = final['seqs']
|
| 53 |
+
|
| 54 |
+
chain_nb = torch.LongTensor([0 if gen_mask else 1 for gen_mask in data['generate_mask']])
|
| 55 |
+
chain_id = ['A' if gen_mask else 'B' for gen_mask in data['generate_mask']]
|
| 56 |
+
icode = [' ' for _ in range(len(data['icode']))]
|
| 57 |
+
for i in range(nums):
|
| 58 |
+
ref_bb_pos = data['pos_heavyatom'][i][:,:4].cpu()
|
| 59 |
+
pred_bb_pos = pos_new[i][:,:4].cpu()
|
| 60 |
+
data_saved = {
|
| 61 |
+
'chain_nb':data['chain_nb'],'chain_id':data['chain_id'],'resseq':data['resseq'],'icode':data['icode'],
|
| 62 |
+
'aa':aa_new[i].cpu(), 'mask_heavyatom':mask_new[i].cpu(), 'pos_heavyatom':pos_new[i].cpu(),
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
save_pdb(data_saved,path=os.path.join(save_root,data["id"],f'{data["id"]}_{i}.pdb'))
|
| 66 |
+
save_pdb(data,path=os.path.join(save_root,data["id"],f'{data["id"]}_gt.pdb'))
|
| 67 |
+
|
| 68 |
+
def save_samples_bb(samples,save_dir):
|
| 69 |
+
# meta data
|
| 70 |
+
batch = recursive_to(samples['batch'],'cpu')
|
| 71 |
+
chain_id = [list(item) for item in zip(*batch['chain_id'])][0] # fix chain id in collate func
|
| 72 |
+
icode = [' ' for _ in range(len(chain_id))] # batch icode have same problem
|
| 73 |
+
nums = len(batch['id'])
|
| 74 |
+
id = batch['id'][0]
|
| 75 |
+
# batch convert
|
| 76 |
+
# aa=batch['aa] if only bb level
|
| 77 |
+
pos_bb = reconstruct_backbone(R=samples['rotmats'],t=samples['trans'],aa=samples['seqs'],chain_nb=batch['chain_nb'],res_nb=batch['res_nb'],mask=batch['res_mask']) # (32,L,4,3)
|
| 78 |
+
pos_ha = F.pad(pos_bb, pad=(0,0,0,15-4), value=0.) # (32,L,A,3) pos14 A=14
|
| 79 |
+
pos_new = torch.where(batch['generate_mask'][:,:,None,None],pos_ha,batch['pos_heavyatom'])
|
| 80 |
+
mask_bb_atoms = torch.zeros_like(batch['mask_heavyatom'])
|
| 81 |
+
mask_bb_atoms[:,:,:4] = True
|
| 82 |
+
mask_new = torch.where(batch['generate_mask'][:,:,None],mask_bb_atoms,batch['mask_heavyatom'])
|
| 83 |
+
aa_new = samples['seqs']
|
| 84 |
+
for i in range(nums):
|
| 85 |
+
data_saved = {
|
| 86 |
+
'chain_nb':batch['chain_nb'][0],'chain_id':chain_id,'resseq':batch['resseq'][0],'icode':icode,
|
| 87 |
+
'aa':aa_new[i], 'mask_heavyatom':mask_new[i], 'pos_heavyatom':pos_new[i],
|
| 88 |
+
}
|
| 89 |
+
save_pdb(data_saved,path=os.path.join(save_dir,f'sample_{i}.pdb'))
|
| 90 |
+
data_saved = {
|
| 91 |
+
'chain_nb':batch['chain_nb'][0],'chain_id':chain_id,'resseq':batch['resseq'][0],'icode':icode,
|
| 92 |
+
'aa':batch['aa'][0], 'mask_heavyatom':batch['mask_heavyatom'][0], 'pos_heavyatom':batch['pos_heavyatom'][0],
|
| 93 |
+
}
|
| 94 |
+
save_pdb(data_saved,path=os.path.join(save_dir,f'gt.pdb'))
|
| 95 |
+
|
| 96 |
+
def save_samples_sc(samples,save_dir):
|
| 97 |
+
# meta data
|
| 98 |
+
batch = recursive_to(samples['batch'],'cpu')
|
| 99 |
+
chain_id = [list(item) for item in zip(*batch['chain_id'])][0] # fix chain id in collate func
|
| 100 |
+
icode = [' ' for _ in range(len(chain_id))] # batch icode have same problem
|
| 101 |
+
nums = len(batch['id'])
|
| 102 |
+
id = batch['id'][0]
|
| 103 |
+
# batch convert
|
| 104 |
+
# aa=batch['aa] if only bb level
|
| 105 |
+
pos_ha,_,_ = full_atom_reconstruction(R_bb=samples['rotmats'],t_bb=samples['trans'],angles=samples['angles'],aa=samples['seqs']) # (32,L,14,3), instead of 15, ignore OXT masked
|
| 106 |
+
pos_ha = F.pad(pos_ha, pad=(0,0,0,15-14), value=0.) # (32,L,A,3) pos14 A=14
|
| 107 |
+
pos_new = torch.where(batch['generate_mask'][:,:,None,None],pos_ha,batch['pos_heavyatom'])
|
| 108 |
+
mask_new = get_heavyatom_mask(samples['seqs'])
|
| 109 |
+
aa_new = samples['seqs']
|
| 110 |
+
for i in range(nums):
|
| 111 |
+
data_saved = {
|
| 112 |
+
'chain_nb':batch['chain_nb'][0],'chain_id':chain_id,'resseq':batch['resseq'][0],'icode':icode,
|
| 113 |
+
'aa':aa_new[i], 'mask_heavyatom':mask_new[i], 'pos_heavyatom':pos_new[i],
|
| 114 |
+
}
|
| 115 |
+
save_pdb(data_saved,path=os.path.join(save_dir,f'sample_{i}.pdb'))
|
| 116 |
+
data_saved = {
|
| 117 |
+
'chain_nb':batch['chain_nb'][0],'chain_id':chain_id,'resseq':batch['resseq'][0],'icode':icode,
|
| 118 |
+
'aa':batch['aa'][0], 'mask_heavyatom':batch['mask_heavyatom'][0], 'pos_heavyatom':batch['pos_heavyatom'][0],
|
| 119 |
+
}
|
| 120 |
+
save_pdb(data_saved,path=os.path.join(save_dir,f'gt.pdb'))
|
| 121 |
+
|
| 122 |
+
if __name__ == '__main__':
|
| 123 |
+
# sample = torch.load('./Codesign/outputs/1aze_B.pt')
|
| 124 |
+
# save_samples_sc(sample,'./misc/test')
|
| 125 |
+
# save_samples_bb(sample,'./misc/test')
|
| 126 |
+
# for k,v in sample.items():
|
| 127 |
+
# if isinstance(v,torch.Tensor):
|
| 128 |
+
# print(f'{k},{v.shape}')
|
| 129 |
+
|
| 130 |
+
# # subdir = 'bb_seq_angle' # bb,bb_seq,bb_seq_angle
|
| 131 |
+
# names = [n.split('.')[0] for n in os.listdir(os.path.join(SAMPLE_DIR,subdir,'outputs'))]
|
| 132 |
+
# for name in tqdm(names):
|
| 133 |
+
# sample = torch.load(os.path.join(SAMPLE_DIR,subdir,'outputs',f'{name}.pt'))
|
| 134 |
+
# os.makedirs(os.path.join(SAMPLE_DIR,subdir,'pdbs',name),exist_ok=True)
|
| 135 |
+
# save_samples_sc(sample,os.path.join(SAMPLE_DIR,subdir,'pdbs',name))
|
| 136 |
+
|
| 137 |
+
args = argparse.ArgumentParser()
|
| 138 |
+
args.add_argument('--SAMPLEDIR', type=str)
|
| 139 |
+
parser = args.parse_args()
|
| 140 |
+
SAMPLE_DIR = parser.SAMPLEDIR
|
| 141 |
+
names = [n.split('.')[0] for n in os.listdir(os.path.join(SAMPLE_DIR,'outputs'))]
|
| 142 |
+
for name in tqdm(names):
|
| 143 |
+
sample = torch.load(os.path.join(SAMPLE_DIR,'outputs',f'{name}.pt'))
|
| 144 |
+
os.makedirs(os.path.join(SAMPLE_DIR,'pdbs',name),exist_ok=True)
|
| 145 |
+
save_samples_sc(sample,os.path.join(SAMPLE_DIR,'pdbs',name))
|
models_con/torsion.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import math
|
| 3 |
+
|
| 4 |
+
from typing import Any, Optional, Union, cast
|
| 5 |
+
|
| 6 |
+
from pepflow.modules.common.geometry import *
|
| 7 |
+
import pepflow.modules.protein.constants as constants
|
| 8 |
+
|
| 9 |
+
"""
|
| 10 |
+
calc torsion angles between (0,2pi)
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
def _get_torsion(p0, p1, p2, p3):
|
| 14 |
+
"""
|
| 15 |
+
Args:
|
| 16 |
+
p0-3: (*, 3).
|
| 17 |
+
Returns:
|
| 18 |
+
Dihedral angles in radian, (*, ).
|
| 19 |
+
"""
|
| 20 |
+
v0 = p2 - p1
|
| 21 |
+
v1 = p0 - p1
|
| 22 |
+
v2 = p3 - p2
|
| 23 |
+
u1 = torch.cross(v0, v1, dim=-1)
|
| 24 |
+
n1 = u1 / torch.linalg.norm(u1, dim=-1, keepdim=True)
|
| 25 |
+
u2 = torch.cross(v0, v2, dim=-1)
|
| 26 |
+
n2 = u2 / torch.linalg.norm(u2, dim=-1, keepdim=True)
|
| 27 |
+
sgn = torch.sign( (torch.cross(v1, v2, dim=-1) * v0).sum(-1) )
|
| 28 |
+
dihed = sgn*torch.acos( (n1 * n2).sum(-1).clamp(min=-0.999999, max=0.999999))
|
| 29 |
+
return dihed
|
| 30 |
+
|
| 31 |
+
def get_chi_angles(restype, pos14):
|
| 32 |
+
chi_angles = torch.full([4], fill_value=float("inf")).to(pos14)
|
| 33 |
+
base_atom_names = constants.chi_angles_atoms[restype]
|
| 34 |
+
for i, four_atom_names in enumerate(base_atom_names):
|
| 35 |
+
atom_indices = [constants.restype_atom14_name_to_index[restype][a] for a in four_atom_names]
|
| 36 |
+
p = torch.stack([pos14[i] for i in atom_indices])
|
| 37 |
+
# if torch.eq(p, 99999).any():
|
| 38 |
+
# continue
|
| 39 |
+
torsion = _get_torsion(*torch.unbind(p, dim=0))
|
| 40 |
+
chi_angles[i] = torsion
|
| 41 |
+
return chi_angles
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def get_psi_angle(pos14: torch.Tensor) -> torch.Tensor:
|
| 45 |
+
return _get_torsion(pos14[0], pos14[1], pos14[2], pos14[3]).reshape([1]) # af style psi, N,CA,C,O
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def get_torsion_angle(pos14: torch.Tensor, aa: torch.LongTensor):
|
| 49 |
+
torsion, torsion_mask = [], []
|
| 50 |
+
for i in range(pos14.shape[0]):
|
| 51 |
+
if aa[i] < constants.AA.UNK: # 0-19
|
| 52 |
+
chi = get_chi_angles(aa[i].item(), pos14[i])
|
| 53 |
+
psi = get_psi_angle(pos14[i])
|
| 54 |
+
torsion_this = torch.cat([psi, chi], dim=0)
|
| 55 |
+
torsion_mask_this = torsion_this.isfinite()
|
| 56 |
+
else:
|
| 57 |
+
torsion_this = torch.full([5], 0.)
|
| 58 |
+
torsion_mask_this = torch.full([5], False)
|
| 59 |
+
torsion.append(torsion_this.nan_to_num(posinf=0.))
|
| 60 |
+
torsion_mask.append(torsion_mask_this)
|
| 61 |
+
|
| 62 |
+
torsion = torch.stack(torsion) % (2*math.pi)
|
| 63 |
+
torsion_mask = torch.stack(torsion_mask).bool()
|
| 64 |
+
|
| 65 |
+
return torsion, torsion_mask
|
| 66 |
+
|
| 67 |
+
def _make_psi_chi_rotation_matrices(angles: torch.Tensor) -> torch.Tensor:
|
| 68 |
+
"""Compute psi and chi rotation matrices from torsional angles.
|
| 69 |
+
|
| 70 |
+
Here we provide angles instead of alpha in af2 between (0,2pi)
|
| 71 |
+
|
| 72 |
+
See alphafold supplementary Algorithm 25 for details.
|
| 73 |
+
|
| 74 |
+
Args:
|
| 75 |
+
angles: (B, N, 5), angles between (0,2pi)
|
| 76 |
+
|
| 77 |
+
Returns:
|
| 78 |
+
Torsional angle rotation matrices, (B, N, 5, 3, 3).
|
| 79 |
+
"""
|
| 80 |
+
batch_size, n_res = angles.shape[:2]
|
| 81 |
+
sine,cosine = torch.sin(angles), torch.cos(angles)
|
| 82 |
+
sine = sine.reshape(batch_size, n_res, -1, 1, 1)
|
| 83 |
+
cosine = cosine.reshape(batch_size, n_res, -1, 1, 1)
|
| 84 |
+
zero = torch.zeros_like(sine)
|
| 85 |
+
one = torch.ones_like(sine)
|
| 86 |
+
|
| 87 |
+
row1 = torch.cat([one, zero, zero], dim=-1) # (B, N, 5, 1, 3)
|
| 88 |
+
row2 = torch.cat([zero, cosine, -sine], dim=-1) # (B, N, 5, 1, 3)
|
| 89 |
+
row3 = torch.cat([zero, sine, cosine], dim=-1) # (B, N, 5, 1, 3)
|
| 90 |
+
R = torch.cat([row1, row2, row3], dim=-2) # (B, N, 5, 3, 3)
|
| 91 |
+
|
| 92 |
+
return R
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def _get_rigid_group(aa: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 96 |
+
"""Extract rigid group constants.
|
| 97 |
+
|
| 98 |
+
Args:
|
| 99 |
+
aa: Amino acid types, (B, N).
|
| 100 |
+
|
| 101 |
+
Returns:
|
| 102 |
+
A tuple of rigid group rotation, translation, atom14 group and atom14 position.
|
| 103 |
+
"""
|
| 104 |
+
batch_size, n_res = aa.size()
|
| 105 |
+
aa = aa.flatten()
|
| 106 |
+
rotation = constants.restype_rigid_group_rotation.to(aa.device)[aa].reshape(batch_size, n_res, 8, 3, 3)
|
| 107 |
+
translation = constants.restype_rigid_group_translation.to(aa.device)[aa].reshape(batch_size, n_res, 8, 3)
|
| 108 |
+
atom14_group = constants.restype_heavyatom_to_rigid_group.to(aa.device)[aa].reshape(batch_size, n_res, 14)
|
| 109 |
+
atom14_position = constants.restype_heavyatom_rigid_group_positions.to(aa.device)[aa].reshape(
|
| 110 |
+
batch_size, n_res, 14, 3
|
| 111 |
+
)
|
| 112 |
+
return rotation, translation, atom14_group, atom14_position
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
# construct heavy atom masks for genrating
|
| 116 |
+
# restype_to_heavyatom_masks = {
|
| 117 |
+
# restype: [name != "" and name !='OXT' for name in names]
|
| 118 |
+
# for restype, names in constants.restype_to_heavyatom_names.items()
|
| 119 |
+
# }
|
| 120 |
+
# print(restype_to_heavyatom_masks[0])
|
| 121 |
+
|
| 122 |
+
restype_to_heavyatom_masks = torch.zeros([22,15]).bool()
|
| 123 |
+
for i in range(21):
|
| 124 |
+
restype_to_heavyatom_masks[i] = torch.tensor([name != "" and name !='OXT' for name in constants.restype_to_heavyatom_names[i]]).bool()
|
| 125 |
+
|
| 126 |
+
def get_heavyatom_mask(aa: torch.Tensor) -> torch.Tensor:
|
| 127 |
+
"""Compute heavy atom masks from amino acid types.
|
| 128 |
+
|
| 129 |
+
Args:
|
| 130 |
+
aa: Amino acid types, (B, N).
|
| 131 |
+
|
| 132 |
+
Returns:
|
| 133 |
+
Heavy atom masks, (B, N, 15).
|
| 134 |
+
"""
|
| 135 |
+
batch_size, n_res = aa.size()
|
| 136 |
+
aa = aa.flatten()
|
| 137 |
+
mask = restype_to_heavyatom_masks.to(aa.device)[aa].reshape(batch_size, n_res, 15)
|
| 138 |
+
return mask
|
| 139 |
+
|
| 140 |
+
def full_atom_reconstruction(
|
| 141 |
+
R_bb: torch.Tensor,
|
| 142 |
+
t_bb: torch.Tensor,
|
| 143 |
+
angles: torch.Tensor,
|
| 144 |
+
aa: torch.Tensor,
|
| 145 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 146 |
+
"""Compute full atom positions from backbone frames and torsional angles.
|
| 147 |
+
|
| 148 |
+
See alphafold supplementary Algorithm 24 for details.
|
| 149 |
+
|
| 150 |
+
Args:
|
| 151 |
+
R_bb: Rotation of backbone frames, (B, N, 3, 3).
|
| 152 |
+
t_bb: Translation of backbone frames, (B, N, 3).
|
| 153 |
+
angles: (B, N, 5), angles between (0,2pi)
|
| 154 |
+
aa: Amino acid types, (B, N).
|
| 155 |
+
|
| 156 |
+
Returns:
|
| 157 |
+
A tuple of atom positions and full frames, (pos14, R, t).
|
| 158 |
+
pos14: Full atom positions in pos14 representations, (B, N, 14, 3).
|
| 159 |
+
R: Rotation of backbone, psi, chi1-4 frames, (B, N, 5, 3, 3).
|
| 160 |
+
t: Rotation of backbone, psi, chi1-4 frames, (B, N, 5, 3).
|
| 161 |
+
"""
|
| 162 |
+
N, L = aa.size()
|
| 163 |
+
|
| 164 |
+
rot_psi, rot_chi1, rot_chi2, rot_chi3, rot_chi4 = _make_psi_chi_rotation_matrices(angles).unbind(dim=2)
|
| 165 |
+
# (B, N, 3, 3)
|
| 166 |
+
zeros = torch.zeros_like(t_bb)
|
| 167 |
+
|
| 168 |
+
rigid_rotation, rigid_translation, atom14_group, atom14_position = _get_rigid_group(aa)
|
| 169 |
+
|
| 170 |
+
R_psi, t_psi = compose_chain(
|
| 171 |
+
[
|
| 172 |
+
(R_bb, t_bb),
|
| 173 |
+
(rigid_rotation[:, :, constants.PSI_FRAME], rigid_translation[:, :, constants.PSI_FRAME]),
|
| 174 |
+
(rot_psi, zeros),
|
| 175 |
+
]
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
R_chi1, t_chi1 = compose_chain(
|
| 179 |
+
[
|
| 180 |
+
(R_bb, t_bb),
|
| 181 |
+
(rigid_rotation[:, :, constants.CHI1_FRAME], rigid_translation[:, :, constants.CHI1_FRAME]),
|
| 182 |
+
(rot_chi1, zeros),
|
| 183 |
+
]
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
R_chi2, t_chi2 = compose_chain(
|
| 187 |
+
[
|
| 188 |
+
(R_chi1, t_chi1),
|
| 189 |
+
(rigid_rotation[:, :, constants.CHI2_FRAME], rigid_translation[:, :, constants.CHI2_FRAME]),
|
| 190 |
+
(rot_chi2, zeros),
|
| 191 |
+
]
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
R_chi3, t_chi3 = compose_chain(
|
| 195 |
+
[
|
| 196 |
+
(R_chi2, t_chi2),
|
| 197 |
+
(rigid_rotation[:, :, constants.CHI3_FRAME], rigid_translation[:, :, constants.CHI3_FRAME]),
|
| 198 |
+
(rot_chi3, zeros),
|
| 199 |
+
]
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
R_chi4, t_chi4 = compose_chain(
|
| 203 |
+
[
|
| 204 |
+
(R_chi3, t_chi3),
|
| 205 |
+
(rigid_rotation[:, :, constants.CHI4_FRAME], rigid_translation[:, :, constants.CHI4_FRAME]),
|
| 206 |
+
(rot_chi4, zeros),
|
| 207 |
+
]
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
# Return Frame
|
| 211 |
+
R_ret = torch.stack([R_bb, R_psi, R_chi1, R_chi2, R_chi3, R_chi4], dim=2)
|
| 212 |
+
t_ret = torch.stack([t_bb, t_psi, t_chi1, t_chi2, t_chi3, t_chi4], dim=2)
|
| 213 |
+
|
| 214 |
+
# Backbone, Omega, Phi, Psi, Chi1,2,3,4
|
| 215 |
+
R_all = torch.stack([R_bb, R_bb, R_bb, R_psi, R_chi1, R_chi2, R_chi3, R_chi4], dim=2) # (B, N, 8, 3, 3)
|
| 216 |
+
t_all = torch.stack([t_bb, t_bb, t_bb, t_psi, t_chi1, t_chi2, t_chi3, t_chi4], dim=2) # (B, N, 8, 3)
|
| 217 |
+
|
| 218 |
+
index_R = atom14_group.reshape(N, L, 14, 1, 1).repeat(1, 1, 1, 3, 3) # (B, N, 14, 3, 3)
|
| 219 |
+
index_t = atom14_group.reshape(N, L, 14, 1).repeat(1, 1, 1, 3) # (B, N, 14, 3)
|
| 220 |
+
|
| 221 |
+
R_atom = torch.gather(R_all, dim=2, index=index_R) # (N, L, 14, 3, 3)
|
| 222 |
+
t_atom = torch.gather(t_all, dim=2, index=index_t) # (N, L, 14, 3)
|
| 223 |
+
p_atom = atom14_position # (N, L, 14, 3)
|
| 224 |
+
|
| 225 |
+
pos14 = torch.matmul(R_atom, p_atom.unsqueeze(-1)).squeeze(-1) + t_atom
|
| 226 |
+
return pos14, R_ret, t_ret
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
torsions_mask = torch.zeros([22,5]).float() # 0-19, X, PAD
|
| 231 |
+
for i in range(21):
|
| 232 |
+
torsions_mask[i] = torch.tensor([True] + constants.chi_angles_mask[i]).float()
|
| 233 |
+
# print(angles_mask)
|
| 234 |
+
|
| 235 |
+
if __name__ =='__main__':
|
| 236 |
+
aa = torch.full([3,8],fill_value=constants.AA.THR).long()
|
| 237 |
+
mask = get_heavyatom_mask(aa)
|
| 238 |
+
print(mask)
|
| 239 |
+
print(mask.shape)
|
models_con/torus.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def tor_expmap(x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
|
| 6 |
+
return (x + u) % (2 * math.pi)
|
| 7 |
+
|
| 8 |
+
def tor_logmap(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
| 9 |
+
return torch.atan2(torch.sin(y - x), torch.cos(y - x))
|
| 10 |
+
|
| 11 |
+
def tor_projx(x: torch.Tensor) -> torch.Tensor:
|
| 12 |
+
return x % (2 * math.pi)
|
| 13 |
+
|
| 14 |
+
def tor_random_uniform(*size, dtype=None, device=None) -> torch.Tensor:
|
| 15 |
+
z = torch.rand(*size, dtype=dtype, device=device)
|
| 16 |
+
return z * 2 * math.pi
|
| 17 |
+
|
| 18 |
+
def tor_uniform_logprob(x):
|
| 19 |
+
dim = x.shape[-1]
|
| 20 |
+
return torch.full_like(x[..., 0], -dim * math.log(2 * math.pi))
|
| 21 |
+
|
| 22 |
+
def tor_geodesic_t(t, angles_1, angles_0):
|
| 23 |
+
# target, base
|
| 24 |
+
tangent_vec = t * tor_logmap(angles_0, angles_1)
|
| 25 |
+
points_at_time_t = tor_expmap(angles_0, tangent_vec)
|
| 26 |
+
return points_at_time_t
|
| 27 |
+
|
| 28 |
+
if __name__ =='__main__':
|
| 29 |
+
a = tor_random_uniform((2,3,5))
|
| 30 |
+
b = tor_random_uniform((2,3,5))
|
| 31 |
+
t = torch.ones((2,1)) * 0.2
|
| 32 |
+
c = tor_geodesic_t(t[...,None],a,b)
|
| 33 |
+
print(c)
|
| 34 |
+
print(c.shape)
|
models_con/utils.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
import copy
|
| 7 |
+
import math
|
| 8 |
+
from tqdm.auto import tqdm
|
| 9 |
+
import functools
|
| 10 |
+
from torch.utils.data import DataLoader
|
| 11 |
+
import os
|
| 12 |
+
import argparse
|
| 13 |
+
|
| 14 |
+
import pandas as pd
|
| 15 |
+
|
| 16 |
+
def process_dic(state_dict):
|
| 17 |
+
new_state_dict = {}
|
| 18 |
+
for k,v in state_dict.items():
|
| 19 |
+
if 'module' in k:
|
| 20 |
+
new_state_dict[k[7:]] = v
|
| 21 |
+
else:
|
| 22 |
+
new_state_dict[k] = v
|
| 23 |
+
return new_state_dict
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def calc_distogram(pos, min_bin, max_bin, num_bins):
|
| 27 |
+
dists_2d = torch.linalg.norm(
|
| 28 |
+
pos[:, :, None, :] - pos[:, None, :, :], axis=-1)[..., None]
|
| 29 |
+
lower = torch.linspace(
|
| 30 |
+
min_bin,
|
| 31 |
+
max_bin,
|
| 32 |
+
num_bins,
|
| 33 |
+
device=pos.device)
|
| 34 |
+
upper = torch.cat([lower[1:], lower.new_tensor([1e8])], dim=-1)
|
| 35 |
+
dgram = ((dists_2d > lower) * (dists_2d < upper)).type(pos.dtype)
|
| 36 |
+
return dgram
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def get_index_embedding(indices, embed_size, max_len=2056):
|
| 40 |
+
"""Creates sine / cosine positional embeddings from a prespecified indices.
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
indices: offsets of size [..., N_edges] of type integer
|
| 44 |
+
max_len: maximum length.
|
| 45 |
+
embed_size: dimension of the embeddings to create
|
| 46 |
+
|
| 47 |
+
Returns:
|
| 48 |
+
positional embedding of shape [N, embed_size]
|
| 49 |
+
"""
|
| 50 |
+
K = torch.arange(embed_size//2, device=indices.device)
|
| 51 |
+
pos_embedding_sin = torch.sin(
|
| 52 |
+
indices[..., None] * math.pi / (max_len**(2*K[None]/embed_size))).to(indices.device)
|
| 53 |
+
pos_embedding_cos = torch.cos(
|
| 54 |
+
indices[..., None] * math.pi / (max_len**(2*K[None]/embed_size))).to(indices.device)
|
| 55 |
+
pos_embedding = torch.cat([
|
| 56 |
+
pos_embedding_sin, pos_embedding_cos], axis=-1)
|
| 57 |
+
return pos_embedding
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def get_time_embedding(timesteps, embedding_dim, max_positions=2000):
|
| 61 |
+
# Code from https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py
|
| 62 |
+
assert len(timesteps.shape) == 1
|
| 63 |
+
timesteps = timesteps * max_positions
|
| 64 |
+
half_dim = embedding_dim // 2
|
| 65 |
+
emb = math.log(max_positions) / (half_dim - 1)
|
| 66 |
+
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
|
| 67 |
+
emb = timesteps.float()[:, None] * emb[None, :]
|
| 68 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
| 69 |
+
if embedding_dim % 2 == 1: # zero pad
|
| 70 |
+
emb = F.pad(emb, (0, 1), mode='constant')
|
| 71 |
+
assert emb.shape == (timesteps.shape[0], embedding_dim)
|
| 72 |
+
return emb
|
openfold/config.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
NUM_RES = "num residues placeholder"
|
| 2 |
+
NUM_MSA_SEQ = "msa placeholder"
|
| 3 |
+
NUM_EXTRA_SEQ = "extra msa placeholder"
|
| 4 |
+
NUM_TEMPLATES = "num templates placeholder"
|
openfold/model/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import glob
|
| 3 |
+
import importlib as importlib
|
| 4 |
+
|
| 5 |
+
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
|
| 6 |
+
__all__ = [
|
| 7 |
+
os.path.basename(f)[:-3]
|
| 8 |
+
for f in _files
|
| 9 |
+
if os.path.isfile(f) and not f.endswith("__init__.py")
|
| 10 |
+
]
|
| 11 |
+
_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
|
| 12 |
+
for _m in _modules:
|
| 13 |
+
globals()[_m[0]] = _m[1]
|
| 14 |
+
|
| 15 |
+
# Avoid needlessly cluttering the global namespace
|
| 16 |
+
del _files, _m, _modules
|
openfold/model/dropout.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
from functools import partialmethod
|
| 19 |
+
from typing import Union, List
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class Dropout(nn.Module):
|
| 23 |
+
"""
|
| 24 |
+
Implementation of dropout with the ability to share the dropout mask
|
| 25 |
+
along a particular dimension.
|
| 26 |
+
|
| 27 |
+
If not in training mode, this module computes the identity function.
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
def __init__(self, r: float, batch_dim: Union[int, List[int]]):
|
| 31 |
+
"""
|
| 32 |
+
Args:
|
| 33 |
+
r:
|
| 34 |
+
Dropout rate
|
| 35 |
+
batch_dim:
|
| 36 |
+
Dimension(s) along which the dropout mask is shared
|
| 37 |
+
"""
|
| 38 |
+
super(Dropout, self).__init__()
|
| 39 |
+
|
| 40 |
+
self.r = r
|
| 41 |
+
if type(batch_dim) == int:
|
| 42 |
+
batch_dim = [batch_dim]
|
| 43 |
+
self.batch_dim = batch_dim
|
| 44 |
+
self.dropout = nn.Dropout(self.r)
|
| 45 |
+
|
| 46 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 47 |
+
"""
|
| 48 |
+
Args:
|
| 49 |
+
x:
|
| 50 |
+
Tensor to which dropout is applied. Can have any shape
|
| 51 |
+
compatible with self.batch_dim
|
| 52 |
+
"""
|
| 53 |
+
shape = list(x.shape)
|
| 54 |
+
if self.batch_dim is not None:
|
| 55 |
+
for bd in self.batch_dim:
|
| 56 |
+
shape[bd] = 1
|
| 57 |
+
mask = x.new_ones(shape)
|
| 58 |
+
mask = self.dropout(mask)
|
| 59 |
+
x *= mask
|
| 60 |
+
return x
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class DropoutRowwise(Dropout):
|
| 64 |
+
"""
|
| 65 |
+
Convenience class for rowwise dropout as described in subsection
|
| 66 |
+
1.11.6.
|
| 67 |
+
"""
|
| 68 |
+
|
| 69 |
+
__init__ = partialmethod(Dropout.__init__, batch_dim=-3)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class DropoutColumnwise(Dropout):
|
| 73 |
+
"""
|
| 74 |
+
Convenience class for columnwise dropout as described in subsection
|
| 75 |
+
1.11.6.
|
| 76 |
+
"""
|
| 77 |
+
|
| 78 |
+
__init__ = partialmethod(Dropout.__init__, batch_dim=-2)
|
openfold/model/embedders.py
ADDED
|
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
from typing import Tuple
|
| 19 |
+
|
| 20 |
+
from openfold.model.primitives import Linear, LayerNorm
|
| 21 |
+
from openfold.utils.tensor_utils import one_hot
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class InputEmbedder(nn.Module):
|
| 25 |
+
"""
|
| 26 |
+
Embeds a subset of the input features.
|
| 27 |
+
|
| 28 |
+
Implements Algorithms 3 (InputEmbedder) and 4 (relpos).
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
def __init__(
|
| 32 |
+
self,
|
| 33 |
+
tf_dim: int,
|
| 34 |
+
msa_dim: int,
|
| 35 |
+
c_z: int,
|
| 36 |
+
c_m: int,
|
| 37 |
+
relpos_k: int,
|
| 38 |
+
**kwargs,
|
| 39 |
+
):
|
| 40 |
+
"""
|
| 41 |
+
Args:
|
| 42 |
+
tf_dim:
|
| 43 |
+
Final dimension of the target features
|
| 44 |
+
msa_dim:
|
| 45 |
+
Final dimension of the MSA features
|
| 46 |
+
c_z:
|
| 47 |
+
Pair embedding dimension
|
| 48 |
+
c_m:
|
| 49 |
+
MSA embedding dimension
|
| 50 |
+
relpos_k:
|
| 51 |
+
Window size used in relative positional encoding
|
| 52 |
+
"""
|
| 53 |
+
super(InputEmbedder, self).__init__()
|
| 54 |
+
|
| 55 |
+
self.tf_dim = tf_dim
|
| 56 |
+
self.msa_dim = msa_dim
|
| 57 |
+
|
| 58 |
+
self.c_z = c_z
|
| 59 |
+
self.c_m = c_m
|
| 60 |
+
|
| 61 |
+
self.linear_tf_z_i = Linear(tf_dim, c_z)
|
| 62 |
+
self.linear_tf_z_j = Linear(tf_dim, c_z)
|
| 63 |
+
self.linear_tf_m = Linear(tf_dim, c_m)
|
| 64 |
+
self.linear_msa_m = Linear(msa_dim, c_m)
|
| 65 |
+
|
| 66 |
+
# RPE stuff
|
| 67 |
+
self.relpos_k = relpos_k
|
| 68 |
+
self.no_bins = 2 * relpos_k + 1
|
| 69 |
+
self.linear_relpos = Linear(self.no_bins, c_z)
|
| 70 |
+
|
| 71 |
+
def relpos(self, ri: torch.Tensor):
|
| 72 |
+
"""
|
| 73 |
+
Computes relative positional encodings
|
| 74 |
+
|
| 75 |
+
Implements Algorithm 4.
|
| 76 |
+
|
| 77 |
+
Args:
|
| 78 |
+
ri:
|
| 79 |
+
"residue_index" features of shape [*, N]
|
| 80 |
+
"""
|
| 81 |
+
d = ri[..., None] - ri[..., None, :]
|
| 82 |
+
boundaries = torch.arange(
|
| 83 |
+
start=-self.relpos_k, end=self.relpos_k + 1, device=d.device
|
| 84 |
+
)
|
| 85 |
+
oh = one_hot(d, boundaries).type(ri.dtype)
|
| 86 |
+
return self.linear_relpos(oh)
|
| 87 |
+
|
| 88 |
+
def forward(
|
| 89 |
+
self,
|
| 90 |
+
tf: torch.Tensor,
|
| 91 |
+
ri: torch.Tensor,
|
| 92 |
+
msa: torch.Tensor,
|
| 93 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 94 |
+
"""
|
| 95 |
+
Args:
|
| 96 |
+
tf:
|
| 97 |
+
"target_feat" features of shape [*, N_res, tf_dim]
|
| 98 |
+
ri:
|
| 99 |
+
"residue_index" features of shape [*, N_res]
|
| 100 |
+
msa:
|
| 101 |
+
"msa_feat" features of shape [*, N_clust, N_res, msa_dim]
|
| 102 |
+
Returns:
|
| 103 |
+
msa_emb:
|
| 104 |
+
[*, N_clust, N_res, C_m] MSA embedding
|
| 105 |
+
pair_emb:
|
| 106 |
+
[*, N_res, N_res, C_z] pair embedding
|
| 107 |
+
|
| 108 |
+
"""
|
| 109 |
+
# [*, N_res, c_z]
|
| 110 |
+
tf_emb_i = self.linear_tf_z_i(tf)
|
| 111 |
+
tf_emb_j = self.linear_tf_z_j(tf)
|
| 112 |
+
|
| 113 |
+
# [*, N_res, N_res, c_z]
|
| 114 |
+
pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
|
| 115 |
+
pair_emb = pair_emb + self.relpos(ri.type(pair_emb.dtype))
|
| 116 |
+
|
| 117 |
+
# [*, N_clust, N_res, c_m]
|
| 118 |
+
n_clust = msa.shape[-3]
|
| 119 |
+
tf_m = (
|
| 120 |
+
self.linear_tf_m(tf)
|
| 121 |
+
.unsqueeze(-3)
|
| 122 |
+
.expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1)))
|
| 123 |
+
)
|
| 124 |
+
msa_emb = self.linear_msa_m(msa) + tf_m
|
| 125 |
+
|
| 126 |
+
return msa_emb, pair_emb
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class RecyclingEmbedder(nn.Module):
|
| 130 |
+
"""
|
| 131 |
+
Embeds the output of an iteration of the model for recycling.
|
| 132 |
+
|
| 133 |
+
Implements Algorithm 32.
|
| 134 |
+
"""
|
| 135 |
+
|
| 136 |
+
def __init__(
|
| 137 |
+
self,
|
| 138 |
+
c_m: int,
|
| 139 |
+
c_z: int,
|
| 140 |
+
min_bin: float,
|
| 141 |
+
max_bin: float,
|
| 142 |
+
no_bins: int,
|
| 143 |
+
inf: float = 1e8,
|
| 144 |
+
**kwargs,
|
| 145 |
+
):
|
| 146 |
+
"""
|
| 147 |
+
Args:
|
| 148 |
+
c_m:
|
| 149 |
+
MSA channel dimension
|
| 150 |
+
c_z:
|
| 151 |
+
Pair embedding channel dimension
|
| 152 |
+
min_bin:
|
| 153 |
+
Smallest distogram bin (Angstroms)
|
| 154 |
+
max_bin:
|
| 155 |
+
Largest distogram bin (Angstroms)
|
| 156 |
+
no_bins:
|
| 157 |
+
Number of distogram bins
|
| 158 |
+
"""
|
| 159 |
+
super(RecyclingEmbedder, self).__init__()
|
| 160 |
+
|
| 161 |
+
self.c_m = c_m
|
| 162 |
+
self.c_z = c_z
|
| 163 |
+
self.min_bin = min_bin
|
| 164 |
+
self.max_bin = max_bin
|
| 165 |
+
self.no_bins = no_bins
|
| 166 |
+
self.inf = inf
|
| 167 |
+
|
| 168 |
+
self.bins = None
|
| 169 |
+
|
| 170 |
+
self.linear = Linear(self.no_bins, self.c_z)
|
| 171 |
+
self.layer_norm_m = LayerNorm(self.c_m)
|
| 172 |
+
self.layer_norm_z = LayerNorm(self.c_z)
|
| 173 |
+
|
| 174 |
+
def forward(
|
| 175 |
+
self,
|
| 176 |
+
m: torch.Tensor,
|
| 177 |
+
z: torch.Tensor,
|
| 178 |
+
x: torch.Tensor,
|
| 179 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 180 |
+
"""
|
| 181 |
+
Args:
|
| 182 |
+
m:
|
| 183 |
+
First row of the MSA embedding. [*, N_res, C_m]
|
| 184 |
+
z:
|
| 185 |
+
[*, N_res, N_res, C_z] pair embedding
|
| 186 |
+
x:
|
| 187 |
+
[*, N_res, 3] predicted C_beta coordinates
|
| 188 |
+
Returns:
|
| 189 |
+
m:
|
| 190 |
+
[*, N_res, C_m] MSA embedding update
|
| 191 |
+
z:
|
| 192 |
+
[*, N_res, N_res, C_z] pair embedding update
|
| 193 |
+
"""
|
| 194 |
+
if self.bins is None:
|
| 195 |
+
self.bins = torch.linspace(
|
| 196 |
+
self.min_bin,
|
| 197 |
+
self.max_bin,
|
| 198 |
+
self.no_bins,
|
| 199 |
+
dtype=x.dtype,
|
| 200 |
+
device=x.device,
|
| 201 |
+
requires_grad=False,
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
# [*, N, C_m]
|
| 205 |
+
m_update = self.layer_norm_m(m)
|
| 206 |
+
|
| 207 |
+
# This squared method might become problematic in FP16 mode.
|
| 208 |
+
# I'm using it because my homegrown method had a stubborn discrepancy I
|
| 209 |
+
# couldn't find in time.
|
| 210 |
+
squared_bins = self.bins ** 2
|
| 211 |
+
upper = torch.cat(
|
| 212 |
+
[squared_bins[1:], squared_bins.new_tensor([self.inf])], dim=-1
|
| 213 |
+
)
|
| 214 |
+
d = torch.sum(
|
| 215 |
+
(x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1, keepdims=True
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
# [*, N, N, no_bins]
|
| 219 |
+
d = ((d > squared_bins) * (d < upper)).type(x.dtype)
|
| 220 |
+
|
| 221 |
+
# [*, N, N, C_z]
|
| 222 |
+
d = self.linear(d)
|
| 223 |
+
z_update = d + self.layer_norm_z(z)
|
| 224 |
+
|
| 225 |
+
return m_update, z_update
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
class TemplateAngleEmbedder(nn.Module):
|
| 229 |
+
"""
|
| 230 |
+
Embeds the "template_angle_feat" feature.
|
| 231 |
+
|
| 232 |
+
Implements Algorithm 2, line 7.
|
| 233 |
+
"""
|
| 234 |
+
|
| 235 |
+
def __init__(
|
| 236 |
+
self,
|
| 237 |
+
c_in: int,
|
| 238 |
+
c_out: int,
|
| 239 |
+
**kwargs,
|
| 240 |
+
):
|
| 241 |
+
"""
|
| 242 |
+
Args:
|
| 243 |
+
c_in:
|
| 244 |
+
Final dimension of "template_angle_feat"
|
| 245 |
+
c_out:
|
| 246 |
+
Output channel dimension
|
| 247 |
+
"""
|
| 248 |
+
super(TemplateAngleEmbedder, self).__init__()
|
| 249 |
+
|
| 250 |
+
self.c_out = c_out
|
| 251 |
+
self.c_in = c_in
|
| 252 |
+
|
| 253 |
+
self.linear_1 = Linear(self.c_in, self.c_out, init="relu")
|
| 254 |
+
self.relu = nn.ReLU()
|
| 255 |
+
self.linear_2 = Linear(self.c_out, self.c_out, init="relu")
|
| 256 |
+
|
| 257 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 258 |
+
"""
|
| 259 |
+
Args:
|
| 260 |
+
x: [*, N_templ, N_res, c_in] "template_angle_feat" features
|
| 261 |
+
Returns:
|
| 262 |
+
x: [*, N_templ, N_res, C_out] embedding
|
| 263 |
+
"""
|
| 264 |
+
x = self.linear_1(x)
|
| 265 |
+
x = self.relu(x)
|
| 266 |
+
x = self.linear_2(x)
|
| 267 |
+
|
| 268 |
+
return x
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
class TemplatePairEmbedder(nn.Module):
|
| 272 |
+
"""
|
| 273 |
+
Embeds "template_pair_feat" features.
|
| 274 |
+
|
| 275 |
+
Implements Algorithm 2, line 9.
|
| 276 |
+
"""
|
| 277 |
+
|
| 278 |
+
def __init__(
|
| 279 |
+
self,
|
| 280 |
+
c_in: int,
|
| 281 |
+
c_out: int,
|
| 282 |
+
**kwargs,
|
| 283 |
+
):
|
| 284 |
+
"""
|
| 285 |
+
Args:
|
| 286 |
+
c_in:
|
| 287 |
+
|
| 288 |
+
c_out:
|
| 289 |
+
Output channel dimension
|
| 290 |
+
"""
|
| 291 |
+
super(TemplatePairEmbedder, self).__init__()
|
| 292 |
+
|
| 293 |
+
self.c_in = c_in
|
| 294 |
+
self.c_out = c_out
|
| 295 |
+
|
| 296 |
+
# Despite there being no relu nearby, the source uses that initializer
|
| 297 |
+
self.linear = Linear(self.c_in, self.c_out, init="relu")
|
| 298 |
+
|
| 299 |
+
def forward(
|
| 300 |
+
self,
|
| 301 |
+
x: torch.Tensor,
|
| 302 |
+
) -> torch.Tensor:
|
| 303 |
+
"""
|
| 304 |
+
Args:
|
| 305 |
+
x:
|
| 306 |
+
[*, C_in] input tensor
|
| 307 |
+
Returns:
|
| 308 |
+
[*, C_out] output tensor
|
| 309 |
+
"""
|
| 310 |
+
x = self.linear(x)
|
| 311 |
+
|
| 312 |
+
return x
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
class ExtraMSAEmbedder(nn.Module):
|
| 316 |
+
"""
|
| 317 |
+
Embeds unclustered MSA sequences.
|
| 318 |
+
|
| 319 |
+
Implements Algorithm 2, line 15
|
| 320 |
+
"""
|
| 321 |
+
|
| 322 |
+
def __init__(
|
| 323 |
+
self,
|
| 324 |
+
c_in: int,
|
| 325 |
+
c_out: int,
|
| 326 |
+
**kwargs,
|
| 327 |
+
):
|
| 328 |
+
"""
|
| 329 |
+
Args:
|
| 330 |
+
c_in:
|
| 331 |
+
Input channel dimension
|
| 332 |
+
c_out:
|
| 333 |
+
Output channel dimension
|
| 334 |
+
"""
|
| 335 |
+
super(ExtraMSAEmbedder, self).__init__()
|
| 336 |
+
|
| 337 |
+
self.c_in = c_in
|
| 338 |
+
self.c_out = c_out
|
| 339 |
+
|
| 340 |
+
self.linear = Linear(self.c_in, self.c_out)
|
| 341 |
+
|
| 342 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 343 |
+
"""
|
| 344 |
+
Args:
|
| 345 |
+
x:
|
| 346 |
+
[*, N_extra_seq, N_res, C_in] "extra_msa_feat" features
|
| 347 |
+
Returns:
|
| 348 |
+
[*, N_extra_seq, N_res, C_out] embedding
|
| 349 |
+
"""
|
| 350 |
+
x = self.linear(x)
|
| 351 |
+
|
| 352 |
+
return x
|
openfold/model/evoformer.py
ADDED
|
@@ -0,0 +1,630 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import math
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
from typing import Tuple, Optional
|
| 20 |
+
from functools import partial
|
| 21 |
+
|
| 22 |
+
from openfold.model.primitives import Linear, LayerNorm
|
| 23 |
+
from openfold.model.dropout import DropoutRowwise, DropoutColumnwise
|
| 24 |
+
from openfold.model.msa import (
|
| 25 |
+
MSARowAttentionWithPairBias,
|
| 26 |
+
MSAColumnAttention,
|
| 27 |
+
MSAColumnGlobalAttention,
|
| 28 |
+
)
|
| 29 |
+
from openfold.model.outer_product_mean import OuterProductMean
|
| 30 |
+
from openfold.model.pair_transition import PairTransition
|
| 31 |
+
from openfold.model.triangular_attention import (
|
| 32 |
+
TriangleAttentionStartingNode,
|
| 33 |
+
TriangleAttentionEndingNode,
|
| 34 |
+
)
|
| 35 |
+
from openfold.model.triangular_multiplicative_update import (
|
| 36 |
+
TriangleMultiplicationOutgoing,
|
| 37 |
+
TriangleMultiplicationIncoming,
|
| 38 |
+
)
|
| 39 |
+
from openfold.utils.checkpointing import checkpoint_blocks, get_checkpoint_fn
|
| 40 |
+
from openfold.utils.tensor_utils import chunk_layer
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class MSATransition(nn.Module):
|
| 44 |
+
"""
|
| 45 |
+
Feed-forward network applied to MSA activations after attention.
|
| 46 |
+
|
| 47 |
+
Implements Algorithm 9
|
| 48 |
+
"""
|
| 49 |
+
def __init__(self, c_m, n):
|
| 50 |
+
"""
|
| 51 |
+
Args:
|
| 52 |
+
c_m:
|
| 53 |
+
MSA channel dimension
|
| 54 |
+
n:
|
| 55 |
+
Factor multiplied to c_m to obtain the hidden channel
|
| 56 |
+
dimension
|
| 57 |
+
"""
|
| 58 |
+
super(MSATransition, self).__init__()
|
| 59 |
+
|
| 60 |
+
self.c_m = c_m
|
| 61 |
+
self.n = n
|
| 62 |
+
|
| 63 |
+
self.layer_norm = LayerNorm(self.c_m)
|
| 64 |
+
self.linear_1 = Linear(self.c_m, self.n * self.c_m, init="relu")
|
| 65 |
+
self.relu = nn.ReLU()
|
| 66 |
+
self.linear_2 = Linear(self.n * self.c_m, self.c_m, init="final")
|
| 67 |
+
|
| 68 |
+
def _transition(self, m, mask):
|
| 69 |
+
m = self.linear_1(m)
|
| 70 |
+
m = self.relu(m)
|
| 71 |
+
m = self.linear_2(m) * mask
|
| 72 |
+
return m
|
| 73 |
+
|
| 74 |
+
@torch.jit.ignore
|
| 75 |
+
def _chunk(self,
|
| 76 |
+
m: torch.Tensor,
|
| 77 |
+
mask: torch.Tensor,
|
| 78 |
+
chunk_size: int,
|
| 79 |
+
) -> torch.Tensor:
|
| 80 |
+
return chunk_layer(
|
| 81 |
+
self._transition,
|
| 82 |
+
{"m": m, "mask": mask},
|
| 83 |
+
chunk_size=chunk_size,
|
| 84 |
+
no_batch_dims=len(m.shape[:-2]),
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def forward(
|
| 89 |
+
self,
|
| 90 |
+
m: torch.Tensor,
|
| 91 |
+
mask: Optional[torch.Tensor] = None,
|
| 92 |
+
chunk_size: Optional[int] = None,
|
| 93 |
+
) -> torch.Tensor:
|
| 94 |
+
"""
|
| 95 |
+
Args:
|
| 96 |
+
m:
|
| 97 |
+
[*, N_seq, N_res, C_m] MSA activation
|
| 98 |
+
mask:
|
| 99 |
+
[*, N_seq, N_res, C_m] MSA mask
|
| 100 |
+
Returns:
|
| 101 |
+
m:
|
| 102 |
+
[*, N_seq, N_res, C_m] MSA activation update
|
| 103 |
+
"""
|
| 104 |
+
# DISCREPANCY: DeepMind forgets to apply the MSA mask here.
|
| 105 |
+
if mask is None:
|
| 106 |
+
mask = m.new_ones(m.shape[:-1])
|
| 107 |
+
|
| 108 |
+
mask = mask.unsqueeze(-1)
|
| 109 |
+
|
| 110 |
+
m = self.layer_norm(m)
|
| 111 |
+
|
| 112 |
+
if chunk_size is not None:
|
| 113 |
+
m = self._chunk(m, mask, chunk_size)
|
| 114 |
+
else:
|
| 115 |
+
m = self._transition(m, mask)
|
| 116 |
+
|
| 117 |
+
return m
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class EvoformerBlockCore(nn.Module):
|
| 121 |
+
def __init__(
|
| 122 |
+
self,
|
| 123 |
+
c_m: int,
|
| 124 |
+
c_z: int,
|
| 125 |
+
c_hidden_opm: int,
|
| 126 |
+
c_hidden_mul: int,
|
| 127 |
+
c_hidden_pair_att: int,
|
| 128 |
+
no_heads_msa: int,
|
| 129 |
+
no_heads_pair: int,
|
| 130 |
+
transition_n: int,
|
| 131 |
+
pair_dropout: float,
|
| 132 |
+
inf: float,
|
| 133 |
+
eps: float,
|
| 134 |
+
_is_extra_msa_stack: bool = False,
|
| 135 |
+
):
|
| 136 |
+
super(EvoformerBlockCore, self).__init__()
|
| 137 |
+
|
| 138 |
+
self.msa_transition = MSATransition(
|
| 139 |
+
c_m=c_m,
|
| 140 |
+
n=transition_n,
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
self.outer_product_mean = OuterProductMean(
|
| 144 |
+
c_m,
|
| 145 |
+
c_z,
|
| 146 |
+
c_hidden_opm,
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
self.tri_mul_out = TriangleMultiplicationOutgoing(
|
| 150 |
+
c_z,
|
| 151 |
+
c_hidden_mul,
|
| 152 |
+
)
|
| 153 |
+
self.tri_mul_in = TriangleMultiplicationIncoming(
|
| 154 |
+
c_z,
|
| 155 |
+
c_hidden_mul,
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
self.tri_att_start = TriangleAttentionStartingNode(
|
| 159 |
+
c_z,
|
| 160 |
+
c_hidden_pair_att,
|
| 161 |
+
no_heads_pair,
|
| 162 |
+
inf=inf,
|
| 163 |
+
)
|
| 164 |
+
self.tri_att_end = TriangleAttentionEndingNode(
|
| 165 |
+
c_z,
|
| 166 |
+
c_hidden_pair_att,
|
| 167 |
+
no_heads_pair,
|
| 168 |
+
inf=inf,
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
self.pair_transition = PairTransition(
|
| 172 |
+
c_z,
|
| 173 |
+
transition_n,
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
self.ps_dropout_row_layer = DropoutRowwise(pair_dropout)
|
| 177 |
+
self.ps_dropout_col_layer = DropoutColumnwise(pair_dropout)
|
| 178 |
+
|
| 179 |
+
def forward(
|
| 180 |
+
self,
|
| 181 |
+
m: torch.Tensor,
|
| 182 |
+
z: torch.Tensor,
|
| 183 |
+
msa_mask: torch.Tensor,
|
| 184 |
+
pair_mask: torch.Tensor,
|
| 185 |
+
chunk_size: Optional[int] = None,
|
| 186 |
+
_mask_trans: bool = True,
|
| 187 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 188 |
+
# DeepMind doesn't mask these transitions in the source, so _mask_trans
|
| 189 |
+
# should be disabled to better approximate the exact activations of
|
| 190 |
+
# the original.
|
| 191 |
+
msa_trans_mask = msa_mask if _mask_trans else None
|
| 192 |
+
pair_trans_mask = pair_mask if _mask_trans else None
|
| 193 |
+
|
| 194 |
+
m = m + self.msa_transition(
|
| 195 |
+
m, mask=msa_trans_mask, chunk_size=chunk_size
|
| 196 |
+
)
|
| 197 |
+
z = z + self.outer_product_mean(
|
| 198 |
+
m, mask=msa_mask, chunk_size=chunk_size
|
| 199 |
+
)
|
| 200 |
+
z = z + self.ps_dropout_row_layer(self.tri_mul_out(z, mask=pair_mask))
|
| 201 |
+
z = z + self.ps_dropout_row_layer(self.tri_mul_in(z, mask=pair_mask))
|
| 202 |
+
z = z + self.ps_dropout_row_layer(
|
| 203 |
+
self.tri_att_start(z, mask=pair_mask, chunk_size=chunk_size)
|
| 204 |
+
)
|
| 205 |
+
z = z + self.ps_dropout_col_layer(
|
| 206 |
+
self.tri_att_end(z, mask=pair_mask, chunk_size=chunk_size)
|
| 207 |
+
)
|
| 208 |
+
z = z + self.pair_transition(
|
| 209 |
+
z, mask=pair_trans_mask, chunk_size=chunk_size
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
return m, z
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
class EvoformerBlock(nn.Module):
|
| 216 |
+
def __init__(self,
|
| 217 |
+
c_m: int,
|
| 218 |
+
c_z: int,
|
| 219 |
+
c_hidden_msa_att: int,
|
| 220 |
+
c_hidden_opm: int,
|
| 221 |
+
c_hidden_mul: int,
|
| 222 |
+
c_hidden_pair_att: int,
|
| 223 |
+
no_heads_msa: int,
|
| 224 |
+
no_heads_pair: int,
|
| 225 |
+
transition_n: int,
|
| 226 |
+
msa_dropout: float,
|
| 227 |
+
pair_dropout: float,
|
| 228 |
+
inf: float,
|
| 229 |
+
eps: float,
|
| 230 |
+
):
|
| 231 |
+
super(EvoformerBlock, self).__init__()
|
| 232 |
+
|
| 233 |
+
self.msa_att_row = MSARowAttentionWithPairBias(
|
| 234 |
+
c_m=c_m,
|
| 235 |
+
c_z=c_z,
|
| 236 |
+
c_hidden=c_hidden_msa_att,
|
| 237 |
+
no_heads=no_heads_msa,
|
| 238 |
+
inf=inf,
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
self.msa_att_col = MSAColumnAttention(
|
| 242 |
+
c_m,
|
| 243 |
+
c_hidden_msa_att,
|
| 244 |
+
no_heads_msa,
|
| 245 |
+
inf=inf,
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
self.msa_dropout_layer = DropoutRowwise(msa_dropout)
|
| 249 |
+
|
| 250 |
+
self.core = EvoformerBlockCore(
|
| 251 |
+
c_m=c_m,
|
| 252 |
+
c_z=c_z,
|
| 253 |
+
c_hidden_opm=c_hidden_opm,
|
| 254 |
+
c_hidden_mul=c_hidden_mul,
|
| 255 |
+
c_hidden_pair_att=c_hidden_pair_att,
|
| 256 |
+
no_heads_msa=no_heads_msa,
|
| 257 |
+
no_heads_pair=no_heads_pair,
|
| 258 |
+
transition_n=transition_n,
|
| 259 |
+
pair_dropout=pair_dropout,
|
| 260 |
+
inf=inf,
|
| 261 |
+
eps=eps,
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
def forward(self,
|
| 265 |
+
m: torch.Tensor,
|
| 266 |
+
z: torch.Tensor,
|
| 267 |
+
msa_mask: torch.Tensor,
|
| 268 |
+
pair_mask: torch.Tensor,
|
| 269 |
+
chunk_size: Optional[int] = None,
|
| 270 |
+
_mask_trans: bool = True,
|
| 271 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 272 |
+
m = m + self.msa_dropout_layer(
|
| 273 |
+
self.msa_att_row(m, z=z, mask=msa_mask, chunk_size=chunk_size)
|
| 274 |
+
)
|
| 275 |
+
m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
|
| 276 |
+
m, z = self.core(
|
| 277 |
+
m,
|
| 278 |
+
z,
|
| 279 |
+
msa_mask=msa_mask,
|
| 280 |
+
pair_mask=pair_mask,
|
| 281 |
+
chunk_size=chunk_size,
|
| 282 |
+
_mask_trans=_mask_trans,
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
return m, z
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
class ExtraMSABlock(nn.Module):
|
| 289 |
+
"""
|
| 290 |
+
Almost identical to the standard EvoformerBlock, except in that the
|
| 291 |
+
ExtraMSABlock uses GlobalAttention for MSA column attention and
|
| 292 |
+
requires more fine-grained control over checkpointing. Separated from
|
| 293 |
+
its twin to preserve the TorchScript-ability of the latter.
|
| 294 |
+
"""
|
| 295 |
+
def __init__(self,
|
| 296 |
+
c_m: int,
|
| 297 |
+
c_z: int,
|
| 298 |
+
c_hidden_msa_att: int,
|
| 299 |
+
c_hidden_opm: int,
|
| 300 |
+
c_hidden_mul: int,
|
| 301 |
+
c_hidden_pair_att: int,
|
| 302 |
+
no_heads_msa: int,
|
| 303 |
+
no_heads_pair: int,
|
| 304 |
+
transition_n: int,
|
| 305 |
+
msa_dropout: float,
|
| 306 |
+
pair_dropout: float,
|
| 307 |
+
inf: float,
|
| 308 |
+
eps: float,
|
| 309 |
+
ckpt: bool,
|
| 310 |
+
):
|
| 311 |
+
super(ExtraMSABlock, self).__init__()
|
| 312 |
+
|
| 313 |
+
self.ckpt = ckpt
|
| 314 |
+
|
| 315 |
+
self.msa_att_row = MSARowAttentionWithPairBias(
|
| 316 |
+
c_m=c_m,
|
| 317 |
+
c_z=c_z,
|
| 318 |
+
c_hidden=c_hidden_msa_att,
|
| 319 |
+
no_heads=no_heads_msa,
|
| 320 |
+
inf=inf,
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
self.msa_att_col = MSAColumnGlobalAttention(
|
| 324 |
+
c_in=c_m,
|
| 325 |
+
c_hidden=c_hidden_msa_att,
|
| 326 |
+
no_heads=no_heads_msa,
|
| 327 |
+
inf=inf,
|
| 328 |
+
eps=eps,
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
self.msa_dropout_layer = DropoutRowwise(msa_dropout)
|
| 332 |
+
|
| 333 |
+
self.core = EvoformerBlockCore(
|
| 334 |
+
c_m=c_m,
|
| 335 |
+
c_z=c_z,
|
| 336 |
+
c_hidden_opm=c_hidden_opm,
|
| 337 |
+
c_hidden_mul=c_hidden_mul,
|
| 338 |
+
c_hidden_pair_att=c_hidden_pair_att,
|
| 339 |
+
no_heads_msa=no_heads_msa,
|
| 340 |
+
no_heads_pair=no_heads_pair,
|
| 341 |
+
transition_n=transition_n,
|
| 342 |
+
pair_dropout=pair_dropout,
|
| 343 |
+
inf=inf,
|
| 344 |
+
eps=eps,
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
def forward(self,
|
| 348 |
+
m: torch.Tensor,
|
| 349 |
+
z: torch.Tensor,
|
| 350 |
+
msa_mask: torch.Tensor,
|
| 351 |
+
pair_mask: torch.Tensor,
|
| 352 |
+
chunk_size: Optional[int] = None,
|
| 353 |
+
_chunk_logits: Optional[int] = 1024,
|
| 354 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 355 |
+
def add(m1, m2):
|
| 356 |
+
# The first operation in a checkpoint can't be in-place, but it's
|
| 357 |
+
# nice to have in-place addition during inference. Thus...
|
| 358 |
+
if(torch.is_grad_enabled()):
|
| 359 |
+
m1 = m1 + m2
|
| 360 |
+
else:
|
| 361 |
+
m1 += m2
|
| 362 |
+
|
| 363 |
+
return m1
|
| 364 |
+
|
| 365 |
+
m = add(m, self.msa_dropout_layer(
|
| 366 |
+
self.msa_att_row(
|
| 367 |
+
m.clone() if torch.is_grad_enabled() else m,
|
| 368 |
+
z=z.clone() if torch.is_grad_enabled() else z,
|
| 369 |
+
mask=msa_mask,
|
| 370 |
+
chunk_size=chunk_size,
|
| 371 |
+
_chunk_logits=_chunk_logits if torch.is_grad_enabled() else None,
|
| 372 |
+
_checkpoint_chunks=
|
| 373 |
+
self.ckpt if torch.is_grad_enabled() else False,
|
| 374 |
+
)
|
| 375 |
+
))
|
| 376 |
+
|
| 377 |
+
def fn(m, z):
|
| 378 |
+
m = add(m, self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size))
|
| 379 |
+
m, z = self.core(
|
| 380 |
+
m, z, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size
|
| 381 |
+
)
|
| 382 |
+
|
| 383 |
+
return m, z
|
| 384 |
+
|
| 385 |
+
if(torch.is_grad_enabled() and self.ckpt):
|
| 386 |
+
checkpoint_fn = get_checkpoint_fn()
|
| 387 |
+
m, z = checkpoint_fn(fn, m, z)
|
| 388 |
+
else:
|
| 389 |
+
m, z = fn(m, z)
|
| 390 |
+
|
| 391 |
+
return m, z
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
class EvoformerStack(nn.Module):
|
| 395 |
+
"""
|
| 396 |
+
Main Evoformer trunk.
|
| 397 |
+
|
| 398 |
+
Implements Algorithm 6.
|
| 399 |
+
"""
|
| 400 |
+
|
| 401 |
+
def __init__(
|
| 402 |
+
self,
|
| 403 |
+
c_m: int,
|
| 404 |
+
c_z: int,
|
| 405 |
+
c_hidden_msa_att: int,
|
| 406 |
+
c_hidden_opm: int,
|
| 407 |
+
c_hidden_mul: int,
|
| 408 |
+
c_hidden_pair_att: int,
|
| 409 |
+
c_s: int,
|
| 410 |
+
no_heads_msa: int,
|
| 411 |
+
no_heads_pair: int,
|
| 412 |
+
no_blocks: int,
|
| 413 |
+
transition_n: int,
|
| 414 |
+
msa_dropout: float,
|
| 415 |
+
pair_dropout: float,
|
| 416 |
+
blocks_per_ckpt: int,
|
| 417 |
+
inf: float,
|
| 418 |
+
eps: float,
|
| 419 |
+
clear_cache_between_blocks: bool = False,
|
| 420 |
+
**kwargs,
|
| 421 |
+
):
|
| 422 |
+
"""
|
| 423 |
+
Args:
|
| 424 |
+
c_m:
|
| 425 |
+
MSA channel dimension
|
| 426 |
+
c_z:
|
| 427 |
+
Pair channel dimension
|
| 428 |
+
c_hidden_msa_att:
|
| 429 |
+
Hidden dimension in MSA attention
|
| 430 |
+
c_hidden_opm:
|
| 431 |
+
Hidden dimension in outer product mean module
|
| 432 |
+
c_hidden_mul:
|
| 433 |
+
Hidden dimension in multiplicative updates
|
| 434 |
+
c_hidden_pair_att:
|
| 435 |
+
Hidden dimension in triangular attention
|
| 436 |
+
c_s:
|
| 437 |
+
Channel dimension of the output "single" embedding
|
| 438 |
+
no_heads_msa:
|
| 439 |
+
Number of heads used for MSA attention
|
| 440 |
+
no_heads_pair:
|
| 441 |
+
Number of heads used for pair attention
|
| 442 |
+
no_blocks:
|
| 443 |
+
Number of Evoformer blocks in the stack
|
| 444 |
+
transition_n:
|
| 445 |
+
Factor by which to multiply c_m to obtain the MSATransition
|
| 446 |
+
hidden dimension
|
| 447 |
+
msa_dropout:
|
| 448 |
+
Dropout rate for MSA activations
|
| 449 |
+
pair_dropout:
|
| 450 |
+
Dropout used for pair activations
|
| 451 |
+
blocks_per_ckpt:
|
| 452 |
+
Number of Evoformer blocks in each activation checkpoint
|
| 453 |
+
clear_cache_between_blocks:
|
| 454 |
+
Whether to clear CUDA's GPU memory cache between blocks of the
|
| 455 |
+
stack. Slows down each block but can reduce fragmentation
|
| 456 |
+
"""
|
| 457 |
+
super(EvoformerStack, self).__init__()
|
| 458 |
+
|
| 459 |
+
self.blocks_per_ckpt = blocks_per_ckpt
|
| 460 |
+
self.clear_cache_between_blocks = clear_cache_between_blocks
|
| 461 |
+
|
| 462 |
+
self.blocks = nn.ModuleList()
|
| 463 |
+
|
| 464 |
+
for _ in range(no_blocks):
|
| 465 |
+
block = EvoformerBlock(
|
| 466 |
+
c_m=c_m,
|
| 467 |
+
c_z=c_z,
|
| 468 |
+
c_hidden_msa_att=c_hidden_msa_att,
|
| 469 |
+
c_hidden_opm=c_hidden_opm,
|
| 470 |
+
c_hidden_mul=c_hidden_mul,
|
| 471 |
+
c_hidden_pair_att=c_hidden_pair_att,
|
| 472 |
+
no_heads_msa=no_heads_msa,
|
| 473 |
+
no_heads_pair=no_heads_pair,
|
| 474 |
+
transition_n=transition_n,
|
| 475 |
+
msa_dropout=msa_dropout,
|
| 476 |
+
pair_dropout=pair_dropout,
|
| 477 |
+
inf=inf,
|
| 478 |
+
eps=eps,
|
| 479 |
+
)
|
| 480 |
+
self.blocks.append(block)
|
| 481 |
+
|
| 482 |
+
self.linear = Linear(c_m, c_s)
|
| 483 |
+
|
| 484 |
+
def forward(self,
|
| 485 |
+
m: torch.Tensor,
|
| 486 |
+
z: torch.Tensor,
|
| 487 |
+
msa_mask: torch.Tensor,
|
| 488 |
+
pair_mask: torch.Tensor,
|
| 489 |
+
chunk_size: int,
|
| 490 |
+
_mask_trans: bool = True,
|
| 491 |
+
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
| 492 |
+
"""
|
| 493 |
+
Args:
|
| 494 |
+
m:
|
| 495 |
+
[*, N_seq, N_res, C_m] MSA embedding
|
| 496 |
+
z:
|
| 497 |
+
[*, N_res, N_res, C_z] pair embedding
|
| 498 |
+
msa_mask:
|
| 499 |
+
[*, N_seq, N_res] MSA mask
|
| 500 |
+
pair_mask:
|
| 501 |
+
[*, N_res, N_res] pair mask
|
| 502 |
+
Returns:
|
| 503 |
+
m:
|
| 504 |
+
[*, N_seq, N_res, C_m] MSA embedding
|
| 505 |
+
z:
|
| 506 |
+
[*, N_res, N_res, C_z] pair embedding
|
| 507 |
+
s:
|
| 508 |
+
[*, N_res, C_s] single embedding (or None if extra MSA stack)
|
| 509 |
+
"""
|
| 510 |
+
blocks = [
|
| 511 |
+
partial(
|
| 512 |
+
b,
|
| 513 |
+
msa_mask=msa_mask,
|
| 514 |
+
pair_mask=pair_mask,
|
| 515 |
+
chunk_size=chunk_size,
|
| 516 |
+
_mask_trans=_mask_trans,
|
| 517 |
+
)
|
| 518 |
+
for b in self.blocks
|
| 519 |
+
]
|
| 520 |
+
|
| 521 |
+
if(self.clear_cache_between_blocks):
|
| 522 |
+
def block_with_cache_clear(block, *args):
|
| 523 |
+
torch.cuda.empty_cache()
|
| 524 |
+
return block(*args)
|
| 525 |
+
|
| 526 |
+
blocks = [partial(block_with_cache_clear, b) for b in blocks]
|
| 527 |
+
|
| 528 |
+
m, z = checkpoint_blocks(
|
| 529 |
+
blocks,
|
| 530 |
+
args=(m, z),
|
| 531 |
+
blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
|
| 532 |
+
)
|
| 533 |
+
|
| 534 |
+
s = self.linear(m[..., 0, :, :])
|
| 535 |
+
|
| 536 |
+
return m, z, s
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
class ExtraMSAStack(nn.Module):
|
| 540 |
+
"""
|
| 541 |
+
Implements Algorithm 18.
|
| 542 |
+
"""
|
| 543 |
+
|
| 544 |
+
def __init__(self,
|
| 545 |
+
c_m: int,
|
| 546 |
+
c_z: int,
|
| 547 |
+
c_hidden_msa_att: int,
|
| 548 |
+
c_hidden_opm: int,
|
| 549 |
+
c_hidden_mul: int,
|
| 550 |
+
c_hidden_pair_att: int,
|
| 551 |
+
no_heads_msa: int,
|
| 552 |
+
no_heads_pair: int,
|
| 553 |
+
no_blocks: int,
|
| 554 |
+
transition_n: int,
|
| 555 |
+
msa_dropout: float,
|
| 556 |
+
pair_dropout: float,
|
| 557 |
+
inf: float,
|
| 558 |
+
eps: float,
|
| 559 |
+
ckpt: bool,
|
| 560 |
+
clear_cache_between_blocks: bool = False,
|
| 561 |
+
**kwargs,
|
| 562 |
+
):
|
| 563 |
+
super(ExtraMSAStack, self).__init__()
|
| 564 |
+
|
| 565 |
+
self.clear_cache_between_blocks = clear_cache_between_blocks
|
| 566 |
+
self.blocks = nn.ModuleList()
|
| 567 |
+
for _ in range(no_blocks):
|
| 568 |
+
block = ExtraMSABlock(
|
| 569 |
+
c_m=c_m,
|
| 570 |
+
c_z=c_z,
|
| 571 |
+
c_hidden_msa_att=c_hidden_msa_att,
|
| 572 |
+
c_hidden_opm=c_hidden_opm,
|
| 573 |
+
c_hidden_mul=c_hidden_mul,
|
| 574 |
+
c_hidden_pair_att=c_hidden_pair_att,
|
| 575 |
+
no_heads_msa=no_heads_msa,
|
| 576 |
+
no_heads_pair=no_heads_pair,
|
| 577 |
+
transition_n=transition_n,
|
| 578 |
+
msa_dropout=msa_dropout,
|
| 579 |
+
pair_dropout=pair_dropout,
|
| 580 |
+
inf=inf,
|
| 581 |
+
eps=eps,
|
| 582 |
+
ckpt=ckpt,
|
| 583 |
+
)
|
| 584 |
+
self.blocks.append(block)
|
| 585 |
+
|
| 586 |
+
def forward(self,
|
| 587 |
+
m: torch.Tensor,
|
| 588 |
+
z: torch.Tensor,
|
| 589 |
+
chunk_size: int,
|
| 590 |
+
msa_mask: Optional[torch.Tensor] = None,
|
| 591 |
+
pair_mask: Optional[torch.Tensor] = None,
|
| 592 |
+
_mask_trans: bool = True,
|
| 593 |
+
) -> torch.Tensor:
|
| 594 |
+
"""
|
| 595 |
+
Args:
|
| 596 |
+
m:
|
| 597 |
+
[*, N_extra, N_res, C_m] extra MSA embedding
|
| 598 |
+
z:
|
| 599 |
+
[*, N_res, N_res, C_z] pair embedding
|
| 600 |
+
msa_mask:
|
| 601 |
+
Optional [*, N_extra, N_res] MSA mask
|
| 602 |
+
pair_mask:
|
| 603 |
+
Optional [*, N_res, N_res] pair mask
|
| 604 |
+
Returns:
|
| 605 |
+
[*, N_res, N_res, C_z] pair update
|
| 606 |
+
"""
|
| 607 |
+
#checkpoint_fn = get_checkpoint_fn()
|
| 608 |
+
#blocks = [
|
| 609 |
+
# partial(b, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size, _chunk_logits=None) for b in self.blocks
|
| 610 |
+
#]
|
| 611 |
+
|
| 612 |
+
#def dodo(b, *args):
|
| 613 |
+
# torch.cuda.empty_cache()
|
| 614 |
+
# return b(*args)
|
| 615 |
+
|
| 616 |
+
#blocks = [partial(dodo, b) for b in blocks]
|
| 617 |
+
|
| 618 |
+
#for b in blocks:
|
| 619 |
+
# if(torch.is_grad_enabled()):
|
| 620 |
+
# m, z = checkpoint_fn(b, *(m, z))
|
| 621 |
+
# else:
|
| 622 |
+
# m, z = b(m, z)
|
| 623 |
+
|
| 624 |
+
for b in self.blocks:
|
| 625 |
+
m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
|
| 626 |
+
|
| 627 |
+
if(self.clear_cache_between_blocks):
|
| 628 |
+
torch.cuda.empty_cache()
|
| 629 |
+
|
| 630 |
+
return z
|
openfold/model/heads.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
|
| 19 |
+
from openfold.model.primitives import Linear, LayerNorm
|
| 20 |
+
from openfold.utils.loss import (
|
| 21 |
+
compute_plddt,
|
| 22 |
+
compute_tm,
|
| 23 |
+
compute_predicted_aligned_error,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class AuxiliaryHeads(nn.Module):
|
| 28 |
+
def __init__(self, config):
|
| 29 |
+
super(AuxiliaryHeads, self).__init__()
|
| 30 |
+
|
| 31 |
+
self.plddt = PerResidueLDDTCaPredictor(
|
| 32 |
+
**config["lddt"],
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
self.distogram = DistogramHead(
|
| 36 |
+
**config["distogram"],
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
self.masked_msa = MaskedMSAHead(
|
| 40 |
+
**config["masked_msa"],
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
self.experimentally_resolved = ExperimentallyResolvedHead(
|
| 44 |
+
**config["experimentally_resolved"],
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
if config.tm.enabled:
|
| 48 |
+
self.tm = TMScoreHead(
|
| 49 |
+
**config.tm,
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
self.config = config
|
| 53 |
+
|
| 54 |
+
def forward(self, outputs):
|
| 55 |
+
aux_out = {}
|
| 56 |
+
lddt_logits = self.plddt(outputs["sm"]["single"])
|
| 57 |
+
aux_out["lddt_logits"] = lddt_logits
|
| 58 |
+
|
| 59 |
+
# Required for relaxation later on
|
| 60 |
+
aux_out["plddt"] = compute_plddt(lddt_logits)
|
| 61 |
+
|
| 62 |
+
distogram_logits = self.distogram(outputs["pair"])
|
| 63 |
+
aux_out["distogram_logits"] = distogram_logits
|
| 64 |
+
|
| 65 |
+
masked_msa_logits = self.masked_msa(outputs["msa"])
|
| 66 |
+
aux_out["masked_msa_logits"] = masked_msa_logits
|
| 67 |
+
|
| 68 |
+
experimentally_resolved_logits = self.experimentally_resolved(
|
| 69 |
+
outputs["single"]
|
| 70 |
+
)
|
| 71 |
+
aux_out[
|
| 72 |
+
"experimentally_resolved_logits"
|
| 73 |
+
] = experimentally_resolved_logits
|
| 74 |
+
|
| 75 |
+
if self.config.tm.enabled:
|
| 76 |
+
tm_logits = self.tm(outputs["pair"])
|
| 77 |
+
aux_out["tm_logits"] = tm_logits
|
| 78 |
+
aux_out["predicted_tm_score"] = compute_tm(
|
| 79 |
+
tm_logits, **self.config.tm
|
| 80 |
+
)
|
| 81 |
+
aux_out.update(
|
| 82 |
+
compute_predicted_aligned_error(
|
| 83 |
+
tm_logits,
|
| 84 |
+
**self.config.tm,
|
| 85 |
+
)
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
return aux_out
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class PerResidueLDDTCaPredictor(nn.Module):
|
| 92 |
+
def __init__(self, no_bins, c_in, c_hidden):
|
| 93 |
+
super(PerResidueLDDTCaPredictor, self).__init__()
|
| 94 |
+
|
| 95 |
+
self.no_bins = no_bins
|
| 96 |
+
self.c_in = c_in
|
| 97 |
+
self.c_hidden = c_hidden
|
| 98 |
+
|
| 99 |
+
self.layer_norm = LayerNorm(self.c_in)
|
| 100 |
+
|
| 101 |
+
self.linear_1 = Linear(self.c_in, self.c_hidden, init="relu")
|
| 102 |
+
self.linear_2 = Linear(self.c_hidden, self.c_hidden, init="relu")
|
| 103 |
+
self.linear_3 = Linear(self.c_hidden, self.no_bins, init="final")
|
| 104 |
+
|
| 105 |
+
self.relu = nn.ReLU()
|
| 106 |
+
|
| 107 |
+
def forward(self, s):
|
| 108 |
+
s = self.layer_norm(s)
|
| 109 |
+
s = self.linear_1(s)
|
| 110 |
+
s = self.relu(s)
|
| 111 |
+
s = self.linear_2(s)
|
| 112 |
+
s = self.relu(s)
|
| 113 |
+
s = self.linear_3(s)
|
| 114 |
+
|
| 115 |
+
return s
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class DistogramHead(nn.Module):
|
| 119 |
+
"""
|
| 120 |
+
Computes a distogram probability distribution.
|
| 121 |
+
|
| 122 |
+
For use in computation of distogram loss, subsection 1.9.8
|
| 123 |
+
"""
|
| 124 |
+
|
| 125 |
+
def __init__(self, c_z, no_bins, **kwargs):
|
| 126 |
+
"""
|
| 127 |
+
Args:
|
| 128 |
+
c_z:
|
| 129 |
+
Input channel dimension
|
| 130 |
+
no_bins:
|
| 131 |
+
Number of distogram bins
|
| 132 |
+
"""
|
| 133 |
+
super(DistogramHead, self).__init__()
|
| 134 |
+
|
| 135 |
+
self.c_z = c_z
|
| 136 |
+
self.no_bins = no_bins
|
| 137 |
+
|
| 138 |
+
self.linear = Linear(self.c_z, self.no_bins, init="final")
|
| 139 |
+
|
| 140 |
+
def forward(self, z): # [*, N, N, C_z]
|
| 141 |
+
"""
|
| 142 |
+
Args:
|
| 143 |
+
z:
|
| 144 |
+
[*, N_res, N_res, C_z] pair embedding
|
| 145 |
+
Returns:
|
| 146 |
+
[*, N, N, no_bins] distogram probability distribution
|
| 147 |
+
"""
|
| 148 |
+
# [*, N, N, no_bins]
|
| 149 |
+
logits = self.linear(z)
|
| 150 |
+
logits = logits + logits.transpose(-2, -3)
|
| 151 |
+
return logits
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class TMScoreHead(nn.Module):
|
| 155 |
+
"""
|
| 156 |
+
For use in computation of TM-score, subsection 1.9.7
|
| 157 |
+
"""
|
| 158 |
+
|
| 159 |
+
def __init__(self, c_z, no_bins, **kwargs):
|
| 160 |
+
"""
|
| 161 |
+
Args:
|
| 162 |
+
c_z:
|
| 163 |
+
Input channel dimension
|
| 164 |
+
no_bins:
|
| 165 |
+
Number of bins
|
| 166 |
+
"""
|
| 167 |
+
super(TMScoreHead, self).__init__()
|
| 168 |
+
|
| 169 |
+
self.c_z = c_z
|
| 170 |
+
self.no_bins = no_bins
|
| 171 |
+
|
| 172 |
+
self.linear = Linear(self.c_z, self.no_bins, init="final")
|
| 173 |
+
|
| 174 |
+
def forward(self, z):
|
| 175 |
+
"""
|
| 176 |
+
Args:
|
| 177 |
+
z:
|
| 178 |
+
[*, N_res, N_res, C_z] pairwise embedding
|
| 179 |
+
Returns:
|
| 180 |
+
[*, N_res, N_res, no_bins] prediction
|
| 181 |
+
"""
|
| 182 |
+
# [*, N, N, no_bins]
|
| 183 |
+
logits = self.linear(z)
|
| 184 |
+
return logits
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
class MaskedMSAHead(nn.Module):
|
| 188 |
+
"""
|
| 189 |
+
For use in computation of masked MSA loss, subsection 1.9.9
|
| 190 |
+
"""
|
| 191 |
+
|
| 192 |
+
def __init__(self, c_m, c_out, **kwargs):
|
| 193 |
+
"""
|
| 194 |
+
Args:
|
| 195 |
+
c_m:
|
| 196 |
+
MSA channel dimension
|
| 197 |
+
c_out:
|
| 198 |
+
Output channel dimension
|
| 199 |
+
"""
|
| 200 |
+
super(MaskedMSAHead, self).__init__()
|
| 201 |
+
|
| 202 |
+
self.c_m = c_m
|
| 203 |
+
self.c_out = c_out
|
| 204 |
+
|
| 205 |
+
self.linear = Linear(self.c_m, self.c_out, init="final")
|
| 206 |
+
|
| 207 |
+
def forward(self, m):
|
| 208 |
+
"""
|
| 209 |
+
Args:
|
| 210 |
+
m:
|
| 211 |
+
[*, N_seq, N_res, C_m] MSA embedding
|
| 212 |
+
Returns:
|
| 213 |
+
[*, N_seq, N_res, C_out] reconstruction
|
| 214 |
+
"""
|
| 215 |
+
# [*, N_seq, N_res, C_out]
|
| 216 |
+
logits = self.linear(m)
|
| 217 |
+
return logits
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
class ExperimentallyResolvedHead(nn.Module):
|
| 221 |
+
"""
|
| 222 |
+
For use in computation of "experimentally resolved" loss, subsection
|
| 223 |
+
1.9.10
|
| 224 |
+
"""
|
| 225 |
+
|
| 226 |
+
def __init__(self, c_s, c_out, **kwargs):
|
| 227 |
+
"""
|
| 228 |
+
Args:
|
| 229 |
+
c_s:
|
| 230 |
+
Input channel dimension
|
| 231 |
+
c_out:
|
| 232 |
+
Number of distogram bins
|
| 233 |
+
"""
|
| 234 |
+
super(ExperimentallyResolvedHead, self).__init__()
|
| 235 |
+
|
| 236 |
+
self.c_s = c_s
|
| 237 |
+
self.c_out = c_out
|
| 238 |
+
|
| 239 |
+
self.linear = Linear(self.c_s, self.c_out, init="final")
|
| 240 |
+
|
| 241 |
+
def forward(self, s):
|
| 242 |
+
"""
|
| 243 |
+
Args:
|
| 244 |
+
s:
|
| 245 |
+
[*, N_res, C_s] single embedding
|
| 246 |
+
Returns:
|
| 247 |
+
[*, N, C_out] logits
|
| 248 |
+
"""
|
| 249 |
+
# [*, N, C_out]
|
| 250 |
+
logits = self.linear(s)
|
| 251 |
+
return logits
|
openfold/model/model.py
ADDED
|
@@ -0,0 +1,446 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from functools import partial
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
|
| 20 |
+
from openfold.utils.feats import (
|
| 21 |
+
pseudo_beta_fn,
|
| 22 |
+
build_extra_msa_feat,
|
| 23 |
+
build_template_angle_feat,
|
| 24 |
+
build_template_pair_feat,
|
| 25 |
+
atom14_to_atom37,
|
| 26 |
+
)
|
| 27 |
+
from openfold.model.embedders import (
|
| 28 |
+
InputEmbedder,
|
| 29 |
+
RecyclingEmbedder,
|
| 30 |
+
TemplateAngleEmbedder,
|
| 31 |
+
TemplatePairEmbedder,
|
| 32 |
+
ExtraMSAEmbedder,
|
| 33 |
+
)
|
| 34 |
+
from openfold.model.evoformer import EvoformerStack, ExtraMSAStack
|
| 35 |
+
from openfold.model.heads import AuxiliaryHeads
|
| 36 |
+
import openfold.np.residue_constants as residue_constants
|
| 37 |
+
from openfold.model.structure_module import StructureModule
|
| 38 |
+
from openfold.model.template import (
|
| 39 |
+
TemplatePairStack,
|
| 40 |
+
TemplatePointwiseAttention,
|
| 41 |
+
)
|
| 42 |
+
from openfold.utils.loss import (
|
| 43 |
+
compute_plddt,
|
| 44 |
+
)
|
| 45 |
+
from openfold.utils.tensor_utils import (
|
| 46 |
+
dict_multimap,
|
| 47 |
+
tensor_tree_map,
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class AlphaFold(nn.Module):
|
| 52 |
+
"""
|
| 53 |
+
Alphafold 2.
|
| 54 |
+
|
| 55 |
+
Implements Algorithm 2 (but with training).
|
| 56 |
+
"""
|
| 57 |
+
|
| 58 |
+
def __init__(self, config):
|
| 59 |
+
"""
|
| 60 |
+
Args:
|
| 61 |
+
config:
|
| 62 |
+
A dict-like config object (like the one in config.py)
|
| 63 |
+
"""
|
| 64 |
+
super(AlphaFold, self).__init__()
|
| 65 |
+
|
| 66 |
+
self.globals = config.globals
|
| 67 |
+
config = config.model
|
| 68 |
+
template_config = config.template
|
| 69 |
+
extra_msa_config = config.extra_msa
|
| 70 |
+
|
| 71 |
+
# Main trunk + structure module
|
| 72 |
+
self.input_embedder = InputEmbedder(
|
| 73 |
+
**config["input_embedder"],
|
| 74 |
+
)
|
| 75 |
+
self.recycling_embedder = RecyclingEmbedder(
|
| 76 |
+
**config["recycling_embedder"],
|
| 77 |
+
)
|
| 78 |
+
self.template_angle_embedder = TemplateAngleEmbedder(
|
| 79 |
+
**template_config["template_angle_embedder"],
|
| 80 |
+
)
|
| 81 |
+
self.template_pair_embedder = TemplatePairEmbedder(
|
| 82 |
+
**template_config["template_pair_embedder"],
|
| 83 |
+
)
|
| 84 |
+
self.template_pair_stack = TemplatePairStack(
|
| 85 |
+
**template_config["template_pair_stack"],
|
| 86 |
+
)
|
| 87 |
+
self.template_pointwise_att = TemplatePointwiseAttention(
|
| 88 |
+
**template_config["template_pointwise_attention"],
|
| 89 |
+
)
|
| 90 |
+
self.extra_msa_embedder = ExtraMSAEmbedder(
|
| 91 |
+
**extra_msa_config["extra_msa_embedder"],
|
| 92 |
+
)
|
| 93 |
+
self.extra_msa_stack = ExtraMSAStack(
|
| 94 |
+
**extra_msa_config["extra_msa_stack"],
|
| 95 |
+
)
|
| 96 |
+
self.evoformer = EvoformerStack(
|
| 97 |
+
**config["evoformer_stack"],
|
| 98 |
+
)
|
| 99 |
+
self.structure_module = StructureModule(
|
| 100 |
+
**config["structure_module"],
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
self.aux_heads = AuxiliaryHeads(
|
| 104 |
+
config["heads"],
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
self.config = config
|
| 108 |
+
|
| 109 |
+
def embed_templates(self, batch, z, pair_mask, templ_dim):
|
| 110 |
+
# Embed the templates one at a time (with a poor man's vmap)
|
| 111 |
+
template_embeds = []
|
| 112 |
+
n_templ = batch["template_aatype"].shape[templ_dim]
|
| 113 |
+
for i in range(n_templ):
|
| 114 |
+
idx = batch["template_aatype"].new_tensor(i)
|
| 115 |
+
single_template_feats = tensor_tree_map(
|
| 116 |
+
lambda t: torch.index_select(t, templ_dim, idx),
|
| 117 |
+
batch,
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
single_template_embeds = {}
|
| 121 |
+
if self.config.template.embed_angles:
|
| 122 |
+
template_angle_feat = build_template_angle_feat(
|
| 123 |
+
single_template_feats,
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
# [*, S_t, N, C_m]
|
| 127 |
+
a = self.template_angle_embedder(template_angle_feat)
|
| 128 |
+
|
| 129 |
+
single_template_embeds["angle"] = a
|
| 130 |
+
|
| 131 |
+
# [*, S_t, N, N, C_t]
|
| 132 |
+
t = build_template_pair_feat(
|
| 133 |
+
single_template_feats,
|
| 134 |
+
inf=self.config.template.inf,
|
| 135 |
+
eps=self.config.template.eps,
|
| 136 |
+
**self.config.template.distogram,
|
| 137 |
+
).to(z.dtype)
|
| 138 |
+
t = self.template_pair_embedder(t)
|
| 139 |
+
|
| 140 |
+
single_template_embeds.update({"pair": t})
|
| 141 |
+
|
| 142 |
+
template_embeds.append(single_template_embeds)
|
| 143 |
+
|
| 144 |
+
template_embeds = dict_multimap(
|
| 145 |
+
partial(torch.cat, dim=templ_dim),
|
| 146 |
+
template_embeds,
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
# [*, S_t, N, N, C_z]
|
| 150 |
+
t = self.template_pair_stack(
|
| 151 |
+
template_embeds["pair"],
|
| 152 |
+
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
|
| 153 |
+
chunk_size=self.globals.chunk_size,
|
| 154 |
+
_mask_trans=self.config._mask_trans,
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
# [*, N, N, C_z]
|
| 158 |
+
t = self.template_pointwise_att(
|
| 159 |
+
t,
|
| 160 |
+
z,
|
| 161 |
+
template_mask=batch["template_mask"].to(dtype=z.dtype),
|
| 162 |
+
chunk_size=self.globals.chunk_size,
|
| 163 |
+
)
|
| 164 |
+
t = t * (torch.sum(batch["template_mask"]) > 0)
|
| 165 |
+
|
| 166 |
+
ret = {}
|
| 167 |
+
if self.config.template.embed_angles:
|
| 168 |
+
ret["template_angle_embedding"] = template_embeds["angle"]
|
| 169 |
+
|
| 170 |
+
ret.update({"template_pair_embedding": t})
|
| 171 |
+
|
| 172 |
+
return ret
|
| 173 |
+
|
| 174 |
+
def iteration(self, feats, m_1_prev, z_prev, x_prev, _recycle=True):
|
| 175 |
+
# Primary output dictionary
|
| 176 |
+
outputs = {}
|
| 177 |
+
|
| 178 |
+
# This needs to be done manually for DeepSpeed's sake
|
| 179 |
+
dtype = next(self.parameters()).dtype
|
| 180 |
+
for k in feats:
|
| 181 |
+
if(feats[k].dtype == torch.float32):
|
| 182 |
+
feats[k] = feats[k].to(dtype=dtype)
|
| 183 |
+
|
| 184 |
+
# Grab some data about the input
|
| 185 |
+
batch_dims = feats["target_feat"].shape[:-2]
|
| 186 |
+
no_batch_dims = len(batch_dims)
|
| 187 |
+
n = feats["target_feat"].shape[-2]
|
| 188 |
+
n_seq = feats["msa_feat"].shape[-3]
|
| 189 |
+
device = feats["target_feat"].device
|
| 190 |
+
|
| 191 |
+
# Prep some features
|
| 192 |
+
seq_mask = feats["seq_mask"]
|
| 193 |
+
pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
|
| 194 |
+
msa_mask = feats["msa_mask"]
|
| 195 |
+
|
| 196 |
+
# Initialize the MSA and pair representations
|
| 197 |
+
|
| 198 |
+
# m: [*, S_c, N, C_m]
|
| 199 |
+
# z: [*, N, N, C_z]
|
| 200 |
+
m, z = self.input_embedder(
|
| 201 |
+
feats["target_feat"],
|
| 202 |
+
feats["residue_index"],
|
| 203 |
+
feats["msa_feat"],
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
# Initialize the recycling embeddings, if needs be
|
| 207 |
+
if None in [m_1_prev, z_prev, x_prev]:
|
| 208 |
+
# [*, N, C_m]
|
| 209 |
+
m_1_prev = m.new_zeros(
|
| 210 |
+
(*batch_dims, n, self.config.input_embedder.c_m),
|
| 211 |
+
requires_grad=False,
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
# [*, N, N, C_z]
|
| 215 |
+
z_prev = z.new_zeros(
|
| 216 |
+
(*batch_dims, n, n, self.config.input_embedder.c_z),
|
| 217 |
+
requires_grad=False,
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
# [*, N, 3]
|
| 221 |
+
x_prev = z.new_zeros(
|
| 222 |
+
(*batch_dims, n, residue_constants.atom_type_num, 3),
|
| 223 |
+
requires_grad=False,
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
x_prev = pseudo_beta_fn(
|
| 227 |
+
feats["aatype"], x_prev, None
|
| 228 |
+
).to(dtype=z.dtype)
|
| 229 |
+
|
| 230 |
+
# m_1_prev_emb: [*, N, C_m]
|
| 231 |
+
# z_prev_emb: [*, N, N, C_z]
|
| 232 |
+
m_1_prev_emb, z_prev_emb = self.recycling_embedder(
|
| 233 |
+
m_1_prev,
|
| 234 |
+
z_prev,
|
| 235 |
+
x_prev,
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
# If the number of recycling iterations is 0, skip recycling
|
| 239 |
+
# altogether. We zero them this way instead of computing them
|
| 240 |
+
# conditionally to avoid leaving parameters unused, which has annoying
|
| 241 |
+
# implications for DDP training.
|
| 242 |
+
if(not _recycle):
|
| 243 |
+
m_1_prev_emb *= 0
|
| 244 |
+
z_prev_emb *= 0
|
| 245 |
+
|
| 246 |
+
# [*, S_c, N, C_m]
|
| 247 |
+
m[..., 0, :, :] += m_1_prev_emb
|
| 248 |
+
|
| 249 |
+
# [*, N, N, C_z]
|
| 250 |
+
z += z_prev_emb
|
| 251 |
+
|
| 252 |
+
# Possibly prevents memory fragmentation
|
| 253 |
+
del m_1_prev, z_prev, x_prev, m_1_prev_emb, z_prev_emb
|
| 254 |
+
|
| 255 |
+
# Embed the templates + merge with MSA/pair embeddings
|
| 256 |
+
if self.config.template.enabled:
|
| 257 |
+
template_feats = {
|
| 258 |
+
k: v for k, v in feats.items() if k.startswith("template_")
|
| 259 |
+
}
|
| 260 |
+
template_embeds = self.embed_templates(
|
| 261 |
+
template_feats,
|
| 262 |
+
z,
|
| 263 |
+
pair_mask.to(dtype=z.dtype),
|
| 264 |
+
no_batch_dims,
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
# [*, N, N, C_z]
|
| 268 |
+
z = z + template_embeds["template_pair_embedding"]
|
| 269 |
+
|
| 270 |
+
if self.config.template.embed_angles:
|
| 271 |
+
# [*, S = S_c + S_t, N, C_m]
|
| 272 |
+
m = torch.cat(
|
| 273 |
+
[m, template_embeds["template_angle_embedding"]],
|
| 274 |
+
dim=-3
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
# [*, S, N]
|
| 278 |
+
torsion_angles_mask = feats["template_torsion_angles_mask"]
|
| 279 |
+
msa_mask = torch.cat(
|
| 280 |
+
[feats["msa_mask"], torsion_angles_mask[..., 2]],
|
| 281 |
+
dim=-2
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
# Embed extra MSA features + merge with pairwise embeddings
|
| 285 |
+
if self.config.extra_msa.enabled:
|
| 286 |
+
# [*, S_e, N, C_e]
|
| 287 |
+
a = self.extra_msa_embedder(build_extra_msa_feat(feats))
|
| 288 |
+
|
| 289 |
+
# [*, N, N, C_z]
|
| 290 |
+
z = self.extra_msa_stack(
|
| 291 |
+
a,
|
| 292 |
+
z,
|
| 293 |
+
msa_mask=feats["extra_msa_mask"].to(dtype=a.dtype),
|
| 294 |
+
chunk_size=self.globals.chunk_size,
|
| 295 |
+
pair_mask=pair_mask.to(dtype=z.dtype),
|
| 296 |
+
_mask_trans=self.config._mask_trans,
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
+
# Run MSA + pair embeddings through the trunk of the network
|
| 300 |
+
# m: [*, S, N, C_m]
|
| 301 |
+
# z: [*, N, N, C_z]
|
| 302 |
+
# s: [*, N, C_s]
|
| 303 |
+
m, z, s = self.evoformer(
|
| 304 |
+
m,
|
| 305 |
+
z,
|
| 306 |
+
msa_mask=msa_mask.to(dtype=m.dtype),
|
| 307 |
+
pair_mask=pair_mask.to(dtype=z.dtype),
|
| 308 |
+
chunk_size=self.globals.chunk_size,
|
| 309 |
+
_mask_trans=self.config._mask_trans,
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
outputs["msa"] = m[..., :n_seq, :, :]
|
| 313 |
+
outputs["pair"] = z
|
| 314 |
+
outputs["single"] = s
|
| 315 |
+
|
| 316 |
+
# Predict 3D structure
|
| 317 |
+
outputs["sm"] = self.structure_module(
|
| 318 |
+
s,
|
| 319 |
+
z,
|
| 320 |
+
feats["aatype"],
|
| 321 |
+
mask=feats["seq_mask"].to(dtype=s.dtype),
|
| 322 |
+
)
|
| 323 |
+
outputs["final_atom_positions"] = atom14_to_atom37(
|
| 324 |
+
outputs["sm"]["positions"][-1], feats
|
| 325 |
+
)
|
| 326 |
+
outputs["final_atom_mask"] = feats["atom37_atom_exists"]
|
| 327 |
+
outputs["final_affine_tensor"] = outputs["sm"]["frames"][-1]
|
| 328 |
+
|
| 329 |
+
# Save embeddings for use during the next recycling iteration
|
| 330 |
+
|
| 331 |
+
# [*, N, C_m]
|
| 332 |
+
m_1_prev = m[..., 0, :, :]
|
| 333 |
+
|
| 334 |
+
# [*, N, N, C_z]
|
| 335 |
+
z_prev = z
|
| 336 |
+
|
| 337 |
+
# [*, N, 3]
|
| 338 |
+
x_prev = outputs["final_atom_positions"]
|
| 339 |
+
|
| 340 |
+
return outputs, m_1_prev, z_prev, x_prev
|
| 341 |
+
|
| 342 |
+
def _disable_activation_checkpointing(self):
|
| 343 |
+
self.template_pair_stack.blocks_per_ckpt = None
|
| 344 |
+
self.evoformer.blocks_per_ckpt = None
|
| 345 |
+
|
| 346 |
+
for b in self.extra_msa_stack.blocks:
|
| 347 |
+
b.ckpt = False
|
| 348 |
+
|
| 349 |
+
def _enable_activation_checkpointing(self):
|
| 350 |
+
self.template_pair_stack.blocks_per_ckpt = (
|
| 351 |
+
self.config.template.template_pair_stack.blocks_per_ckpt
|
| 352 |
+
)
|
| 353 |
+
self.evoformer.blocks_per_ckpt = (
|
| 354 |
+
self.config.evoformer_stack.blocks_per_ckpt
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
+
for b in self.extra_msa_stack.blocks:
|
| 358 |
+
b.ckpt = self.config.extra_msa.extra_msa_stack.ckpt
|
| 359 |
+
|
| 360 |
+
def forward(self, batch):
|
| 361 |
+
"""
|
| 362 |
+
Args:
|
| 363 |
+
batch:
|
| 364 |
+
Dictionary of arguments outlined in Algorithm 2. Keys must
|
| 365 |
+
include the official names of the features in the
|
| 366 |
+
supplement subsection 1.2.9.
|
| 367 |
+
|
| 368 |
+
The final dimension of each input must have length equal to
|
| 369 |
+
the number of recycling iterations.
|
| 370 |
+
|
| 371 |
+
Features (without the recycling dimension):
|
| 372 |
+
|
| 373 |
+
"aatype" ([*, N_res]):
|
| 374 |
+
Contrary to the supplement, this tensor of residue
|
| 375 |
+
indices is not one-hot.
|
| 376 |
+
"target_feat" ([*, N_res, C_tf])
|
| 377 |
+
One-hot encoding of the target sequence. C_tf is
|
| 378 |
+
config.model.input_embedder.tf_dim.
|
| 379 |
+
"residue_index" ([*, N_res])
|
| 380 |
+
Tensor whose final dimension consists of
|
| 381 |
+
consecutive indices from 0 to N_res.
|
| 382 |
+
"msa_feat" ([*, N_seq, N_res, C_msa])
|
| 383 |
+
MSA features, constructed as in the supplement.
|
| 384 |
+
C_msa is config.model.input_embedder.msa_dim.
|
| 385 |
+
"seq_mask" ([*, N_res])
|
| 386 |
+
1-D sequence mask
|
| 387 |
+
"msa_mask" ([*, N_seq, N_res])
|
| 388 |
+
MSA mask
|
| 389 |
+
"pair_mask" ([*, N_res, N_res])
|
| 390 |
+
2-D pair mask
|
| 391 |
+
"extra_msa_mask" ([*, N_extra, N_res])
|
| 392 |
+
Extra MSA mask
|
| 393 |
+
"template_mask" ([*, N_templ])
|
| 394 |
+
Template mask (on the level of templates, not
|
| 395 |
+
residues)
|
| 396 |
+
"template_aatype" ([*, N_templ, N_res])
|
| 397 |
+
Tensor of template residue indices (indices greater
|
| 398 |
+
than 19 are clamped to 20 (Unknown))
|
| 399 |
+
"template_all_atom_positions"
|
| 400 |
+
([*, N_templ, N_res, 37, 3])
|
| 401 |
+
Template atom coordinates in atom37 format
|
| 402 |
+
"template_all_atom_mask" ([*, N_templ, N_res, 37])
|
| 403 |
+
Template atom coordinate mask
|
| 404 |
+
"template_pseudo_beta" ([*, N_templ, N_res, 3])
|
| 405 |
+
Positions of template carbon "pseudo-beta" atoms
|
| 406 |
+
(i.e. C_beta for all residues but glycine, for
|
| 407 |
+
for which C_alpha is used instead)
|
| 408 |
+
"template_pseudo_beta_mask" ([*, N_templ, N_res])
|
| 409 |
+
Pseudo-beta mask
|
| 410 |
+
"""
|
| 411 |
+
# Initialize recycling embeddings
|
| 412 |
+
m_1_prev, z_prev, x_prev = None, None, None
|
| 413 |
+
|
| 414 |
+
# Disable activation checkpointing for the first few recycling iters
|
| 415 |
+
is_grad_enabled = torch.is_grad_enabled()
|
| 416 |
+
self._disable_activation_checkpointing()
|
| 417 |
+
|
| 418 |
+
# Main recycling loop
|
| 419 |
+
num_iters = batch["aatype"].shape[-1]
|
| 420 |
+
for cycle_no in range(num_iters):
|
| 421 |
+
# Select the features for the current recycling cycle
|
| 422 |
+
fetch_cur_batch = lambda t: t[..., cycle_no]
|
| 423 |
+
feats = tensor_tree_map(fetch_cur_batch, batch)
|
| 424 |
+
|
| 425 |
+
# Enable grad iff we're training and it's the final recycling layer
|
| 426 |
+
is_final_iter = cycle_no == (num_iters - 1)
|
| 427 |
+
with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
|
| 428 |
+
if is_final_iter:
|
| 429 |
+
self._enable_activation_checkpointing()
|
| 430 |
+
# Sidestep AMP bug (PyTorch issue #65766)
|
| 431 |
+
if torch.is_autocast_enabled():
|
| 432 |
+
torch.clear_autocast_cache()
|
| 433 |
+
|
| 434 |
+
# Run the next iteration of the model
|
| 435 |
+
outputs, m_1_prev, z_prev, x_prev = self.iteration(
|
| 436 |
+
feats,
|
| 437 |
+
m_1_prev,
|
| 438 |
+
z_prev,
|
| 439 |
+
x_prev,
|
| 440 |
+
_recycle=(num_iters > 1)
|
| 441 |
+
)
|
| 442 |
+
|
| 443 |
+
# Run auxiliary heads
|
| 444 |
+
outputs.update(self.aux_heads(outputs))
|
| 445 |
+
|
| 446 |
+
return outputs
|
openfold/model/msa.py
ADDED
|
@@ -0,0 +1,392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import math
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
from typing import Optional, List, Tuple
|
| 20 |
+
|
| 21 |
+
from openfold.model.primitives import (
|
| 22 |
+
Linear,
|
| 23 |
+
LayerNorm,
|
| 24 |
+
Attention,
|
| 25 |
+
GlobalAttention,
|
| 26 |
+
_attention_chunked_trainable,
|
| 27 |
+
)
|
| 28 |
+
from openfold.utils.checkpointing import get_checkpoint_fn
|
| 29 |
+
from openfold.utils.tensor_utils import (
|
| 30 |
+
chunk_layer,
|
| 31 |
+
permute_final_dims,
|
| 32 |
+
flatten_final_dims,
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class MSAAttention(nn.Module):
|
| 37 |
+
def __init__(
|
| 38 |
+
self,
|
| 39 |
+
c_in,
|
| 40 |
+
c_hidden,
|
| 41 |
+
no_heads,
|
| 42 |
+
pair_bias=False,
|
| 43 |
+
c_z=None,
|
| 44 |
+
inf=1e9,
|
| 45 |
+
):
|
| 46 |
+
"""
|
| 47 |
+
Args:
|
| 48 |
+
c_in:
|
| 49 |
+
Input channel dimension
|
| 50 |
+
c_hidden:
|
| 51 |
+
Per-head hidden channel dimension
|
| 52 |
+
no_heads:
|
| 53 |
+
Number of attention heads
|
| 54 |
+
pair_bias:
|
| 55 |
+
Whether to use pair embedding bias
|
| 56 |
+
c_z:
|
| 57 |
+
Pair embedding channel dimension. Ignored unless pair_bias
|
| 58 |
+
is true
|
| 59 |
+
inf:
|
| 60 |
+
A large number to be used in computing the attention mask
|
| 61 |
+
"""
|
| 62 |
+
super(MSAAttention, self).__init__()
|
| 63 |
+
|
| 64 |
+
self.c_in = c_in
|
| 65 |
+
self.c_hidden = c_hidden
|
| 66 |
+
self.no_heads = no_heads
|
| 67 |
+
self.pair_bias = pair_bias
|
| 68 |
+
self.c_z = c_z
|
| 69 |
+
self.inf = inf
|
| 70 |
+
|
| 71 |
+
self.layer_norm_m = LayerNorm(self.c_in)
|
| 72 |
+
|
| 73 |
+
self.layer_norm_z = None
|
| 74 |
+
self.linear_z = None
|
| 75 |
+
if self.pair_bias:
|
| 76 |
+
self.layer_norm_z = LayerNorm(self.c_z)
|
| 77 |
+
self.linear_z = Linear(
|
| 78 |
+
self.c_z, self.no_heads, bias=False, init="normal"
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
self.mha = Attention(
|
| 82 |
+
self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
@torch.jit.ignore
|
| 86 |
+
def _chunk(self,
|
| 87 |
+
m: torch.Tensor,
|
| 88 |
+
biases: List[torch.Tensor],
|
| 89 |
+
chunk_size: int,
|
| 90 |
+
) -> torch.Tensor:
|
| 91 |
+
return chunk_layer(
|
| 92 |
+
self.mha,
|
| 93 |
+
{"q_x": m, "kv_x": m, "biases": biases},
|
| 94 |
+
chunk_size=chunk_size,
|
| 95 |
+
no_batch_dims=len(m.shape[:-2]),
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
def _prep_inputs(self,
|
| 99 |
+
m: torch.Tensor,
|
| 100 |
+
z: Optional[torch.Tensor],
|
| 101 |
+
mask: Optional[torch.Tensor]
|
| 102 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 103 |
+
# [*, N_seq, N_res, C_m]
|
| 104 |
+
m = self.layer_norm_m(m)
|
| 105 |
+
|
| 106 |
+
n_seq, n_res = m.shape[-3:-1]
|
| 107 |
+
if mask is None:
|
| 108 |
+
# [*, N_seq, N_res]
|
| 109 |
+
mask = m.new_ones(
|
| 110 |
+
m.shape[:-3] + (n_seq, n_res),
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
# [*, N_seq, 1, 1, N_res]
|
| 114 |
+
mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
|
| 115 |
+
|
| 116 |
+
# This step simply returns a larger view of the bias, and does not
|
| 117 |
+
# consume additional memory.
|
| 118 |
+
# [*, N_seq, no_heads, N_res, N_res]
|
| 119 |
+
#bias = bias.expand(
|
| 120 |
+
# ((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1)
|
| 121 |
+
#)
|
| 122 |
+
|
| 123 |
+
if (self.pair_bias and
|
| 124 |
+
z is not None and # For the
|
| 125 |
+
self.layer_norm_z is not None and # benefit of
|
| 126 |
+
self.linear_z is not None # TorchScript
|
| 127 |
+
):
|
| 128 |
+
# [*, N_res, N_res, C_z]
|
| 129 |
+
z = self.layer_norm_z(z)
|
| 130 |
+
|
| 131 |
+
# [*, N_res, N_res, no_heads]
|
| 132 |
+
z = self.linear_z(z)
|
| 133 |
+
|
| 134 |
+
# [*, 1, no_heads, N_res, N_res]
|
| 135 |
+
z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4)
|
| 136 |
+
|
| 137 |
+
return m, mask_bias, z
|
| 138 |
+
|
| 139 |
+
@torch.jit.ignore
|
| 140 |
+
def _chunked_msa_attn(self,
|
| 141 |
+
m: torch.Tensor,
|
| 142 |
+
z: Optional[torch.Tensor],
|
| 143 |
+
mask: Optional[torch.Tensor],
|
| 144 |
+
chunk_logits: int,
|
| 145 |
+
checkpoint: bool,
|
| 146 |
+
) -> torch.Tensor:
|
| 147 |
+
MSA_DIM = -4
|
| 148 |
+
|
| 149 |
+
def _get_qkv(m, z):
|
| 150 |
+
m, mask_bias, z = self._prep_inputs(m, z, mask)
|
| 151 |
+
q, k, v = self.mha._prep_qkv(m, m)
|
| 152 |
+
return m, q, k, v, mask_bias, z
|
| 153 |
+
|
| 154 |
+
checkpoint_fn = get_checkpoint_fn()
|
| 155 |
+
|
| 156 |
+
if(torch.is_grad_enabled() and checkpoint):
|
| 157 |
+
m, q, k, v, mask_bias, z = checkpoint_fn(_get_qkv, m, z)
|
| 158 |
+
else:
|
| 159 |
+
m, q, k, v, mask_bias, z = _get_qkv(m, z)
|
| 160 |
+
|
| 161 |
+
o = _attention_chunked_trainable(
|
| 162 |
+
query=q,
|
| 163 |
+
key=k,
|
| 164 |
+
value=v,
|
| 165 |
+
biases=[mask_bias, z],
|
| 166 |
+
chunk_size=chunk_logits,
|
| 167 |
+
chunk_dim=MSA_DIM,
|
| 168 |
+
checkpoint=checkpoint,
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
if(torch.is_grad_enabled() and checkpoint):
|
| 172 |
+
# Storing an additional m here is far from ideal
|
| 173 |
+
m = checkpoint_fn(self.mha._wrap_up, o, m)
|
| 174 |
+
else:
|
| 175 |
+
m = self.mha._wrap_up(o, m)
|
| 176 |
+
|
| 177 |
+
return m
|
| 178 |
+
|
| 179 |
+
def forward(self,
|
| 180 |
+
m: torch.Tensor,
|
| 181 |
+
z: Optional[torch.Tensor] = None,
|
| 182 |
+
mask: Optional[torch.Tensor] = None,
|
| 183 |
+
chunk_size: Optional[int] = None,
|
| 184 |
+
_chunk_logits: Optional[int] = None,
|
| 185 |
+
_checkpoint_chunks: Optional[bool] = None,
|
| 186 |
+
) -> torch.Tensor:
|
| 187 |
+
"""
|
| 188 |
+
Args:
|
| 189 |
+
m:
|
| 190 |
+
[*, N_seq, N_res, C_m] MSA embedding
|
| 191 |
+
z:
|
| 192 |
+
[*, N_res, N_res, C_z] pair embedding. Required only if
|
| 193 |
+
pair_bias is True
|
| 194 |
+
mask:
|
| 195 |
+
[*, N_seq, N_res] MSA mask
|
| 196 |
+
chunk_size:
|
| 197 |
+
Size of chunks into which the inputs are split along their
|
| 198 |
+
batch dimensions. A low value decreases memory overhead at the
|
| 199 |
+
cost of slower execution. Chunking is not performed by default.
|
| 200 |
+
|
| 201 |
+
"""
|
| 202 |
+
if(_chunk_logits is not None):
|
| 203 |
+
return self._chunked_msa_attn(
|
| 204 |
+
m=m, z=z, mask=mask,
|
| 205 |
+
chunk_logits=_chunk_logits, checkpoint=_checkpoint_chunks
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
m, mask_bias, z = self._prep_inputs(m, z, mask)
|
| 209 |
+
|
| 210 |
+
biases = [mask_bias]
|
| 211 |
+
if(z is not None):
|
| 212 |
+
biases.append(z)
|
| 213 |
+
|
| 214 |
+
if chunk_size is not None:
|
| 215 |
+
m = self._chunk(m, biases, chunk_size)
|
| 216 |
+
else:
|
| 217 |
+
m = self.mha(
|
| 218 |
+
q_x=m,
|
| 219 |
+
kv_x=m,
|
| 220 |
+
biases=biases
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
return m
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
class MSARowAttentionWithPairBias(MSAAttention):
|
| 227 |
+
"""
|
| 228 |
+
Implements Algorithm 7.
|
| 229 |
+
"""
|
| 230 |
+
|
| 231 |
+
def __init__(self, c_m, c_z, c_hidden, no_heads, inf=1e9):
|
| 232 |
+
"""
|
| 233 |
+
Args:
|
| 234 |
+
c_m:
|
| 235 |
+
Input channel dimension
|
| 236 |
+
c_z:
|
| 237 |
+
Pair embedding channel dimension
|
| 238 |
+
c_hidden:
|
| 239 |
+
Per-head hidden channel dimension
|
| 240 |
+
no_heads:
|
| 241 |
+
Number of attention heads
|
| 242 |
+
inf:
|
| 243 |
+
Large number used to construct attention masks
|
| 244 |
+
"""
|
| 245 |
+
super(MSARowAttentionWithPairBias, self).__init__(
|
| 246 |
+
c_m,
|
| 247 |
+
c_hidden,
|
| 248 |
+
no_heads,
|
| 249 |
+
pair_bias=True,
|
| 250 |
+
c_z=c_z,
|
| 251 |
+
inf=inf,
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
class MSAColumnAttention(nn.Module):
|
| 256 |
+
"""
|
| 257 |
+
Implements Algorithm 8.
|
| 258 |
+
|
| 259 |
+
By rights, this should also be a subclass of MSAAttention. Alas,
|
| 260 |
+
most inheritance isn't supported by TorchScript.
|
| 261 |
+
"""
|
| 262 |
+
|
| 263 |
+
def __init__(self, c_m, c_hidden, no_heads, inf=1e9):
|
| 264 |
+
"""
|
| 265 |
+
Args:
|
| 266 |
+
c_m:
|
| 267 |
+
MSA channel dimension
|
| 268 |
+
c_hidden:
|
| 269 |
+
Per-head hidden channel dimension
|
| 270 |
+
no_heads:
|
| 271 |
+
Number of attention heads
|
| 272 |
+
inf:
|
| 273 |
+
Large number used to construct attention masks
|
| 274 |
+
"""
|
| 275 |
+
super(MSAColumnAttention, self).__init__()
|
| 276 |
+
|
| 277 |
+
self.c_m = c_m
|
| 278 |
+
self.c_hidden = c_hidden
|
| 279 |
+
self.no_heads = no_heads
|
| 280 |
+
self.inf = inf
|
| 281 |
+
|
| 282 |
+
self._msa_att = MSAAttention(
|
| 283 |
+
c_in=c_m,
|
| 284 |
+
c_hidden=c_hidden,
|
| 285 |
+
no_heads=no_heads,
|
| 286 |
+
pair_bias=False,
|
| 287 |
+
c_z=None,
|
| 288 |
+
inf=inf,
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
def forward(self,
|
| 292 |
+
m: torch.Tensor,
|
| 293 |
+
mask: Optional[torch.Tensor] = None,
|
| 294 |
+
chunk_size: Optional[int] = None
|
| 295 |
+
) -> torch.Tensor:
|
| 296 |
+
"""
|
| 297 |
+
Args:
|
| 298 |
+
m:
|
| 299 |
+
[*, N_seq, N_res, C_m] MSA embedding
|
| 300 |
+
mask:
|
| 301 |
+
[*, N_seq, N_res] MSA mask
|
| 302 |
+
chunk_size:
|
| 303 |
+
Size of chunks into which the inputs are split along their
|
| 304 |
+
batch dimensions. A low value decreases memory overhead at the
|
| 305 |
+
cost of slower execution. Chunking is not performed by default.
|
| 306 |
+
"""
|
| 307 |
+
# [*, N_res, N_seq, C_in]
|
| 308 |
+
m = m.transpose(-2, -3)
|
| 309 |
+
if mask is not None:
|
| 310 |
+
mask = mask.transpose(-1, -2)
|
| 311 |
+
|
| 312 |
+
m = self._msa_att(m, mask=mask, chunk_size=chunk_size)
|
| 313 |
+
|
| 314 |
+
# [*, N_seq, N_res, C_in]
|
| 315 |
+
m = m.transpose(-2, -3)
|
| 316 |
+
if mask is not None:
|
| 317 |
+
mask = mask.transpose(-1, -2)
|
| 318 |
+
|
| 319 |
+
return m
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
class MSAColumnGlobalAttention(nn.Module):
|
| 323 |
+
def __init__(
|
| 324 |
+
self, c_in, c_hidden, no_heads, inf=1e9, eps=1e-10,
|
| 325 |
+
):
|
| 326 |
+
super(MSAColumnGlobalAttention, self).__init__()
|
| 327 |
+
|
| 328 |
+
self.c_in = c_in
|
| 329 |
+
self.c_hidden = c_hidden
|
| 330 |
+
self.no_heads = no_heads
|
| 331 |
+
self.inf = inf
|
| 332 |
+
self.eps = eps
|
| 333 |
+
|
| 334 |
+
self.layer_norm_m = nn.LayerNorm(c_in)
|
| 335 |
+
|
| 336 |
+
self.global_attention = GlobalAttention(
|
| 337 |
+
c_in=c_in,
|
| 338 |
+
c_hidden=c_hidden,
|
| 339 |
+
no_heads=no_heads,
|
| 340 |
+
inf=inf,
|
| 341 |
+
eps=eps,
|
| 342 |
+
)
|
| 343 |
+
|
| 344 |
+
@torch.jit.ignore
|
| 345 |
+
def _chunk(self,
|
| 346 |
+
m: torch.Tensor,
|
| 347 |
+
mask: torch.Tensor,
|
| 348 |
+
chunk_size: int,
|
| 349 |
+
) -> torch.Tensor:
|
| 350 |
+
mha_input = {
|
| 351 |
+
"m": m,
|
| 352 |
+
"mask": mask,
|
| 353 |
+
}
|
| 354 |
+
return chunk_layer(
|
| 355 |
+
self.global_attention,
|
| 356 |
+
mha_input,
|
| 357 |
+
chunk_size=chunk_size,
|
| 358 |
+
no_batch_dims=len(m.shape[:-2]),
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
def forward(
|
| 362 |
+
self,
|
| 363 |
+
m: torch.Tensor,
|
| 364 |
+
mask: Optional[torch.Tensor] = None,
|
| 365 |
+
chunk_size: Optional[int] = None,
|
| 366 |
+
) -> torch.Tensor:
|
| 367 |
+
n_seq, n_res, c_in = m.shape[-3:]
|
| 368 |
+
|
| 369 |
+
if mask is None:
|
| 370 |
+
# [*, N_seq, N_res]
|
| 371 |
+
mask = torch.ones(
|
| 372 |
+
m.shape[:-1],
|
| 373 |
+
dtype=m.dtype,
|
| 374 |
+
device=m.device,
|
| 375 |
+
).detach()
|
| 376 |
+
|
| 377 |
+
# [*, N_res, N_seq, C_in]
|
| 378 |
+
m = m.transpose(-2, -3)
|
| 379 |
+
mask = mask.transpose(-1, -2)
|
| 380 |
+
|
| 381 |
+
# [*, N_res, N_seq, C_in]
|
| 382 |
+
m = self.layer_norm_m(m)
|
| 383 |
+
|
| 384 |
+
if chunk_size is not None:
|
| 385 |
+
m = self._chunk(m, mask, chunk_size)
|
| 386 |
+
else:
|
| 387 |
+
m = self.global_attention(m=m, mask=mask)
|
| 388 |
+
|
| 389 |
+
# [*, N_seq, N_res, C_in]
|
| 390 |
+
m = m.transpose(-2, -3)
|
| 391 |
+
|
| 392 |
+
return m
|
openfold/model/outer_product_mean.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from functools import partial
|
| 17 |
+
from typing import Optional
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
|
| 22 |
+
from openfold.model.primitives import Linear
|
| 23 |
+
from openfold.utils.tensor_utils import chunk_layer
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class OuterProductMean(nn.Module):
|
| 27 |
+
"""
|
| 28 |
+
Implements Algorithm 10.
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
def __init__(self, c_m, c_z, c_hidden, eps=1e-3):
|
| 32 |
+
"""
|
| 33 |
+
Args:
|
| 34 |
+
c_m:
|
| 35 |
+
MSA embedding channel dimension
|
| 36 |
+
c_z:
|
| 37 |
+
Pair embedding channel dimension
|
| 38 |
+
c_hidden:
|
| 39 |
+
Hidden channel dimension
|
| 40 |
+
"""
|
| 41 |
+
super(OuterProductMean, self).__init__()
|
| 42 |
+
|
| 43 |
+
self.c_m = c_m
|
| 44 |
+
self.c_z = c_z
|
| 45 |
+
self.c_hidden = c_hidden
|
| 46 |
+
self.eps = eps
|
| 47 |
+
|
| 48 |
+
self.layer_norm = nn.LayerNorm(c_m)
|
| 49 |
+
self.linear_1 = Linear(c_m, c_hidden)
|
| 50 |
+
self.linear_2 = Linear(c_m, c_hidden)
|
| 51 |
+
self.linear_out = Linear(c_hidden ** 2, c_z, init="final")
|
| 52 |
+
|
| 53 |
+
def _opm(self, a, b):
|
| 54 |
+
# [*, N_res, N_res, C, C]
|
| 55 |
+
outer = torch.einsum("...bac,...dae->...bdce", a, b)
|
| 56 |
+
|
| 57 |
+
# [*, N_res, N_res, C * C]
|
| 58 |
+
outer = outer.reshape(outer.shape[:-2] + (-1,))
|
| 59 |
+
|
| 60 |
+
# [*, N_res, N_res, C_z]
|
| 61 |
+
outer = self.linear_out(outer)
|
| 62 |
+
|
| 63 |
+
return outer
|
| 64 |
+
|
| 65 |
+
@torch.jit.ignore
|
| 66 |
+
def _chunk(self,
|
| 67 |
+
a: torch.Tensor,
|
| 68 |
+
b: torch.Tensor,
|
| 69 |
+
chunk_size: int
|
| 70 |
+
) -> torch.Tensor:
|
| 71 |
+
# Since the "batch dim" in this case is not a true batch dimension
|
| 72 |
+
# (in that the shape of the output depends on it), we need to
|
| 73 |
+
# iterate over it ourselves
|
| 74 |
+
a_reshape = a.reshape((-1,) + a.shape[-3:])
|
| 75 |
+
b_reshape = b.reshape((-1,) + b.shape[-3:])
|
| 76 |
+
out = []
|
| 77 |
+
for a_prime, b_prime in zip(a_reshape, b_reshape):
|
| 78 |
+
outer = chunk_layer(
|
| 79 |
+
partial(self._opm, b=b_prime),
|
| 80 |
+
{"a": a_prime},
|
| 81 |
+
chunk_size=chunk_size,
|
| 82 |
+
no_batch_dims=1,
|
| 83 |
+
)
|
| 84 |
+
out.append(outer)
|
| 85 |
+
outer = torch.stack(out, dim=0)
|
| 86 |
+
outer = outer.reshape(a.shape[:-3] + outer.shape[1:])
|
| 87 |
+
|
| 88 |
+
return outer
|
| 89 |
+
|
| 90 |
+
def forward(self,
|
| 91 |
+
m: torch.Tensor,
|
| 92 |
+
mask: Optional[torch.Tensor] = None,
|
| 93 |
+
chunk_size: Optional[int] = None
|
| 94 |
+
) -> torch.Tensor:
|
| 95 |
+
"""
|
| 96 |
+
Args:
|
| 97 |
+
m:
|
| 98 |
+
[*, N_seq, N_res, C_m] MSA embedding
|
| 99 |
+
mask:
|
| 100 |
+
[*, N_seq, N_res] MSA mask
|
| 101 |
+
Returns:
|
| 102 |
+
[*, N_res, N_res, C_z] pair embedding update
|
| 103 |
+
"""
|
| 104 |
+
if mask is None:
|
| 105 |
+
mask = m.new_ones(m.shape[:-1])
|
| 106 |
+
|
| 107 |
+
# [*, N_seq, N_res, C_m]
|
| 108 |
+
m = self.layer_norm(m)
|
| 109 |
+
|
| 110 |
+
# [*, N_seq, N_res, C]
|
| 111 |
+
mask = mask.unsqueeze(-1)
|
| 112 |
+
a = self.linear_1(m) * mask
|
| 113 |
+
b = self.linear_2(m) * mask
|
| 114 |
+
|
| 115 |
+
a = a.transpose(-2, -3)
|
| 116 |
+
b = b.transpose(-2, -3)
|
| 117 |
+
|
| 118 |
+
if chunk_size is not None:
|
| 119 |
+
outer = self._chunk(a, b, chunk_size)
|
| 120 |
+
else:
|
| 121 |
+
outer = self._opm(a, b)
|
| 122 |
+
|
| 123 |
+
# [*, N_res, N_res, 1]
|
| 124 |
+
norm = torch.einsum("...abc,...adc->...bdc", mask, mask)
|
| 125 |
+
|
| 126 |
+
# [*, N_res, N_res, C_z]
|
| 127 |
+
outer = outer / (self.eps + norm)
|
| 128 |
+
|
| 129 |
+
return outer
|
openfold/model/pair_transition.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
from typing import Optional
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
|
| 20 |
+
from openfold.model.primitives import Linear, LayerNorm
|
| 21 |
+
from openfold.utils.tensor_utils import chunk_layer
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class PairTransition(nn.Module):
|
| 25 |
+
"""
|
| 26 |
+
Implements Algorithm 15.
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
def __init__(self, c_z, n):
|
| 30 |
+
"""
|
| 31 |
+
Args:
|
| 32 |
+
c_z:
|
| 33 |
+
Pair transition channel dimension
|
| 34 |
+
n:
|
| 35 |
+
Factor by which c_z is multiplied to obtain hidden channel
|
| 36 |
+
dimension
|
| 37 |
+
"""
|
| 38 |
+
super(PairTransition, self).__init__()
|
| 39 |
+
|
| 40 |
+
self.c_z = c_z
|
| 41 |
+
self.n = n
|
| 42 |
+
|
| 43 |
+
self.layer_norm = LayerNorm(self.c_z)
|
| 44 |
+
self.linear_1 = Linear(self.c_z, self.n * self.c_z, init="relu")
|
| 45 |
+
self.relu = nn.ReLU()
|
| 46 |
+
self.linear_2 = Linear(self.n * self.c_z, c_z, init="final")
|
| 47 |
+
|
| 48 |
+
def _transition(self, z, mask):
|
| 49 |
+
# [*, N_res, N_res, C_hidden]
|
| 50 |
+
z = self.linear_1(z)
|
| 51 |
+
z = self.relu(z)
|
| 52 |
+
|
| 53 |
+
# [*, N_res, N_res, C_z]
|
| 54 |
+
z = self.linear_2(z) * mask
|
| 55 |
+
|
| 56 |
+
return z
|
| 57 |
+
|
| 58 |
+
@torch.jit.ignore
|
| 59 |
+
def _chunk(self,
|
| 60 |
+
z: torch.Tensor,
|
| 61 |
+
mask: torch.Tensor,
|
| 62 |
+
chunk_size: int,
|
| 63 |
+
) -> torch.Tensor:
|
| 64 |
+
return chunk_layer(
|
| 65 |
+
self._transition,
|
| 66 |
+
{"z": z, "mask": mask},
|
| 67 |
+
chunk_size=chunk_size,
|
| 68 |
+
no_batch_dims=len(z.shape[:-2]),
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def forward(self,
|
| 73 |
+
z: torch.Tensor,
|
| 74 |
+
mask: Optional[torch.Tensor] = None,
|
| 75 |
+
chunk_size: Optional[int] = None,
|
| 76 |
+
) -> torch.Tensor:
|
| 77 |
+
"""
|
| 78 |
+
Args:
|
| 79 |
+
z:
|
| 80 |
+
[*, N_res, N_res, C_z] pair embedding
|
| 81 |
+
Returns:
|
| 82 |
+
[*, N_res, N_res, C_z] pair embedding update
|
| 83 |
+
"""
|
| 84 |
+
# DISCREPANCY: DeepMind forgets to apply the mask in this module.
|
| 85 |
+
if mask is None:
|
| 86 |
+
mask = z.new_ones(z.shape[:-1])
|
| 87 |
+
|
| 88 |
+
# [*, N_res, N_res, 1]
|
| 89 |
+
mask = mask.unsqueeze(-1)
|
| 90 |
+
|
| 91 |
+
# [*, N_res, N_res, C_z]
|
| 92 |
+
z = self.layer_norm(z)
|
| 93 |
+
|
| 94 |
+
if chunk_size is not None:
|
| 95 |
+
z = self._chunk(z, mask, chunk_size)
|
| 96 |
+
else:
|
| 97 |
+
z = self._transition(z=z, mask=mask)
|
| 98 |
+
|
| 99 |
+
return z
|
openfold/model/primitives.py
ADDED
|
@@ -0,0 +1,587 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from functools import partial
|
| 17 |
+
import math
|
| 18 |
+
from typing import Optional, Callable, List, Tuple, Sequence
|
| 19 |
+
import numpy as np
|
| 20 |
+
|
| 21 |
+
import deepspeed
|
| 22 |
+
import torch
|
| 23 |
+
import torch.nn as nn
|
| 24 |
+
from scipy.stats import truncnorm
|
| 25 |
+
|
| 26 |
+
from openfold.utils.checkpointing import get_checkpoint_fn
|
| 27 |
+
from openfold.utils.tensor_utils import (
|
| 28 |
+
permute_final_dims,
|
| 29 |
+
flatten_final_dims,
|
| 30 |
+
_chunk_slice,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _prod(nums):
|
| 35 |
+
out = 1
|
| 36 |
+
for n in nums:
|
| 37 |
+
out = out * n
|
| 38 |
+
return out
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _calculate_fan(linear_weight_shape, fan="fan_in"):
|
| 42 |
+
fan_out, fan_in = linear_weight_shape
|
| 43 |
+
|
| 44 |
+
if fan == "fan_in":
|
| 45 |
+
f = fan_in
|
| 46 |
+
elif fan == "fan_out":
|
| 47 |
+
f = fan_out
|
| 48 |
+
elif fan == "fan_avg":
|
| 49 |
+
f = (fan_in + fan_out) / 2
|
| 50 |
+
else:
|
| 51 |
+
raise ValueError("Invalid fan option")
|
| 52 |
+
|
| 53 |
+
return f
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def trunc_normal_init_(weights, scale=1.0, fan="fan_in"):
|
| 57 |
+
shape = weights.shape
|
| 58 |
+
f = _calculate_fan(shape, fan)
|
| 59 |
+
scale = scale / max(1, f)
|
| 60 |
+
a = -2
|
| 61 |
+
b = 2
|
| 62 |
+
std = math.sqrt(scale) / truncnorm.std(a=a, b=b, loc=0, scale=1)
|
| 63 |
+
size = _prod(shape)
|
| 64 |
+
samples = truncnorm.rvs(a=a, b=b, loc=0, scale=std, size=size)
|
| 65 |
+
samples = np.reshape(samples, shape)
|
| 66 |
+
with torch.no_grad():
|
| 67 |
+
weights.copy_(torch.tensor(samples, device=weights.device))
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def lecun_normal_init_(weights):
|
| 71 |
+
trunc_normal_init_(weights, scale=1.0)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def he_normal_init_(weights):
|
| 75 |
+
trunc_normal_init_(weights, scale=2.0)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def glorot_uniform_init_(weights):
|
| 79 |
+
nn.init.xavier_uniform_(weights, gain=1)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def final_init_(weights):
|
| 83 |
+
with torch.no_grad():
|
| 84 |
+
weights.fill_(0.0)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def gating_init_(weights):
|
| 88 |
+
with torch.no_grad():
|
| 89 |
+
weights.fill_(0.0)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def normal_init_(weights):
|
| 93 |
+
torch.nn.init.kaiming_normal_(weights, nonlinearity="linear")
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def ipa_point_weights_init_(weights):
|
| 97 |
+
with torch.no_grad():
|
| 98 |
+
softplus_inverse_1 = 0.541324854612918
|
| 99 |
+
weights.fill_(softplus_inverse_1)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class Linear(nn.Linear):
|
| 103 |
+
"""
|
| 104 |
+
A Linear layer with built-in nonstandard initializations. Called just
|
| 105 |
+
like torch.nn.Linear.
|
| 106 |
+
|
| 107 |
+
Implements the initializers in 1.11.4, plus some additional ones found
|
| 108 |
+
in the code.
|
| 109 |
+
"""
|
| 110 |
+
|
| 111 |
+
def __init__(
|
| 112 |
+
self,
|
| 113 |
+
in_dim: int,
|
| 114 |
+
out_dim: int,
|
| 115 |
+
bias: bool = True,
|
| 116 |
+
init: str = "default",
|
| 117 |
+
init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None,
|
| 118 |
+
):
|
| 119 |
+
"""
|
| 120 |
+
Args:
|
| 121 |
+
in_dim:
|
| 122 |
+
The final dimension of inputs to the layer
|
| 123 |
+
out_dim:
|
| 124 |
+
The final dimension of layer outputs
|
| 125 |
+
bias:
|
| 126 |
+
Whether to learn an additive bias. True by default
|
| 127 |
+
init:
|
| 128 |
+
The initializer to use. Choose from:
|
| 129 |
+
|
| 130 |
+
"default": LeCun fan-in truncated normal initialization
|
| 131 |
+
"relu": He initialization w/ truncated normal distribution
|
| 132 |
+
"glorot": Fan-average Glorot uniform initialization
|
| 133 |
+
"gating": Weights=0, Bias=1
|
| 134 |
+
"normal": Normal initialization with std=1/sqrt(fan_in)
|
| 135 |
+
"final": Weights=0, Bias=0
|
| 136 |
+
|
| 137 |
+
Overridden by init_fn if the latter is not None.
|
| 138 |
+
init_fn:
|
| 139 |
+
A custom initializer taking weight and bias as inputs.
|
| 140 |
+
Overrides init if not None.
|
| 141 |
+
"""
|
| 142 |
+
super(Linear, self).__init__(in_dim, out_dim, bias=bias)
|
| 143 |
+
|
| 144 |
+
if bias:
|
| 145 |
+
with torch.no_grad():
|
| 146 |
+
self.bias.fill_(0)
|
| 147 |
+
|
| 148 |
+
if init_fn is not None:
|
| 149 |
+
init_fn(self.weight, self.bias)
|
| 150 |
+
else:
|
| 151 |
+
if init == "default":
|
| 152 |
+
lecun_normal_init_(self.weight)
|
| 153 |
+
elif init == "relu":
|
| 154 |
+
he_normal_init_(self.weight)
|
| 155 |
+
elif init == "glorot":
|
| 156 |
+
glorot_uniform_init_(self.weight)
|
| 157 |
+
elif init == "gating":
|
| 158 |
+
gating_init_(self.weight)
|
| 159 |
+
if bias:
|
| 160 |
+
with torch.no_grad():
|
| 161 |
+
self.bias.fill_(1.0)
|
| 162 |
+
elif init == "normal":
|
| 163 |
+
normal_init_(self.weight)
|
| 164 |
+
elif init == "final":
|
| 165 |
+
final_init_(self.weight)
|
| 166 |
+
else:
|
| 167 |
+
raise ValueError("Invalid init string.")
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
class LayerNorm(nn.Module):
|
| 171 |
+
def __init__(self, c_in, eps=1e-5):
|
| 172 |
+
super(LayerNorm, self).__init__()
|
| 173 |
+
|
| 174 |
+
self.c_in = (c_in,)
|
| 175 |
+
self.eps = eps
|
| 176 |
+
|
| 177 |
+
self.weight = nn.Parameter(torch.ones(c_in))
|
| 178 |
+
self.bias = nn.Parameter(torch.zeros(c_in))
|
| 179 |
+
|
| 180 |
+
def forward(self, x):
|
| 181 |
+
d = x.dtype
|
| 182 |
+
if(d is torch.bfloat16 and not deepspeed.utils.is_initialized()):
|
| 183 |
+
with torch.cuda.amp.autocast(enabled=False):
|
| 184 |
+
out = nn.functional.layer_norm(
|
| 185 |
+
x,
|
| 186 |
+
self.c_in,
|
| 187 |
+
self.weight.to(dtype=d),
|
| 188 |
+
self.bias.to(dtype=d),
|
| 189 |
+
self.eps
|
| 190 |
+
)
|
| 191 |
+
else:
|
| 192 |
+
out = nn.functional.layer_norm(
|
| 193 |
+
x,
|
| 194 |
+
self.c_in,
|
| 195 |
+
self.weight,
|
| 196 |
+
self.bias,
|
| 197 |
+
self.eps,
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
return out
|
| 201 |
+
|
| 202 |
+
@torch.jit.ignore
|
| 203 |
+
def softmax(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
| 204 |
+
"""
|
| 205 |
+
Softmax, but without automatic casting to fp32 when the input is of
|
| 206 |
+
type bfloat16
|
| 207 |
+
"""
|
| 208 |
+
d = t.dtype
|
| 209 |
+
if(d is torch.bfloat16 and not deepspeed.utils.is_initialized()):
|
| 210 |
+
with torch.cuda.amp.autocast(enabled=False):
|
| 211 |
+
s = torch.nn.functional.softmax(t, dim=dim)
|
| 212 |
+
else:
|
| 213 |
+
s = torch.nn.functional.softmax(t, dim=dim)
|
| 214 |
+
|
| 215 |
+
return s
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
#@torch.jit.script
|
| 219 |
+
def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, biases: List[torch.Tensor]) -> torch.Tensor:
|
| 220 |
+
# [*, H, Q, C_hidden]
|
| 221 |
+
query = permute_final_dims(query, (1, 0, 2))
|
| 222 |
+
|
| 223 |
+
# [*, H, C_hidden, K]
|
| 224 |
+
key = permute_final_dims(key, (1, 2, 0))
|
| 225 |
+
|
| 226 |
+
# [*, H, V, C_hidden]
|
| 227 |
+
value = permute_final_dims(value, (1, 0, 2))
|
| 228 |
+
|
| 229 |
+
# [*, H, Q, K]
|
| 230 |
+
a = torch.matmul(query, key)
|
| 231 |
+
|
| 232 |
+
for b in biases:
|
| 233 |
+
a += b
|
| 234 |
+
|
| 235 |
+
a = softmax(a, -1)
|
| 236 |
+
|
| 237 |
+
# [*, H, Q, C_hidden]
|
| 238 |
+
a = torch.matmul(a, value)
|
| 239 |
+
|
| 240 |
+
# [*, Q, H, C_hidden]
|
| 241 |
+
a = a.transpose(-2, -3)
|
| 242 |
+
|
| 243 |
+
return a
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
@torch.jit.ignore
|
| 247 |
+
def _attention_chunked_trainable(
|
| 248 |
+
query, key, value, biases, chunk_size, chunk_dim, checkpoint,
|
| 249 |
+
):
|
| 250 |
+
if(checkpoint and len(biases) > 2):
|
| 251 |
+
raise ValueError(
|
| 252 |
+
"Checkpointed version permits only permits two bias terms"
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
def _checkpointable_attention(q, k, v, b1, b2):
|
| 256 |
+
bs = [b for b in [b1, b2] if b is not None]
|
| 257 |
+
return _attention(q, k, v, bs)
|
| 258 |
+
|
| 259 |
+
o_chunks = []
|
| 260 |
+
checkpoint_fn = get_checkpoint_fn()
|
| 261 |
+
count = query.shape[chunk_dim]
|
| 262 |
+
for start in range(0, count, chunk_size):
|
| 263 |
+
end = start + chunk_size
|
| 264 |
+
idx = [slice(None)] * len(query.shape)
|
| 265 |
+
idx[chunk_dim] = slice(start, end)
|
| 266 |
+
idx_tup = tuple(idx)
|
| 267 |
+
q_chunk = query[idx_tup]
|
| 268 |
+
k_chunk = key[idx_tup]
|
| 269 |
+
v_chunk = value[idx_tup]
|
| 270 |
+
|
| 271 |
+
def _slice_bias(b):
|
| 272 |
+
idx[chunk_dim] = (
|
| 273 |
+
slice(start, end) if b.shape[chunk_dim] != 1 else slice(None)
|
| 274 |
+
)
|
| 275 |
+
return b[tuple(idx)]
|
| 276 |
+
|
| 277 |
+
if(checkpoint):
|
| 278 |
+
bias_1_chunk, bias_2_chunk = [
|
| 279 |
+
_slice_bias(b) if b is not None else None
|
| 280 |
+
for b in (biases + [None, None])[:2]
|
| 281 |
+
]
|
| 282 |
+
|
| 283 |
+
o_chunk = checkpoint_fn(_checkpointable_attention,
|
| 284 |
+
q_chunk, k_chunk, v_chunk, bias_1_chunk, bias_2_chunk
|
| 285 |
+
)
|
| 286 |
+
else:
|
| 287 |
+
bias_chunks = [
|
| 288 |
+
_slice_bias(b) for b in biases
|
| 289 |
+
]
|
| 290 |
+
|
| 291 |
+
o_chunk = _attention(q_chunk, k_chunk, v_chunk, bias_chunks)
|
| 292 |
+
|
| 293 |
+
o_chunks.append(o_chunk)
|
| 294 |
+
|
| 295 |
+
o = torch.cat(o_chunks, dim=chunk_dim)
|
| 296 |
+
return o
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
class Attention(nn.Module):
|
| 300 |
+
"""
|
| 301 |
+
Standard multi-head attention using AlphaFold's default layer
|
| 302 |
+
initialization. Allows multiple bias vectors.
|
| 303 |
+
"""
|
| 304 |
+
def __init__(
|
| 305 |
+
self,
|
| 306 |
+
c_q: int,
|
| 307 |
+
c_k: int,
|
| 308 |
+
c_v: int,
|
| 309 |
+
c_hidden: int,
|
| 310 |
+
no_heads: int,
|
| 311 |
+
gating: bool = True,
|
| 312 |
+
):
|
| 313 |
+
"""
|
| 314 |
+
Args:
|
| 315 |
+
c_q:
|
| 316 |
+
Input dimension of query data
|
| 317 |
+
c_k:
|
| 318 |
+
Input dimension of key data
|
| 319 |
+
c_v:
|
| 320 |
+
Input dimension of value data
|
| 321 |
+
c_hidden:
|
| 322 |
+
Per-head hidden dimension
|
| 323 |
+
no_heads:
|
| 324 |
+
Number of attention heads
|
| 325 |
+
gating:
|
| 326 |
+
Whether the output should be gated using query data
|
| 327 |
+
"""
|
| 328 |
+
super(Attention, self).__init__()
|
| 329 |
+
|
| 330 |
+
self.c_q = c_q
|
| 331 |
+
self.c_k = c_k
|
| 332 |
+
self.c_v = c_v
|
| 333 |
+
self.c_hidden = c_hidden
|
| 334 |
+
self.no_heads = no_heads
|
| 335 |
+
self.gating = gating
|
| 336 |
+
|
| 337 |
+
# DISCREPANCY: c_hidden is not the per-head channel dimension, as
|
| 338 |
+
# stated in the supplement, but the overall channel dimension.
|
| 339 |
+
|
| 340 |
+
self.linear_q = Linear(
|
| 341 |
+
self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot"
|
| 342 |
+
)
|
| 343 |
+
self.linear_k = Linear(
|
| 344 |
+
self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot"
|
| 345 |
+
)
|
| 346 |
+
self.linear_v = Linear(
|
| 347 |
+
self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot"
|
| 348 |
+
)
|
| 349 |
+
self.linear_o = Linear(
|
| 350 |
+
self.c_hidden * self.no_heads, self.c_q, init="final"
|
| 351 |
+
)
|
| 352 |
+
|
| 353 |
+
self.linear_g = None
|
| 354 |
+
if self.gating:
|
| 355 |
+
self.linear_g = Linear(
|
| 356 |
+
self.c_q, self.c_hidden * self.no_heads, init="gating"
|
| 357 |
+
)
|
| 358 |
+
|
| 359 |
+
self.sigmoid = nn.Sigmoid()
|
| 360 |
+
|
| 361 |
+
def _prep_qkv(self,
|
| 362 |
+
q_x: torch.Tensor,
|
| 363 |
+
kv_x: torch.Tensor
|
| 364 |
+
) -> Tuple[
|
| 365 |
+
torch.Tensor, torch.Tensor, torch.Tensor
|
| 366 |
+
]:
|
| 367 |
+
# [*, Q/K/V, H * C_hidden]
|
| 368 |
+
q = self.linear_q(q_x)
|
| 369 |
+
k = self.linear_k(kv_x)
|
| 370 |
+
v = self.linear_v(kv_x)
|
| 371 |
+
|
| 372 |
+
# [*, Q/K, H, C_hidden]
|
| 373 |
+
q = q.view(q.shape[:-1] + (self.no_heads, -1))
|
| 374 |
+
k = k.view(k.shape[:-1] + (self.no_heads, -1))
|
| 375 |
+
v = v.view(v.shape[:-1] + (self.no_heads, -1))
|
| 376 |
+
|
| 377 |
+
q /= math.sqrt(self.c_hidden)
|
| 378 |
+
|
| 379 |
+
return q, k, v
|
| 380 |
+
|
| 381 |
+
def _wrap_up(self,
|
| 382 |
+
o: torch.Tensor,
|
| 383 |
+
q_x: torch.Tensor
|
| 384 |
+
) -> torch.Tensor:
|
| 385 |
+
if(self.linear_g is not None):
|
| 386 |
+
g = self.sigmoid(self.linear_g(q_x))
|
| 387 |
+
|
| 388 |
+
# [*, Q, H, C_hidden]
|
| 389 |
+
g = g.view(g.shape[:-1] + (self.no_heads, -1))
|
| 390 |
+
o = o * g
|
| 391 |
+
|
| 392 |
+
# [*, Q, H * C_hidden]
|
| 393 |
+
o = flatten_final_dims(o, 2)
|
| 394 |
+
|
| 395 |
+
# [*, Q, C_q]
|
| 396 |
+
o = self.linear_o(o)
|
| 397 |
+
|
| 398 |
+
return o
|
| 399 |
+
|
| 400 |
+
def forward(
|
| 401 |
+
self,
|
| 402 |
+
q_x: torch.Tensor,
|
| 403 |
+
kv_x: torch.Tensor,
|
| 404 |
+
biases: Optional[List[torch.Tensor]] = None,
|
| 405 |
+
use_lma: bool = False,
|
| 406 |
+
q_chunk_size: Optional[int] = None,
|
| 407 |
+
kv_chunk_size: Optional[int] = None,
|
| 408 |
+
) -> torch.Tensor:
|
| 409 |
+
"""
|
| 410 |
+
Args:
|
| 411 |
+
q_x:
|
| 412 |
+
[*, Q, C_q] query data
|
| 413 |
+
kv_x:
|
| 414 |
+
[*, K, C_k] key data
|
| 415 |
+
biases:
|
| 416 |
+
List of biases that broadcast to [*, H, Q, K]
|
| 417 |
+
use_lma:
|
| 418 |
+
Whether to use low-memory attention
|
| 419 |
+
q_chunk_size:
|
| 420 |
+
Query chunk size (for LMA)
|
| 421 |
+
kv_chunk_size:
|
| 422 |
+
Key/Value chunk size (for LMA)
|
| 423 |
+
Returns
|
| 424 |
+
[*, Q, C_q] attention update
|
| 425 |
+
"""
|
| 426 |
+
if(biases is None):
|
| 427 |
+
biases = []
|
| 428 |
+
if(use_lma and (q_chunk_size is None or kv_chunk_size is None)):
|
| 429 |
+
raise ValueError(
|
| 430 |
+
"If use_lma is specified, q_chunk_size and kv_chunk_size must "
|
| 431 |
+
"be provided"
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
q, k, v = self._prep_qkv(q_x, kv_x)
|
| 435 |
+
|
| 436 |
+
if(use_lma):
|
| 437 |
+
biases = [
|
| 438 |
+
b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],))
|
| 439 |
+
for b in biases
|
| 440 |
+
]
|
| 441 |
+
|
| 442 |
+
o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size)
|
| 443 |
+
else:
|
| 444 |
+
o = _attention(q, k, v, biases)
|
| 445 |
+
|
| 446 |
+
o = self._wrap_up(o, q_x)
|
| 447 |
+
|
| 448 |
+
return o
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
class GlobalAttention(nn.Module):
|
| 452 |
+
def __init__(self, c_in, c_hidden, no_heads, inf, eps):
|
| 453 |
+
super(GlobalAttention, self).__init__()
|
| 454 |
+
|
| 455 |
+
self.c_in = c_in
|
| 456 |
+
self.c_hidden = c_hidden
|
| 457 |
+
self.no_heads = no_heads
|
| 458 |
+
self.inf = inf
|
| 459 |
+
self.eps = eps
|
| 460 |
+
|
| 461 |
+
self.linear_q = Linear(
|
| 462 |
+
c_in, c_hidden * no_heads, bias=False, init="glorot"
|
| 463 |
+
)
|
| 464 |
+
|
| 465 |
+
self.linear_k = Linear(
|
| 466 |
+
c_in, c_hidden, bias=False, init="glorot",
|
| 467 |
+
)
|
| 468 |
+
self.linear_v = Linear(
|
| 469 |
+
c_in, c_hidden, bias=False, init="glorot",
|
| 470 |
+
)
|
| 471 |
+
self.linear_g = Linear(c_in, c_hidden * no_heads, init="gating")
|
| 472 |
+
self.linear_o = Linear(c_hidden * no_heads, c_in, init="final")
|
| 473 |
+
|
| 474 |
+
self.sigmoid = nn.Sigmoid()
|
| 475 |
+
|
| 476 |
+
def forward(self, m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
| 477 |
+
# [*, N_res, C_in]
|
| 478 |
+
q = torch.sum(m * mask.unsqueeze(-1), dim=-2) / (
|
| 479 |
+
torch.sum(mask, dim=-1)[..., None] + self.eps
|
| 480 |
+
)
|
| 481 |
+
|
| 482 |
+
# [*, N_res, H * C_hidden]
|
| 483 |
+
q = self.linear_q(q)
|
| 484 |
+
q *= (self.c_hidden ** (-0.5))
|
| 485 |
+
|
| 486 |
+
# [*, N_res, H, C_hidden]
|
| 487 |
+
q = q.view(q.shape[:-1] + (self.no_heads, -1))
|
| 488 |
+
|
| 489 |
+
# [*, N_res, N_seq, C_hidden]
|
| 490 |
+
k = self.linear_k(m)
|
| 491 |
+
v = self.linear_v(m)
|
| 492 |
+
|
| 493 |
+
# [*, N_res, H, N_seq]
|
| 494 |
+
a = torch.matmul(
|
| 495 |
+
q,
|
| 496 |
+
k.transpose(-1, -2), # [*, N_res, C_hidden, N_seq]
|
| 497 |
+
)
|
| 498 |
+
bias = (self.inf * (mask - 1))[..., :, None, :]
|
| 499 |
+
a += bias
|
| 500 |
+
a = softmax(a)
|
| 501 |
+
|
| 502 |
+
# [*, N_res, H, C_hidden]
|
| 503 |
+
o = torch.matmul(
|
| 504 |
+
a,
|
| 505 |
+
v,
|
| 506 |
+
)
|
| 507 |
+
|
| 508 |
+
# [*, N_res, N_seq, C_hidden]
|
| 509 |
+
g = self.sigmoid(self.linear_g(m))
|
| 510 |
+
|
| 511 |
+
# [*, N_res, N_seq, H, C_hidden]
|
| 512 |
+
g = g.view(g.shape[:-1] + (self.no_heads, -1))
|
| 513 |
+
|
| 514 |
+
# [*, N_res, N_seq, H, C_hidden]
|
| 515 |
+
o = o.unsqueeze(-3) * g
|
| 516 |
+
|
| 517 |
+
# [*, N_res, N_seq, H * C_hidden]
|
| 518 |
+
o = o.reshape(o.shape[:-2] + (-1,))
|
| 519 |
+
|
| 520 |
+
# [*, N_res, N_seq, C_in]
|
| 521 |
+
m = self.linear_o(o)
|
| 522 |
+
|
| 523 |
+
return m
|
| 524 |
+
|
| 525 |
+
|
| 526 |
+
def _lma(
|
| 527 |
+
q: torch.Tensor,
|
| 528 |
+
k: torch.Tensor,
|
| 529 |
+
v: torch.Tensor,
|
| 530 |
+
biases: List[torch.Tensor],
|
| 531 |
+
q_chunk_size: int,
|
| 532 |
+
kv_chunk_size: int,
|
| 533 |
+
):
|
| 534 |
+
no_q, no_kv = q.shape[-3], k.shape[-3]
|
| 535 |
+
|
| 536 |
+
# [*, Q, H, C_hidden]
|
| 537 |
+
o = q.new_zeros(q.shape)
|
| 538 |
+
for q_s in range(0, no_q, q_chunk_size):
|
| 539 |
+
q_chunk = q[..., q_s: q_s + q_chunk_size, :, :]
|
| 540 |
+
large_bias_chunks = [
|
| 541 |
+
b[..., q_s: q_s + q_chunk_size, :] for b in biases
|
| 542 |
+
]
|
| 543 |
+
|
| 544 |
+
maxes = []
|
| 545 |
+
weights = []
|
| 546 |
+
values = []
|
| 547 |
+
for kv_s in range(0, no_kv, kv_chunk_size):
|
| 548 |
+
k_chunk = k[..., kv_s: kv_s + kv_chunk_size, :, :]
|
| 549 |
+
v_chunk = v[..., kv_s: kv_s + kv_chunk_size, :, :]
|
| 550 |
+
small_bias_chunks = [
|
| 551 |
+
b[..., kv_s: kv_s + kv_chunk_size] for b in large_bias_chunks
|
| 552 |
+
]
|
| 553 |
+
|
| 554 |
+
a = torch.einsum(
|
| 555 |
+
"...qhd,...khd->...hqk", q_chunk, k_chunk,
|
| 556 |
+
)
|
| 557 |
+
|
| 558 |
+
for b in small_bias_chunks:
|
| 559 |
+
a += b
|
| 560 |
+
|
| 561 |
+
a = a.transpose(-2, -3)
|
| 562 |
+
|
| 563 |
+
max_a = torch.max(a, dim=-1, keepdim=True)[0]
|
| 564 |
+
exp_a = torch.exp(a - max_a)
|
| 565 |
+
exp_v = torch.einsum("...vhf,...qhv->...qhf", v_chunk, exp_a)
|
| 566 |
+
|
| 567 |
+
maxes.append(max_a.detach().squeeze(-1))
|
| 568 |
+
weights.append(torch.sum(exp_a, dim=-1))
|
| 569 |
+
values.append(exp_v)
|
| 570 |
+
|
| 571 |
+
chunk_max = torch.stack(maxes, dim=-3)
|
| 572 |
+
chunk_weights = torch.stack(weights, dim=-3)
|
| 573 |
+
chunk_values = torch.stack(values, dim=-4)
|
| 574 |
+
|
| 575 |
+
global_max = torch.max(chunk_max, dim=-3, keepdim=True)[0]
|
| 576 |
+
max_diffs = torch.exp(chunk_max - global_max)
|
| 577 |
+
chunk_values *= max_diffs.unsqueeze(-1)
|
| 578 |
+
chunk_weights *= max_diffs
|
| 579 |
+
|
| 580 |
+
all_values = torch.sum(chunk_values, dim=-4)
|
| 581 |
+
all_weights = torch.sum(chunk_weights.unsqueeze(-1), dim=-4)
|
| 582 |
+
|
| 583 |
+
q_chunk_out = all_values / all_weights
|
| 584 |
+
|
| 585 |
+
o[..., q_s: q_s + q_chunk_size, :, :] = q_chunk_out
|
| 586 |
+
|
| 587 |
+
return o
|
openfold/model/structure_module.py
ADDED
|
@@ -0,0 +1,820 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
from functools import reduce
|
| 16 |
+
import importlib
|
| 17 |
+
import math
|
| 18 |
+
import sys
|
| 19 |
+
from operator import mul
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn as nn
|
| 23 |
+
from typing import Optional, Tuple, Sequence
|
| 24 |
+
|
| 25 |
+
from openfold.model.primitives import Linear, LayerNorm, ipa_point_weights_init_
|
| 26 |
+
from openfold.np.residue_constants import (
|
| 27 |
+
restype_rigid_group_default_frame,
|
| 28 |
+
restype_atom14_to_rigid_group,
|
| 29 |
+
restype_atom14_mask,
|
| 30 |
+
restype_atom14_rigid_group_positions,
|
| 31 |
+
)
|
| 32 |
+
from openfold.utils.feats import (
|
| 33 |
+
frames_and_literature_positions_to_atom14_pos,
|
| 34 |
+
torsion_angles_to_frames,
|
| 35 |
+
)
|
| 36 |
+
from openfold.utils.precision_utils import is_fp16_enabled
|
| 37 |
+
from openfold.utils.rigid_utils import Rotation, Rigid
|
| 38 |
+
from openfold.utils.tensor_utils import (
|
| 39 |
+
dict_multimap,
|
| 40 |
+
permute_final_dims,
|
| 41 |
+
flatten_final_dims,
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
# attn_core_inplace_cuda = importlib.import_module("attn_core_inplace_cuda")
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class AngleResnetBlock(nn.Module):
|
| 48 |
+
def __init__(self, c_hidden):
|
| 49 |
+
"""
|
| 50 |
+
Args:
|
| 51 |
+
c_hidden:
|
| 52 |
+
Hidden channel dimension
|
| 53 |
+
"""
|
| 54 |
+
super(AngleResnetBlock, self).__init__()
|
| 55 |
+
|
| 56 |
+
self.c_hidden = c_hidden
|
| 57 |
+
|
| 58 |
+
self.linear_1 = Linear(self.c_hidden, self.c_hidden, init="relu")
|
| 59 |
+
self.linear_2 = Linear(self.c_hidden, self.c_hidden, init="final")
|
| 60 |
+
|
| 61 |
+
self.relu = nn.ReLU()
|
| 62 |
+
|
| 63 |
+
def forward(self, a: torch.Tensor) -> torch.Tensor:
|
| 64 |
+
|
| 65 |
+
s_initial = a
|
| 66 |
+
|
| 67 |
+
a = self.relu(a)
|
| 68 |
+
a = self.linear_1(a)
|
| 69 |
+
a = self.relu(a)
|
| 70 |
+
a = self.linear_2(a)
|
| 71 |
+
|
| 72 |
+
return a + s_initial
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class AngleResnet(nn.Module):
|
| 76 |
+
"""
|
| 77 |
+
Implements Algorithm 20, lines 11-14
|
| 78 |
+
"""
|
| 79 |
+
|
| 80 |
+
def __init__(self, c_in, c_hidden, no_blocks, no_angles, epsilon):
|
| 81 |
+
"""
|
| 82 |
+
Args:
|
| 83 |
+
c_in:
|
| 84 |
+
Input channel dimension
|
| 85 |
+
c_hidden:
|
| 86 |
+
Hidden channel dimension
|
| 87 |
+
no_blocks:
|
| 88 |
+
Number of resnet blocks
|
| 89 |
+
no_angles:
|
| 90 |
+
Number of torsion angles to generate
|
| 91 |
+
epsilon:
|
| 92 |
+
Small constant for normalization
|
| 93 |
+
"""
|
| 94 |
+
super(AngleResnet, self).__init__()
|
| 95 |
+
|
| 96 |
+
self.c_in = c_in
|
| 97 |
+
self.c_hidden = c_hidden
|
| 98 |
+
self.no_blocks = no_blocks
|
| 99 |
+
self.no_angles = no_angles
|
| 100 |
+
self.eps = epsilon
|
| 101 |
+
|
| 102 |
+
self.linear_in = Linear(self.c_in, self.c_hidden)
|
| 103 |
+
self.linear_initial = Linear(self.c_in, self.c_hidden)
|
| 104 |
+
|
| 105 |
+
self.layers = nn.ModuleList()
|
| 106 |
+
for _ in range(self.no_blocks):
|
| 107 |
+
layer = AngleResnetBlock(c_hidden=self.c_hidden)
|
| 108 |
+
self.layers.append(layer)
|
| 109 |
+
|
| 110 |
+
self.linear_out = Linear(self.c_hidden, self.no_angles * 2)
|
| 111 |
+
|
| 112 |
+
self.relu = nn.ReLU()
|
| 113 |
+
|
| 114 |
+
def forward(
|
| 115 |
+
self, s: torch.Tensor, s_initial: torch.Tensor
|
| 116 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 117 |
+
"""
|
| 118 |
+
Args:
|
| 119 |
+
s:
|
| 120 |
+
[*, C_hidden] single embedding
|
| 121 |
+
s_initial:
|
| 122 |
+
[*, C_hidden] single embedding as of the start of the
|
| 123 |
+
StructureModule
|
| 124 |
+
Returns:
|
| 125 |
+
[*, no_angles, 2] predicted angles
|
| 126 |
+
"""
|
| 127 |
+
# NOTE: The ReLU's applied to the inputs are absent from the supplement
|
| 128 |
+
# pseudocode but present in the source. For maximal compatibility with
|
| 129 |
+
# the pretrained weights, I'm going with the source.
|
| 130 |
+
|
| 131 |
+
# [*, C_hidden]
|
| 132 |
+
s_initial = self.relu(s_initial)
|
| 133 |
+
s_initial = self.linear_initial(s_initial)
|
| 134 |
+
s = self.relu(s)
|
| 135 |
+
s = self.linear_in(s)
|
| 136 |
+
s = s + s_initial
|
| 137 |
+
|
| 138 |
+
for l in self.layers:
|
| 139 |
+
s = l(s)
|
| 140 |
+
|
| 141 |
+
s = self.relu(s)
|
| 142 |
+
|
| 143 |
+
# [*, no_angles * 2]
|
| 144 |
+
s = self.linear_out(s)
|
| 145 |
+
|
| 146 |
+
# [*, no_angles, 2]
|
| 147 |
+
s = s.view(s.shape[:-1] + (-1, 2))
|
| 148 |
+
|
| 149 |
+
unnormalized_s = s
|
| 150 |
+
norm_denom = torch.sqrt(
|
| 151 |
+
torch.clamp(
|
| 152 |
+
torch.sum(s ** 2, dim=-1, keepdim=True),
|
| 153 |
+
min=self.eps,
|
| 154 |
+
)
|
| 155 |
+
)
|
| 156 |
+
s = s / norm_denom
|
| 157 |
+
|
| 158 |
+
return unnormalized_s, s
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
class InvariantPointAttention(nn.Module):
|
| 162 |
+
"""
|
| 163 |
+
Implements Algorithm 22.
|
| 164 |
+
"""
|
| 165 |
+
def __init__(
|
| 166 |
+
self,
|
| 167 |
+
c_s: int,
|
| 168 |
+
c_z: int,
|
| 169 |
+
c_hidden: int,
|
| 170 |
+
no_heads: int,
|
| 171 |
+
no_qk_points: int,
|
| 172 |
+
no_v_points: int,
|
| 173 |
+
inf: float = 1e5,
|
| 174 |
+
eps: float = 1e-8,
|
| 175 |
+
):
|
| 176 |
+
"""
|
| 177 |
+
Args:
|
| 178 |
+
c_s:
|
| 179 |
+
Single representation channel dimension
|
| 180 |
+
c_z:
|
| 181 |
+
Pair representation channel dimension
|
| 182 |
+
c_hidden:
|
| 183 |
+
Hidden channel dimension
|
| 184 |
+
no_heads:
|
| 185 |
+
Number of attention heads
|
| 186 |
+
no_qk_points:
|
| 187 |
+
Number of query/key points to generate
|
| 188 |
+
no_v_points:
|
| 189 |
+
Number of value points to generate
|
| 190 |
+
"""
|
| 191 |
+
super(InvariantPointAttention, self).__init__()
|
| 192 |
+
|
| 193 |
+
self.c_s = c_s
|
| 194 |
+
self.c_z = c_z
|
| 195 |
+
self.c_hidden = c_hidden
|
| 196 |
+
self.no_heads = no_heads
|
| 197 |
+
self.no_qk_points = no_qk_points
|
| 198 |
+
self.no_v_points = no_v_points
|
| 199 |
+
self.inf = inf
|
| 200 |
+
self.eps = eps
|
| 201 |
+
|
| 202 |
+
# These linear layers differ from their specifications in the
|
| 203 |
+
# supplement. There, they lack bias and use Glorot initialization.
|
| 204 |
+
# Here as in the official source, they have bias and use the default
|
| 205 |
+
# Lecun initialization.
|
| 206 |
+
hc = self.c_hidden * self.no_heads
|
| 207 |
+
self.linear_q = Linear(self.c_s, hc)
|
| 208 |
+
self.linear_kv = Linear(self.c_s, 2 * hc)
|
| 209 |
+
|
| 210 |
+
hpq = self.no_heads * self.no_qk_points * 3
|
| 211 |
+
self.linear_q_points = Linear(self.c_s, hpq)
|
| 212 |
+
|
| 213 |
+
hpkv = self.no_heads * (self.no_qk_points + self.no_v_points) * 3
|
| 214 |
+
self.linear_kv_points = Linear(self.c_s, hpkv)
|
| 215 |
+
|
| 216 |
+
hpv = self.no_heads * self.no_v_points * 3
|
| 217 |
+
|
| 218 |
+
self.linear_b = Linear(self.c_z, self.no_heads)
|
| 219 |
+
|
| 220 |
+
self.head_weights = nn.Parameter(torch.zeros((no_heads)))
|
| 221 |
+
ipa_point_weights_init_(self.head_weights)
|
| 222 |
+
|
| 223 |
+
concat_out_dim = self.no_heads * (
|
| 224 |
+
self.c_z + self.c_hidden + self.no_v_points * 4
|
| 225 |
+
)
|
| 226 |
+
self.linear_out = Linear(concat_out_dim, self.c_s, init="final")
|
| 227 |
+
|
| 228 |
+
self.softmax = nn.Softmax(dim=-1)
|
| 229 |
+
self.softplus = nn.Softplus()
|
| 230 |
+
|
| 231 |
+
def forward(
|
| 232 |
+
self,
|
| 233 |
+
s: torch.Tensor,
|
| 234 |
+
z: Optional[torch.Tensor],
|
| 235 |
+
r: Rigid,
|
| 236 |
+
mask: torch.Tensor,
|
| 237 |
+
inplace_safe: bool = False,
|
| 238 |
+
_offload_inference: bool = False,
|
| 239 |
+
_z_reference_list: Optional[Sequence[torch.Tensor]] = None,
|
| 240 |
+
) -> torch.Tensor:
|
| 241 |
+
"""
|
| 242 |
+
Args:
|
| 243 |
+
s:
|
| 244 |
+
[*, N_res, C_s] single representation
|
| 245 |
+
z:
|
| 246 |
+
[*, N_res, N_res, C_z] pair representation
|
| 247 |
+
r:
|
| 248 |
+
[*, N_res] transformation object
|
| 249 |
+
mask:
|
| 250 |
+
[*, N_res] mask
|
| 251 |
+
Returns:
|
| 252 |
+
[*, N_res, C_s] single representation update
|
| 253 |
+
"""
|
| 254 |
+
if(_offload_inference and inplace_safe):
|
| 255 |
+
z = _z_reference_list
|
| 256 |
+
else:
|
| 257 |
+
z = [z]
|
| 258 |
+
|
| 259 |
+
#######################################
|
| 260 |
+
# Generate scalar and point activations
|
| 261 |
+
#######################################
|
| 262 |
+
# [*, N_res, H * C_hidden]
|
| 263 |
+
q = self.linear_q(s)
|
| 264 |
+
kv = self.linear_kv(s)
|
| 265 |
+
|
| 266 |
+
# [*, N_res, H, C_hidden]
|
| 267 |
+
q = q.view(q.shape[:-1] + (self.no_heads, -1))
|
| 268 |
+
|
| 269 |
+
# [*, N_res, H, 2 * C_hidden]
|
| 270 |
+
kv = kv.view(kv.shape[:-1] + (self.no_heads, -1))
|
| 271 |
+
|
| 272 |
+
# [*, N_res, H, C_hidden]
|
| 273 |
+
k, v = torch.split(kv, self.c_hidden, dim=-1)
|
| 274 |
+
|
| 275 |
+
# [*, N_res, H * P_q * 3]
|
| 276 |
+
q_pts = self.linear_q_points(s)
|
| 277 |
+
|
| 278 |
+
# This is kind of clunky, but it's how the original does it
|
| 279 |
+
# [*, N_res, H * P_q, 3]
|
| 280 |
+
q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1)
|
| 281 |
+
q_pts = torch.stack(q_pts, dim=-1)
|
| 282 |
+
q_pts = r[..., None].apply(q_pts)
|
| 283 |
+
|
| 284 |
+
# [*, N_res, H, P_q, 3]
|
| 285 |
+
q_pts = q_pts.view(
|
| 286 |
+
q_pts.shape[:-2] + (self.no_heads, self.no_qk_points, 3)
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
# [*, N_res, H * (P_q + P_v) * 3]
|
| 290 |
+
kv_pts = self.linear_kv_points(s)
|
| 291 |
+
|
| 292 |
+
# [*, N_res, H * (P_q + P_v), 3]
|
| 293 |
+
kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1)
|
| 294 |
+
kv_pts = torch.stack(kv_pts, dim=-1)
|
| 295 |
+
kv_pts = r[..., None].apply(kv_pts)
|
| 296 |
+
|
| 297 |
+
# [*, N_res, H, (P_q + P_v), 3]
|
| 298 |
+
kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.no_heads, -1, 3))
|
| 299 |
+
|
| 300 |
+
# [*, N_res, H, P_q/P_v, 3]
|
| 301 |
+
k_pts, v_pts = torch.split(
|
| 302 |
+
kv_pts, [self.no_qk_points, self.no_v_points], dim=-2
|
| 303 |
+
)
|
| 304 |
+
|
| 305 |
+
##########################
|
| 306 |
+
# Compute attention scores
|
| 307 |
+
##########################
|
| 308 |
+
# [*, N_res, N_res, H]
|
| 309 |
+
b = self.linear_b(z[0])
|
| 310 |
+
|
| 311 |
+
if(_offload_inference):
|
| 312 |
+
assert(sys.getrefcount(z[0]) == 2)
|
| 313 |
+
z[0] = z[0].cpu()
|
| 314 |
+
|
| 315 |
+
# [*, H, N_res, N_res]
|
| 316 |
+
if(is_fp16_enabled()):
|
| 317 |
+
with torch.cuda.amp.autocast(enabled=False):
|
| 318 |
+
a = torch.matmul(
|
| 319 |
+
permute_final_dims(q.float(), (1, 0, 2)), # [*, H, N_res, C_hidden]
|
| 320 |
+
permute_final_dims(k.float(), (1, 2, 0)), # [*, H, C_hidden, N_res]
|
| 321 |
+
)
|
| 322 |
+
else:
|
| 323 |
+
a = torch.matmul(
|
| 324 |
+
permute_final_dims(q, (1, 0, 2)), # [*, H, N_res, C_hidden]
|
| 325 |
+
permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res]
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
a *= math.sqrt(1.0 / (3 * self.c_hidden))
|
| 329 |
+
a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))
|
| 330 |
+
|
| 331 |
+
# [*, N_res, N_res, H, P_q, 3]
|
| 332 |
+
pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
|
| 333 |
+
if(inplace_safe):
|
| 334 |
+
pt_att *= pt_att
|
| 335 |
+
else:
|
| 336 |
+
pt_att = pt_att ** 2
|
| 337 |
+
|
| 338 |
+
# [*, N_res, N_res, H, P_q]
|
| 339 |
+
pt_att = sum(torch.unbind(pt_att, dim=-1))
|
| 340 |
+
head_weights = self.softplus(self.head_weights).view(
|
| 341 |
+
*((1,) * len(pt_att.shape[:-2]) + (-1, 1))
|
| 342 |
+
)
|
| 343 |
+
head_weights = head_weights * math.sqrt(
|
| 344 |
+
1.0 / (3 * (self.no_qk_points * 9.0 / 2))
|
| 345 |
+
)
|
| 346 |
+
if(inplace_safe):
|
| 347 |
+
pt_att *= head_weights
|
| 348 |
+
else:
|
| 349 |
+
pt_att = pt_att * head_weights
|
| 350 |
+
|
| 351 |
+
# [*, N_res, N_res, H]
|
| 352 |
+
pt_att = torch.sum(pt_att, dim=-1) * (-0.5)
|
| 353 |
+
# [*, N_res, N_res]
|
| 354 |
+
square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
|
| 355 |
+
square_mask = self.inf * (square_mask - 1)
|
| 356 |
+
|
| 357 |
+
# [*, H, N_res, N_res]
|
| 358 |
+
pt_att = permute_final_dims(pt_att, (2, 0, 1))
|
| 359 |
+
|
| 360 |
+
if(inplace_safe):
|
| 361 |
+
a += pt_att
|
| 362 |
+
del pt_att
|
| 363 |
+
a += square_mask.unsqueeze(-3)
|
| 364 |
+
# in-place softmax
|
| 365 |
+
attn_core_inplace_cuda.forward_(
|
| 366 |
+
a,
|
| 367 |
+
reduce(mul, a.shape[:-1]),
|
| 368 |
+
a.shape[-1],
|
| 369 |
+
)
|
| 370 |
+
else:
|
| 371 |
+
a = a + pt_att
|
| 372 |
+
a = a + square_mask.unsqueeze(-3)
|
| 373 |
+
a = self.softmax(a)
|
| 374 |
+
|
| 375 |
+
################
|
| 376 |
+
# Compute output
|
| 377 |
+
################
|
| 378 |
+
# [*, N_res, H, C_hidden]
|
| 379 |
+
o = torch.matmul(
|
| 380 |
+
a, v.transpose(-2, -3).to(dtype=a.dtype)
|
| 381 |
+
).transpose(-2, -3)
|
| 382 |
+
|
| 383 |
+
# [*, N_res, H * C_hidden]
|
| 384 |
+
o = flatten_final_dims(o, 2)
|
| 385 |
+
|
| 386 |
+
# [*, H, 3, N_res, P_v]
|
| 387 |
+
if(inplace_safe):
|
| 388 |
+
v_pts = permute_final_dims(v_pts, (1, 3, 0, 2))
|
| 389 |
+
o_pt = [
|
| 390 |
+
torch.matmul(a, v.to(a.dtype))
|
| 391 |
+
for v in torch.unbind(v_pts, dim=-3)
|
| 392 |
+
]
|
| 393 |
+
o_pt = torch.stack(o_pt, dim=-3)
|
| 394 |
+
else:
|
| 395 |
+
o_pt = torch.sum(
|
| 396 |
+
(
|
| 397 |
+
a[..., None, :, :, None]
|
| 398 |
+
* permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
|
| 399 |
+
),
|
| 400 |
+
dim=-2,
|
| 401 |
+
)
|
| 402 |
+
|
| 403 |
+
# [*, N_res, H, P_v, 3]
|
| 404 |
+
o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
|
| 405 |
+
o_pt = r[..., None, None].invert_apply(o_pt)
|
| 406 |
+
|
| 407 |
+
# [*, N_res, H * P_v]
|
| 408 |
+
o_pt_norm = flatten_final_dims(
|
| 409 |
+
torch.sqrt(torch.sum(o_pt ** 2, dim=-1) + self.eps), 2
|
| 410 |
+
)
|
| 411 |
+
|
| 412 |
+
# [*, N_res, H * P_v, 3]
|
| 413 |
+
o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)
|
| 414 |
+
|
| 415 |
+
if(_offload_inference):
|
| 416 |
+
z[0] = z[0].to(o_pt.device)
|
| 417 |
+
|
| 418 |
+
# [*, N_res, H, C_z]
|
| 419 |
+
o_pair = torch.matmul(a.transpose(-2, -3), z[0].to(dtype=a.dtype))
|
| 420 |
+
|
| 421 |
+
# [*, N_res, H * C_z]
|
| 422 |
+
o_pair = flatten_final_dims(o_pair, 2)
|
| 423 |
+
|
| 424 |
+
# [*, N_res, C_s]
|
| 425 |
+
s = self.linear_out(
|
| 426 |
+
torch.cat(
|
| 427 |
+
(o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1
|
| 428 |
+
).to(dtype=z[0].dtype)
|
| 429 |
+
)
|
| 430 |
+
|
| 431 |
+
return s
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
class BackboneUpdate(nn.Module):
|
| 435 |
+
"""
|
| 436 |
+
Implements part of Algorithm 23.
|
| 437 |
+
"""
|
| 438 |
+
|
| 439 |
+
def __init__(self, c_s):
|
| 440 |
+
"""
|
| 441 |
+
Args:
|
| 442 |
+
c_s:
|
| 443 |
+
Single representation channel dimension
|
| 444 |
+
"""
|
| 445 |
+
super(BackboneUpdate, self).__init__()
|
| 446 |
+
|
| 447 |
+
self.c_s = c_s
|
| 448 |
+
|
| 449 |
+
self.linear = Linear(self.c_s, 6, init="final")
|
| 450 |
+
|
| 451 |
+
def forward(self, s: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 452 |
+
"""
|
| 453 |
+
Args:
|
| 454 |
+
[*, N_res, C_s] single representation
|
| 455 |
+
Returns:
|
| 456 |
+
[*, N_res, 6] update vector
|
| 457 |
+
"""
|
| 458 |
+
# [*, 6]
|
| 459 |
+
update = self.linear(s)
|
| 460 |
+
|
| 461 |
+
return update
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
class StructureModuleTransitionLayer(nn.Module):
|
| 465 |
+
def __init__(self, c):
|
| 466 |
+
super(StructureModuleTransitionLayer, self).__init__()
|
| 467 |
+
|
| 468 |
+
self.c = c
|
| 469 |
+
|
| 470 |
+
self.linear_1 = Linear(self.c, self.c, init="relu")
|
| 471 |
+
self.linear_2 = Linear(self.c, self.c, init="relu")
|
| 472 |
+
self.linear_3 = Linear(self.c, self.c, init="final")
|
| 473 |
+
|
| 474 |
+
self.relu = nn.ReLU()
|
| 475 |
+
|
| 476 |
+
def forward(self, s):
|
| 477 |
+
s_initial = s
|
| 478 |
+
s = self.linear_1(s)
|
| 479 |
+
s = self.relu(s)
|
| 480 |
+
s = self.linear_2(s)
|
| 481 |
+
s = self.relu(s)
|
| 482 |
+
s = self.linear_3(s)
|
| 483 |
+
|
| 484 |
+
s = s + s_initial
|
| 485 |
+
|
| 486 |
+
return s
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
class StructureModuleTransition(nn.Module):
|
| 490 |
+
def __init__(self, c, num_layers, dropout_rate):
|
| 491 |
+
super(StructureModuleTransition, self).__init__()
|
| 492 |
+
|
| 493 |
+
self.c = c
|
| 494 |
+
self.num_layers = num_layers
|
| 495 |
+
self.dropout_rate = dropout_rate
|
| 496 |
+
|
| 497 |
+
self.layers = nn.ModuleList()
|
| 498 |
+
for _ in range(self.num_layers):
|
| 499 |
+
l = StructureModuleTransitionLayer(self.c)
|
| 500 |
+
self.layers.append(l)
|
| 501 |
+
|
| 502 |
+
self.dropout = nn.Dropout(self.dropout_rate)
|
| 503 |
+
self.layer_norm = LayerNorm(self.c)
|
| 504 |
+
|
| 505 |
+
def forward(self, s):
|
| 506 |
+
for l in self.layers:
|
| 507 |
+
s = l(s)
|
| 508 |
+
|
| 509 |
+
s = self.dropout(s)
|
| 510 |
+
s = self.layer_norm(s)
|
| 511 |
+
|
| 512 |
+
return s
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
class StructureModule(nn.Module):
|
| 516 |
+
def __init__(
|
| 517 |
+
self,
|
| 518 |
+
c_s,
|
| 519 |
+
c_z,
|
| 520 |
+
c_ipa,
|
| 521 |
+
c_resnet,
|
| 522 |
+
no_heads_ipa,
|
| 523 |
+
no_qk_points,
|
| 524 |
+
no_v_points,
|
| 525 |
+
dropout_rate,
|
| 526 |
+
no_blocks,
|
| 527 |
+
no_transition_layers,
|
| 528 |
+
no_resnet_blocks,
|
| 529 |
+
no_angles,
|
| 530 |
+
trans_scale_factor,
|
| 531 |
+
epsilon,
|
| 532 |
+
inf,
|
| 533 |
+
**kwargs,
|
| 534 |
+
):
|
| 535 |
+
"""
|
| 536 |
+
Args:
|
| 537 |
+
c_s:
|
| 538 |
+
Single representation channel dimension
|
| 539 |
+
c_z:
|
| 540 |
+
Pair representation channel dimension
|
| 541 |
+
c_ipa:
|
| 542 |
+
IPA hidden channel dimension
|
| 543 |
+
c_resnet:
|
| 544 |
+
Angle resnet (Alg. 23 lines 11-14) hidden channel dimension
|
| 545 |
+
no_heads_ipa:
|
| 546 |
+
Number of IPA heads
|
| 547 |
+
no_qk_points:
|
| 548 |
+
Number of query/key points to generate during IPA
|
| 549 |
+
no_v_points:
|
| 550 |
+
Number of value points to generate during IPA
|
| 551 |
+
dropout_rate:
|
| 552 |
+
Dropout rate used throughout the layer
|
| 553 |
+
no_blocks:
|
| 554 |
+
Number of structure module blocks
|
| 555 |
+
no_transition_layers:
|
| 556 |
+
Number of layers in the single representation transition
|
| 557 |
+
(Alg. 23 lines 8-9)
|
| 558 |
+
no_resnet_blocks:
|
| 559 |
+
Number of blocks in the angle resnet
|
| 560 |
+
no_angles:
|
| 561 |
+
Number of angles to generate in the angle resnet
|
| 562 |
+
trans_scale_factor:
|
| 563 |
+
Scale of single representation transition hidden dimension
|
| 564 |
+
epsilon:
|
| 565 |
+
Small number used in angle resnet normalization
|
| 566 |
+
inf:
|
| 567 |
+
Large number used for attention masking
|
| 568 |
+
"""
|
| 569 |
+
super(StructureModule, self).__init__()
|
| 570 |
+
|
| 571 |
+
self.c_s = c_s
|
| 572 |
+
self.c_z = c_z
|
| 573 |
+
self.c_ipa = c_ipa
|
| 574 |
+
self.c_resnet = c_resnet
|
| 575 |
+
self.no_heads_ipa = no_heads_ipa
|
| 576 |
+
self.no_qk_points = no_qk_points
|
| 577 |
+
self.no_v_points = no_v_points
|
| 578 |
+
self.dropout_rate = dropout_rate
|
| 579 |
+
self.no_blocks = no_blocks
|
| 580 |
+
self.no_transition_layers = no_transition_layers
|
| 581 |
+
self.no_resnet_blocks = no_resnet_blocks
|
| 582 |
+
self.no_angles = no_angles
|
| 583 |
+
self.trans_scale_factor = trans_scale_factor
|
| 584 |
+
self.epsilon = epsilon
|
| 585 |
+
self.inf = inf
|
| 586 |
+
|
| 587 |
+
# Buffers to be lazily initialized later
|
| 588 |
+
# self.default_frames
|
| 589 |
+
# self.group_idx
|
| 590 |
+
# self.atom_mask
|
| 591 |
+
# self.lit_positions
|
| 592 |
+
|
| 593 |
+
self.layer_norm_s = LayerNorm(self.c_s)
|
| 594 |
+
self.layer_norm_z = LayerNorm(self.c_z)
|
| 595 |
+
|
| 596 |
+
self.linear_in = Linear(self.c_s, self.c_s)
|
| 597 |
+
|
| 598 |
+
self.ipa = InvariantPointAttention(
|
| 599 |
+
self.c_s,
|
| 600 |
+
self.c_z,
|
| 601 |
+
self.c_ipa,
|
| 602 |
+
self.no_heads_ipa,
|
| 603 |
+
self.no_qk_points,
|
| 604 |
+
self.no_v_points,
|
| 605 |
+
inf=self.inf,
|
| 606 |
+
eps=self.epsilon,
|
| 607 |
+
)
|
| 608 |
+
|
| 609 |
+
self.ipa_dropout = nn.Dropout(self.dropout_rate)
|
| 610 |
+
self.layer_norm_ipa = LayerNorm(self.c_s)
|
| 611 |
+
|
| 612 |
+
self.transition = StructureModuleTransition(
|
| 613 |
+
self.c_s,
|
| 614 |
+
self.no_transition_layers,
|
| 615 |
+
self.dropout_rate,
|
| 616 |
+
)
|
| 617 |
+
|
| 618 |
+
self.bb_update = BackboneUpdate(self.c_s)
|
| 619 |
+
|
| 620 |
+
self.angle_resnet = AngleResnet(
|
| 621 |
+
self.c_s,
|
| 622 |
+
self.c_resnet,
|
| 623 |
+
self.no_resnet_blocks,
|
| 624 |
+
self.no_angles,
|
| 625 |
+
self.epsilon,
|
| 626 |
+
)
|
| 627 |
+
|
| 628 |
+
def forward(
|
| 629 |
+
self,
|
| 630 |
+
evoformer_output_dict,
|
| 631 |
+
aatype,
|
| 632 |
+
mask=None,
|
| 633 |
+
inplace_safe=False,
|
| 634 |
+
_offload_inference=False,
|
| 635 |
+
):
|
| 636 |
+
"""
|
| 637 |
+
Args:
|
| 638 |
+
evoformer_output_dict:
|
| 639 |
+
Dictionary containing:
|
| 640 |
+
"single":
|
| 641 |
+
[*, N_res, C_s] single representation
|
| 642 |
+
"pair":
|
| 643 |
+
[*, N_res, N_res, C_z] pair representation
|
| 644 |
+
aatype:
|
| 645 |
+
[*, N_res] amino acid indices
|
| 646 |
+
mask:
|
| 647 |
+
Optional [*, N_res] sequence mask
|
| 648 |
+
Returns:
|
| 649 |
+
A dictionary of outputs
|
| 650 |
+
"""
|
| 651 |
+
s = evoformer_output_dict["single"]
|
| 652 |
+
|
| 653 |
+
if mask is None:
|
| 654 |
+
# [*, N]
|
| 655 |
+
mask = s.new_ones(s.shape[:-1])
|
| 656 |
+
|
| 657 |
+
# [*, N, C_s]
|
| 658 |
+
s = self.layer_norm_s(s)
|
| 659 |
+
|
| 660 |
+
# [*, N, N, C_z]
|
| 661 |
+
z = self.layer_norm_z(evoformer_output_dict["pair"])
|
| 662 |
+
|
| 663 |
+
z_reference_list = None
|
| 664 |
+
if(_offload_inference):
|
| 665 |
+
assert(sys.getrefcount(evoformer_output_dict["pair"]) == 2)
|
| 666 |
+
evoformer_output_dict["pair"] = evoformer_output_dict["pair"].cpu()
|
| 667 |
+
z_reference_list = [z]
|
| 668 |
+
z = None
|
| 669 |
+
|
| 670 |
+
# [*, N, C_s]
|
| 671 |
+
s_initial = s
|
| 672 |
+
s = self.linear_in(s)
|
| 673 |
+
|
| 674 |
+
# [*, N]
|
| 675 |
+
rigids = Rigid.identity(
|
| 676 |
+
s.shape[:-1],
|
| 677 |
+
s.dtype,
|
| 678 |
+
s.device,
|
| 679 |
+
self.training,
|
| 680 |
+
fmt="quat",
|
| 681 |
+
)
|
| 682 |
+
outputs = []
|
| 683 |
+
for i in range(self.no_blocks):
|
| 684 |
+
# [*, N, C_s]
|
| 685 |
+
s = s + self.ipa(
|
| 686 |
+
s,
|
| 687 |
+
z,
|
| 688 |
+
rigids,
|
| 689 |
+
mask,
|
| 690 |
+
inplace_safe=inplace_safe,
|
| 691 |
+
_offload_inference=_offload_inference,
|
| 692 |
+
_z_reference_list=z_reference_list
|
| 693 |
+
)
|
| 694 |
+
s = self.ipa_dropout(s)
|
| 695 |
+
s = self.layer_norm_ipa(s)
|
| 696 |
+
s = self.transition(s)
|
| 697 |
+
|
| 698 |
+
# [*, N]
|
| 699 |
+
rigids = rigids.compose_q_update_vec(self.bb_update(s))
|
| 700 |
+
|
| 701 |
+
# To hew as closely as possible to AlphaFold, we convert our
|
| 702 |
+
# quaternion-based transformations to rotation-matrix ones
|
| 703 |
+
# here
|
| 704 |
+
backb_to_global = Rigid(
|
| 705 |
+
Rotation(
|
| 706 |
+
rot_mats=rigids.get_rots().get_rot_mats(),
|
| 707 |
+
quats=None
|
| 708 |
+
),
|
| 709 |
+
rigids.get_trans(),
|
| 710 |
+
)
|
| 711 |
+
|
| 712 |
+
backb_to_global = backb_to_global.scale_translation(
|
| 713 |
+
self.trans_scale_factor
|
| 714 |
+
)
|
| 715 |
+
|
| 716 |
+
# [*, N, 7, 2]
|
| 717 |
+
unnormalized_angles, angles = self.angle_resnet(s, s_initial)
|
| 718 |
+
|
| 719 |
+
all_frames_to_global = self.torsion_angles_to_frames(
|
| 720 |
+
backb_to_global,
|
| 721 |
+
angles,
|
| 722 |
+
aatype,
|
| 723 |
+
)
|
| 724 |
+
|
| 725 |
+
pred_xyz = self.frames_and_literature_positions_to_atom14_pos(
|
| 726 |
+
all_frames_to_global,
|
| 727 |
+
aatype,
|
| 728 |
+
)
|
| 729 |
+
|
| 730 |
+
scaled_rigids = rigids.scale_translation(self.trans_scale_factor)
|
| 731 |
+
|
| 732 |
+
preds = {
|
| 733 |
+
"frames": scaled_rigids.to_tensor_7(),
|
| 734 |
+
"sidechain_frames": all_frames_to_global.to_tensor_4x4(),
|
| 735 |
+
"unnormalized_angles": unnormalized_angles,
|
| 736 |
+
"angles": angles,
|
| 737 |
+
"positions": pred_xyz,
|
| 738 |
+
"states": s,
|
| 739 |
+
}
|
| 740 |
+
|
| 741 |
+
outputs.append(preds)
|
| 742 |
+
|
| 743 |
+
rigids = rigids.stop_rot_gradient()
|
| 744 |
+
|
| 745 |
+
del z, z_reference_list
|
| 746 |
+
|
| 747 |
+
if(_offload_inference):
|
| 748 |
+
evoformer_output_dict["pair"] = (
|
| 749 |
+
evoformer_output_dict["pair"].to(s.device)
|
| 750 |
+
)
|
| 751 |
+
|
| 752 |
+
outputs = dict_multimap(torch.stack, outputs)
|
| 753 |
+
outputs["single"] = s
|
| 754 |
+
|
| 755 |
+
return outputs
|
| 756 |
+
|
| 757 |
+
def _init_residue_constants(self, float_dtype, device):
|
| 758 |
+
if not hasattr(self, "default_frames"):
|
| 759 |
+
self.register_buffer(
|
| 760 |
+
"default_frames",
|
| 761 |
+
torch.tensor(
|
| 762 |
+
restype_rigid_group_default_frame,
|
| 763 |
+
dtype=float_dtype,
|
| 764 |
+
device=device,
|
| 765 |
+
requires_grad=False,
|
| 766 |
+
),
|
| 767 |
+
persistent=False,
|
| 768 |
+
)
|
| 769 |
+
if not hasattr(self, "group_idx"):
|
| 770 |
+
self.register_buffer(
|
| 771 |
+
"group_idx",
|
| 772 |
+
torch.tensor(
|
| 773 |
+
restype_atom14_to_rigid_group,
|
| 774 |
+
device=device,
|
| 775 |
+
requires_grad=False,
|
| 776 |
+
),
|
| 777 |
+
persistent=False,
|
| 778 |
+
)
|
| 779 |
+
if not hasattr(self, "atom_mask"):
|
| 780 |
+
self.register_buffer(
|
| 781 |
+
"atom_mask",
|
| 782 |
+
torch.tensor(
|
| 783 |
+
restype_atom14_mask,
|
| 784 |
+
dtype=float_dtype,
|
| 785 |
+
device=device,
|
| 786 |
+
requires_grad=False,
|
| 787 |
+
),
|
| 788 |
+
persistent=False,
|
| 789 |
+
)
|
| 790 |
+
if not hasattr(self, "lit_positions"):
|
| 791 |
+
self.register_buffer(
|
| 792 |
+
"lit_positions",
|
| 793 |
+
torch.tensor(
|
| 794 |
+
restype_atom14_rigid_group_positions,
|
| 795 |
+
dtype=float_dtype,
|
| 796 |
+
device=device,
|
| 797 |
+
requires_grad=False,
|
| 798 |
+
),
|
| 799 |
+
persistent=False,
|
| 800 |
+
)
|
| 801 |
+
|
| 802 |
+
def torsion_angles_to_frames(self, r, alpha, f):
|
| 803 |
+
# Lazily initialize the residue constants on the correct device
|
| 804 |
+
self._init_residue_constants(alpha.dtype, alpha.device)
|
| 805 |
+
# Separated purely to make testing less annoying
|
| 806 |
+
return torsion_angles_to_frames(r, alpha, f, self.default_frames)
|
| 807 |
+
|
| 808 |
+
def frames_and_literature_positions_to_atom14_pos(
|
| 809 |
+
self, r, f # [*, N, 8] # [*, N]
|
| 810 |
+
):
|
| 811 |
+
# Lazily initialize the residue constants on the correct device
|
| 812 |
+
self._init_residue_constants(r.get_rots().dtype, r.get_rots().device)
|
| 813 |
+
return frames_and_literature_positions_to_atom14_pos(
|
| 814 |
+
r,
|
| 815 |
+
f,
|
| 816 |
+
self.default_frames,
|
| 817 |
+
self.group_idx,
|
| 818 |
+
self.atom_mask,
|
| 819 |
+
self.lit_positions,
|
| 820 |
+
)
|
openfold/model/template.py
ADDED
|
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
from functools import partial
|
| 16 |
+
import math
|
| 17 |
+
from typing import Optional, List
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
|
| 22 |
+
from openfold.model.primitives import Linear, LayerNorm, Attention
|
| 23 |
+
from openfold.model.dropout import (
|
| 24 |
+
DropoutRowwise,
|
| 25 |
+
DropoutColumnwise,
|
| 26 |
+
)
|
| 27 |
+
from openfold.model.pair_transition import PairTransition
|
| 28 |
+
from openfold.model.triangular_attention import (
|
| 29 |
+
TriangleAttentionStartingNode,
|
| 30 |
+
TriangleAttentionEndingNode,
|
| 31 |
+
)
|
| 32 |
+
from openfold.model.triangular_multiplicative_update import (
|
| 33 |
+
TriangleMultiplicationOutgoing,
|
| 34 |
+
TriangleMultiplicationIncoming,
|
| 35 |
+
)
|
| 36 |
+
from openfold.utils.checkpointing import checkpoint_blocks
|
| 37 |
+
from openfold.utils.tensor_utils import (
|
| 38 |
+
chunk_layer,
|
| 39 |
+
permute_final_dims,
|
| 40 |
+
flatten_final_dims,
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class TemplatePointwiseAttention(nn.Module):
|
| 45 |
+
"""
|
| 46 |
+
Implements Algorithm 17.
|
| 47 |
+
"""
|
| 48 |
+
def __init__(self, c_t, c_z, c_hidden, no_heads, inf, **kwargs):
|
| 49 |
+
"""
|
| 50 |
+
Args:
|
| 51 |
+
c_t:
|
| 52 |
+
Template embedding channel dimension
|
| 53 |
+
c_z:
|
| 54 |
+
Pair embedding channel dimension
|
| 55 |
+
c_hidden:
|
| 56 |
+
Hidden channel dimension
|
| 57 |
+
"""
|
| 58 |
+
super(TemplatePointwiseAttention, self).__init__()
|
| 59 |
+
|
| 60 |
+
self.c_t = c_t
|
| 61 |
+
self.c_z = c_z
|
| 62 |
+
self.c_hidden = c_hidden
|
| 63 |
+
self.no_heads = no_heads
|
| 64 |
+
self.inf = inf
|
| 65 |
+
|
| 66 |
+
self.mha = Attention(
|
| 67 |
+
self.c_z,
|
| 68 |
+
self.c_t,
|
| 69 |
+
self.c_t,
|
| 70 |
+
self.c_hidden,
|
| 71 |
+
self.no_heads,
|
| 72 |
+
gating=False,
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
def _chunk(self,
|
| 76 |
+
z: torch.Tensor,
|
| 77 |
+
t: torch.Tensor,
|
| 78 |
+
biases: List[torch.Tensor],
|
| 79 |
+
chunk_size: int,
|
| 80 |
+
) -> torch.Tensor:
|
| 81 |
+
mha_inputs = {
|
| 82 |
+
"q_x": z,
|
| 83 |
+
"kv_x": t,
|
| 84 |
+
"biases": biases,
|
| 85 |
+
}
|
| 86 |
+
return chunk_layer(
|
| 87 |
+
self.mha,
|
| 88 |
+
mha_inputs,
|
| 89 |
+
chunk_size=chunk_size,
|
| 90 |
+
no_batch_dims=len(z.shape[:-2]),
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def forward(self,
|
| 95 |
+
t: torch.Tensor,
|
| 96 |
+
z: torch.Tensor,
|
| 97 |
+
template_mask: Optional[torch.Tensor] = None,
|
| 98 |
+
chunk_size: Optional[int] = None
|
| 99 |
+
) -> torch.Tensor:
|
| 100 |
+
"""
|
| 101 |
+
Args:
|
| 102 |
+
t:
|
| 103 |
+
[*, N_templ, N_res, N_res, C_t] template embedding
|
| 104 |
+
z:
|
| 105 |
+
[*, N_res, N_res, C_t] pair embedding
|
| 106 |
+
template_mask:
|
| 107 |
+
[*, N_templ] template mask
|
| 108 |
+
Returns:
|
| 109 |
+
[*, N_res, N_res, C_z] pair embedding update
|
| 110 |
+
"""
|
| 111 |
+
if template_mask is None:
|
| 112 |
+
template_mask = t.new_ones(t.shape[:-3])
|
| 113 |
+
|
| 114 |
+
bias = self.inf * (template_mask[..., None, None, None, None, :] - 1)
|
| 115 |
+
|
| 116 |
+
# [*, N_res, N_res, 1, C_z]
|
| 117 |
+
z = z.unsqueeze(-2)
|
| 118 |
+
|
| 119 |
+
# [*, N_res, N_res, N_temp, C_t]
|
| 120 |
+
t = permute_final_dims(t, (1, 2, 0, 3))
|
| 121 |
+
|
| 122 |
+
# [*, N_res, N_res, 1, C_z]
|
| 123 |
+
biases = [bias]
|
| 124 |
+
if chunk_size is not None:
|
| 125 |
+
z = self._chunk(z, t, biases, chunk_size)
|
| 126 |
+
else:
|
| 127 |
+
z = self.mha(q_x=z, kv_x=t, biases=biases)
|
| 128 |
+
|
| 129 |
+
# [*, N_res, N_res, C_z]
|
| 130 |
+
z = z.squeeze(-2)
|
| 131 |
+
|
| 132 |
+
return z
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class TemplatePairStackBlock(nn.Module):
|
| 136 |
+
def __init__(
|
| 137 |
+
self,
|
| 138 |
+
c_t: int,
|
| 139 |
+
c_hidden_tri_att: int,
|
| 140 |
+
c_hidden_tri_mul: int,
|
| 141 |
+
no_heads: int,
|
| 142 |
+
pair_transition_n: int,
|
| 143 |
+
dropout_rate: float,
|
| 144 |
+
inf: float,
|
| 145 |
+
**kwargs,
|
| 146 |
+
):
|
| 147 |
+
super(TemplatePairStackBlock, self).__init__()
|
| 148 |
+
|
| 149 |
+
self.c_t = c_t
|
| 150 |
+
self.c_hidden_tri_att = c_hidden_tri_att
|
| 151 |
+
self.c_hidden_tri_mul = c_hidden_tri_mul
|
| 152 |
+
self.no_heads = no_heads
|
| 153 |
+
self.pair_transition_n = pair_transition_n
|
| 154 |
+
self.dropout_rate = dropout_rate
|
| 155 |
+
self.inf = inf
|
| 156 |
+
|
| 157 |
+
self.dropout_row = DropoutRowwise(self.dropout_rate)
|
| 158 |
+
self.dropout_col = DropoutColumnwise(self.dropout_rate)
|
| 159 |
+
|
| 160 |
+
self.tri_att_start = TriangleAttentionStartingNode(
|
| 161 |
+
self.c_t,
|
| 162 |
+
self.c_hidden_tri_att,
|
| 163 |
+
self.no_heads,
|
| 164 |
+
inf=inf,
|
| 165 |
+
)
|
| 166 |
+
self.tri_att_end = TriangleAttentionEndingNode(
|
| 167 |
+
self.c_t,
|
| 168 |
+
self.c_hidden_tri_att,
|
| 169 |
+
self.no_heads,
|
| 170 |
+
inf=inf,
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
self.tri_mul_out = TriangleMultiplicationOutgoing(
|
| 174 |
+
self.c_t,
|
| 175 |
+
self.c_hidden_tri_mul,
|
| 176 |
+
)
|
| 177 |
+
self.tri_mul_in = TriangleMultiplicationIncoming(
|
| 178 |
+
self.c_t,
|
| 179 |
+
self.c_hidden_tri_mul,
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
self.pair_transition = PairTransition(
|
| 183 |
+
self.c_t,
|
| 184 |
+
self.pair_transition_n,
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
def forward(self,
|
| 188 |
+
z: torch.Tensor,
|
| 189 |
+
mask: torch.Tensor,
|
| 190 |
+
chunk_size: Optional[int] = None,
|
| 191 |
+
_mask_trans: bool = True
|
| 192 |
+
):
|
| 193 |
+
single_templates = [
|
| 194 |
+
t.unsqueeze(-4) for t in torch.unbind(z, dim=-4)
|
| 195 |
+
]
|
| 196 |
+
single_templates_masks = [
|
| 197 |
+
m.unsqueeze(-3) for m in torch.unbind(mask, dim=-3)
|
| 198 |
+
]
|
| 199 |
+
for i in range(len(single_templates)):
|
| 200 |
+
single = single_templates[i]
|
| 201 |
+
single_mask = single_templates_masks[i]
|
| 202 |
+
|
| 203 |
+
single = single + self.dropout_row(
|
| 204 |
+
self.tri_att_start(
|
| 205 |
+
single,
|
| 206 |
+
chunk_size=chunk_size,
|
| 207 |
+
mask=single_mask
|
| 208 |
+
)
|
| 209 |
+
)
|
| 210 |
+
single = single + self.dropout_col(
|
| 211 |
+
self.tri_att_end(
|
| 212 |
+
single,
|
| 213 |
+
chunk_size=chunk_size,
|
| 214 |
+
mask=single_mask
|
| 215 |
+
)
|
| 216 |
+
)
|
| 217 |
+
single = single + self.dropout_row(
|
| 218 |
+
self.tri_mul_out(
|
| 219 |
+
single,
|
| 220 |
+
mask=single_mask
|
| 221 |
+
)
|
| 222 |
+
)
|
| 223 |
+
single = single + self.dropout_row(
|
| 224 |
+
self.tri_mul_in(
|
| 225 |
+
single,
|
| 226 |
+
mask=single_mask
|
| 227 |
+
)
|
| 228 |
+
)
|
| 229 |
+
single = single + self.pair_transition(
|
| 230 |
+
single,
|
| 231 |
+
mask=single_mask if _mask_trans else None,
|
| 232 |
+
chunk_size=chunk_size,
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
single_templates[i] = single
|
| 236 |
+
|
| 237 |
+
z = torch.cat(single_templates, dim=-4)
|
| 238 |
+
|
| 239 |
+
return z
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
class TemplatePairStack(nn.Module):
|
| 243 |
+
"""
|
| 244 |
+
Implements Algorithm 16.
|
| 245 |
+
"""
|
| 246 |
+
def __init__(
|
| 247 |
+
self,
|
| 248 |
+
c_t,
|
| 249 |
+
c_hidden_tri_att,
|
| 250 |
+
c_hidden_tri_mul,
|
| 251 |
+
no_blocks,
|
| 252 |
+
no_heads,
|
| 253 |
+
pair_transition_n,
|
| 254 |
+
dropout_rate,
|
| 255 |
+
blocks_per_ckpt,
|
| 256 |
+
inf=1e9,
|
| 257 |
+
**kwargs,
|
| 258 |
+
):
|
| 259 |
+
"""
|
| 260 |
+
Args:
|
| 261 |
+
c_t:
|
| 262 |
+
Template embedding channel dimension
|
| 263 |
+
c_hidden_tri_att:
|
| 264 |
+
Per-head hidden dimension for triangular attention
|
| 265 |
+
c_hidden_tri_att:
|
| 266 |
+
Hidden dimension for triangular multiplication
|
| 267 |
+
no_blocks:
|
| 268 |
+
Number of blocks in the stack
|
| 269 |
+
pair_transition_n:
|
| 270 |
+
Scale of pair transition (Alg. 15) hidden dimension
|
| 271 |
+
dropout_rate:
|
| 272 |
+
Dropout rate used throughout the stack
|
| 273 |
+
blocks_per_ckpt:
|
| 274 |
+
Number of blocks per activation checkpoint. None disables
|
| 275 |
+
activation checkpointing
|
| 276 |
+
"""
|
| 277 |
+
super(TemplatePairStack, self).__init__()
|
| 278 |
+
|
| 279 |
+
self.blocks_per_ckpt = blocks_per_ckpt
|
| 280 |
+
|
| 281 |
+
self.blocks = nn.ModuleList()
|
| 282 |
+
for _ in range(no_blocks):
|
| 283 |
+
block = TemplatePairStackBlock(
|
| 284 |
+
c_t=c_t,
|
| 285 |
+
c_hidden_tri_att=c_hidden_tri_att,
|
| 286 |
+
c_hidden_tri_mul=c_hidden_tri_mul,
|
| 287 |
+
no_heads=no_heads,
|
| 288 |
+
pair_transition_n=pair_transition_n,
|
| 289 |
+
dropout_rate=dropout_rate,
|
| 290 |
+
inf=inf,
|
| 291 |
+
)
|
| 292 |
+
self.blocks.append(block)
|
| 293 |
+
|
| 294 |
+
self.layer_norm = LayerNorm(c_t)
|
| 295 |
+
|
| 296 |
+
def forward(
|
| 297 |
+
self,
|
| 298 |
+
t: torch.tensor,
|
| 299 |
+
mask: torch.tensor,
|
| 300 |
+
chunk_size: int,
|
| 301 |
+
_mask_trans: bool = True,
|
| 302 |
+
):
|
| 303 |
+
"""
|
| 304 |
+
Args:
|
| 305 |
+
t:
|
| 306 |
+
[*, N_templ, N_res, N_res, C_t] template embedding
|
| 307 |
+
mask:
|
| 308 |
+
[*, N_templ, N_res, N_res] mask
|
| 309 |
+
Returns:
|
| 310 |
+
[*, N_templ, N_res, N_res, C_t] template embedding update
|
| 311 |
+
"""
|
| 312 |
+
if(mask.shape[-3] == 1):
|
| 313 |
+
expand_idx = list(mask.shape)
|
| 314 |
+
expand_idx[-3] = t.shape[-4]
|
| 315 |
+
mask = mask.expand(*expand_idx)
|
| 316 |
+
|
| 317 |
+
t, = checkpoint_blocks(
|
| 318 |
+
blocks=[
|
| 319 |
+
partial(
|
| 320 |
+
b,
|
| 321 |
+
mask=mask,
|
| 322 |
+
chunk_size=chunk_size,
|
| 323 |
+
_mask_trans=_mask_trans,
|
| 324 |
+
)
|
| 325 |
+
for b in self.blocks
|
| 326 |
+
],
|
| 327 |
+
args=(t,),
|
| 328 |
+
blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
t = self.layer_norm(t)
|
| 332 |
+
|
| 333 |
+
return t
|
openfold/model/torchscript.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from typing import Optional, Sequence, Tuple
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
|
| 20 |
+
from openfold.model.dropout import (
|
| 21 |
+
DropoutRowwise,
|
| 22 |
+
DropoutColumnwise,
|
| 23 |
+
)
|
| 24 |
+
from openfold.model.evoformer import (
|
| 25 |
+
EvoformerBlock,
|
| 26 |
+
EvoformerStack,
|
| 27 |
+
)
|
| 28 |
+
from openfold.model.outer_product_mean import OuterProductMean
|
| 29 |
+
from openfold.model.msa import (
|
| 30 |
+
MSARowAttentionWithPairBias,
|
| 31 |
+
MSAColumnAttention,
|
| 32 |
+
MSAColumnGlobalAttention,
|
| 33 |
+
)
|
| 34 |
+
from openfold.model.pair_transition import PairTransition
|
| 35 |
+
from openfold.model.primitives import Attention, GlobalAttention
|
| 36 |
+
from openfold.model.structure_module import (
|
| 37 |
+
InvariantPointAttention,
|
| 38 |
+
BackboneUpdate,
|
| 39 |
+
)
|
| 40 |
+
from openfold.model.template import TemplatePairStackBlock
|
| 41 |
+
from openfold.model.triangular_attention import (
|
| 42 |
+
TriangleAttentionStartingNode,
|
| 43 |
+
TriangleAttentionEndingNode,
|
| 44 |
+
)
|
| 45 |
+
from openfold.model.triangular_multiplicative_update import (
|
| 46 |
+
TriangleMultiplicationOutgoing,
|
| 47 |
+
TriangleMultiplicationIncoming,
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def script_preset_(model: torch.nn.Module):
|
| 52 |
+
"""
|
| 53 |
+
TorchScript a handful of low-level but frequently used submodule types
|
| 54 |
+
that are known to be scriptable.
|
| 55 |
+
|
| 56 |
+
Args:
|
| 57 |
+
model:
|
| 58 |
+
A torch.nn.Module. It should contain at least some modules from
|
| 59 |
+
this repository, or this function won't do anything.
|
| 60 |
+
"""
|
| 61 |
+
script_submodules_(
|
| 62 |
+
model,
|
| 63 |
+
[
|
| 64 |
+
nn.Dropout,
|
| 65 |
+
Attention,
|
| 66 |
+
GlobalAttention,
|
| 67 |
+
EvoformerBlock,
|
| 68 |
+
#TemplatePairStackBlock,
|
| 69 |
+
],
|
| 70 |
+
attempt_trace=False,
|
| 71 |
+
batch_dims=None,
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def _get_module_device(module: torch.nn.Module) -> torch.device:
|
| 76 |
+
"""
|
| 77 |
+
Fetches the device of a module, assuming that all of the module's
|
| 78 |
+
parameters reside on a single device
|
| 79 |
+
|
| 80 |
+
Args:
|
| 81 |
+
module: A torch.nn.Module
|
| 82 |
+
Returns:
|
| 83 |
+
The module's device
|
| 84 |
+
"""
|
| 85 |
+
return next(module.parameters()).device
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _trace_module(module, batch_dims=None):
|
| 89 |
+
if(batch_dims is None):
|
| 90 |
+
batch_dims = ()
|
| 91 |
+
|
| 92 |
+
# Stand-in values
|
| 93 |
+
n_seq = 10
|
| 94 |
+
n_res = 10
|
| 95 |
+
|
| 96 |
+
device = _get_module_device(module)
|
| 97 |
+
|
| 98 |
+
def msa(channel_dim):
|
| 99 |
+
return torch.rand(
|
| 100 |
+
(*batch_dims, n_seq, n_res, channel_dim),
|
| 101 |
+
device=device,
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
def pair(channel_dim):
|
| 105 |
+
return torch.rand(
|
| 106 |
+
(*batch_dims, n_res, n_res, channel_dim),
|
| 107 |
+
device=device,
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
if(isinstance(module, MSARowAttentionWithPairBias)):
|
| 111 |
+
inputs = {
|
| 112 |
+
"forward": (
|
| 113 |
+
msa(module.c_in), # m
|
| 114 |
+
pair(module.c_z), # z
|
| 115 |
+
torch.randint(
|
| 116 |
+
0, 2,
|
| 117 |
+
(*batch_dims, n_seq, n_res)
|
| 118 |
+
), # mask
|
| 119 |
+
),
|
| 120 |
+
}
|
| 121 |
+
elif(isinstance(module, MSAColumnAttention)):
|
| 122 |
+
inputs = {
|
| 123 |
+
"forward": (
|
| 124 |
+
msa(module.c_in), # m
|
| 125 |
+
torch.randint(
|
| 126 |
+
0, 2,
|
| 127 |
+
(*batch_dims, n_seq, n_res)
|
| 128 |
+
), # mask
|
| 129 |
+
),
|
| 130 |
+
}
|
| 131 |
+
elif(isinstance(module, OuterProductMean)):
|
| 132 |
+
inputs = {
|
| 133 |
+
"forward": (
|
| 134 |
+
msa(module.c_m),
|
| 135 |
+
torch.randint(
|
| 136 |
+
0, 2,
|
| 137 |
+
(*batch_dims, n_seq, n_res)
|
| 138 |
+
)
|
| 139 |
+
)
|
| 140 |
+
}
|
| 141 |
+
else:
|
| 142 |
+
raise TypeError(
|
| 143 |
+
f"tracing is not supported for modules of type {type(module)}"
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
return torch.jit.trace_module(module, inputs)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def _script_submodules_helper_(
|
| 150 |
+
model,
|
| 151 |
+
types,
|
| 152 |
+
attempt_trace,
|
| 153 |
+
to_trace,
|
| 154 |
+
):
|
| 155 |
+
for name, child in model.named_children():
|
| 156 |
+
if(types is None or any(isinstance(child, t) for t in types)):
|
| 157 |
+
try:
|
| 158 |
+
scripted = torch.jit.script(child)
|
| 159 |
+
setattr(model, name, scripted)
|
| 160 |
+
continue
|
| 161 |
+
except (RuntimeError, torch.jit.frontend.NotSupportedError) as e:
|
| 162 |
+
if(attempt_trace):
|
| 163 |
+
to_trace.add(type(child))
|
| 164 |
+
else:
|
| 165 |
+
raise e
|
| 166 |
+
|
| 167 |
+
_script_submodules_helper_(child, types, attempt_trace, to_trace)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def _trace_submodules_(
|
| 171 |
+
model,
|
| 172 |
+
types,
|
| 173 |
+
batch_dims=None,
|
| 174 |
+
):
|
| 175 |
+
for name, child in model.named_children():
|
| 176 |
+
if(any(isinstance(child, t) for t in types)):
|
| 177 |
+
traced = _trace_module(child, batch_dims=batch_dims)
|
| 178 |
+
setattr(model, name, traced)
|
| 179 |
+
else:
|
| 180 |
+
_trace_submodules_(child, types, batch_dims=batch_dims)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def script_submodules_(
|
| 184 |
+
model: nn.Module,
|
| 185 |
+
types: Optional[Sequence[type]] = None,
|
| 186 |
+
attempt_trace: Optional[bool] = True,
|
| 187 |
+
batch_dims: Optional[Tuple[int]] = None,
|
| 188 |
+
):
|
| 189 |
+
"""
|
| 190 |
+
Convert all submodules whose types match one of those in the input
|
| 191 |
+
list to recursively scripted equivalents in place. To script the entire
|
| 192 |
+
model, just call torch.jit.script on it directly.
|
| 193 |
+
|
| 194 |
+
When types is None, all submodules are scripted.
|
| 195 |
+
|
| 196 |
+
Args:
|
| 197 |
+
model:
|
| 198 |
+
A torch.nn.Module
|
| 199 |
+
types:
|
| 200 |
+
A list of types of submodules to script
|
| 201 |
+
attempt_trace:
|
| 202 |
+
Whether to attempt to trace specified modules if scripting
|
| 203 |
+
fails. Recall that tracing eliminates all conditional
|
| 204 |
+
logic---with great tracing comes the mild responsibility of
|
| 205 |
+
having to remember to ensure that the modules in question
|
| 206 |
+
perform the same computations no matter what.
|
| 207 |
+
"""
|
| 208 |
+
to_trace = set()
|
| 209 |
+
|
| 210 |
+
# Aggressively script as much as possible first...
|
| 211 |
+
_script_submodules_helper_(model, types, attempt_trace, to_trace)
|
| 212 |
+
|
| 213 |
+
# ... and then trace stragglers.
|
| 214 |
+
if(attempt_trace and len(to_trace) > 0):
|
| 215 |
+
_trace_submodules_(model, to_trace, batch_dims=batch_dims)
|
openfold/model/triangular_attention.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from functools import partialmethod, partial
|
| 17 |
+
import math
|
| 18 |
+
from typing import Optional, List
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
|
| 23 |
+
from openfold.model.primitives import Linear, LayerNorm, Attention
|
| 24 |
+
from openfold.utils.tensor_utils import (
|
| 25 |
+
chunk_layer,
|
| 26 |
+
permute_final_dims,
|
| 27 |
+
flatten_final_dims,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class TriangleAttention(nn.Module):
|
| 32 |
+
def __init__(
|
| 33 |
+
self, c_in, c_hidden, no_heads, starting, inf=1e9
|
| 34 |
+
):
|
| 35 |
+
"""
|
| 36 |
+
Args:
|
| 37 |
+
c_in:
|
| 38 |
+
Input channel dimension
|
| 39 |
+
c_hidden:
|
| 40 |
+
Overall hidden channel dimension (not per-head)
|
| 41 |
+
no_heads:
|
| 42 |
+
Number of attention heads
|
| 43 |
+
"""
|
| 44 |
+
super(TriangleAttention, self).__init__()
|
| 45 |
+
|
| 46 |
+
self.c_in = c_in
|
| 47 |
+
self.c_hidden = c_hidden
|
| 48 |
+
self.no_heads = no_heads
|
| 49 |
+
self.starting = starting
|
| 50 |
+
self.inf = inf
|
| 51 |
+
|
| 52 |
+
self.layer_norm = LayerNorm(self.c_in)
|
| 53 |
+
|
| 54 |
+
self.linear = Linear(c_in, self.no_heads, bias=False, init="normal")
|
| 55 |
+
|
| 56 |
+
self.mha = Attention(
|
| 57 |
+
self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
@torch.jit.ignore
|
| 61 |
+
def _chunk(self,
|
| 62 |
+
x: torch.Tensor,
|
| 63 |
+
biases: List[torch.Tensor],
|
| 64 |
+
chunk_size: int,
|
| 65 |
+
) -> torch.Tensor:
|
| 66 |
+
mha_inputs = {
|
| 67 |
+
"q_x": x,
|
| 68 |
+
"kv_x": x,
|
| 69 |
+
"biases": biases,
|
| 70 |
+
}
|
| 71 |
+
return chunk_layer(
|
| 72 |
+
partial(self.mha),
|
| 73 |
+
mha_inputs,
|
| 74 |
+
chunk_size=chunk_size,
|
| 75 |
+
no_batch_dims=len(x.shape[:-2]),
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
def forward(self,
|
| 79 |
+
x: torch.Tensor,
|
| 80 |
+
mask: Optional[torch.Tensor] = None,
|
| 81 |
+
chunk_size: Optional[int] = None
|
| 82 |
+
) -> torch.Tensor:
|
| 83 |
+
"""
|
| 84 |
+
Args:
|
| 85 |
+
x:
|
| 86 |
+
[*, I, J, C_in] input tensor (e.g. the pair representation)
|
| 87 |
+
Returns:
|
| 88 |
+
[*, I, J, C_in] output tensor
|
| 89 |
+
"""
|
| 90 |
+
if mask is None:
|
| 91 |
+
# [*, I, J]
|
| 92 |
+
mask = x.new_ones(
|
| 93 |
+
x.shape[:-1],
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
# Shape annotations assume self.starting. Else, I and J are flipped
|
| 97 |
+
if not self.starting:
|
| 98 |
+
x = x.transpose(-2, -3)
|
| 99 |
+
mask = mask.transpose(-1, -2)
|
| 100 |
+
|
| 101 |
+
# [*, I, J, C_in]
|
| 102 |
+
x = self.layer_norm(x)
|
| 103 |
+
|
| 104 |
+
# [*, I, 1, 1, J]
|
| 105 |
+
mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
|
| 106 |
+
|
| 107 |
+
# [*, H, I, J]
|
| 108 |
+
triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1))
|
| 109 |
+
|
| 110 |
+
# [*, 1, H, I, J]
|
| 111 |
+
triangle_bias = triangle_bias.unsqueeze(-4)
|
| 112 |
+
|
| 113 |
+
biases = [mask_bias, triangle_bias]
|
| 114 |
+
|
| 115 |
+
if chunk_size is not None:
|
| 116 |
+
x = self._chunk(x, biases, chunk_size)
|
| 117 |
+
else:
|
| 118 |
+
x = self.mha(q_x=x, kv_x=x, biases=biases)
|
| 119 |
+
|
| 120 |
+
if not self.starting:
|
| 121 |
+
x = x.transpose(-2, -3)
|
| 122 |
+
|
| 123 |
+
return x
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class TriangleAttentionStartingNode(TriangleAttention):
|
| 127 |
+
"""
|
| 128 |
+
Implements Algorithm 13.
|
| 129 |
+
"""
|
| 130 |
+
|
| 131 |
+
__init__ = partialmethod(TriangleAttention.__init__, starting=True)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class TriangleAttentionEndingNode(TriangleAttention):
|
| 135 |
+
"""
|
| 136 |
+
Implements Algorithm 14.
|
| 137 |
+
"""
|
| 138 |
+
|
| 139 |
+
__init__ = partialmethod(TriangleAttention.__init__, starting=False)
|
openfold/model/triangular_multiplicative_update.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from functools import partialmethod
|
| 17 |
+
from typing import Optional
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
|
| 22 |
+
from openfold.model.primitives import Linear, LayerNorm
|
| 23 |
+
from openfold.utils.tensor_utils import permute_final_dims
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class TriangleMultiplicativeUpdate(nn.Module):
|
| 27 |
+
"""
|
| 28 |
+
Implements Algorithms 11 and 12.
|
| 29 |
+
"""
|
| 30 |
+
def __init__(self, c_z, c_hidden, _outgoing=True):
|
| 31 |
+
"""
|
| 32 |
+
Args:
|
| 33 |
+
c_z:
|
| 34 |
+
Input channel dimension
|
| 35 |
+
c:
|
| 36 |
+
Hidden channel dimension
|
| 37 |
+
"""
|
| 38 |
+
super(TriangleMultiplicativeUpdate, self).__init__()
|
| 39 |
+
self.c_z = c_z
|
| 40 |
+
self.c_hidden = c_hidden
|
| 41 |
+
self._outgoing = _outgoing
|
| 42 |
+
|
| 43 |
+
self.linear_a_p = Linear(self.c_z, self.c_hidden)
|
| 44 |
+
self.linear_a_g = Linear(self.c_z, self.c_hidden, init="gating")
|
| 45 |
+
self.linear_b_p = Linear(self.c_z, self.c_hidden)
|
| 46 |
+
self.linear_b_g = Linear(self.c_z, self.c_hidden, init="gating")
|
| 47 |
+
self.linear_g = Linear(self.c_z, self.c_z, init="gating")
|
| 48 |
+
self.linear_z = Linear(self.c_hidden, self.c_z, init="final")
|
| 49 |
+
|
| 50 |
+
self.layer_norm_in = LayerNorm(self.c_z)
|
| 51 |
+
self.layer_norm_out = LayerNorm(self.c_hidden)
|
| 52 |
+
|
| 53 |
+
self.sigmoid = nn.Sigmoid()
|
| 54 |
+
|
| 55 |
+
def _combine_projections(self,
|
| 56 |
+
a: torch.Tensor,
|
| 57 |
+
b: torch.Tensor,
|
| 58 |
+
) -> torch.Tensor:
|
| 59 |
+
raise NotImplementedError("This method needs to be overridden")
|
| 60 |
+
|
| 61 |
+
def forward(self,
|
| 62 |
+
z: torch.Tensor,
|
| 63 |
+
mask: Optional[torch.Tensor] = None
|
| 64 |
+
) -> torch.Tensor:
|
| 65 |
+
"""
|
| 66 |
+
Args:
|
| 67 |
+
x:
|
| 68 |
+
[*, N_res, N_res, C_z] input tensor
|
| 69 |
+
mask:
|
| 70 |
+
[*, N_res, N_res] input mask
|
| 71 |
+
Returns:
|
| 72 |
+
[*, N_res, N_res, C_z] output tensor
|
| 73 |
+
"""
|
| 74 |
+
if mask is None:
|
| 75 |
+
mask = z.new_ones(z.shape[:-1])
|
| 76 |
+
|
| 77 |
+
mask = mask.unsqueeze(-1)
|
| 78 |
+
|
| 79 |
+
z = self.layer_norm_in(z)
|
| 80 |
+
a = self.linear_a_p(z) * self.sigmoid(self.linear_a_g(z))
|
| 81 |
+
a = a * mask
|
| 82 |
+
b = self.linear_b_p(z) * self.sigmoid(self.linear_b_g(z))
|
| 83 |
+
b = b * mask
|
| 84 |
+
x = self._combine_projections(a, b)
|
| 85 |
+
x = self.layer_norm_out(x)
|
| 86 |
+
x = self.linear_z(x)
|
| 87 |
+
g = self.sigmoid(self.linear_g(z))
|
| 88 |
+
z = x * g
|
| 89 |
+
|
| 90 |
+
return z
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class TriangleMultiplicationOutgoing(TriangleMultiplicativeUpdate):
|
| 94 |
+
"""
|
| 95 |
+
Implements Algorithm 11.
|
| 96 |
+
"""
|
| 97 |
+
def _combine_projections(self,
|
| 98 |
+
a: torch.Tensor, # [*, N_i, N_k, C]
|
| 99 |
+
b: torch.Tensor, # [*, N_j, N_k, C]
|
| 100 |
+
):
|
| 101 |
+
# [*, C, N_i, N_j]
|
| 102 |
+
p = torch.matmul(
|
| 103 |
+
permute_final_dims(a, (2, 0, 1)),
|
| 104 |
+
permute_final_dims(b, (2, 1, 0)),
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
# [*, N_i, N_j, C]
|
| 108 |
+
return permute_final_dims(p, (1, 2, 0))
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate):
|
| 112 |
+
"""
|
| 113 |
+
Implements Algorithm 12.
|
| 114 |
+
"""
|
| 115 |
+
def _combine_projections(self,
|
| 116 |
+
a: torch.Tensor, # [*, N_k, N_i, C]
|
| 117 |
+
b: torch.Tensor, # [*, N_k, N_j, C]
|
| 118 |
+
):
|
| 119 |
+
# [*, C, N_i, N_j]
|
| 120 |
+
p = torch.matmul(
|
| 121 |
+
permute_final_dims(a, (2, 1, 0)),
|
| 122 |
+
permute_final_dims(b, (2, 0, 1)),
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
# [*, N_i, N_j, C]
|
| 126 |
+
return permute_final_dims(p, (1, 2, 0))
|
| 127 |
+
|
openfold/np/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import glob
|
| 3 |
+
import importlib as importlib
|
| 4 |
+
|
| 5 |
+
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
|
| 6 |
+
__all__ = [
|
| 7 |
+
os.path.basename(f)[:-3]
|
| 8 |
+
for f in _files
|
| 9 |
+
if os.path.isfile(f) and not f.endswith("__init__.py")
|
| 10 |
+
]
|
| 11 |
+
_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
|
| 12 |
+
for _m in _modules:
|
| 13 |
+
globals()[_m[0]] = _m[1]
|
| 14 |
+
|
| 15 |
+
# Avoid needlessly cluttering the global namespace
|
| 16 |
+
del _files, _m, _modules
|
openfold/np/protein.py
ADDED
|
@@ -0,0 +1,438 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""Protein data type."""
|
| 17 |
+
import dataclasses
|
| 18 |
+
import io
|
| 19 |
+
from typing import Any, Sequence, Mapping, Optional
|
| 20 |
+
import re
|
| 21 |
+
import string
|
| 22 |
+
|
| 23 |
+
from openfold.np import residue_constants
|
| 24 |
+
from Bio.PDB import PDBParser
|
| 25 |
+
import numpy as np
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
FeatureDict = Mapping[str, np.ndarray]
|
| 29 |
+
ModelOutput = Mapping[str, Any] # Is a nested dict.
|
| 30 |
+
PICO_TO_ANGSTROM = 0.01
|
| 31 |
+
|
| 32 |
+
@dataclasses.dataclass(frozen=True)
|
| 33 |
+
class Protein:
|
| 34 |
+
"""Protein structure representation."""
|
| 35 |
+
|
| 36 |
+
# Cartesian coordinates of atoms in angstroms. The atom types correspond to
|
| 37 |
+
# residue_constants.atom_types, i.e. the first three are N, CA, CB.
|
| 38 |
+
atom_positions: np.ndarray # [num_res, num_atom_type, 3]
|
| 39 |
+
|
| 40 |
+
# Amino-acid type for each residue represented as an integer between 0 and
|
| 41 |
+
# 20, where 20 is 'X'.
|
| 42 |
+
aatype: np.ndarray # [num_res]
|
| 43 |
+
|
| 44 |
+
# Binary float mask to indicate presence of a particular atom. 1.0 if an atom
|
| 45 |
+
# is present and 0.0 if not. This should be used for loss masking.
|
| 46 |
+
atom_mask: np.ndarray # [num_res, num_atom_type]
|
| 47 |
+
|
| 48 |
+
# Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
|
| 49 |
+
residue_index: np.ndarray # [num_res]
|
| 50 |
+
|
| 51 |
+
# B-factors, or temperature factors, of each residue (in sq. angstroms units),
|
| 52 |
+
# representing the displacement of the residue from its ground truth mean
|
| 53 |
+
# value.
|
| 54 |
+
b_factors: np.ndarray # [num_res, num_atom_type]
|
| 55 |
+
|
| 56 |
+
# Chain indices for multi-chain predictions
|
| 57 |
+
chain_index: Optional[np.ndarray] = None
|
| 58 |
+
|
| 59 |
+
# Optional remark about the protein. Included as a comment in output PDB
|
| 60 |
+
# files
|
| 61 |
+
remark: Optional[str] = None
|
| 62 |
+
|
| 63 |
+
# Templates used to generate this protein (prediction-only)
|
| 64 |
+
parents: Optional[Sequence[str]] = None
|
| 65 |
+
|
| 66 |
+
# Chain corresponding to each parent
|
| 67 |
+
parents_chain_index: Optional[Sequence[int]] = None
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
|
| 71 |
+
"""Takes a PDB string and constructs a Protein object.
|
| 72 |
+
|
| 73 |
+
WARNING: All non-standard residue types will be converted into UNK. All
|
| 74 |
+
non-standard atoms will be ignored.
|
| 75 |
+
|
| 76 |
+
Args:
|
| 77 |
+
pdb_str: The contents of the pdb file
|
| 78 |
+
chain_id: If None, then the pdb file must contain a single chain (which
|
| 79 |
+
will be parsed). If chain_id is specified (e.g. A), then only that chain
|
| 80 |
+
is parsed.
|
| 81 |
+
|
| 82 |
+
Returns:
|
| 83 |
+
A new `Protein` parsed from the pdb contents.
|
| 84 |
+
"""
|
| 85 |
+
pdb_fh = io.StringIO(pdb_str)
|
| 86 |
+
parser = PDBParser(QUIET=True)
|
| 87 |
+
structure = parser.get_structure("none", pdb_fh)
|
| 88 |
+
models = list(structure.get_models())
|
| 89 |
+
if len(models) != 1:
|
| 90 |
+
raise ValueError(
|
| 91 |
+
f"Only single model PDBs are supported. Found {len(models)} models."
|
| 92 |
+
)
|
| 93 |
+
model = models[0]
|
| 94 |
+
|
| 95 |
+
atom_positions = []
|
| 96 |
+
aatype = []
|
| 97 |
+
atom_mask = []
|
| 98 |
+
residue_index = []
|
| 99 |
+
chain_ids = []
|
| 100 |
+
b_factors = []
|
| 101 |
+
|
| 102 |
+
for chain in model:
|
| 103 |
+
if(chain_id is not None and chain.id != chain_id):
|
| 104 |
+
continue
|
| 105 |
+
for res in chain:
|
| 106 |
+
if res.id[2] != " ":
|
| 107 |
+
raise ValueError(
|
| 108 |
+
f"PDB contains an insertion code at chain {chain.id} and residue "
|
| 109 |
+
f"index {res.id[1]}. These are not supported."
|
| 110 |
+
)
|
| 111 |
+
res_shortname = residue_constants.restype_3to1.get(res.resname, "X")
|
| 112 |
+
restype_idx = residue_constants.restype_order.get(
|
| 113 |
+
res_shortname, residue_constants.restype_num
|
| 114 |
+
)
|
| 115 |
+
pos = np.zeros((residue_constants.atom_type_num, 3))
|
| 116 |
+
mask = np.zeros((residue_constants.atom_type_num,))
|
| 117 |
+
res_b_factors = np.zeros((residue_constants.atom_type_num,))
|
| 118 |
+
for atom in res:
|
| 119 |
+
if atom.name not in residue_constants.atom_types:
|
| 120 |
+
continue
|
| 121 |
+
pos[residue_constants.atom_order[atom.name]] = atom.coord
|
| 122 |
+
mask[residue_constants.atom_order[atom.name]] = 1.0
|
| 123 |
+
res_b_factors[
|
| 124 |
+
residue_constants.atom_order[atom.name]
|
| 125 |
+
] = atom.bfactor
|
| 126 |
+
if np.sum(mask) < 0.5:
|
| 127 |
+
# If no known atom positions are reported for the residue then skip it.
|
| 128 |
+
continue
|
| 129 |
+
aatype.append(restype_idx)
|
| 130 |
+
atom_positions.append(pos)
|
| 131 |
+
atom_mask.append(mask)
|
| 132 |
+
residue_index.append(res.id[1])
|
| 133 |
+
chain_ids.append(chain.id)
|
| 134 |
+
b_factors.append(res_b_factors)
|
| 135 |
+
|
| 136 |
+
parents = None
|
| 137 |
+
parents_chain_index = None
|
| 138 |
+
if("PARENT" in pdb_str):
|
| 139 |
+
parents = []
|
| 140 |
+
parents_chain_index = []
|
| 141 |
+
chain_id = 0
|
| 142 |
+
for l in pdb_str.split("\n"):
|
| 143 |
+
if("PARENT" in l):
|
| 144 |
+
if(not "N/A" in l):
|
| 145 |
+
parent_names = l.split()[1:]
|
| 146 |
+
parents.extend(parent_names)
|
| 147 |
+
parents_chain_index.extend([
|
| 148 |
+
chain_id for _ in parent_names
|
| 149 |
+
])
|
| 150 |
+
chain_id += 1
|
| 151 |
+
|
| 152 |
+
unique_chain_ids = np.unique(chain_ids)
|
| 153 |
+
chain_id_mapping = {cid: n for n, cid in enumerate(string.ascii_uppercase)}
|
| 154 |
+
chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids])
|
| 155 |
+
|
| 156 |
+
return Protein(
|
| 157 |
+
atom_positions=np.array(atom_positions),
|
| 158 |
+
atom_mask=np.array(atom_mask),
|
| 159 |
+
aatype=np.array(aatype),
|
| 160 |
+
residue_index=np.array(residue_index),
|
| 161 |
+
chain_index=chain_index,
|
| 162 |
+
b_factors=np.array(b_factors),
|
| 163 |
+
parents=parents,
|
| 164 |
+
parents_chain_index=parents_chain_index,
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def from_proteinnet_string(proteinnet_str: str) -> Protein:
|
| 169 |
+
tag_re = r'(\[[A-Z]+\]\n)'
|
| 170 |
+
tags = [
|
| 171 |
+
tag.strip() for tag in re.split(tag_re, proteinnet_str) if len(tag) > 0
|
| 172 |
+
]
|
| 173 |
+
groups = zip(tags[0::2], [l.split('\n') for l in tags[1::2]])
|
| 174 |
+
|
| 175 |
+
atoms = ['N', 'CA', 'C']
|
| 176 |
+
aatype = None
|
| 177 |
+
atom_positions = None
|
| 178 |
+
atom_mask = None
|
| 179 |
+
for g in groups:
|
| 180 |
+
if("[PRIMARY]" == g[0]):
|
| 181 |
+
seq = g[1][0].strip()
|
| 182 |
+
for i in range(len(seq)):
|
| 183 |
+
if(seq[i] not in residue_constants.restypes):
|
| 184 |
+
seq[i] = 'X'
|
| 185 |
+
aatype = np.array([
|
| 186 |
+
residue_constants.restype_order.get(
|
| 187 |
+
res_symbol, residue_constants.restype_num
|
| 188 |
+
) for res_symbol in seq
|
| 189 |
+
])
|
| 190 |
+
elif("[TERTIARY]" == g[0]):
|
| 191 |
+
tertiary = []
|
| 192 |
+
for axis in range(3):
|
| 193 |
+
tertiary.append(list(map(float, g[1][axis].split())))
|
| 194 |
+
tertiary_np = np.array(tertiary)
|
| 195 |
+
atom_positions = np.zeros(
|
| 196 |
+
(len(tertiary[0])//3, residue_constants.atom_type_num, 3)
|
| 197 |
+
).astype(np.float32)
|
| 198 |
+
for i, atom in enumerate(atoms):
|
| 199 |
+
atom_positions[:, residue_constants.atom_order[atom], :] = (
|
| 200 |
+
np.transpose(tertiary_np[:, i::3])
|
| 201 |
+
)
|
| 202 |
+
atom_positions *= PICO_TO_ANGSTROM
|
| 203 |
+
elif("[MASK]" == g[0]):
|
| 204 |
+
mask = np.array(list(map({'-': 0, '+': 1}.get, g[1][0].strip())))
|
| 205 |
+
atom_mask = np.zeros(
|
| 206 |
+
(len(mask), residue_constants.atom_type_num,)
|
| 207 |
+
).astype(np.float32)
|
| 208 |
+
for i, atom in enumerate(atoms):
|
| 209 |
+
atom_mask[:, residue_constants.atom_order[atom]] = 1
|
| 210 |
+
atom_mask *= mask[..., None]
|
| 211 |
+
|
| 212 |
+
return Protein(
|
| 213 |
+
atom_positions=atom_positions,
|
| 214 |
+
atom_mask=atom_mask,
|
| 215 |
+
aatype=aatype,
|
| 216 |
+
residue_index=np.arange(len(aatype)),
|
| 217 |
+
b_factors=None,
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def get_pdb_headers(prot: Protein, chain_id: int = 0) -> Sequence[str]:
|
| 222 |
+
pdb_headers = []
|
| 223 |
+
|
| 224 |
+
remark = prot.remark
|
| 225 |
+
if(remark is not None):
|
| 226 |
+
pdb_headers.append(f"REMARK {remark}")
|
| 227 |
+
|
| 228 |
+
parents = prot.parents
|
| 229 |
+
parents_chain_index = prot.parents_chain_index
|
| 230 |
+
if(parents_chain_index is not None):
|
| 231 |
+
parents = [
|
| 232 |
+
p for i, p in zip(parents_chain_index, parents) if i == chain_id
|
| 233 |
+
]
|
| 234 |
+
|
| 235 |
+
if(parents is None or len(parents) == 0):
|
| 236 |
+
parents = ["N/A"]
|
| 237 |
+
|
| 238 |
+
pdb_headers.append(f"PARENT {' '.join(parents)}")
|
| 239 |
+
|
| 240 |
+
return pdb_headers
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def add_pdb_headers(prot: Protein, pdb_str: str) -> str:
|
| 244 |
+
""" Add pdb headers to an existing PDB string. Useful during multi-chain
|
| 245 |
+
recycling
|
| 246 |
+
"""
|
| 247 |
+
out_pdb_lines = []
|
| 248 |
+
lines = pdb_str.split('\n')
|
| 249 |
+
|
| 250 |
+
remark = prot.remark
|
| 251 |
+
if(remark is not None):
|
| 252 |
+
out_pdb_lines.append(f"REMARK {remark}")
|
| 253 |
+
|
| 254 |
+
parents_per_chain = None
|
| 255 |
+
if(prot.parents is not None and len(prot.parents) > 0):
|
| 256 |
+
parents_per_chain = []
|
| 257 |
+
if(prot.parents_chain_index is not None):
|
| 258 |
+
cur_chain = prot.parents_chain_index[0]
|
| 259 |
+
parent_dict = {}
|
| 260 |
+
for p, i in zip(prot.parents, prot.parents_chain_index):
|
| 261 |
+
parent_dict.setdefault(str(i), [])
|
| 262 |
+
parent_dict[str(i)].append(p)
|
| 263 |
+
|
| 264 |
+
max_idx = max([int(chain_idx) for chain_idx in parent_dict])
|
| 265 |
+
for i in range(max_idx + 1):
|
| 266 |
+
chain_parents = parent_dict.get(str(i), ["N/A"])
|
| 267 |
+
parents_per_chain.append(chain_parents)
|
| 268 |
+
else:
|
| 269 |
+
parents_per_chain.append(prot.parents)
|
| 270 |
+
else:
|
| 271 |
+
parents_per_chain = [["N/A"]]
|
| 272 |
+
|
| 273 |
+
make_parent_line = lambda p: f"PARENT {' '.join(p)}"
|
| 274 |
+
|
| 275 |
+
out_pdb_lines.append(make_parent_line(parents_per_chain[0]))
|
| 276 |
+
|
| 277 |
+
chain_counter = 0
|
| 278 |
+
for i, l in enumerate(lines):
|
| 279 |
+
if("PARENT" not in l and "REMARK" not in l):
|
| 280 |
+
out_pdb_lines.append(l)
|
| 281 |
+
if("TER" in l and not "END" in lines[i + 1]):
|
| 282 |
+
chain_counter += 1
|
| 283 |
+
if(not chain_counter >= len(parents_per_chain)):
|
| 284 |
+
chain_parents = parents_per_chain[chain_counter]
|
| 285 |
+
else:
|
| 286 |
+
chain_parents = ["N/A"]
|
| 287 |
+
|
| 288 |
+
out_pdb_lines.append(make_parent_line(chain_parents))
|
| 289 |
+
|
| 290 |
+
return '\n'.join(out_pdb_lines)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def to_pdb(prot: Protein) -> str:
|
| 294 |
+
"""Converts a `Protein` instance to a PDB string.
|
| 295 |
+
|
| 296 |
+
Args:
|
| 297 |
+
prot: The protein to convert to PDB.
|
| 298 |
+
|
| 299 |
+
Returns:
|
| 300 |
+
PDB string.
|
| 301 |
+
"""
|
| 302 |
+
restypes = residue_constants.restypes + ["X"]
|
| 303 |
+
res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], "UNK")
|
| 304 |
+
atom_types = residue_constants.atom_types
|
| 305 |
+
|
| 306 |
+
pdb_lines = []
|
| 307 |
+
|
| 308 |
+
atom_mask = prot.atom_mask
|
| 309 |
+
aatype = prot.aatype
|
| 310 |
+
atom_positions = prot.atom_positions
|
| 311 |
+
residue_index = prot.residue_index.astype(int)
|
| 312 |
+
b_factors = prot.b_factors
|
| 313 |
+
chain_index = prot.chain_index
|
| 314 |
+
|
| 315 |
+
if np.any(aatype > residue_constants.restype_num):
|
| 316 |
+
raise ValueError("Invalid aatypes.")
|
| 317 |
+
|
| 318 |
+
headers = get_pdb_headers(prot)
|
| 319 |
+
if(len(headers) > 0):
|
| 320 |
+
pdb_lines.extend(headers)
|
| 321 |
+
|
| 322 |
+
n = aatype.shape[0]
|
| 323 |
+
atom_index = 1
|
| 324 |
+
prev_chain_index = 0
|
| 325 |
+
chain_tags = string.ascii_uppercase
|
| 326 |
+
# Add all atom sites.
|
| 327 |
+
for i in range(n):
|
| 328 |
+
res_name_3 = res_1to3(aatype[i])
|
| 329 |
+
for atom_name, pos, mask, b_factor in zip(
|
| 330 |
+
atom_types, atom_positions[i], atom_mask[i], b_factors[i]
|
| 331 |
+
):
|
| 332 |
+
if mask < 0.5:
|
| 333 |
+
continue
|
| 334 |
+
|
| 335 |
+
record_type = "ATOM"
|
| 336 |
+
name = atom_name if len(atom_name) == 4 else f" {atom_name}"
|
| 337 |
+
alt_loc = ""
|
| 338 |
+
insertion_code = ""
|
| 339 |
+
occupancy = 1.00
|
| 340 |
+
element = atom_name[
|
| 341 |
+
0
|
| 342 |
+
] # Protein supports only C, N, O, S, this works.
|
| 343 |
+
charge = ""
|
| 344 |
+
|
| 345 |
+
chain_tag = "A"
|
| 346 |
+
if(chain_index is not None):
|
| 347 |
+
chain_tag = chain_tags[chain_index[i]]
|
| 348 |
+
|
| 349 |
+
# PDB is a columnar format, every space matters here!
|
| 350 |
+
atom_line = (
|
| 351 |
+
f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}"
|
| 352 |
+
f"{res_name_3:>3} {chain_tag:>1}"
|
| 353 |
+
f"{residue_index[i]:>4}{insertion_code:>1} "
|
| 354 |
+
f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
|
| 355 |
+
f"{occupancy:>6.2f}{b_factor:>6.2f} "
|
| 356 |
+
f"{element:>2}{charge:>2}"
|
| 357 |
+
)
|
| 358 |
+
pdb_lines.append(atom_line)
|
| 359 |
+
atom_index += 1
|
| 360 |
+
|
| 361 |
+
should_terminate = (i == n - 1)
|
| 362 |
+
if(chain_index is not None):
|
| 363 |
+
if(i != n - 1 and chain_index[i + 1] != prev_chain_index):
|
| 364 |
+
should_terminate = True
|
| 365 |
+
prev_chain_index = chain_index[i + 1]
|
| 366 |
+
|
| 367 |
+
if(should_terminate):
|
| 368 |
+
# Close the chain.
|
| 369 |
+
chain_end = "TER"
|
| 370 |
+
chain_termination_line = (
|
| 371 |
+
f"{chain_end:<6}{atom_index:>5} "
|
| 372 |
+
f"{res_1to3(aatype[i]):>3} "
|
| 373 |
+
f"{chain_tag:>1}{residue_index[i]:>4}"
|
| 374 |
+
)
|
| 375 |
+
pdb_lines.append(chain_termination_line)
|
| 376 |
+
atom_index += 1
|
| 377 |
+
|
| 378 |
+
if(i != n - 1):
|
| 379 |
+
# "prev" is a misnomer here. This happens at the beginning of
|
| 380 |
+
# each new chain.
|
| 381 |
+
pdb_lines.extend(get_pdb_headers(prot, prev_chain_index))
|
| 382 |
+
|
| 383 |
+
pdb_lines.append("END")
|
| 384 |
+
pdb_lines.append("")
|
| 385 |
+
return "\n".join(pdb_lines)
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
def ideal_atom_mask(prot: Protein) -> np.ndarray:
|
| 389 |
+
"""Computes an ideal atom mask.
|
| 390 |
+
|
| 391 |
+
`Protein.atom_mask` typically is defined according to the atoms that are
|
| 392 |
+
reported in the PDB. This function computes a mask according to heavy atoms
|
| 393 |
+
that should be present in the given sequence of amino acids.
|
| 394 |
+
|
| 395 |
+
Args:
|
| 396 |
+
prot: `Protein` whose fields are `numpy.ndarray` objects.
|
| 397 |
+
|
| 398 |
+
Returns:
|
| 399 |
+
An ideal atom mask.
|
| 400 |
+
"""
|
| 401 |
+
return residue_constants.STANDARD_ATOM_MASK[prot.aatype]
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
def from_prediction(
|
| 405 |
+
features: FeatureDict,
|
| 406 |
+
result: ModelOutput,
|
| 407 |
+
b_factors: Optional[np.ndarray] = None,
|
| 408 |
+
chain_index: Optional[np.ndarray] = None,
|
| 409 |
+
remark: Optional[str] = None,
|
| 410 |
+
parents: Optional[Sequence[str]] = None,
|
| 411 |
+
parents_chain_index: Optional[Sequence[int]] = None
|
| 412 |
+
) -> Protein:
|
| 413 |
+
"""Assembles a protein from a prediction.
|
| 414 |
+
|
| 415 |
+
Args:
|
| 416 |
+
features: Dictionary holding model inputs.
|
| 417 |
+
result: Dictionary holding model outputs.
|
| 418 |
+
b_factors: (Optional) B-factors to use for the protein.
|
| 419 |
+
chain_index: (Optional) Chain indices for multi-chain predictions
|
| 420 |
+
remark: (Optional) Remark about the prediction
|
| 421 |
+
parents: (Optional) List of template names
|
| 422 |
+
Returns:
|
| 423 |
+
A protein instance.
|
| 424 |
+
"""
|
| 425 |
+
if b_factors is None:
|
| 426 |
+
b_factors = np.zeros_like(result["final_atom_mask"])
|
| 427 |
+
|
| 428 |
+
return Protein(
|
| 429 |
+
aatype=features["aatype"],
|
| 430 |
+
atom_positions=result["final_atom_positions"],
|
| 431 |
+
atom_mask=result["final_atom_mask"],
|
| 432 |
+
residue_index=features["residue_index"] + 1,
|
| 433 |
+
b_factors=b_factors,
|
| 434 |
+
chain_index=chain_index,
|
| 435 |
+
remark=remark,
|
| 436 |
+
parents=parents,
|
| 437 |
+
parents_chain_index=parents_chain_index,
|
| 438 |
+
)
|
openfold/np/relax/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import glob
|
| 3 |
+
import importlib as importlib
|
| 4 |
+
|
| 5 |
+
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
|
| 6 |
+
__all__ = [
|
| 7 |
+
os.path.basename(f)[:-3]
|
| 8 |
+
for f in _files
|
| 9 |
+
if os.path.isfile(f) and not f.endswith("__init__.py")
|
| 10 |
+
]
|
| 11 |
+
_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
|
| 12 |
+
for _m in _modules:
|
| 13 |
+
globals()[_m[0]] = _m[1]
|
| 14 |
+
|
| 15 |
+
# Avoid needlessly cluttering the global namespace
|
| 16 |
+
del _files, _m, _modules
|
openfold/np/relax/amber_minimize.py
ADDED
|
@@ -0,0 +1,612 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""Restrained Amber Minimization of a structure."""
|
| 17 |
+
|
| 18 |
+
import io
|
| 19 |
+
import time
|
| 20 |
+
from typing import Collection, Optional, Sequence
|
| 21 |
+
|
| 22 |
+
from absl import logging
|
| 23 |
+
from openfold.np import (
|
| 24 |
+
protein,
|
| 25 |
+
residue_constants,
|
| 26 |
+
)
|
| 27 |
+
import openfold.utils.loss as loss
|
| 28 |
+
from openfold.np.relax import cleanup, utils
|
| 29 |
+
import ml_collections
|
| 30 |
+
import numpy as np
|
| 31 |
+
import openmm
|
| 32 |
+
from openmm import unit
|
| 33 |
+
from openmm import app as openmm_app
|
| 34 |
+
from openmm.app.internal.pdbstructure import PdbStructure
|
| 35 |
+
|
| 36 |
+
ENERGY = unit.kilocalories_per_mole
|
| 37 |
+
LENGTH = unit.angstroms
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def will_restrain(atom: openmm_app.Atom, rset: str) -> bool:
|
| 41 |
+
"""Returns True if the atom will be restrained by the given restraint set."""
|
| 42 |
+
|
| 43 |
+
if rset == "non_hydrogen":
|
| 44 |
+
return atom.element.name != "hydrogen"
|
| 45 |
+
elif rset == "c_alpha":
|
| 46 |
+
return atom.name == "CA"
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _add_restraints(
|
| 50 |
+
system: openmm.System,
|
| 51 |
+
reference_pdb: openmm_app.PDBFile,
|
| 52 |
+
stiffness: unit.Unit,
|
| 53 |
+
rset: str,
|
| 54 |
+
exclude_residues: Sequence[int],
|
| 55 |
+
):
|
| 56 |
+
"""Adds a harmonic potential that restrains the system to a structure."""
|
| 57 |
+
assert rset in ["non_hydrogen", "c_alpha"]
|
| 58 |
+
|
| 59 |
+
force = openmm.CustomExternalForce(
|
| 60 |
+
"0.5 * k * ((x-x0)^2 + (y-y0)^2 + (z-z0)^2)"
|
| 61 |
+
)
|
| 62 |
+
force.addGlobalParameter("k", stiffness)
|
| 63 |
+
for p in ["x0", "y0", "z0"]:
|
| 64 |
+
force.addPerParticleParameter(p)
|
| 65 |
+
|
| 66 |
+
for i, atom in enumerate(reference_pdb.topology.atoms()):
|
| 67 |
+
if atom.residue.index in exclude_residues:
|
| 68 |
+
continue
|
| 69 |
+
if will_restrain(atom, rset):
|
| 70 |
+
force.addParticle(i, reference_pdb.positions[i])
|
| 71 |
+
logging.info(
|
| 72 |
+
"Restraining %d / %d particles.",
|
| 73 |
+
force.getNumParticles(),
|
| 74 |
+
system.getNumParticles(),
|
| 75 |
+
)
|
| 76 |
+
system.addForce(force)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _openmm_minimize(
|
| 80 |
+
pdb_str: str,
|
| 81 |
+
max_iterations: int,
|
| 82 |
+
tolerance: unit.Unit,
|
| 83 |
+
stiffness: unit.Unit,
|
| 84 |
+
restraint_set: str,
|
| 85 |
+
exclude_residues: Sequence[int],
|
| 86 |
+
use_gpu: bool,
|
| 87 |
+
):
|
| 88 |
+
"""Minimize energy via openmm."""
|
| 89 |
+
|
| 90 |
+
pdb_file = io.StringIO(pdb_str)
|
| 91 |
+
pdb = openmm_app.PDBFile(pdb_file)
|
| 92 |
+
|
| 93 |
+
force_field = openmm_app.ForceField("amber99sb.xml")
|
| 94 |
+
constraints = openmm_app.HBonds
|
| 95 |
+
system = force_field.createSystem(pdb.topology, constraints=constraints)
|
| 96 |
+
if stiffness > 0 * ENERGY / (LENGTH ** 2):
|
| 97 |
+
_add_restraints(system, pdb, stiffness, restraint_set, exclude_residues)
|
| 98 |
+
|
| 99 |
+
integrator = openmm.LangevinIntegrator(0, 0.01, 0.0)
|
| 100 |
+
platform = openmm.Platform.getPlatformByName("CUDA" if use_gpu else "CPU")
|
| 101 |
+
simulation = openmm_app.Simulation(
|
| 102 |
+
pdb.topology, system, integrator, platform
|
| 103 |
+
)
|
| 104 |
+
simulation.context.setPositions(pdb.positions)
|
| 105 |
+
|
| 106 |
+
ret = {}
|
| 107 |
+
state = simulation.context.getState(getEnergy=True, getPositions=True)
|
| 108 |
+
ret["einit"] = state.getPotentialEnergy().value_in_unit(ENERGY)
|
| 109 |
+
ret["posinit"] = state.getPositions(asNumpy=True).value_in_unit(LENGTH)
|
| 110 |
+
simulation.minimizeEnergy(maxIterations=max_iterations, tolerance=tolerance)
|
| 111 |
+
state = simulation.context.getState(getEnergy=True, getPositions=True)
|
| 112 |
+
ret["efinal"] = state.getPotentialEnergy().value_in_unit(ENERGY)
|
| 113 |
+
ret["pos"] = state.getPositions(asNumpy=True).value_in_unit(LENGTH)
|
| 114 |
+
ret["min_pdb"] = _get_pdb_string(simulation.topology, state.getPositions())
|
| 115 |
+
return ret
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def _get_pdb_string(topology: openmm_app.Topology, positions: unit.Quantity):
|
| 119 |
+
"""Returns a pdb string provided OpenMM topology and positions."""
|
| 120 |
+
with io.StringIO() as f:
|
| 121 |
+
openmm_app.PDBFile.writeFile(topology, positions, f)
|
| 122 |
+
return f.getvalue()
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def _check_cleaned_atoms(pdb_cleaned_string: str, pdb_ref_string: str):
|
| 126 |
+
"""Checks that no atom positions have been altered by cleaning."""
|
| 127 |
+
cleaned = openmm_app.PDBFile(io.StringIO(pdb_cleaned_string))
|
| 128 |
+
reference = openmm_app.PDBFile(io.StringIO(pdb_ref_string))
|
| 129 |
+
|
| 130 |
+
cl_xyz = np.array(cleaned.getPositions().value_in_unit(LENGTH))
|
| 131 |
+
ref_xyz = np.array(reference.getPositions().value_in_unit(LENGTH))
|
| 132 |
+
|
| 133 |
+
for ref_res, cl_res in zip(
|
| 134 |
+
reference.topology.residues(), cleaned.topology.residues()
|
| 135 |
+
):
|
| 136 |
+
assert ref_res.name == cl_res.name
|
| 137 |
+
for rat in ref_res.atoms():
|
| 138 |
+
for cat in cl_res.atoms():
|
| 139 |
+
if cat.name == rat.name:
|
| 140 |
+
if not np.array_equal(
|
| 141 |
+
cl_xyz[cat.index], ref_xyz[rat.index]
|
| 142 |
+
):
|
| 143 |
+
raise ValueError(
|
| 144 |
+
f"Coordinates of cleaned atom {cat} do not match "
|
| 145 |
+
f"coordinates of reference atom {rat}."
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def _check_residues_are_well_defined(prot: protein.Protein):
|
| 150 |
+
"""Checks that all residues contain non-empty atom sets."""
|
| 151 |
+
if (prot.atom_mask.sum(axis=-1) == 0).any():
|
| 152 |
+
raise ValueError(
|
| 153 |
+
"Amber minimization can only be performed on proteins with"
|
| 154 |
+
" well-defined residues. This protein contains at least"
|
| 155 |
+
" one residue with no atoms."
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def _check_atom_mask_is_ideal(prot):
|
| 160 |
+
"""Sanity-check the atom mask is ideal, up to a possible OXT."""
|
| 161 |
+
atom_mask = prot.atom_mask
|
| 162 |
+
ideal_atom_mask = protein.ideal_atom_mask(prot)
|
| 163 |
+
utils.assert_equal_nonterminal_atom_types(atom_mask, ideal_atom_mask)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def clean_protein(prot: protein.Protein, checks: bool = True):
|
| 167 |
+
"""Adds missing atoms to Protein instance.
|
| 168 |
+
|
| 169 |
+
Args:
|
| 170 |
+
prot: A `protein.Protein` instance.
|
| 171 |
+
checks: A `bool` specifying whether to add additional checks to the cleaning
|
| 172 |
+
process.
|
| 173 |
+
|
| 174 |
+
Returns:
|
| 175 |
+
pdb_string: A string of the cleaned protein.
|
| 176 |
+
"""
|
| 177 |
+
_check_atom_mask_is_ideal(prot)
|
| 178 |
+
|
| 179 |
+
# Clean pdb.
|
| 180 |
+
prot_pdb_string = protein.to_pdb(prot)
|
| 181 |
+
pdb_file = io.StringIO(prot_pdb_string)
|
| 182 |
+
alterations_info = {}
|
| 183 |
+
fixed_pdb = cleanup.fix_pdb(pdb_file, alterations_info)
|
| 184 |
+
fixed_pdb_file = io.StringIO(fixed_pdb)
|
| 185 |
+
pdb_structure = PdbStructure(fixed_pdb_file)
|
| 186 |
+
cleanup.clean_structure(pdb_structure, alterations_info)
|
| 187 |
+
|
| 188 |
+
logging.info("alterations info: %s", alterations_info)
|
| 189 |
+
|
| 190 |
+
# Write pdb file of cleaned structure.
|
| 191 |
+
as_file = openmm_app.PDBFile(pdb_structure)
|
| 192 |
+
pdb_string = _get_pdb_string(as_file.getTopology(), as_file.getPositions())
|
| 193 |
+
if checks:
|
| 194 |
+
_check_cleaned_atoms(pdb_string, prot_pdb_string)
|
| 195 |
+
return pdb_string
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def make_atom14_positions(prot):
|
| 199 |
+
"""Constructs denser atom positions (14 dimensions instead of 37)."""
|
| 200 |
+
restype_atom14_to_atom37 = [] # mapping (restype, atom14) --> atom37
|
| 201 |
+
restype_atom37_to_atom14 = [] # mapping (restype, atom37) --> atom14
|
| 202 |
+
restype_atom14_mask = []
|
| 203 |
+
|
| 204 |
+
for rt in residue_constants.restypes:
|
| 205 |
+
atom_names = residue_constants.restype_name_to_atom14_names[
|
| 206 |
+
residue_constants.restype_1to3[rt]
|
| 207 |
+
]
|
| 208 |
+
|
| 209 |
+
restype_atom14_to_atom37.append(
|
| 210 |
+
[
|
| 211 |
+
(residue_constants.atom_order[name] if name else 0)
|
| 212 |
+
for name in atom_names
|
| 213 |
+
]
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
|
| 217 |
+
restype_atom37_to_atom14.append(
|
| 218 |
+
[
|
| 219 |
+
(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
|
| 220 |
+
for name in residue_constants.atom_types
|
| 221 |
+
]
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
restype_atom14_mask.append(
|
| 225 |
+
[(1.0 if name else 0.0) for name in atom_names]
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
# Add dummy mapping for restype 'UNK'.
|
| 229 |
+
restype_atom14_to_atom37.append([0] * 14)
|
| 230 |
+
restype_atom37_to_atom14.append([0] * 37)
|
| 231 |
+
restype_atom14_mask.append([0.0] * 14)
|
| 232 |
+
|
| 233 |
+
restype_atom14_to_atom37 = np.array(
|
| 234 |
+
restype_atom14_to_atom37, dtype=int
|
| 235 |
+
)
|
| 236 |
+
restype_atom37_to_atom14 = np.array(
|
| 237 |
+
restype_atom37_to_atom14, dtype=int
|
| 238 |
+
)
|
| 239 |
+
restype_atom14_mask = np.array(restype_atom14_mask, dtype=np.float32)
|
| 240 |
+
|
| 241 |
+
# Create the mapping for (residx, atom14) --> atom37, i.e. an array
|
| 242 |
+
# with shape (num_res, 14) containing the atom37 indices for this protein.
|
| 243 |
+
residx_atom14_to_atom37 = restype_atom14_to_atom37[prot["aatype"]]
|
| 244 |
+
residx_atom14_mask = restype_atom14_mask[prot["aatype"]]
|
| 245 |
+
|
| 246 |
+
# Create a mask for known ground truth positions.
|
| 247 |
+
residx_atom14_gt_mask = residx_atom14_mask * np.take_along_axis(
|
| 248 |
+
prot["all_atom_mask"], residx_atom14_to_atom37, axis=1
|
| 249 |
+
).astype(np.float32)
|
| 250 |
+
|
| 251 |
+
# Gather the ground truth positions.
|
| 252 |
+
residx_atom14_gt_positions = residx_atom14_gt_mask[:, :, None] * (
|
| 253 |
+
np.take_along_axis(
|
| 254 |
+
prot["all_atom_positions"],
|
| 255 |
+
residx_atom14_to_atom37[..., None],
|
| 256 |
+
axis=1,
|
| 257 |
+
)
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
prot["atom14_atom_exists"] = residx_atom14_mask
|
| 261 |
+
prot["atom14_gt_exists"] = residx_atom14_gt_mask
|
| 262 |
+
prot["atom14_gt_positions"] = residx_atom14_gt_positions
|
| 263 |
+
|
| 264 |
+
prot["residx_atom14_to_atom37"] = residx_atom14_to_atom37.astype(np.int64)
|
| 265 |
+
|
| 266 |
+
# Create the gather indices for mapping back.
|
| 267 |
+
residx_atom37_to_atom14 = restype_atom37_to_atom14[prot["aatype"]]
|
| 268 |
+
prot["residx_atom37_to_atom14"] = residx_atom37_to_atom14.astype(np.int64)
|
| 269 |
+
|
| 270 |
+
# Create the corresponding mask.
|
| 271 |
+
restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
|
| 272 |
+
for restype, restype_letter in enumerate(residue_constants.restypes):
|
| 273 |
+
restype_name = residue_constants.restype_1to3[restype_letter]
|
| 274 |
+
atom_names = residue_constants.residue_atoms[restype_name]
|
| 275 |
+
for atom_name in atom_names:
|
| 276 |
+
atom_type = residue_constants.atom_order[atom_name]
|
| 277 |
+
restype_atom37_mask[restype, atom_type] = 1
|
| 278 |
+
|
| 279 |
+
residx_atom37_mask = restype_atom37_mask[prot["aatype"]]
|
| 280 |
+
prot["atom37_atom_exists"] = residx_atom37_mask
|
| 281 |
+
|
| 282 |
+
# As the atom naming is ambiguous for 7 of the 20 amino acids, provide
|
| 283 |
+
# alternative ground truth coordinates where the naming is swapped
|
| 284 |
+
restype_3 = [
|
| 285 |
+
residue_constants.restype_1to3[res]
|
| 286 |
+
for res in residue_constants.restypes
|
| 287 |
+
]
|
| 288 |
+
restype_3 += ["UNK"]
|
| 289 |
+
|
| 290 |
+
# Matrices for renaming ambiguous atoms.
|
| 291 |
+
all_matrices = {res: np.eye(14, dtype=np.float32) for res in restype_3}
|
| 292 |
+
for resname, swap in residue_constants.residue_atom_renaming_swaps.items():
|
| 293 |
+
correspondences = np.arange(14)
|
| 294 |
+
for source_atom_swap, target_atom_swap in swap.items():
|
| 295 |
+
source_index = residue_constants.restype_name_to_atom14_names[
|
| 296 |
+
resname
|
| 297 |
+
].index(source_atom_swap)
|
| 298 |
+
target_index = residue_constants.restype_name_to_atom14_names[
|
| 299 |
+
resname
|
| 300 |
+
].index(target_atom_swap)
|
| 301 |
+
correspondences[source_index] = target_index
|
| 302 |
+
correspondences[target_index] = source_index
|
| 303 |
+
renaming_matrix = np.zeros((14, 14), dtype=np.float32)
|
| 304 |
+
for index, correspondence in enumerate(correspondences):
|
| 305 |
+
renaming_matrix[index, correspondence] = 1.0
|
| 306 |
+
all_matrices[resname] = renaming_matrix.astype(np.float32)
|
| 307 |
+
renaming_matrices = np.stack(
|
| 308 |
+
[all_matrices[restype] for restype in restype_3]
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
# Pick the transformation matrices for the given residue sequence
|
| 312 |
+
# shape (num_res, 14, 14).
|
| 313 |
+
renaming_transform = renaming_matrices[prot["aatype"]]
|
| 314 |
+
|
| 315 |
+
# Apply it to the ground truth positions. shape (num_res, 14, 3).
|
| 316 |
+
alternative_gt_positions = np.einsum(
|
| 317 |
+
"rac,rab->rbc", residx_atom14_gt_positions, renaming_transform
|
| 318 |
+
)
|
| 319 |
+
prot["atom14_alt_gt_positions"] = alternative_gt_positions
|
| 320 |
+
|
| 321 |
+
# Create the mask for the alternative ground truth (differs from the
|
| 322 |
+
# ground truth mask, if only one of the atoms in an ambiguous pair has a
|
| 323 |
+
# ground truth position).
|
| 324 |
+
alternative_gt_mask = np.einsum(
|
| 325 |
+
"ra,rab->rb", residx_atom14_gt_mask, renaming_transform
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
prot["atom14_alt_gt_exists"] = alternative_gt_mask
|
| 329 |
+
|
| 330 |
+
# Create an ambiguous atoms mask. shape: (21, 14).
|
| 331 |
+
restype_atom14_is_ambiguous = np.zeros((21, 14), dtype=np.float32)
|
| 332 |
+
for resname, swap in residue_constants.residue_atom_renaming_swaps.items():
|
| 333 |
+
for atom_name1, atom_name2 in swap.items():
|
| 334 |
+
restype = residue_constants.restype_order[
|
| 335 |
+
residue_constants.restype_3to1[resname]
|
| 336 |
+
]
|
| 337 |
+
atom_idx1 = residue_constants.restype_name_to_atom14_names[
|
| 338 |
+
resname
|
| 339 |
+
].index(atom_name1)
|
| 340 |
+
atom_idx2 = residue_constants.restype_name_to_atom14_names[
|
| 341 |
+
resname
|
| 342 |
+
].index(atom_name2)
|
| 343 |
+
restype_atom14_is_ambiguous[restype, atom_idx1] = 1
|
| 344 |
+
restype_atom14_is_ambiguous[restype, atom_idx2] = 1
|
| 345 |
+
|
| 346 |
+
# From this create an ambiguous_mask for the given sequence.
|
| 347 |
+
prot["atom14_atom_is_ambiguous"] = restype_atom14_is_ambiguous[
|
| 348 |
+
prot["aatype"]
|
| 349 |
+
]
|
| 350 |
+
|
| 351 |
+
return prot
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
def find_violations(prot_np: protein.Protein):
|
| 355 |
+
"""Analyzes a protein and returns structural violation information.
|
| 356 |
+
|
| 357 |
+
Args:
|
| 358 |
+
prot_np: A protein.
|
| 359 |
+
|
| 360 |
+
Returns:
|
| 361 |
+
violations: A `dict` of structure components with structural violations.
|
| 362 |
+
violation_metrics: A `dict` of violation metrics.
|
| 363 |
+
"""
|
| 364 |
+
batch = {
|
| 365 |
+
"aatype": prot_np.aatype,
|
| 366 |
+
"all_atom_positions": prot_np.atom_positions.astype(np.float32),
|
| 367 |
+
"all_atom_mask": prot_np.atom_mask.astype(np.float32),
|
| 368 |
+
"residue_index": prot_np.residue_index,
|
| 369 |
+
}
|
| 370 |
+
|
| 371 |
+
batch["seq_mask"] = np.ones_like(batch["aatype"], np.float32)
|
| 372 |
+
batch = make_atom14_positions(batch)
|
| 373 |
+
|
| 374 |
+
violations = loss.find_structural_violations_np(
|
| 375 |
+
batch=batch,
|
| 376 |
+
atom14_pred_positions=batch["atom14_gt_positions"],
|
| 377 |
+
config=ml_collections.ConfigDict(
|
| 378 |
+
{
|
| 379 |
+
"violation_tolerance_factor": 12, # Taken from model config.
|
| 380 |
+
"clash_overlap_tolerance": 1.5, # Taken from model config.
|
| 381 |
+
}
|
| 382 |
+
),
|
| 383 |
+
)
|
| 384 |
+
violation_metrics = loss.compute_violation_metrics_np(
|
| 385 |
+
batch=batch,
|
| 386 |
+
atom14_pred_positions=batch["atom14_gt_positions"],
|
| 387 |
+
violations=violations,
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
return violations, violation_metrics
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
def get_violation_metrics(prot: protein.Protein):
|
| 394 |
+
"""Computes violation and alignment metrics."""
|
| 395 |
+
structural_violations, struct_metrics = find_violations(prot)
|
| 396 |
+
violation_idx = np.flatnonzero(
|
| 397 |
+
structural_violations["total_per_residue_violations_mask"]
|
| 398 |
+
)
|
| 399 |
+
|
| 400 |
+
struct_metrics["residue_violations"] = violation_idx
|
| 401 |
+
struct_metrics["num_residue_violations"] = len(violation_idx)
|
| 402 |
+
struct_metrics["structural_violations"] = structural_violations
|
| 403 |
+
return struct_metrics
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
def _run_one_iteration(
|
| 407 |
+
*,
|
| 408 |
+
pdb_string: str,
|
| 409 |
+
max_iterations: int,
|
| 410 |
+
tolerance: float,
|
| 411 |
+
stiffness: float,
|
| 412 |
+
restraint_set: str,
|
| 413 |
+
max_attempts: int,
|
| 414 |
+
exclude_residues: Optional[Collection[int]] = None,
|
| 415 |
+
use_gpu: bool,
|
| 416 |
+
):
|
| 417 |
+
"""Runs the minimization pipeline.
|
| 418 |
+
|
| 419 |
+
Args:
|
| 420 |
+
pdb_string: A pdb string.
|
| 421 |
+
max_iterations: An `int` specifying the maximum number of L-BFGS iterations.
|
| 422 |
+
A value of 0 specifies no limit.
|
| 423 |
+
tolerance: kcal/mol, the energy tolerance of L-BFGS.
|
| 424 |
+
stiffness: kcal/mol A**2, spring constant of heavy atom restraining
|
| 425 |
+
potential.
|
| 426 |
+
restraint_set: The set of atoms to restrain.
|
| 427 |
+
max_attempts: The maximum number of minimization attempts.
|
| 428 |
+
exclude_residues: An optional list of zero-indexed residues to exclude from
|
| 429 |
+
restraints.
|
| 430 |
+
use_gpu: Whether to run relaxation on GPU
|
| 431 |
+
Returns:
|
| 432 |
+
A `dict` of minimization info.
|
| 433 |
+
"""
|
| 434 |
+
exclude_residues = exclude_residues or []
|
| 435 |
+
|
| 436 |
+
# Assign physical dimensions.
|
| 437 |
+
tolerance = tolerance * ENERGY
|
| 438 |
+
stiffness = stiffness * ENERGY / (LENGTH ** 2)
|
| 439 |
+
|
| 440 |
+
start = time.perf_counter()
|
| 441 |
+
minimized = False
|
| 442 |
+
attempts = 0
|
| 443 |
+
while not minimized and attempts < max_attempts:
|
| 444 |
+
attempts += 1
|
| 445 |
+
try:
|
| 446 |
+
logging.info(
|
| 447 |
+
"Minimizing protein, attempt %d of %d.", attempts, max_attempts
|
| 448 |
+
)
|
| 449 |
+
ret = _openmm_minimize(
|
| 450 |
+
pdb_string,
|
| 451 |
+
max_iterations=max_iterations,
|
| 452 |
+
tolerance=tolerance,
|
| 453 |
+
stiffness=stiffness,
|
| 454 |
+
restraint_set=restraint_set,
|
| 455 |
+
exclude_residues=exclude_residues,
|
| 456 |
+
use_gpu=use_gpu,
|
| 457 |
+
)
|
| 458 |
+
minimized = True
|
| 459 |
+
except Exception as e: # pylint: disable=broad-except
|
| 460 |
+
print(e)
|
| 461 |
+
logging.info(e)
|
| 462 |
+
if not minimized:
|
| 463 |
+
raise ValueError(f"Minimization failed after {max_attempts} attempts.")
|
| 464 |
+
ret["opt_time"] = time.perf_counter() - start
|
| 465 |
+
ret["min_attempts"] = attempts
|
| 466 |
+
return ret
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
def run_pipeline(
|
| 470 |
+
prot: protein.Protein,
|
| 471 |
+
stiffness: float,
|
| 472 |
+
use_gpu: bool,
|
| 473 |
+
max_outer_iterations: int = 1,
|
| 474 |
+
place_hydrogens_every_iteration: bool = True,
|
| 475 |
+
max_iterations: int = 0,
|
| 476 |
+
tolerance: float = 2.39,
|
| 477 |
+
restraint_set: str = "non_hydrogen",
|
| 478 |
+
max_attempts: int = 100,
|
| 479 |
+
checks: bool = True,
|
| 480 |
+
exclude_residues: Optional[Sequence[int]] = None,
|
| 481 |
+
):
|
| 482 |
+
"""Run iterative amber relax.
|
| 483 |
+
|
| 484 |
+
Successive relax iterations are performed until all violations have been
|
| 485 |
+
resolved. Each iteration involves a restrained Amber minimization, with
|
| 486 |
+
restraint exclusions determined by violation-participating residues.
|
| 487 |
+
|
| 488 |
+
Args:
|
| 489 |
+
prot: A protein to be relaxed.
|
| 490 |
+
stiffness: kcal/mol A**2, the restraint stiffness.
|
| 491 |
+
use_gpu: Whether to run on GPU
|
| 492 |
+
max_outer_iterations: The maximum number of iterative minimization.
|
| 493 |
+
place_hydrogens_every_iteration: Whether hydrogens are re-initialized
|
| 494 |
+
prior to every minimization.
|
| 495 |
+
max_iterations: An `int` specifying the maximum number of L-BFGS steps
|
| 496 |
+
per relax iteration. A value of 0 specifies no limit.
|
| 497 |
+
tolerance: kcal/mol, the energy tolerance of L-BFGS.
|
| 498 |
+
The default value is the OpenMM default.
|
| 499 |
+
restraint_set: The set of atoms to restrain.
|
| 500 |
+
max_attempts: The maximum number of minimization attempts per iteration.
|
| 501 |
+
checks: Whether to perform cleaning checks.
|
| 502 |
+
exclude_residues: An optional list of zero-indexed residues to exclude from
|
| 503 |
+
restraints.
|
| 504 |
+
|
| 505 |
+
Returns:
|
| 506 |
+
out: A dictionary of output values.
|
| 507 |
+
"""
|
| 508 |
+
|
| 509 |
+
# `protein.to_pdb` will strip any poorly-defined residues so we need to
|
| 510 |
+
# perform this check before `clean_protein`.
|
| 511 |
+
_check_residues_are_well_defined(prot)
|
| 512 |
+
pdb_string = clean_protein(prot, checks=checks)
|
| 513 |
+
|
| 514 |
+
exclude_residues = exclude_residues or []
|
| 515 |
+
exclude_residues = set(exclude_residues)
|
| 516 |
+
violations = np.inf
|
| 517 |
+
iteration = 0
|
| 518 |
+
|
| 519 |
+
while violations > 0 and iteration < max_outer_iterations:
|
| 520 |
+
ret = _run_one_iteration(
|
| 521 |
+
pdb_string=pdb_string,
|
| 522 |
+
exclude_residues=exclude_residues,
|
| 523 |
+
max_iterations=max_iterations,
|
| 524 |
+
tolerance=tolerance,
|
| 525 |
+
stiffness=stiffness,
|
| 526 |
+
restraint_set=restraint_set,
|
| 527 |
+
max_attempts=max_attempts,
|
| 528 |
+
use_gpu=use_gpu,
|
| 529 |
+
)
|
| 530 |
+
prot = protein.from_pdb_string(ret["min_pdb"])
|
| 531 |
+
if place_hydrogens_every_iteration:
|
| 532 |
+
pdb_string = clean_protein(prot, checks=True)
|
| 533 |
+
else:
|
| 534 |
+
pdb_string = ret["min_pdb"]
|
| 535 |
+
ret.update(get_violation_metrics(prot))
|
| 536 |
+
ret.update(
|
| 537 |
+
{
|
| 538 |
+
"num_exclusions": len(exclude_residues),
|
| 539 |
+
"iteration": iteration,
|
| 540 |
+
}
|
| 541 |
+
)
|
| 542 |
+
violations = ret["violations_per_residue"]
|
| 543 |
+
exclude_residues = exclude_residues.union(ret["residue_violations"])
|
| 544 |
+
|
| 545 |
+
logging.info(
|
| 546 |
+
"Iteration completed: Einit %.2f Efinal %.2f Time %.2f s "
|
| 547 |
+
"num residue violations %d num residue exclusions %d ",
|
| 548 |
+
ret["einit"],
|
| 549 |
+
ret["efinal"],
|
| 550 |
+
ret["opt_time"],
|
| 551 |
+
ret["num_residue_violations"],
|
| 552 |
+
ret["num_exclusions"],
|
| 553 |
+
)
|
| 554 |
+
iteration += 1
|
| 555 |
+
return ret
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
def get_initial_energies(
|
| 559 |
+
pdb_strs: Sequence[str],
|
| 560 |
+
stiffness: float = 0.0,
|
| 561 |
+
restraint_set: str = "non_hydrogen",
|
| 562 |
+
exclude_residues: Optional[Sequence[int]] = None,
|
| 563 |
+
):
|
| 564 |
+
"""Returns initial potential energies for a sequence of PDBs.
|
| 565 |
+
|
| 566 |
+
Assumes the input PDBs are ready for minimization, and all have the same
|
| 567 |
+
topology.
|
| 568 |
+
Allows time to be saved by not pdbfixing / rebuilding the system.
|
| 569 |
+
|
| 570 |
+
Args:
|
| 571 |
+
pdb_strs: List of PDB strings.
|
| 572 |
+
stiffness: kcal/mol A**2, spring constant of heavy atom restraining
|
| 573 |
+
potential.
|
| 574 |
+
restraint_set: Which atom types to restrain.
|
| 575 |
+
exclude_residues: An optional list of zero-indexed residues to exclude from
|
| 576 |
+
restraints.
|
| 577 |
+
|
| 578 |
+
Returns:
|
| 579 |
+
A list of initial energies in the same order as pdb_strs.
|
| 580 |
+
"""
|
| 581 |
+
exclude_residues = exclude_residues or []
|
| 582 |
+
|
| 583 |
+
openmm_pdbs = [
|
| 584 |
+
openmm_app.PDBFile(PdbStructure(io.StringIO(p))) for p in pdb_strs
|
| 585 |
+
]
|
| 586 |
+
force_field = openmm_app.ForceField("amber99sb.xml")
|
| 587 |
+
system = force_field.createSystem(
|
| 588 |
+
openmm_pdbs[0].topology, constraints=openmm_app.HBonds
|
| 589 |
+
)
|
| 590 |
+
stiffness = stiffness * ENERGY / (LENGTH ** 2)
|
| 591 |
+
if stiffness > 0 * ENERGY / (LENGTH ** 2):
|
| 592 |
+
_add_restraints(
|
| 593 |
+
system, openmm_pdbs[0], stiffness, restraint_set, exclude_residues
|
| 594 |
+
)
|
| 595 |
+
simulation = openmm_app.Simulation(
|
| 596 |
+
openmm_pdbs[0].topology,
|
| 597 |
+
system,
|
| 598 |
+
openmm.LangevinIntegrator(0, 0.01, 0.0),
|
| 599 |
+
openmm.Platform.getPlatformByName("CPU"),
|
| 600 |
+
)
|
| 601 |
+
energies = []
|
| 602 |
+
for pdb in openmm_pdbs:
|
| 603 |
+
try:
|
| 604 |
+
simulation.context.setPositions(pdb.positions)
|
| 605 |
+
state = simulation.context.getState(getEnergy=True)
|
| 606 |
+
energies.append(state.getPotentialEnergy().value_in_unit(ENERGY))
|
| 607 |
+
except Exception as e: # pylint: disable=broad-except
|
| 608 |
+
logging.error(
|
| 609 |
+
"Error getting initial energy, returning large value %s", e
|
| 610 |
+
)
|
| 611 |
+
energies.append(unit.Quantity(1e20, ENERGY))
|
| 612 |
+
return energies
|
openfold/np/relax/cleanup.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Cleans up a PDB file using pdbfixer in preparation for OpenMM simulations.
|
| 16 |
+
|
| 17 |
+
fix_pdb uses a third-party tool. We also support fixing some additional edge
|
| 18 |
+
cases like removing chains of length one (see clean_structure).
|
| 19 |
+
"""
|
| 20 |
+
import io
|
| 21 |
+
|
| 22 |
+
import pdbfixer
|
| 23 |
+
from simtk.openmm import app
|
| 24 |
+
from simtk.openmm.app import element
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def fix_pdb(pdbfile, alterations_info):
|
| 28 |
+
"""Apply pdbfixer to the contents of a PDB file; return a PDB string result.
|
| 29 |
+
|
| 30 |
+
1) Replaces nonstandard residues.
|
| 31 |
+
2) Removes heterogens (non protein residues) including water.
|
| 32 |
+
3) Adds missing residues and missing atoms within existing residues.
|
| 33 |
+
4) Adds hydrogens assuming pH=7.0.
|
| 34 |
+
5) KeepIds is currently true, so the fixer must keep the existing chain and
|
| 35 |
+
residue identifiers. This will fail for some files in wider PDB that have
|
| 36 |
+
invalid IDs.
|
| 37 |
+
|
| 38 |
+
Args:
|
| 39 |
+
pdbfile: Input PDB file handle.
|
| 40 |
+
alterations_info: A dict that will store details of changes made.
|
| 41 |
+
|
| 42 |
+
Returns:
|
| 43 |
+
A PDB string representing the fixed structure.
|
| 44 |
+
"""
|
| 45 |
+
fixer = pdbfixer.PDBFixer(pdbfile=pdbfile)
|
| 46 |
+
fixer.findNonstandardResidues()
|
| 47 |
+
alterations_info["nonstandard_residues"] = fixer.nonstandardResidues
|
| 48 |
+
fixer.replaceNonstandardResidues()
|
| 49 |
+
_remove_heterogens(fixer, alterations_info, keep_water=False)
|
| 50 |
+
fixer.findMissingResidues()
|
| 51 |
+
alterations_info["missing_residues"] = fixer.missingResidues
|
| 52 |
+
fixer.findMissingAtoms()
|
| 53 |
+
alterations_info["missing_heavy_atoms"] = fixer.missingAtoms
|
| 54 |
+
alterations_info["missing_terminals"] = fixer.missingTerminals
|
| 55 |
+
fixer.addMissingAtoms(seed=0)
|
| 56 |
+
fixer.addMissingHydrogens()
|
| 57 |
+
out_handle = io.StringIO()
|
| 58 |
+
app.PDBFile.writeFile(
|
| 59 |
+
fixer.topology, fixer.positions, out_handle, keepIds=True
|
| 60 |
+
)
|
| 61 |
+
return out_handle.getvalue()
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def clean_structure(pdb_structure, alterations_info):
|
| 65 |
+
"""Applies additional fixes to an OpenMM structure, to handle edge cases.
|
| 66 |
+
|
| 67 |
+
Args:
|
| 68 |
+
pdb_structure: An OpenMM structure to modify and fix.
|
| 69 |
+
alterations_info: A dict that will store details of changes made.
|
| 70 |
+
"""
|
| 71 |
+
_replace_met_se(pdb_structure, alterations_info)
|
| 72 |
+
_remove_chains_of_length_one(pdb_structure, alterations_info)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def _remove_heterogens(fixer, alterations_info, keep_water):
|
| 76 |
+
"""Removes the residues that Pdbfixer considers to be heterogens.
|
| 77 |
+
|
| 78 |
+
Args:
|
| 79 |
+
fixer: A Pdbfixer instance.
|
| 80 |
+
alterations_info: A dict that will store details of changes made.
|
| 81 |
+
keep_water: If True, water (HOH) is not considered to be a heterogen.
|
| 82 |
+
"""
|
| 83 |
+
initial_resnames = set()
|
| 84 |
+
for chain in fixer.topology.chains():
|
| 85 |
+
for residue in chain.residues():
|
| 86 |
+
initial_resnames.add(residue.name)
|
| 87 |
+
fixer.removeHeterogens(keepWater=keep_water)
|
| 88 |
+
final_resnames = set()
|
| 89 |
+
for chain in fixer.topology.chains():
|
| 90 |
+
for residue in chain.residues():
|
| 91 |
+
final_resnames.add(residue.name)
|
| 92 |
+
alterations_info["removed_heterogens"] = initial_resnames.difference(
|
| 93 |
+
final_resnames
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def _replace_met_se(pdb_structure, alterations_info):
|
| 98 |
+
"""Replace the Se in any MET residues that were not marked as modified."""
|
| 99 |
+
modified_met_residues = []
|
| 100 |
+
for res in pdb_structure.iter_residues():
|
| 101 |
+
name = res.get_name_with_spaces().strip()
|
| 102 |
+
if name == "MET":
|
| 103 |
+
s_atom = res.get_atom("SD")
|
| 104 |
+
if s_atom.element_symbol == "Se":
|
| 105 |
+
s_atom.element_symbol = "S"
|
| 106 |
+
s_atom.element = element.get_by_symbol("S")
|
| 107 |
+
modified_met_residues.append(s_atom.residue_number)
|
| 108 |
+
alterations_info["Se_in_MET"] = modified_met_residues
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def _remove_chains_of_length_one(pdb_structure, alterations_info):
|
| 112 |
+
"""Removes chains that correspond to a single amino acid.
|
| 113 |
+
|
| 114 |
+
A single amino acid in a chain is both N and C terminus. There is no force
|
| 115 |
+
template for this case.
|
| 116 |
+
|
| 117 |
+
Args:
|
| 118 |
+
pdb_structure: An OpenMM pdb_structure to modify and fix.
|
| 119 |
+
alterations_info: A dict that will store details of changes made.
|
| 120 |
+
"""
|
| 121 |
+
removed_chains = {}
|
| 122 |
+
for model in pdb_structure.iter_models():
|
| 123 |
+
valid_chains = [c for c in model.iter_chains() if len(c) > 1]
|
| 124 |
+
invalid_chain_ids = [
|
| 125 |
+
c.chain_id for c in model.iter_chains() if len(c) <= 1
|
| 126 |
+
]
|
| 127 |
+
model.chains = valid_chains
|
| 128 |
+
for chain_id in invalid_chain_ids:
|
| 129 |
+
model.chains_by_id.pop(chain_id)
|
| 130 |
+
removed_chains[model.number] = invalid_chain_ids
|
| 131 |
+
alterations_info["removed_chains"] = removed_chains
|
openfold/np/relax/relax.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""Amber relaxation."""
|
| 17 |
+
from typing import Any, Dict, Sequence, Tuple
|
| 18 |
+
from openfold.np import protein
|
| 19 |
+
from openfold.np.relax import amber_minimize, utils
|
| 20 |
+
import numpy as np
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class AmberRelaxation(object):
|
| 24 |
+
"""Amber relaxation."""
|
| 25 |
+
def __init__(
|
| 26 |
+
self,
|
| 27 |
+
*,
|
| 28 |
+
max_iterations: int,
|
| 29 |
+
tolerance: float,
|
| 30 |
+
stiffness: float,
|
| 31 |
+
exclude_residues: Sequence[int],
|
| 32 |
+
max_outer_iterations: int,
|
| 33 |
+
use_gpu: bool,
|
| 34 |
+
):
|
| 35 |
+
"""Initialize Amber Relaxer.
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
max_iterations: Maximum number of L-BFGS iterations. 0 means no max.
|
| 39 |
+
tolerance: kcal/mol, the energy tolerance of L-BFGS.
|
| 40 |
+
stiffness: kcal/mol A**2, spring constant of heavy atom restraining
|
| 41 |
+
potential.
|
| 42 |
+
exclude_residues: Residues to exclude from per-atom restraining.
|
| 43 |
+
Zero-indexed.
|
| 44 |
+
max_outer_iterations: Maximum number of violation-informed relax
|
| 45 |
+
iterations. A value of 1 will run the non-iterative procedure used in
|
| 46 |
+
CASP14. Use 20 so that >95% of the bad cases are relaxed. Relax finishes
|
| 47 |
+
as soon as there are no violations, hence in most cases this causes no
|
| 48 |
+
slowdown. In the worst case we do 20 outer iterations.
|
| 49 |
+
use_gpu: Whether to run on GPU
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
self._max_iterations = max_iterations
|
| 53 |
+
self._tolerance = tolerance
|
| 54 |
+
self._stiffness = stiffness
|
| 55 |
+
self._exclude_residues = exclude_residues
|
| 56 |
+
self._max_outer_iterations = max_outer_iterations
|
| 57 |
+
self._use_gpu = use_gpu
|
| 58 |
+
|
| 59 |
+
def process(
|
| 60 |
+
self, *, prot: protein.Protein
|
| 61 |
+
) -> Tuple[str, Dict[str, Any], np.ndarray]:
|
| 62 |
+
"""Runs Amber relax on a prediction, adds hydrogens, returns PDB string."""
|
| 63 |
+
out = amber_minimize.run_pipeline(
|
| 64 |
+
prot=prot,
|
| 65 |
+
max_iterations=self._max_iterations,
|
| 66 |
+
tolerance=self._tolerance,
|
| 67 |
+
stiffness=self._stiffness,
|
| 68 |
+
exclude_residues=self._exclude_residues,
|
| 69 |
+
max_outer_iterations=self._max_outer_iterations,
|
| 70 |
+
use_gpu=self._use_gpu,
|
| 71 |
+
)
|
| 72 |
+
min_pos = out["pos"]
|
| 73 |
+
start_pos = out["posinit"]
|
| 74 |
+
rmsd = np.sqrt(np.sum((start_pos - min_pos) ** 2) / start_pos.shape[0])
|
| 75 |
+
debug_data = {
|
| 76 |
+
"initial_energy": out["einit"],
|
| 77 |
+
"final_energy": out["efinal"],
|
| 78 |
+
"attempts": out["min_attempts"],
|
| 79 |
+
"rmsd": rmsd,
|
| 80 |
+
}
|
| 81 |
+
pdb_str = amber_minimize.clean_protein(prot)
|
| 82 |
+
min_pdb = utils.overwrite_pdb_coordinates(pdb_str, min_pos)
|
| 83 |
+
min_pdb = utils.overwrite_b_factors(min_pdb, prot.b_factors)
|
| 84 |
+
utils.assert_equal_nonterminal_atom_types(
|
| 85 |
+
protein.from_pdb_string(min_pdb).atom_mask, prot.atom_mask
|
| 86 |
+
)
|
| 87 |
+
violations = out["structural_violations"][
|
| 88 |
+
"total_per_residue_violations_mask"
|
| 89 |
+
]
|
| 90 |
+
return min_pdb, debug_data, violations
|
openfold/np/relax/utils.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""Utils for minimization."""
|
| 17 |
+
import io
|
| 18 |
+
from openfold.np import residue_constants
|
| 19 |
+
from Bio import PDB
|
| 20 |
+
import numpy as np
|
| 21 |
+
# simtk.openmm is not supported anymore. Remove simtk.
|
| 22 |
+
# https://github.com/openmm/openmm/releases
|
| 23 |
+
from openmm import app as openmm_app
|
| 24 |
+
from openmm.app.internal.pdbstructure import PdbStructure
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def overwrite_pdb_coordinates(pdb_str: str, pos) -> str:
|
| 28 |
+
pdb_file = io.StringIO(pdb_str)
|
| 29 |
+
structure = PdbStructure(pdb_file)
|
| 30 |
+
topology = openmm_app.PDBFile(structure).getTopology()
|
| 31 |
+
with io.StringIO() as f:
|
| 32 |
+
openmm_app.PDBFile.writeFile(topology, pos, f)
|
| 33 |
+
return f.getvalue()
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def overwrite_b_factors(pdb_str: str, bfactors: np.ndarray) -> str:
|
| 37 |
+
"""Overwrites the B-factors in pdb_str with contents of bfactors array.
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
pdb_str: An input PDB string.
|
| 41 |
+
bfactors: A numpy array with shape [1, n_residues, 37]. We assume that the
|
| 42 |
+
B-factors are per residue; i.e. that the nonzero entries are identical in
|
| 43 |
+
[0, i, :].
|
| 44 |
+
|
| 45 |
+
Returns:
|
| 46 |
+
A new PDB string with the B-factors replaced.
|
| 47 |
+
"""
|
| 48 |
+
if bfactors.shape[-1] != residue_constants.atom_type_num:
|
| 49 |
+
raise ValueError(
|
| 50 |
+
f"Invalid final dimension size for bfactors: {bfactors.shape[-1]}."
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
parser = PDB.PDBParser(QUIET=True)
|
| 54 |
+
handle = io.StringIO(pdb_str)
|
| 55 |
+
structure = parser.get_structure("", handle)
|
| 56 |
+
|
| 57 |
+
curr_resid = ("", "", "")
|
| 58 |
+
idx = -1
|
| 59 |
+
for atom in structure.get_atoms():
|
| 60 |
+
atom_resid = atom.parent.get_id()
|
| 61 |
+
if atom_resid != curr_resid:
|
| 62 |
+
idx += 1
|
| 63 |
+
if idx >= bfactors.shape[0]:
|
| 64 |
+
raise ValueError(
|
| 65 |
+
"Index into bfactors exceeds number of residues. "
|
| 66 |
+
"B-factors shape: {shape}, idx: {idx}."
|
| 67 |
+
)
|
| 68 |
+
curr_resid = atom_resid
|
| 69 |
+
atom.bfactor = bfactors[idx, residue_constants.atom_order["CA"]]
|
| 70 |
+
|
| 71 |
+
new_pdb = io.StringIO()
|
| 72 |
+
pdb_io = PDB.PDBIO()
|
| 73 |
+
pdb_io.set_structure(structure)
|
| 74 |
+
pdb_io.save(new_pdb)
|
| 75 |
+
return new_pdb.getvalue()
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def assert_equal_nonterminal_atom_types(
|
| 79 |
+
atom_mask: np.ndarray, ref_atom_mask: np.ndarray
|
| 80 |
+
):
|
| 81 |
+
"""Checks that pre- and post-minimized proteins have same atom set."""
|
| 82 |
+
# Ignore any terminal OXT atoms which may have been added by minimization.
|
| 83 |
+
oxt = residue_constants.atom_order["OXT"]
|
| 84 |
+
no_oxt_mask = np.ones(shape=atom_mask.shape, dtype=np.bool)
|
| 85 |
+
no_oxt_mask[..., oxt] = False
|
| 86 |
+
np.testing.assert_almost_equal(
|
| 87 |
+
ref_atom_mask[no_oxt_mask], atom_mask[no_oxt_mask]
|
| 88 |
+
)
|
openfold/np/residue_constants.py
ADDED
|
@@ -0,0 +1,1310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""Constants used in AlphaFold."""
|
| 17 |
+
|
| 18 |
+
import collections
|
| 19 |
+
import functools
|
| 20 |
+
from typing import Mapping, List, Tuple
|
| 21 |
+
from importlib import resources
|
| 22 |
+
|
| 23 |
+
import numpy as np
|
| 24 |
+
import tree
|
| 25 |
+
|
| 26 |
+
# Internal import (35fd).
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# Distance from one CA to next CA [trans configuration: omega = 180].
|
| 30 |
+
ca_ca = 3.80209737096
|
| 31 |
+
|
| 32 |
+
# Format: The list for each AA type contains chi1, chi2, chi3, chi4 in
|
| 33 |
+
# this order (or a relevant subset from chi1 onwards). ALA and GLY don't have
|
| 34 |
+
# chi angles so their chi angle lists are empty.
|
| 35 |
+
chi_angles_atoms = {
|
| 36 |
+
"ALA": [],
|
| 37 |
+
# Chi5 in arginine is always 0 +- 5 degrees, so ignore it.
|
| 38 |
+
"ARG": [
|
| 39 |
+
["N", "CA", "CB", "CG"],
|
| 40 |
+
["CA", "CB", "CG", "CD"],
|
| 41 |
+
["CB", "CG", "CD", "NE"],
|
| 42 |
+
["CG", "CD", "NE", "CZ"],
|
| 43 |
+
],
|
| 44 |
+
"ASN": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]],
|
| 45 |
+
"ASP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]],
|
| 46 |
+
"CYS": [["N", "CA", "CB", "SG"]],
|
| 47 |
+
"GLN": [
|
| 48 |
+
["N", "CA", "CB", "CG"],
|
| 49 |
+
["CA", "CB", "CG", "CD"],
|
| 50 |
+
["CB", "CG", "CD", "OE1"],
|
| 51 |
+
],
|
| 52 |
+
"GLU": [
|
| 53 |
+
["N", "CA", "CB", "CG"],
|
| 54 |
+
["CA", "CB", "CG", "CD"],
|
| 55 |
+
["CB", "CG", "CD", "OE1"],
|
| 56 |
+
],
|
| 57 |
+
"GLY": [],
|
| 58 |
+
"HIS": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "ND1"]],
|
| 59 |
+
"ILE": [["N", "CA", "CB", "CG1"], ["CA", "CB", "CG1", "CD1"]],
|
| 60 |
+
"LEU": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
|
| 61 |
+
"LYS": [
|
| 62 |
+
["N", "CA", "CB", "CG"],
|
| 63 |
+
["CA", "CB", "CG", "CD"],
|
| 64 |
+
["CB", "CG", "CD", "CE"],
|
| 65 |
+
["CG", "CD", "CE", "NZ"],
|
| 66 |
+
],
|
| 67 |
+
"MET": [
|
| 68 |
+
["N", "CA", "CB", "CG"],
|
| 69 |
+
["CA", "CB", "CG", "SD"],
|
| 70 |
+
["CB", "CG", "SD", "CE"],
|
| 71 |
+
],
|
| 72 |
+
"PHE": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
|
| 73 |
+
"PRO": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"]],
|
| 74 |
+
"SER": [["N", "CA", "CB", "OG"]],
|
| 75 |
+
"THR": [["N", "CA", "CB", "OG1"]],
|
| 76 |
+
"TRP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
|
| 77 |
+
"TYR": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
|
| 78 |
+
"VAL": [["N", "CA", "CB", "CG1"]],
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
# If chi angles given in fixed-length array, this matrix determines how to mask
|
| 82 |
+
# them for each AA type. The order is as per restype_order (see below).
|
| 83 |
+
chi_angles_mask = [
|
| 84 |
+
[0.0, 0.0, 0.0, 0.0], # ALA
|
| 85 |
+
[1.0, 1.0, 1.0, 1.0], # ARG
|
| 86 |
+
[1.0, 1.0, 0.0, 0.0], # ASN
|
| 87 |
+
[1.0, 1.0, 0.0, 0.0], # ASP
|
| 88 |
+
[1.0, 0.0, 0.0, 0.0], # CYS
|
| 89 |
+
[1.0, 1.0, 1.0, 0.0], # GLN
|
| 90 |
+
[1.0, 1.0, 1.0, 0.0], # GLU
|
| 91 |
+
[0.0, 0.0, 0.0, 0.0], # GLY
|
| 92 |
+
[1.0, 1.0, 0.0, 0.0], # HIS
|
| 93 |
+
[1.0, 1.0, 0.0, 0.0], # ILE
|
| 94 |
+
[1.0, 1.0, 0.0, 0.0], # LEU
|
| 95 |
+
[1.0, 1.0, 1.0, 1.0], # LYS
|
| 96 |
+
[1.0, 1.0, 1.0, 0.0], # MET
|
| 97 |
+
[1.0, 1.0, 0.0, 0.0], # PHE
|
| 98 |
+
[1.0, 1.0, 0.0, 0.0], # PRO
|
| 99 |
+
[1.0, 0.0, 0.0, 0.0], # SER
|
| 100 |
+
[1.0, 0.0, 0.0, 0.0], # THR
|
| 101 |
+
[1.0, 1.0, 0.0, 0.0], # TRP
|
| 102 |
+
[1.0, 1.0, 0.0, 0.0], # TYR
|
| 103 |
+
[1.0, 0.0, 0.0, 0.0], # VAL
|
| 104 |
+
]
|
| 105 |
+
|
| 106 |
+
# The following chi angles are pi periodic: they can be rotated by a multiple
|
| 107 |
+
# of pi without affecting the structure.
|
| 108 |
+
chi_pi_periodic = [
|
| 109 |
+
[0.0, 0.0, 0.0, 0.0], # ALA
|
| 110 |
+
[0.0, 0.0, 0.0, 0.0], # ARG
|
| 111 |
+
[0.0, 0.0, 0.0, 0.0], # ASN
|
| 112 |
+
[0.0, 1.0, 0.0, 0.0], # ASP
|
| 113 |
+
[0.0, 0.0, 0.0, 0.0], # CYS
|
| 114 |
+
[0.0, 0.0, 0.0, 0.0], # GLN
|
| 115 |
+
[0.0, 0.0, 1.0, 0.0], # GLU
|
| 116 |
+
[0.0, 0.0, 0.0, 0.0], # GLY
|
| 117 |
+
[0.0, 0.0, 0.0, 0.0], # HIS
|
| 118 |
+
[0.0, 0.0, 0.0, 0.0], # ILE
|
| 119 |
+
[0.0, 0.0, 0.0, 0.0], # LEU
|
| 120 |
+
[0.0, 0.0, 0.0, 0.0], # LYS
|
| 121 |
+
[0.0, 0.0, 0.0, 0.0], # MET
|
| 122 |
+
[0.0, 1.0, 0.0, 0.0], # PHE
|
| 123 |
+
[0.0, 0.0, 0.0, 0.0], # PRO
|
| 124 |
+
[0.0, 0.0, 0.0, 0.0], # SER
|
| 125 |
+
[0.0, 0.0, 0.0, 0.0], # THR
|
| 126 |
+
[0.0, 0.0, 0.0, 0.0], # TRP
|
| 127 |
+
[0.0, 1.0, 0.0, 0.0], # TYR
|
| 128 |
+
[0.0, 0.0, 0.0, 0.0], # VAL
|
| 129 |
+
[0.0, 0.0, 0.0, 0.0], # UNK
|
| 130 |
+
]
|
| 131 |
+
|
| 132 |
+
# Atoms positions relative to the 8 rigid groups, defined by the pre-omega, phi,
|
| 133 |
+
# psi and chi angles:
|
| 134 |
+
# 0: 'backbone group',
|
| 135 |
+
# 1: 'pre-omega-group', (empty)
|
| 136 |
+
# 2: 'phi-group', (currently empty, because it defines only hydrogens)
|
| 137 |
+
# 3: 'psi-group',
|
| 138 |
+
# 4,5,6,7: 'chi1,2,3,4-group'
|
| 139 |
+
# The atom positions are relative to the axis-end-atom of the corresponding
|
| 140 |
+
# rotation axis. The x-axis is in direction of the rotation axis, and the y-axis
|
| 141 |
+
# is defined such that the dihedral-angle-definiting atom (the last entry in
|
| 142 |
+
# chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate).
|
| 143 |
+
# format: [atomname, group_idx, rel_position]
|
| 144 |
+
rigid_group_atom_positions = {
|
| 145 |
+
"ALA": [
|
| 146 |
+
["N", 0, (-0.525, 1.363, 0.000)],
|
| 147 |
+
["CA", 0, (0.000, 0.000, 0.000)],
|
| 148 |
+
["C", 0, (1.526, -0.000, -0.000)],
|
| 149 |
+
["CB", 0, (-0.529, -0.774, -1.205)],
|
| 150 |
+
["O", 3, (0.627, 1.062, 0.000)],
|
| 151 |
+
],
|
| 152 |
+
"ARG": [
|
| 153 |
+
["N", 0, (-0.524, 1.362, -0.000)],
|
| 154 |
+
["CA", 0, (0.000, 0.000, 0.000)],
|
| 155 |
+
["C", 0, (1.525, -0.000, -0.000)],
|
| 156 |
+
["CB", 0, (-0.524, -0.778, -1.209)],
|
| 157 |
+
["O", 3, (0.626, 1.062, 0.000)],
|
| 158 |
+
["CG", 4, (0.616, 1.390, -0.000)],
|
| 159 |
+
["CD", 5, (0.564, 1.414, 0.000)],
|
| 160 |
+
["NE", 6, (0.539, 1.357, -0.000)],
|
| 161 |
+
["NH1", 7, (0.206, 2.301, 0.000)],
|
| 162 |
+
["NH2", 7, (2.078, 0.978, -0.000)],
|
| 163 |
+
["CZ", 7, (0.758, 1.093, -0.000)],
|
| 164 |
+
],
|
| 165 |
+
"ASN": [
|
| 166 |
+
["N", 0, (-0.536, 1.357, 0.000)],
|
| 167 |
+
["CA", 0, (0.000, 0.000, 0.000)],
|
| 168 |
+
["C", 0, (1.526, -0.000, -0.000)],
|
| 169 |
+
["CB", 0, (-0.531, -0.787, -1.200)],
|
| 170 |
+
["O", 3, (0.625, 1.062, 0.000)],
|
| 171 |
+
["CG", 4, (0.584, 1.399, 0.000)],
|
| 172 |
+
["ND2", 5, (0.593, -1.188, 0.001)],
|
| 173 |
+
["OD1", 5, (0.633, 1.059, 0.000)],
|
| 174 |
+
],
|
| 175 |
+
"ASP": [
|
| 176 |
+
["N", 0, (-0.525, 1.362, -0.000)],
|
| 177 |
+
["CA", 0, (0.000, 0.000, 0.000)],
|
| 178 |
+
["C", 0, (1.527, 0.000, -0.000)],
|
| 179 |
+
["CB", 0, (-0.526, -0.778, -1.208)],
|
| 180 |
+
["O", 3, (0.626, 1.062, -0.000)],
|
| 181 |
+
["CG", 4, (0.593, 1.398, -0.000)],
|
| 182 |
+
["OD1", 5, (0.610, 1.091, 0.000)],
|
| 183 |
+
["OD2", 5, (0.592, -1.101, -0.003)],
|
| 184 |
+
],
|
| 185 |
+
"CYS": [
|
| 186 |
+
["N", 0, (-0.522, 1.362, -0.000)],
|
| 187 |
+
["CA", 0, (0.000, 0.000, 0.000)],
|
| 188 |
+
["C", 0, (1.524, 0.000, 0.000)],
|
| 189 |
+
["CB", 0, (-0.519, -0.773, -1.212)],
|
| 190 |
+
["O", 3, (0.625, 1.062, -0.000)],
|
| 191 |
+
["SG", 4, (0.728, 1.653, 0.000)],
|
| 192 |
+
],
|
| 193 |
+
"GLN": [
|
| 194 |
+
["N", 0, (-0.526, 1.361, -0.000)],
|
| 195 |
+
["CA", 0, (0.000, 0.000, 0.000)],
|
| 196 |
+
["C", 0, (1.526, 0.000, 0.000)],
|
| 197 |
+
["CB", 0, (-0.525, -0.779, -1.207)],
|
| 198 |
+
["O", 3, (0.626, 1.062, -0.000)],
|
| 199 |
+
["CG", 4, (0.615, 1.393, 0.000)],
|
| 200 |
+
["CD", 5, (0.587, 1.399, -0.000)],
|
| 201 |
+
["NE2", 6, (0.593, -1.189, -0.001)],
|
| 202 |
+
["OE1", 6, (0.634, 1.060, 0.000)],
|
| 203 |
+
],
|
| 204 |
+
"GLU": [
|
| 205 |
+
["N", 0, (-0.528, 1.361, 0.000)],
|
| 206 |
+
["CA", 0, (0.000, 0.000, 0.000)],
|
| 207 |
+
["C", 0, (1.526, -0.000, -0.000)],
|
| 208 |
+
["CB", 0, (-0.526, -0.781, -1.207)],
|
| 209 |
+
["O", 3, (0.626, 1.062, 0.000)],
|
| 210 |
+
["CG", 4, (0.615, 1.392, 0.000)],
|
| 211 |
+
["CD", 5, (0.600, 1.397, 0.000)],
|
| 212 |
+
["OE1", 6, (0.607, 1.095, -0.000)],
|
| 213 |
+
["OE2", 6, (0.589, -1.104, -0.001)],
|
| 214 |
+
],
|
| 215 |
+
"GLY": [
|
| 216 |
+
["N", 0, (-0.572, 1.337, 0.000)],
|
| 217 |
+
["CA", 0, (0.000, 0.000, 0.000)],
|
| 218 |
+
["C", 0, (1.517, -0.000, -0.000)],
|
| 219 |
+
["O", 3, (0.626, 1.062, -0.000)],
|
| 220 |
+
],
|
| 221 |
+
"HIS": [
|
| 222 |
+
["N", 0, (-0.527, 1.360, 0.000)],
|
| 223 |
+
["CA", 0, (0.000, 0.000, 0.000)],
|
| 224 |
+
["C", 0, (1.525, 0.000, 0.000)],
|
| 225 |
+
["CB", 0, (-0.525, -0.778, -1.208)],
|
| 226 |
+
["O", 3, (0.625, 1.063, 0.000)],
|
| 227 |
+
["CG", 4, (0.600, 1.370, -0.000)],
|
| 228 |
+
["CD2", 5, (0.889, -1.021, 0.003)],
|
| 229 |
+
["ND1", 5, (0.744, 1.160, -0.000)],
|
| 230 |
+
["CE1", 5, (2.030, 0.851, 0.002)],
|
| 231 |
+
["NE2", 5, (2.145, -0.466, 0.004)],
|
| 232 |
+
],
|
| 233 |
+
"ILE": [
|
| 234 |
+
["N", 0, (-0.493, 1.373, -0.000)],
|
| 235 |
+
["CA", 0, (0.000, 0.000, 0.000)],
|
| 236 |
+
["C", 0, (1.527, -0.000, -0.000)],
|
| 237 |
+
["CB", 0, (-0.536, -0.793, -1.213)],
|
| 238 |
+
["O", 3, (0.627, 1.062, -0.000)],
|
| 239 |
+
["CG1", 4, (0.534, 1.437, -0.000)],
|
| 240 |
+
["CG2", 4, (0.540, -0.785, -1.199)],
|
| 241 |
+
["CD1", 5, (0.619, 1.391, 0.000)],
|
| 242 |
+
],
|
| 243 |
+
"LEU": [
|
| 244 |
+
["N", 0, (-0.520, 1.363, 0.000)],
|
| 245 |
+
["CA", 0, (0.000, 0.000, 0.000)],
|
| 246 |
+
["C", 0, (1.525, -0.000, -0.000)],
|
| 247 |
+
["CB", 0, (-0.522, -0.773, -1.214)],
|
| 248 |
+
["O", 3, (0.625, 1.063, -0.000)],
|
| 249 |
+
["CG", 4, (0.678, 1.371, 0.000)],
|
| 250 |
+
["CD1", 5, (0.530, 1.430, -0.000)],
|
| 251 |
+
["CD2", 5, (0.535, -0.774, 1.200)],
|
| 252 |
+
],
|
| 253 |
+
"LYS": [
|
| 254 |
+
["N", 0, (-0.526, 1.362, -0.000)],
|
| 255 |
+
["CA", 0, (0.000, 0.000, 0.000)],
|
| 256 |
+
["C", 0, (1.526, 0.000, 0.000)],
|
| 257 |
+
["CB", 0, (-0.524, -0.778, -1.208)],
|
| 258 |
+
["O", 3, (0.626, 1.062, -0.000)],
|
| 259 |
+
["CG", 4, (0.619, 1.390, 0.000)],
|
| 260 |
+
["CD", 5, (0.559, 1.417, 0.000)],
|
| 261 |
+
["CE", 6, (0.560, 1.416, 0.000)],
|
| 262 |
+
["NZ", 7, (0.554, 1.387, 0.000)],
|
| 263 |
+
],
|
| 264 |
+
"MET": [
|
| 265 |
+
["N", 0, (-0.521, 1.364, -0.000)],
|
| 266 |
+
["CA", 0, (0.000, 0.000, 0.000)],
|
| 267 |
+
["C", 0, (1.525, 0.000, 0.000)],
|
| 268 |
+
["CB", 0, (-0.523, -0.776, -1.210)],
|
| 269 |
+
["O", 3, (0.625, 1.062, -0.000)],
|
| 270 |
+
["CG", 4, (0.613, 1.391, -0.000)],
|
| 271 |
+
["SD", 5, (0.703, 1.695, 0.000)],
|
| 272 |
+
["CE", 6, (0.320, 1.786, -0.000)],
|
| 273 |
+
],
|
| 274 |
+
"PHE": [
|
| 275 |
+
["N", 0, (-0.518, 1.363, 0.000)],
|
| 276 |
+
["CA", 0, (0.000, 0.000, 0.000)],
|
| 277 |
+
["C", 0, (1.524, 0.000, -0.000)],
|
| 278 |
+
["CB", 0, (-0.525, -0.776, -1.212)],
|
| 279 |
+
["O", 3, (0.626, 1.062, -0.000)],
|
| 280 |
+
["CG", 4, (0.607, 1.377, 0.000)],
|
| 281 |
+
["CD1", 5, (0.709, 1.195, -0.000)],
|
| 282 |
+
["CD2", 5, (0.706, -1.196, 0.000)],
|
| 283 |
+
["CE1", 5, (2.102, 1.198, -0.000)],
|
| 284 |
+
["CE2", 5, (2.098, -1.201, -0.000)],
|
| 285 |
+
["CZ", 5, (2.794, -0.003, -0.001)],
|
| 286 |
+
],
|
| 287 |
+
"PRO": [
|
| 288 |
+
["N", 0, (-0.566, 1.351, -0.000)],
|
| 289 |
+
["CA", 0, (0.000, 0.000, 0.000)],
|
| 290 |
+
["C", 0, (1.527, -0.000, 0.000)],
|
| 291 |
+
["CB", 0, (-0.546, -0.611, -1.293)],
|
| 292 |
+
["O", 3, (0.621, 1.066, 0.000)],
|
| 293 |
+
["CG", 4, (0.382, 1.445, 0.0)],
|
| 294 |
+
# ['CD', 5, (0.427, 1.440, 0.0)],
|
| 295 |
+
["CD", 5, (0.477, 1.424, 0.0)], # manually made angle 2 degrees larger
|
| 296 |
+
],
|
| 297 |
+
"SER": [
|
| 298 |
+
["N", 0, (-0.529, 1.360, -0.000)],
|
| 299 |
+
["CA", 0, (0.000, 0.000, 0.000)],
|
| 300 |
+
["C", 0, (1.525, -0.000, -0.000)],
|
| 301 |
+
["CB", 0, (-0.518, -0.777, -1.211)],
|
| 302 |
+
["O", 3, (0.626, 1.062, -0.000)],
|
| 303 |
+
["OG", 4, (0.503, 1.325, 0.000)],
|
| 304 |
+
],
|
| 305 |
+
"THR": [
|
| 306 |
+
["N", 0, (-0.517, 1.364, 0.000)],
|
| 307 |
+
["CA", 0, (0.000, 0.000, 0.000)],
|
| 308 |
+
["C", 0, (1.526, 0.000, -0.000)],
|
| 309 |
+
["CB", 0, (-0.516, -0.793, -1.215)],
|
| 310 |
+
["O", 3, (0.626, 1.062, 0.000)],
|
| 311 |
+
["CG2", 4, (0.550, -0.718, -1.228)],
|
| 312 |
+
["OG1", 4, (0.472, 1.353, 0.000)],
|
| 313 |
+
],
|
| 314 |
+
"TRP": [
|
| 315 |
+
["N", 0, (-0.521, 1.363, 0.000)],
|
| 316 |
+
["CA", 0, (0.000, 0.000, 0.000)],
|
| 317 |
+
["C", 0, (1.525, -0.000, 0.000)],
|
| 318 |
+
["CB", 0, (-0.523, -0.776, -1.212)],
|
| 319 |
+
["O", 3, (0.627, 1.062, 0.000)],
|
| 320 |
+
["CG", 4, (0.609, 1.370, -0.000)],
|
| 321 |
+
["CD1", 5, (0.824, 1.091, 0.000)],
|
| 322 |
+
["CD2", 5, (0.854, -1.148, -0.005)],
|
| 323 |
+
["CE2", 5, (2.186, -0.678, -0.007)],
|
| 324 |
+
["CE3", 5, (0.622, -2.530, -0.007)],
|
| 325 |
+
["NE1", 5, (2.140, 0.690, -0.004)],
|
| 326 |
+
["CH2", 5, (3.028, -2.890, -0.013)],
|
| 327 |
+
["CZ2", 5, (3.283, -1.543, -0.011)],
|
| 328 |
+
["CZ3", 5, (1.715, -3.389, -0.011)],
|
| 329 |
+
],
|
| 330 |
+
"TYR": [
|
| 331 |
+
["N", 0, (-0.522, 1.362, 0.000)],
|
| 332 |
+
["CA", 0, (0.000, 0.000, 0.000)],
|
| 333 |
+
["C", 0, (1.524, -0.000, -0.000)],
|
| 334 |
+
["CB", 0, (-0.522, -0.776, -1.213)],
|
| 335 |
+
["O", 3, (0.627, 1.062, -0.000)],
|
| 336 |
+
["CG", 4, (0.607, 1.382, -0.000)],
|
| 337 |
+
["CD1", 5, (0.716, 1.195, -0.000)],
|
| 338 |
+
["CD2", 5, (0.713, -1.194, -0.001)],
|
| 339 |
+
["CE1", 5, (2.107, 1.200, -0.002)],
|
| 340 |
+
["CE2", 5, (2.104, -1.201, -0.003)],
|
| 341 |
+
["OH", 5, (4.168, -0.002, -0.005)],
|
| 342 |
+
["CZ", 5, (2.791, -0.001, -0.003)],
|
| 343 |
+
],
|
| 344 |
+
"VAL": [
|
| 345 |
+
["N", 0, (-0.494, 1.373, -0.000)],
|
| 346 |
+
["CA", 0, (0.000, 0.000, 0.000)],
|
| 347 |
+
["C", 0, (1.527, -0.000, -0.000)],
|
| 348 |
+
["CB", 0, (-0.533, -0.795, -1.213)],
|
| 349 |
+
["O", 3, (0.627, 1.062, -0.000)],
|
| 350 |
+
["CG1", 4, (0.540, 1.429, -0.000)],
|
| 351 |
+
["CG2", 4, (0.533, -0.776, 1.203)],
|
| 352 |
+
],
|
| 353 |
+
}
|
| 354 |
+
|
| 355 |
+
# A list of atoms (excluding hydrogen) for each AA type. PDB naming convention.
|
| 356 |
+
residue_atoms = {
|
| 357 |
+
"ALA": ["C", "CA", "CB", "N", "O"],
|
| 358 |
+
"ARG": ["C", "CA", "CB", "CG", "CD", "CZ", "N", "NE", "O", "NH1", "NH2"],
|
| 359 |
+
"ASP": ["C", "CA", "CB", "CG", "N", "O", "OD1", "OD2"],
|
| 360 |
+
"ASN": ["C", "CA", "CB", "CG", "N", "ND2", "O", "OD1"],
|
| 361 |
+
"CYS": ["C", "CA", "CB", "N", "O", "SG"],
|
| 362 |
+
"GLU": ["C", "CA", "CB", "CG", "CD", "N", "O", "OE1", "OE2"],
|
| 363 |
+
"GLN": ["C", "CA", "CB", "CG", "CD", "N", "NE2", "O", "OE1"],
|
| 364 |
+
"GLY": ["C", "CA", "N", "O"],
|
| 365 |
+
"HIS": ["C", "CA", "CB", "CG", "CD2", "CE1", "N", "ND1", "NE2", "O"],
|
| 366 |
+
"ILE": ["C", "CA", "CB", "CG1", "CG2", "CD1", "N", "O"],
|
| 367 |
+
"LEU": ["C", "CA", "CB", "CG", "CD1", "CD2", "N", "O"],
|
| 368 |
+
"LYS": ["C", "CA", "CB", "CG", "CD", "CE", "N", "NZ", "O"],
|
| 369 |
+
"MET": ["C", "CA", "CB", "CG", "CE", "N", "O", "SD"],
|
| 370 |
+
"PHE": ["C", "CA", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "N", "O"],
|
| 371 |
+
"PRO": ["C", "CA", "CB", "CG", "CD", "N", "O"],
|
| 372 |
+
"SER": ["C", "CA", "CB", "N", "O", "OG"],
|
| 373 |
+
"THR": ["C", "CA", "CB", "CG2", "N", "O", "OG1"],
|
| 374 |
+
"TRP": [
|
| 375 |
+
"C",
|
| 376 |
+
"CA",
|
| 377 |
+
"CB",
|
| 378 |
+
"CG",
|
| 379 |
+
"CD1",
|
| 380 |
+
"CD2",
|
| 381 |
+
"CE2",
|
| 382 |
+
"CE3",
|
| 383 |
+
"CZ2",
|
| 384 |
+
"CZ3",
|
| 385 |
+
"CH2",
|
| 386 |
+
"N",
|
| 387 |
+
"NE1",
|
| 388 |
+
"O",
|
| 389 |
+
],
|
| 390 |
+
"TYR": [
|
| 391 |
+
"C",
|
| 392 |
+
"CA",
|
| 393 |
+
"CB",
|
| 394 |
+
"CG",
|
| 395 |
+
"CD1",
|
| 396 |
+
"CD2",
|
| 397 |
+
"CE1",
|
| 398 |
+
"CE2",
|
| 399 |
+
"CZ",
|
| 400 |
+
"N",
|
| 401 |
+
"O",
|
| 402 |
+
"OH",
|
| 403 |
+
],
|
| 404 |
+
"VAL": ["C", "CA", "CB", "CG1", "CG2", "N", "O"],
|
| 405 |
+
}
|
| 406 |
+
|
| 407 |
+
# Naming swaps for ambiguous atom names.
|
| 408 |
+
# Due to symmetries in the amino acids the naming of atoms is ambiguous in
|
| 409 |
+
# 4 of the 20 amino acids.
|
| 410 |
+
# (The LDDT paper lists 7 amino acids as ambiguous, but the naming ambiguities
|
| 411 |
+
# in LEU, VAL and ARG can be resolved by using the 3d constellations of
|
| 412 |
+
# the 'ambiguous' atoms and their neighbours)
|
| 413 |
+
# TODO: ^ interpret this
|
| 414 |
+
residue_atom_renaming_swaps = {
|
| 415 |
+
"ASP": {"OD1": "OD2"},
|
| 416 |
+
"GLU": {"OE1": "OE2"},
|
| 417 |
+
"PHE": {"CD1": "CD2", "CE1": "CE2"},
|
| 418 |
+
"TYR": {"CD1": "CD2", "CE1": "CE2"},
|
| 419 |
+
}
|
| 420 |
+
|
| 421 |
+
# Van der Waals radii [Angstroem] of the atoms (from Wikipedia)
|
| 422 |
+
van_der_waals_radius = {
|
| 423 |
+
"C": 1.7,
|
| 424 |
+
"N": 1.55,
|
| 425 |
+
"O": 1.52,
|
| 426 |
+
"S": 1.8,
|
| 427 |
+
}
|
| 428 |
+
|
| 429 |
+
Bond = collections.namedtuple(
|
| 430 |
+
"Bond", ["atom1_name", "atom2_name", "length", "stddev"]
|
| 431 |
+
)
|
| 432 |
+
BondAngle = collections.namedtuple(
|
| 433 |
+
"BondAngle",
|
| 434 |
+
["atom1_name", "atom2_name", "atom3name", "angle_rad", "stddev"],
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
@functools.lru_cache(maxsize=None)
|
| 439 |
+
def load_stereo_chemical_props() -> Tuple[
|
| 440 |
+
Mapping[str, List[Bond]],
|
| 441 |
+
Mapping[str, List[Bond]],
|
| 442 |
+
Mapping[str, List[BondAngle]],
|
| 443 |
+
]:
|
| 444 |
+
"""Load stereo_chemical_props.txt into a nice structure.
|
| 445 |
+
|
| 446 |
+
Load literature values for bond lengths and bond angles and translate
|
| 447 |
+
bond angles into the length of the opposite edge of the triangle
|
| 448 |
+
("residue_virtual_bonds").
|
| 449 |
+
|
| 450 |
+
Returns:
|
| 451 |
+
residue_bonds: dict that maps resname --> list of Bond tuples
|
| 452 |
+
residue_virtual_bonds: dict that maps resname --> list of Bond tuples
|
| 453 |
+
residue_bond_angles: dict that maps resname --> list of BondAngle tuples
|
| 454 |
+
"""
|
| 455 |
+
# TODO: this file should be downloaded in a setup script
|
| 456 |
+
stereo_chemical_props = resources.read_text("openfold.resources", "stereo_chemical_props.txt")
|
| 457 |
+
|
| 458 |
+
lines_iter = iter(stereo_chemical_props.splitlines())
|
| 459 |
+
# Load bond lengths.
|
| 460 |
+
residue_bonds = {}
|
| 461 |
+
next(lines_iter) # Skip header line.
|
| 462 |
+
for line in lines_iter:
|
| 463 |
+
if line.strip() == "-":
|
| 464 |
+
break
|
| 465 |
+
bond, resname, length, stddev = line.split()
|
| 466 |
+
atom1, atom2 = bond.split("-")
|
| 467 |
+
if resname not in residue_bonds:
|
| 468 |
+
residue_bonds[resname] = []
|
| 469 |
+
residue_bonds[resname].append(
|
| 470 |
+
Bond(atom1, atom2, float(length), float(stddev))
|
| 471 |
+
)
|
| 472 |
+
residue_bonds["UNK"] = []
|
| 473 |
+
|
| 474 |
+
# Load bond angles.
|
| 475 |
+
residue_bond_angles = {}
|
| 476 |
+
next(lines_iter) # Skip empty line.
|
| 477 |
+
next(lines_iter) # Skip header line.
|
| 478 |
+
for line in lines_iter:
|
| 479 |
+
if line.strip() == "-":
|
| 480 |
+
break
|
| 481 |
+
bond, resname, angle_degree, stddev_degree = line.split()
|
| 482 |
+
atom1, atom2, atom3 = bond.split("-")
|
| 483 |
+
if resname not in residue_bond_angles:
|
| 484 |
+
residue_bond_angles[resname] = []
|
| 485 |
+
residue_bond_angles[resname].append(
|
| 486 |
+
BondAngle(
|
| 487 |
+
atom1,
|
| 488 |
+
atom2,
|
| 489 |
+
atom3,
|
| 490 |
+
float(angle_degree) / 180.0 * np.pi,
|
| 491 |
+
float(stddev_degree) / 180.0 * np.pi,
|
| 492 |
+
)
|
| 493 |
+
)
|
| 494 |
+
residue_bond_angles["UNK"] = []
|
| 495 |
+
|
| 496 |
+
def make_bond_key(atom1_name, atom2_name):
|
| 497 |
+
"""Unique key to lookup bonds."""
|
| 498 |
+
return "-".join(sorted([atom1_name, atom2_name]))
|
| 499 |
+
|
| 500 |
+
# Translate bond angles into distances ("virtual bonds").
|
| 501 |
+
residue_virtual_bonds = {}
|
| 502 |
+
for resname, bond_angles in residue_bond_angles.items():
|
| 503 |
+
# Create a fast lookup dict for bond lengths.
|
| 504 |
+
bond_cache = {}
|
| 505 |
+
for b in residue_bonds[resname]:
|
| 506 |
+
bond_cache[make_bond_key(b.atom1_name, b.atom2_name)] = b
|
| 507 |
+
residue_virtual_bonds[resname] = []
|
| 508 |
+
for ba in bond_angles:
|
| 509 |
+
bond1 = bond_cache[make_bond_key(ba.atom1_name, ba.atom2_name)]
|
| 510 |
+
bond2 = bond_cache[make_bond_key(ba.atom2_name, ba.atom3name)]
|
| 511 |
+
|
| 512 |
+
# Compute distance between atom1 and atom3 using the law of cosines
|
| 513 |
+
# c^2 = a^2 + b^2 - 2ab*cos(gamma).
|
| 514 |
+
gamma = ba.angle_rad
|
| 515 |
+
length = np.sqrt(
|
| 516 |
+
bond1.length ** 2
|
| 517 |
+
+ bond2.length ** 2
|
| 518 |
+
- 2 * bond1.length * bond2.length * np.cos(gamma)
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
# Propagation of uncertainty assuming uncorrelated errors.
|
| 522 |
+
dl_outer = 0.5 / length
|
| 523 |
+
dl_dgamma = (
|
| 524 |
+
2 * bond1.length * bond2.length * np.sin(gamma)
|
| 525 |
+
) * dl_outer
|
| 526 |
+
dl_db1 = (
|
| 527 |
+
2 * bond1.length - 2 * bond2.length * np.cos(gamma)
|
| 528 |
+
) * dl_outer
|
| 529 |
+
dl_db2 = (
|
| 530 |
+
2 * bond2.length - 2 * bond1.length * np.cos(gamma)
|
| 531 |
+
) * dl_outer
|
| 532 |
+
stddev = np.sqrt(
|
| 533 |
+
(dl_dgamma * ba.stddev) ** 2
|
| 534 |
+
+ (dl_db1 * bond1.stddev) ** 2
|
| 535 |
+
+ (dl_db2 * bond2.stddev) ** 2
|
| 536 |
+
)
|
| 537 |
+
residue_virtual_bonds[resname].append(
|
| 538 |
+
Bond(ba.atom1_name, ba.atom3name, length, stddev)
|
| 539 |
+
)
|
| 540 |
+
|
| 541 |
+
return (residue_bonds, residue_virtual_bonds, residue_bond_angles)
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
# Between-residue bond lengths for general bonds (first element) and for Proline
|
| 545 |
+
# (second element).
|
| 546 |
+
between_res_bond_length_c_n = [1.329, 1.341]
|
| 547 |
+
between_res_bond_length_stddev_c_n = [0.014, 0.016]
|
| 548 |
+
|
| 549 |
+
# Between-residue cos_angles.
|
| 550 |
+
between_res_cos_angles_c_n_ca = [-0.5203, 0.0353] # degrees: 121.352 +- 2.315
|
| 551 |
+
between_res_cos_angles_ca_c_n = [-0.4473, 0.0311] # degrees: 116.568 +- 1.995
|
| 552 |
+
|
| 553 |
+
# This mapping is used when we need to store atom data in a format that requires
|
| 554 |
+
# fixed atom data size for every residue (e.g. a numpy array).
|
| 555 |
+
atom_types = [
|
| 556 |
+
"N",
|
| 557 |
+
"CA",
|
| 558 |
+
"C",
|
| 559 |
+
"CB",
|
| 560 |
+
"O",
|
| 561 |
+
"CG",
|
| 562 |
+
"CG1",
|
| 563 |
+
"CG2",
|
| 564 |
+
"OG",
|
| 565 |
+
"OG1",
|
| 566 |
+
"SG",
|
| 567 |
+
"CD",
|
| 568 |
+
"CD1",
|
| 569 |
+
"CD2",
|
| 570 |
+
"ND1",
|
| 571 |
+
"ND2",
|
| 572 |
+
"OD1",
|
| 573 |
+
"OD2",
|
| 574 |
+
"SD",
|
| 575 |
+
"CE",
|
| 576 |
+
"CE1",
|
| 577 |
+
"CE2",
|
| 578 |
+
"CE3",
|
| 579 |
+
"NE",
|
| 580 |
+
"NE1",
|
| 581 |
+
"NE2",
|
| 582 |
+
"OE1",
|
| 583 |
+
"OE2",
|
| 584 |
+
"CH2",
|
| 585 |
+
"NH1",
|
| 586 |
+
"NH2",
|
| 587 |
+
"OH",
|
| 588 |
+
"CZ",
|
| 589 |
+
"CZ2",
|
| 590 |
+
"CZ3",
|
| 591 |
+
"NZ",
|
| 592 |
+
"OXT",
|
| 593 |
+
]
|
| 594 |
+
atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)}
|
| 595 |
+
atom_type_num = len(atom_types) # := 37.
|
| 596 |
+
|
| 597 |
+
# A compact atom encoding with 14 columns
|
| 598 |
+
# pylint: disable=line-too-long
|
| 599 |
+
# pylint: disable=bad-whitespace
|
| 600 |
+
restype_name_to_atom14_names = {
|
| 601 |
+
"ALA": ["N", "CA", "C", "O", "CB", "", "", "", "", "", "", "", "", ""],
|
| 602 |
+
"ARG": [
|
| 603 |
+
"N",
|
| 604 |
+
"CA",
|
| 605 |
+
"C",
|
| 606 |
+
"O",
|
| 607 |
+
"CB",
|
| 608 |
+
"CG",
|
| 609 |
+
"CD",
|
| 610 |
+
"NE",
|
| 611 |
+
"CZ",
|
| 612 |
+
"NH1",
|
| 613 |
+
"NH2",
|
| 614 |
+
"",
|
| 615 |
+
"",
|
| 616 |
+
"",
|
| 617 |
+
],
|
| 618 |
+
"ASN": [
|
| 619 |
+
"N",
|
| 620 |
+
"CA",
|
| 621 |
+
"C",
|
| 622 |
+
"O",
|
| 623 |
+
"CB",
|
| 624 |
+
"CG",
|
| 625 |
+
"OD1",
|
| 626 |
+
"ND2",
|
| 627 |
+
"",
|
| 628 |
+
"",
|
| 629 |
+
"",
|
| 630 |
+
"",
|
| 631 |
+
"",
|
| 632 |
+
"",
|
| 633 |
+
],
|
| 634 |
+
"ASP": [
|
| 635 |
+
"N",
|
| 636 |
+
"CA",
|
| 637 |
+
"C",
|
| 638 |
+
"O",
|
| 639 |
+
"CB",
|
| 640 |
+
"CG",
|
| 641 |
+
"OD1",
|
| 642 |
+
"OD2",
|
| 643 |
+
"",
|
| 644 |
+
"",
|
| 645 |
+
"",
|
| 646 |
+
"",
|
| 647 |
+
"",
|
| 648 |
+
"",
|
| 649 |
+
],
|
| 650 |
+
"CYS": ["N", "CA", "C", "O", "CB", "SG", "", "", "", "", "", "", "", ""],
|
| 651 |
+
"GLN": [
|
| 652 |
+
"N",
|
| 653 |
+
"CA",
|
| 654 |
+
"C",
|
| 655 |
+
"O",
|
| 656 |
+
"CB",
|
| 657 |
+
"CG",
|
| 658 |
+
"CD",
|
| 659 |
+
"OE1",
|
| 660 |
+
"NE2",
|
| 661 |
+
"",
|
| 662 |
+
"",
|
| 663 |
+
"",
|
| 664 |
+
"",
|
| 665 |
+
"",
|
| 666 |
+
],
|
| 667 |
+
"GLU": [
|
| 668 |
+
"N",
|
| 669 |
+
"CA",
|
| 670 |
+
"C",
|
| 671 |
+
"O",
|
| 672 |
+
"CB",
|
| 673 |
+
"CG",
|
| 674 |
+
"CD",
|
| 675 |
+
"OE1",
|
| 676 |
+
"OE2",
|
| 677 |
+
"",
|
| 678 |
+
"",
|
| 679 |
+
"",
|
| 680 |
+
"",
|
| 681 |
+
"",
|
| 682 |
+
],
|
| 683 |
+
"GLY": ["N", "CA", "C", "O", "", "", "", "", "", "", "", "", "", ""],
|
| 684 |
+
"HIS": [
|
| 685 |
+
"N",
|
| 686 |
+
"CA",
|
| 687 |
+
"C",
|
| 688 |
+
"O",
|
| 689 |
+
"CB",
|
| 690 |
+
"CG",
|
| 691 |
+
"ND1",
|
| 692 |
+
"CD2",
|
| 693 |
+
"CE1",
|
| 694 |
+
"NE2",
|
| 695 |
+
"",
|
| 696 |
+
"",
|
| 697 |
+
"",
|
| 698 |
+
"",
|
| 699 |
+
],
|
| 700 |
+
"ILE": [
|
| 701 |
+
"N",
|
| 702 |
+
"CA",
|
| 703 |
+
"C",
|
| 704 |
+
"O",
|
| 705 |
+
"CB",
|
| 706 |
+
"CG1",
|
| 707 |
+
"CG2",
|
| 708 |
+
"CD1",
|
| 709 |
+
"",
|
| 710 |
+
"",
|
| 711 |
+
"",
|
| 712 |
+
"",
|
| 713 |
+
"",
|
| 714 |
+
"",
|
| 715 |
+
],
|
| 716 |
+
"LEU": [
|
| 717 |
+
"N",
|
| 718 |
+
"CA",
|
| 719 |
+
"C",
|
| 720 |
+
"O",
|
| 721 |
+
"CB",
|
| 722 |
+
"CG",
|
| 723 |
+
"CD1",
|
| 724 |
+
"CD2",
|
| 725 |
+
"",
|
| 726 |
+
"",
|
| 727 |
+
"",
|
| 728 |
+
"",
|
| 729 |
+
"",
|
| 730 |
+
"",
|
| 731 |
+
],
|
| 732 |
+
"LYS": [
|
| 733 |
+
"N",
|
| 734 |
+
"CA",
|
| 735 |
+
"C",
|
| 736 |
+
"O",
|
| 737 |
+
"CB",
|
| 738 |
+
"CG",
|
| 739 |
+
"CD",
|
| 740 |
+
"CE",
|
| 741 |
+
"NZ",
|
| 742 |
+
"",
|
| 743 |
+
"",
|
| 744 |
+
"",
|
| 745 |
+
"",
|
| 746 |
+
"",
|
| 747 |
+
],
|
| 748 |
+
"MET": [
|
| 749 |
+
"N",
|
| 750 |
+
"CA",
|
| 751 |
+
"C",
|
| 752 |
+
"O",
|
| 753 |
+
"CB",
|
| 754 |
+
"CG",
|
| 755 |
+
"SD",
|
| 756 |
+
"CE",
|
| 757 |
+
"",
|
| 758 |
+
"",
|
| 759 |
+
"",
|
| 760 |
+
"",
|
| 761 |
+
"",
|
| 762 |
+
"",
|
| 763 |
+
],
|
| 764 |
+
"PHE": [
|
| 765 |
+
"N",
|
| 766 |
+
"CA",
|
| 767 |
+
"C",
|
| 768 |
+
"O",
|
| 769 |
+
"CB",
|
| 770 |
+
"CG",
|
| 771 |
+
"CD1",
|
| 772 |
+
"CD2",
|
| 773 |
+
"CE1",
|
| 774 |
+
"CE2",
|
| 775 |
+
"CZ",
|
| 776 |
+
"",
|
| 777 |
+
"",
|
| 778 |
+
"",
|
| 779 |
+
],
|
| 780 |
+
"PRO": ["N", "CA", "C", "O", "CB", "CG", "CD", "", "", "", "", "", "", ""],
|
| 781 |
+
"SER": ["N", "CA", "C", "O", "CB", "OG", "", "", "", "", "", "", "", ""],
|
| 782 |
+
"THR": [
|
| 783 |
+
"N",
|
| 784 |
+
"CA",
|
| 785 |
+
"C",
|
| 786 |
+
"O",
|
| 787 |
+
"CB",
|
| 788 |
+
"OG1",
|
| 789 |
+
"CG2",
|
| 790 |
+
"",
|
| 791 |
+
"",
|
| 792 |
+
"",
|
| 793 |
+
"",
|
| 794 |
+
"",
|
| 795 |
+
"",
|
| 796 |
+
"",
|
| 797 |
+
],
|
| 798 |
+
"TRP": [
|
| 799 |
+
"N",
|
| 800 |
+
"CA",
|
| 801 |
+
"C",
|
| 802 |
+
"O",
|
| 803 |
+
"CB",
|
| 804 |
+
"CG",
|
| 805 |
+
"CD1",
|
| 806 |
+
"CD2",
|
| 807 |
+
"NE1",
|
| 808 |
+
"CE2",
|
| 809 |
+
"CE3",
|
| 810 |
+
"CZ2",
|
| 811 |
+
"CZ3",
|
| 812 |
+
"CH2",
|
| 813 |
+
],
|
| 814 |
+
"TYR": [
|
| 815 |
+
"N",
|
| 816 |
+
"CA",
|
| 817 |
+
"C",
|
| 818 |
+
"O",
|
| 819 |
+
"CB",
|
| 820 |
+
"CG",
|
| 821 |
+
"CD1",
|
| 822 |
+
"CD2",
|
| 823 |
+
"CE1",
|
| 824 |
+
"CE2",
|
| 825 |
+
"CZ",
|
| 826 |
+
"OH",
|
| 827 |
+
"",
|
| 828 |
+
"",
|
| 829 |
+
],
|
| 830 |
+
"VAL": [
|
| 831 |
+
"N",
|
| 832 |
+
"CA",
|
| 833 |
+
"C",
|
| 834 |
+
"O",
|
| 835 |
+
"CB",
|
| 836 |
+
"CG1",
|
| 837 |
+
"CG2",
|
| 838 |
+
"",
|
| 839 |
+
"",
|
| 840 |
+
"",
|
| 841 |
+
"",
|
| 842 |
+
"",
|
| 843 |
+
"",
|
| 844 |
+
"",
|
| 845 |
+
],
|
| 846 |
+
"UNK": ["", "", "", "", "", "", "", "", "", "", "", "", "", ""],
|
| 847 |
+
}
|
| 848 |
+
# pylint: enable=line-too-long
|
| 849 |
+
# pylint: enable=bad-whitespace
|
| 850 |
+
|
| 851 |
+
|
| 852 |
+
# This is the standard residue order when coding AA type as a number.
|
| 853 |
+
# Reproduce it by taking 3-letter AA codes and sorting them alphabetically.
|
| 854 |
+
restypes = [
|
| 855 |
+
"A",
|
| 856 |
+
"R",
|
| 857 |
+
"N",
|
| 858 |
+
"D",
|
| 859 |
+
"C",
|
| 860 |
+
"Q",
|
| 861 |
+
"E",
|
| 862 |
+
"G",
|
| 863 |
+
"H",
|
| 864 |
+
"I",
|
| 865 |
+
"L",
|
| 866 |
+
"K",
|
| 867 |
+
"M",
|
| 868 |
+
"F",
|
| 869 |
+
"P",
|
| 870 |
+
"S",
|
| 871 |
+
"T",
|
| 872 |
+
"W",
|
| 873 |
+
"Y",
|
| 874 |
+
"V",
|
| 875 |
+
]
|
| 876 |
+
restype_order = {restype: i for i, restype in enumerate(restypes)}
|
| 877 |
+
restype_num = len(restypes) # := 20.
|
| 878 |
+
unk_restype_index = restype_num # Catch-all index for unknown restypes.
|
| 879 |
+
|
| 880 |
+
restypes_with_x = restypes + ["X"]
|
| 881 |
+
restype_order_with_x = {restype: i for i, restype in enumerate(restypes_with_x)}
|
| 882 |
+
|
| 883 |
+
|
| 884 |
+
def sequence_to_onehot(
|
| 885 |
+
sequence: str, mapping: Mapping[str, int], map_unknown_to_x: bool = False
|
| 886 |
+
) -> np.ndarray:
|
| 887 |
+
"""Maps the given sequence into a one-hot encoded matrix.
|
| 888 |
+
|
| 889 |
+
Args:
|
| 890 |
+
sequence: An amino acid sequence.
|
| 891 |
+
mapping: A dictionary mapping amino acids to integers.
|
| 892 |
+
map_unknown_to_x: If True, any amino acid that is not in the mapping will be
|
| 893 |
+
mapped to the unknown amino acid 'X'. If the mapping doesn't contain
|
| 894 |
+
amino acid 'X', an error will be thrown. If False, any amino acid not in
|
| 895 |
+
the mapping will throw an error.
|
| 896 |
+
|
| 897 |
+
Returns:
|
| 898 |
+
A numpy array of shape (seq_len, num_unique_aas) with one-hot encoding of
|
| 899 |
+
the sequence.
|
| 900 |
+
|
| 901 |
+
Raises:
|
| 902 |
+
ValueError: If the mapping doesn't contain values from 0 to
|
| 903 |
+
num_unique_aas - 1 without any gaps.
|
| 904 |
+
"""
|
| 905 |
+
num_entries = max(mapping.values()) + 1
|
| 906 |
+
|
| 907 |
+
if sorted(set(mapping.values())) != list(range(num_entries)):
|
| 908 |
+
raise ValueError(
|
| 909 |
+
"The mapping must have values from 0 to num_unique_aas-1 "
|
| 910 |
+
"without any gaps. Got: %s" % sorted(mapping.values())
|
| 911 |
+
)
|
| 912 |
+
|
| 913 |
+
one_hot_arr = np.zeros((len(sequence), num_entries), dtype=int)
|
| 914 |
+
|
| 915 |
+
for aa_index, aa_type in enumerate(sequence):
|
| 916 |
+
if map_unknown_to_x:
|
| 917 |
+
if aa_type.isalpha() and aa_type.isupper():
|
| 918 |
+
aa_id = mapping.get(aa_type, mapping["X"])
|
| 919 |
+
else:
|
| 920 |
+
raise ValueError(
|
| 921 |
+
f"Invalid character in the sequence: {aa_type}"
|
| 922 |
+
)
|
| 923 |
+
else:
|
| 924 |
+
aa_id = mapping[aa_type]
|
| 925 |
+
one_hot_arr[aa_index, aa_id] = 1
|
| 926 |
+
|
| 927 |
+
return one_hot_arr
|
| 928 |
+
|
| 929 |
+
|
| 930 |
+
restype_1to3 = {
|
| 931 |
+
"A": "ALA",
|
| 932 |
+
"R": "ARG",
|
| 933 |
+
"N": "ASN",
|
| 934 |
+
"D": "ASP",
|
| 935 |
+
"C": "CYS",
|
| 936 |
+
"Q": "GLN",
|
| 937 |
+
"E": "GLU",
|
| 938 |
+
"G": "GLY",
|
| 939 |
+
"H": "HIS",
|
| 940 |
+
"I": "ILE",
|
| 941 |
+
"L": "LEU",
|
| 942 |
+
"K": "LYS",
|
| 943 |
+
"M": "MET",
|
| 944 |
+
"F": "PHE",
|
| 945 |
+
"P": "PRO",
|
| 946 |
+
"S": "SER",
|
| 947 |
+
"T": "THR",
|
| 948 |
+
"W": "TRP",
|
| 949 |
+
"Y": "TYR",
|
| 950 |
+
"V": "VAL",
|
| 951 |
+
}
|
| 952 |
+
|
| 953 |
+
|
| 954 |
+
# NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 by being a simple
|
| 955 |
+
# 1-to-1 mapping of 3 letter names to one letter names. The latter contains
|
| 956 |
+
# many more, and less common, three letter names as keys and maps many of these
|
| 957 |
+
# to the same one letter name (including 'X' and 'U' which we don't use here).
|
| 958 |
+
restype_3to1 = {v: k for k, v in restype_1to3.items()}
|
| 959 |
+
|
| 960 |
+
# Define a restype name for all unknown residues.
|
| 961 |
+
unk_restype = "UNK"
|
| 962 |
+
|
| 963 |
+
resnames = [restype_1to3[r] for r in restypes] + [unk_restype]
|
| 964 |
+
resname_to_idx = {resname: i for i, resname in enumerate(resnames)}
|
| 965 |
+
|
| 966 |
+
|
| 967 |
+
# The mapping here uses hhblits convention, so that B is mapped to D, J and O
|
| 968 |
+
# are mapped to X, U is mapped to C, and Z is mapped to E. Other than that the
|
| 969 |
+
# remaining 20 amino acids are kept in alphabetical order.
|
| 970 |
+
# There are 2 non-amino acid codes, X (representing any amino acid) and
|
| 971 |
+
# "-" representing a missing amino acid in an alignment. The id for these
|
| 972 |
+
# codes is put at the end (20 and 21) so that they can easily be ignored if
|
| 973 |
+
# desired.
|
| 974 |
+
HHBLITS_AA_TO_ID = {
|
| 975 |
+
"A": 0,
|
| 976 |
+
"B": 2,
|
| 977 |
+
"C": 1,
|
| 978 |
+
"D": 2,
|
| 979 |
+
"E": 3,
|
| 980 |
+
"F": 4,
|
| 981 |
+
"G": 5,
|
| 982 |
+
"H": 6,
|
| 983 |
+
"I": 7,
|
| 984 |
+
"J": 20,
|
| 985 |
+
"K": 8,
|
| 986 |
+
"L": 9,
|
| 987 |
+
"M": 10,
|
| 988 |
+
"N": 11,
|
| 989 |
+
"O": 20,
|
| 990 |
+
"P": 12,
|
| 991 |
+
"Q": 13,
|
| 992 |
+
"R": 14,
|
| 993 |
+
"S": 15,
|
| 994 |
+
"T": 16,
|
| 995 |
+
"U": 1,
|
| 996 |
+
"V": 17,
|
| 997 |
+
"W": 18,
|
| 998 |
+
"X": 20,
|
| 999 |
+
"Y": 19,
|
| 1000 |
+
"Z": 3,
|
| 1001 |
+
"-": 21,
|
| 1002 |
+
}
|
| 1003 |
+
|
| 1004 |
+
# Partial inversion of HHBLITS_AA_TO_ID.
|
| 1005 |
+
ID_TO_HHBLITS_AA = {
|
| 1006 |
+
0: "A",
|
| 1007 |
+
1: "C", # Also U.
|
| 1008 |
+
2: "D", # Also B.
|
| 1009 |
+
3: "E", # Also Z.
|
| 1010 |
+
4: "F",
|
| 1011 |
+
5: "G",
|
| 1012 |
+
6: "H",
|
| 1013 |
+
7: "I",
|
| 1014 |
+
8: "K",
|
| 1015 |
+
9: "L",
|
| 1016 |
+
10: "M",
|
| 1017 |
+
11: "N",
|
| 1018 |
+
12: "P",
|
| 1019 |
+
13: "Q",
|
| 1020 |
+
14: "R",
|
| 1021 |
+
15: "S",
|
| 1022 |
+
16: "T",
|
| 1023 |
+
17: "V",
|
| 1024 |
+
18: "W",
|
| 1025 |
+
19: "Y",
|
| 1026 |
+
20: "X", # Includes J and O.
|
| 1027 |
+
21: "-",
|
| 1028 |
+
}
|
| 1029 |
+
|
| 1030 |
+
restypes_with_x_and_gap = restypes + ["X", "-"]
|
| 1031 |
+
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE = tuple(
|
| 1032 |
+
restypes_with_x_and_gap.index(ID_TO_HHBLITS_AA[i])
|
| 1033 |
+
for i in range(len(restypes_with_x_and_gap))
|
| 1034 |
+
)
|
| 1035 |
+
|
| 1036 |
+
|
| 1037 |
+
def _make_standard_atom_mask() -> np.ndarray:
|
| 1038 |
+
"""Returns [num_res_types, num_atom_types] mask array."""
|
| 1039 |
+
# +1 to account for unknown (all 0s).
|
| 1040 |
+
mask = np.zeros([restype_num + 1, atom_type_num], dtype=int)
|
| 1041 |
+
for restype, restype_letter in enumerate(restypes):
|
| 1042 |
+
restype_name = restype_1to3[restype_letter]
|
| 1043 |
+
atom_names = residue_atoms[restype_name]
|
| 1044 |
+
for atom_name in atom_names:
|
| 1045 |
+
atom_type = atom_order[atom_name]
|
| 1046 |
+
mask[restype, atom_type] = 1
|
| 1047 |
+
return mask
|
| 1048 |
+
|
| 1049 |
+
|
| 1050 |
+
STANDARD_ATOM_MASK = _make_standard_atom_mask()
|
| 1051 |
+
|
| 1052 |
+
|
| 1053 |
+
# A one hot representation for the first and second atoms defining the axis
|
| 1054 |
+
# of rotation for each chi-angle in each residue.
|
| 1055 |
+
def chi_angle_atom(atom_index: int) -> np.ndarray:
|
| 1056 |
+
"""Define chi-angle rigid groups via one-hot representations."""
|
| 1057 |
+
chi_angles_index = {}
|
| 1058 |
+
one_hots = []
|
| 1059 |
+
|
| 1060 |
+
for k, v in chi_angles_atoms.items():
|
| 1061 |
+
indices = [atom_types.index(s[atom_index]) for s in v]
|
| 1062 |
+
indices.extend([-1] * (4 - len(indices)))
|
| 1063 |
+
chi_angles_index[k] = indices
|
| 1064 |
+
|
| 1065 |
+
for r in restypes:
|
| 1066 |
+
res3 = restype_1to3[r]
|
| 1067 |
+
one_hot = np.eye(atom_type_num)[chi_angles_index[res3]]
|
| 1068 |
+
one_hots.append(one_hot)
|
| 1069 |
+
|
| 1070 |
+
one_hots.append(np.zeros([4, atom_type_num])) # Add zeros for residue `X`.
|
| 1071 |
+
one_hot = np.stack(one_hots, axis=0)
|
| 1072 |
+
one_hot = np.transpose(one_hot, [0, 2, 1])
|
| 1073 |
+
|
| 1074 |
+
return one_hot
|
| 1075 |
+
|
| 1076 |
+
|
| 1077 |
+
chi_atom_1_one_hot = chi_angle_atom(1)
|
| 1078 |
+
chi_atom_2_one_hot = chi_angle_atom(2)
|
| 1079 |
+
|
| 1080 |
+
# An array like chi_angles_atoms but using indices rather than names.
|
| 1081 |
+
chi_angles_atom_indices = [chi_angles_atoms[restype_1to3[r]] for r in restypes]
|
| 1082 |
+
chi_angles_atom_indices = tree.map_structure(
|
| 1083 |
+
lambda atom_name: atom_order[atom_name], chi_angles_atom_indices
|
| 1084 |
+
)
|
| 1085 |
+
chi_angles_atom_indices = np.array(
|
| 1086 |
+
[
|
| 1087 |
+
chi_atoms + ([[0, 0, 0, 0]] * (4 - len(chi_atoms)))
|
| 1088 |
+
for chi_atoms in chi_angles_atom_indices
|
| 1089 |
+
]
|
| 1090 |
+
)
|
| 1091 |
+
|
| 1092 |
+
# Mapping from (res_name, atom_name) pairs to the atom's chi group index
|
| 1093 |
+
# and atom index within that group.
|
| 1094 |
+
chi_groups_for_atom = collections.defaultdict(list)
|
| 1095 |
+
for res_name, chi_angle_atoms_for_res in chi_angles_atoms.items():
|
| 1096 |
+
for chi_group_i, chi_group in enumerate(chi_angle_atoms_for_res):
|
| 1097 |
+
for atom_i, atom in enumerate(chi_group):
|
| 1098 |
+
chi_groups_for_atom[(res_name, atom)].append((chi_group_i, atom_i))
|
| 1099 |
+
chi_groups_for_atom = dict(chi_groups_for_atom)
|
| 1100 |
+
|
| 1101 |
+
|
| 1102 |
+
def _make_rigid_transformation_4x4(ex, ey, translation):
|
| 1103 |
+
"""Create a rigid 4x4 transformation matrix from two axes and transl."""
|
| 1104 |
+
# Normalize ex.
|
| 1105 |
+
ex_normalized = ex / np.linalg.norm(ex)
|
| 1106 |
+
|
| 1107 |
+
# make ey perpendicular to ex
|
| 1108 |
+
ey_normalized = ey - np.dot(ey, ex_normalized) * ex_normalized
|
| 1109 |
+
ey_normalized /= np.linalg.norm(ey_normalized)
|
| 1110 |
+
|
| 1111 |
+
# compute ez as cross product
|
| 1112 |
+
eznorm = np.cross(ex_normalized, ey_normalized)
|
| 1113 |
+
m = np.stack(
|
| 1114 |
+
[ex_normalized, ey_normalized, eznorm, translation]
|
| 1115 |
+
).transpose()
|
| 1116 |
+
m = np.concatenate([m, [[0.0, 0.0, 0.0, 1.0]]], axis=0)
|
| 1117 |
+
return m
|
| 1118 |
+
|
| 1119 |
+
|
| 1120 |
+
# create an array with (restype, atomtype) --> rigid_group_idx
|
| 1121 |
+
# and an array with (restype, atomtype, coord) for the atom positions
|
| 1122 |
+
# and compute affine transformation matrices (4,4) from one rigid group to the
|
| 1123 |
+
# previous group
|
| 1124 |
+
restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=int)
|
| 1125 |
+
restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
|
| 1126 |
+
restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32)
|
| 1127 |
+
restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=int)
|
| 1128 |
+
restype_atom14_mask = np.zeros([21, 14], dtype=np.float32)
|
| 1129 |
+
restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32)
|
| 1130 |
+
restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32)
|
| 1131 |
+
|
| 1132 |
+
|
| 1133 |
+
def _make_rigid_group_constants():
|
| 1134 |
+
"""Fill the arrays above."""
|
| 1135 |
+
for restype, restype_letter in enumerate(restypes):
|
| 1136 |
+
resname = restype_1to3[restype_letter]
|
| 1137 |
+
for atomname, group_idx, atom_position in rigid_group_atom_positions[
|
| 1138 |
+
resname
|
| 1139 |
+
]:
|
| 1140 |
+
atomtype = atom_order[atomname]
|
| 1141 |
+
restype_atom37_to_rigid_group[restype, atomtype] = group_idx
|
| 1142 |
+
restype_atom37_mask[restype, atomtype] = 1
|
| 1143 |
+
restype_atom37_rigid_group_positions[
|
| 1144 |
+
restype, atomtype, :
|
| 1145 |
+
] = atom_position
|
| 1146 |
+
|
| 1147 |
+
atom14idx = restype_name_to_atom14_names[resname].index(atomname)
|
| 1148 |
+
restype_atom14_to_rigid_group[restype, atom14idx] = group_idx
|
| 1149 |
+
restype_atom14_mask[restype, atom14idx] = 1
|
| 1150 |
+
restype_atom14_rigid_group_positions[
|
| 1151 |
+
restype, atom14idx, :
|
| 1152 |
+
] = atom_position
|
| 1153 |
+
|
| 1154 |
+
for restype, restype_letter in enumerate(restypes):
|
| 1155 |
+
resname = restype_1to3[restype_letter]
|
| 1156 |
+
atom_positions = {
|
| 1157 |
+
name: np.array(pos)
|
| 1158 |
+
for name, _, pos in rigid_group_atom_positions[resname]
|
| 1159 |
+
}
|
| 1160 |
+
|
| 1161 |
+
# backbone to backbone is the identity transform
|
| 1162 |
+
restype_rigid_group_default_frame[restype, 0, :, :] = np.eye(4)
|
| 1163 |
+
|
| 1164 |
+
# pre-omega-frame to backbone (currently dummy identity matrix)
|
| 1165 |
+
restype_rigid_group_default_frame[restype, 1, :, :] = np.eye(4)
|
| 1166 |
+
|
| 1167 |
+
# phi-frame to backbone
|
| 1168 |
+
mat = _make_rigid_transformation_4x4(
|
| 1169 |
+
ex=atom_positions["N"] - atom_positions["CA"],
|
| 1170 |
+
ey=np.array([1.0, 0.0, 0.0]),
|
| 1171 |
+
translation=atom_positions["N"],
|
| 1172 |
+
)
|
| 1173 |
+
restype_rigid_group_default_frame[restype, 2, :, :] = mat
|
| 1174 |
+
|
| 1175 |
+
# psi-frame to backbone
|
| 1176 |
+
mat = _make_rigid_transformation_4x4(
|
| 1177 |
+
ex=atom_positions["C"] - atom_positions["CA"],
|
| 1178 |
+
ey=atom_positions["CA"] - atom_positions["N"],
|
| 1179 |
+
translation=atom_positions["C"],
|
| 1180 |
+
)
|
| 1181 |
+
restype_rigid_group_default_frame[restype, 3, :, :] = mat
|
| 1182 |
+
|
| 1183 |
+
# chi1-frame to backbone
|
| 1184 |
+
if chi_angles_mask[restype][0]:
|
| 1185 |
+
base_atom_names = chi_angles_atoms[resname][0]
|
| 1186 |
+
base_atom_positions = [
|
| 1187 |
+
atom_positions[name] for name in base_atom_names
|
| 1188 |
+
]
|
| 1189 |
+
mat = _make_rigid_transformation_4x4(
|
| 1190 |
+
ex=base_atom_positions[2] - base_atom_positions[1],
|
| 1191 |
+
ey=base_atom_positions[0] - base_atom_positions[1],
|
| 1192 |
+
translation=base_atom_positions[2],
|
| 1193 |
+
)
|
| 1194 |
+
restype_rigid_group_default_frame[restype, 4, :, :] = mat
|
| 1195 |
+
|
| 1196 |
+
# chi2-frame to chi1-frame
|
| 1197 |
+
# chi3-frame to chi2-frame
|
| 1198 |
+
# chi4-frame to chi3-frame
|
| 1199 |
+
# luckily all rotation axes for the next frame start at (0,0,0) of the
|
| 1200 |
+
# previous frame
|
| 1201 |
+
for chi_idx in range(1, 4):
|
| 1202 |
+
if chi_angles_mask[restype][chi_idx]:
|
| 1203 |
+
axis_end_atom_name = chi_angles_atoms[resname][chi_idx][2]
|
| 1204 |
+
axis_end_atom_position = atom_positions[axis_end_atom_name]
|
| 1205 |
+
mat = _make_rigid_transformation_4x4(
|
| 1206 |
+
ex=axis_end_atom_position,
|
| 1207 |
+
ey=np.array([-1.0, 0.0, 0.0]),
|
| 1208 |
+
translation=axis_end_atom_position,
|
| 1209 |
+
)
|
| 1210 |
+
restype_rigid_group_default_frame[
|
| 1211 |
+
restype, 4 + chi_idx, :, :
|
| 1212 |
+
] = mat
|
| 1213 |
+
|
| 1214 |
+
|
| 1215 |
+
_make_rigid_group_constants()
|
| 1216 |
+
|
| 1217 |
+
|
| 1218 |
+
def make_atom14_dists_bounds(
|
| 1219 |
+
overlap_tolerance=1.5, bond_length_tolerance_factor=15
|
| 1220 |
+
):
|
| 1221 |
+
"""compute upper and lower bounds for bonds to assess violations."""
|
| 1222 |
+
restype_atom14_bond_lower_bound = np.zeros([21, 14, 14], np.float32)
|
| 1223 |
+
restype_atom14_bond_upper_bound = np.zeros([21, 14, 14], np.float32)
|
| 1224 |
+
restype_atom14_bond_stddev = np.zeros([21, 14, 14], np.float32)
|
| 1225 |
+
residue_bonds, residue_virtual_bonds, _ = load_stereo_chemical_props()
|
| 1226 |
+
for restype, restype_letter in enumerate(restypes):
|
| 1227 |
+
resname = restype_1to3[restype_letter]
|
| 1228 |
+
atom_list = restype_name_to_atom14_names[resname]
|
| 1229 |
+
|
| 1230 |
+
# create lower and upper bounds for clashes
|
| 1231 |
+
for atom1_idx, atom1_name in enumerate(atom_list):
|
| 1232 |
+
if not atom1_name:
|
| 1233 |
+
continue
|
| 1234 |
+
atom1_radius = van_der_waals_radius[atom1_name[0]]
|
| 1235 |
+
for atom2_idx, atom2_name in enumerate(atom_list):
|
| 1236 |
+
if (not atom2_name) or atom1_idx == atom2_idx:
|
| 1237 |
+
continue
|
| 1238 |
+
atom2_radius = van_der_waals_radius[atom2_name[0]]
|
| 1239 |
+
lower = atom1_radius + atom2_radius - overlap_tolerance
|
| 1240 |
+
upper = 1e10
|
| 1241 |
+
restype_atom14_bond_lower_bound[
|
| 1242 |
+
restype, atom1_idx, atom2_idx
|
| 1243 |
+
] = lower
|
| 1244 |
+
restype_atom14_bond_lower_bound[
|
| 1245 |
+
restype, atom2_idx, atom1_idx
|
| 1246 |
+
] = lower
|
| 1247 |
+
restype_atom14_bond_upper_bound[
|
| 1248 |
+
restype, atom1_idx, atom2_idx
|
| 1249 |
+
] = upper
|
| 1250 |
+
restype_atom14_bond_upper_bound[
|
| 1251 |
+
restype, atom2_idx, atom1_idx
|
| 1252 |
+
] = upper
|
| 1253 |
+
|
| 1254 |
+
# overwrite lower and upper bounds for bonds and angles
|
| 1255 |
+
for b in residue_bonds[resname] + residue_virtual_bonds[resname]:
|
| 1256 |
+
atom1_idx = atom_list.index(b.atom1_name)
|
| 1257 |
+
atom2_idx = atom_list.index(b.atom2_name)
|
| 1258 |
+
lower = b.length - bond_length_tolerance_factor * b.stddev
|
| 1259 |
+
upper = b.length + bond_length_tolerance_factor * b.stddev
|
| 1260 |
+
restype_atom14_bond_lower_bound[
|
| 1261 |
+
restype, atom1_idx, atom2_idx
|
| 1262 |
+
] = lower
|
| 1263 |
+
restype_atom14_bond_lower_bound[
|
| 1264 |
+
restype, atom2_idx, atom1_idx
|
| 1265 |
+
] = lower
|
| 1266 |
+
restype_atom14_bond_upper_bound[
|
| 1267 |
+
restype, atom1_idx, atom2_idx
|
| 1268 |
+
] = upper
|
| 1269 |
+
restype_atom14_bond_upper_bound[
|
| 1270 |
+
restype, atom2_idx, atom1_idx
|
| 1271 |
+
] = upper
|
| 1272 |
+
restype_atom14_bond_stddev[restype, atom1_idx, atom2_idx] = b.stddev
|
| 1273 |
+
restype_atom14_bond_stddev[restype, atom2_idx, atom1_idx] = b.stddev
|
| 1274 |
+
return {
|
| 1275 |
+
"lower_bound": restype_atom14_bond_lower_bound, # shape (21,14,14)
|
| 1276 |
+
"upper_bound": restype_atom14_bond_upper_bound, # shape (21,14,14)
|
| 1277 |
+
"stddev": restype_atom14_bond_stddev, # shape (21,14,14)
|
| 1278 |
+
}
|
| 1279 |
+
|
| 1280 |
+
|
| 1281 |
+
restype_atom14_ambiguous_atoms = np.zeros((21, 14), dtype=np.float32)
|
| 1282 |
+
restype_atom14_ambiguous_atoms_swap_idx = np.tile(
|
| 1283 |
+
np.arange(14, dtype=int), (21, 1)
|
| 1284 |
+
)
|
| 1285 |
+
|
| 1286 |
+
|
| 1287 |
+
def _make_atom14_ambiguity_feats():
|
| 1288 |
+
for res, pairs in residue_atom_renaming_swaps.items():
|
| 1289 |
+
res_idx = restype_order[restype_3to1[res]]
|
| 1290 |
+
for atom1, atom2 in pairs.items():
|
| 1291 |
+
atom1_idx = restype_name_to_atom14_names[res].index(atom1)
|
| 1292 |
+
atom2_idx = restype_name_to_atom14_names[res].index(atom2)
|
| 1293 |
+
restype_atom14_ambiguous_atoms[res_idx, atom1_idx] = 1
|
| 1294 |
+
restype_atom14_ambiguous_atoms[res_idx, atom2_idx] = 1
|
| 1295 |
+
restype_atom14_ambiguous_atoms_swap_idx[
|
| 1296 |
+
res_idx, atom1_idx
|
| 1297 |
+
] = atom2_idx
|
| 1298 |
+
restype_atom14_ambiguous_atoms_swap_idx[
|
| 1299 |
+
res_idx, atom2_idx
|
| 1300 |
+
] = atom1_idx
|
| 1301 |
+
|
| 1302 |
+
|
| 1303 |
+
_make_atom14_ambiguity_feats()
|
| 1304 |
+
|
| 1305 |
+
|
| 1306 |
+
def aatype_to_str_sequence(aatype):
|
| 1307 |
+
return ''.join([
|
| 1308 |
+
restypes_with_x[aatype[i]]
|
| 1309 |
+
for i in range(len(aatype))
|
| 1310 |
+
])
|