PepFlow / eval /geometry.py
Irwiny123's picture
添加PepFlow模型初始代码
ef423c5
from Bio.PDB import PDBParser, Superimposer, is_aa, Select, NeighborSearch
import tmtools
import os
import numpy as np
import mdtraj as md
from Bio.SeqUtils import seq1
import warnings
from Bio import BiopythonWarning, SeqIO
import difflib
import torch
# Ignore PDBConstructionWarning. The filter is registered for the
# BiopythonWarning category, and warnings filters also match subclasses,
# so every Biopython warning is silenced at import time.
warnings.filterwarnings('ignore', category=BiopythonWarning)
def get_chain_from_pdb(pdb_path, chain_id='A'):
    """Look up one chain in the first model of a PDB file.

    Args:
        pdb_path: path to the PDB file to parse.
        chain_id: identifier of the chain to return (default 'A').

    Returns:
        The first chain of model 0 whose id equals ``chain_id``, or None
        when no such chain exists.
    """
    first_model = PDBParser().get_structure('X', pdb_path)[0]
    return next((ch for ch in first_model if ch.id == chain_id), None)
def diff_ratio(str1, str2):
    """Return the difflib similarity ratio of two sequences.

    The value is ``SequenceMatcher.ratio()``: 2*M/T, where M is the number
    of matched elements and T the combined length, so it lies in [0, 1].
    Two empty inputs score 1.0.
    """
    return difflib.SequenceMatcher(isjunk=None, a=str1, b=str2).ratio()
#######################################
# RMSD and TM-score
#######################################
def align_chains(chain1, chain2):
    """Pair residues of two chains by position and drop unusable pairs.

    Residues are walked in lockstep. Pairing stops at the end of the
    shorter chain. A pair is kept only when the chain1 residue passes
    ``is_aa`` (default, non-strict check) and has a CA atom. The chain2
    partner is NOT checked and is kept as-is.

    Returns:
        (kept1, kept2): two equal-length lists of paired residues.
    """
    kept1, kept2 = [], []
    for res_a, res_b in zip(chain1.get_residues(), chain2.get_residues()):
        # The chain1 residue must at least have a CA atom.
        if not (is_aa(res_a) and res_a.has_id('CA')):
            continue
        kept1.append(res_a)
        kept2.append(res_b)
    return kept1, kept2
def get_rmsd(chain1, chain2):
    """Compute CA RMSD between two chains, without and with superposition.

    CA atoms are paired by their iteration order in each chain. The chains
    are therefore expected to list corresponding residues in the same order
    with the same number of CA atoms.

    Args:
        chain1: Biopython chain used as the fixed structure (may be None).
        chain2: Biopython chain used as the moving structure (may be None).

    Returns:
        None if either chain is None. Otherwise a tuple ``(rmsd1, rmsd2)``:
        - rmsd1 is the RMSD of the raw CA coordinates in their input frames
          (no alignment).
        - rmsd2 is ``Superimposer.rms``, the RMSD after the optimal rigid
          superposition computed by Biopython's Superimposer.
    """
    if chain1 is None or chain2 is None:
        return None
    # Collect the CA atoms once per chain; the same atom lists feed both the
    # numpy RMSD and the Superimposer instead of re-walking each chain.
    ca1 = [atom for atom in chain1.get_atoms() if atom.name == 'CA']
    ca2 = [atom for atom in chain2.get_atoms() if atom.name == 'CA']
    pos1 = np.array([atom.get_coord() for atom in ca1])
    pos2 = np.array([atom.get_coord() for atom in ca2])
    # Sum of squared deviations over all coordinates, averaged per CA atom.
    rmsd1 = np.sqrt(np.sum((pos1 - pos2) ** 2) / len(pos1))
    super_imposer = Superimposer()
    # set_atoms(fixed, moving) only computes the fit; the atoms are not moved.
    super_imposer.set_atoms(ca1, ca2)
    rmsd2 = super_imposer.rms
    return rmsd1, rmsd2
def get_tm(chain1, chain2):
    """Compute the TM-score between two chains from their CA coordinates.

    Args:
        chain1: Biopython chain (may be None).
        chain2: Biopython chain (may be None).

    Returns:
        None if either chain is None, matching get_rmsd. Otherwise the
        ``tm_norm_chain2`` field of tmtools' TM-align result, which per the
        tmtools docs is the TM-score normalised by the length of the second
        structure.
    """
    # Same None handling as get_rmsd; previously a missing chain crashed
    # with AttributeError on ``None.get_atoms()``.
    if chain1 is None or chain2 is None:
        return None
    pos1 = np.array([atom.get_coord() for atom in chain1.get_atoms() if atom.name == 'CA'])
    pos2 = np.array([atom.get_coord() for atom in chain2.get_atoms() if atom.name == 'CA'])
    # Placeholder poly-alanine sequences of matching length are passed:
    # only the structural score from the result is returned here.
    tm_results = tmtools.tm_align(pos1, pos2, 'A' * len(pos1), 'A' * len(pos2))
    return tm_results.tm_norm_chain2
def get_traj_chain(pdb, chain):
    """Load a PDB file with mdtraj and keep only the atoms of one chain.

    The chain id is converted to a positional index using the order in which
    Biopython iterates the chains of the first model. That index is then used
    in an mdtraj ``chainid`` selection.
    NOTE(review): this assumes Biopython's chain order matches mdtraj's chain
    numbering; confirm for files with extra solvent/HETATM chains.

    Args:
        pdb: path to the PDB file.
        chain: chain id to keep.

    Returns:
        An mdtraj trajectory sliced to the selected chain's atoms.

    Raises:
        KeyError: if ``chain`` is not a chain id in the first model.
    """
    first_model = PDBParser().get_structure('X', pdb)[0]
    index_of = {}
    for position, ch in enumerate(first_model):
        index_of[ch.id] = position
    full_traj = md.load(pdb)
    selected = full_traj.topology.select(f"chainid {index_of[chain]}")
    return full_traj.atom_slice(selected)
def get_second_stru(pdb, chain):
    """Return the simplified DSSP secondary-structure assignment of one chain.

    Args:
        pdb: path to the PDB file.
        chain: chain id whose residues are assigned.

    Returns:
        The result of ``md.compute_dssp(..., simplified=True)`` on the
        trajectory restricted to ``chain`` (per the mdtraj docs, an array
        with one row per frame and one column per residue).

    Raises:
        KeyError: if ``chain`` is not a chain id in the first model.
    """
    # The chain loading/slicing was a verbatim copy of get_traj_chain;
    # delegate to it so both stay consistent.
    return md.compute_dssp(get_traj_chain(pdb, chain), simplified=True)
def get_ss(traj1, traj2):
    """Return the fraction of positions with identical simplified DSSP codes.

    Both trajectories are assigned with ``md.compute_dssp(simplified=True)``.
    The assignments are compared element-wise, and the mean of the boolean
    matches is returned. The two trajectories are expected to have matching
    DSSP array shapes.
    """
    codes_a = md.compute_dssp(traj1, simplified=True)
    codes_b = md.compute_dssp(traj2, simplified=True)
    agreement = codes_a == codes_b
    return agreement.mean()
def get_bind_site(pdb, chain_id, cutoff=10.0):
    """Return residue numbers of receptor residues close to a peptide chain.

    Every chain in the first model other than ``chain_id`` is treated as
    receptor. A receptor residue is reported when one of its CA atoms lies
    within ``cutoff`` of any CA atom of chain ``chain_id``.

    Args:
        pdb: path to the PDB file.
        chain_id: id of the peptide chain.
        cutoff: neighbour-search radius in the PDB's coordinate units.
            Defaults to 10.0, the previously hard-coded value.

    Returns:
        Set of int residue sequence numbers (``res.get_id()[1]``).
        - Chain id and insertion code are dropped, so identically numbered
          residues on different receptor chains collapse into one entry.
        - Empty when there are no receptor CA atoms or the peptide has none.

    Raises:
        KeyError: if ``chain_id`` is not a chain in the first model.
    """
    parser = PDBParser()
    structure = parser.get_structure('X', pdb)[0]
    peps = [atom for res in structure[chain_id] for atom in res if atom.get_name() == 'CA']
    recs = [
        atom
        for chain in structure if chain.get_id() != chain_id
        for res in chain
        for atom in res
        if atom.get_name() == 'CA'
    ]
    # NeighborSearch expects a non-empty atom list (e.g. a single-chain file
    # has no receptor atoms); with nothing to search there is no binding site.
    if not recs or not peps:
        return set()
    search = NeighborSearch(recs)
    return {
        res.get_id()[1]
        for atom in peps
        for res in search.search(atom.get_coord(), cutoff, level='R')
    }
def get_bind_ratio(pdb1, pdb2, chain_id1, chain_id2):
    """Return the fraction of the reference binding site recovered by pdb1.

    Binding sites come from get_bind_site. ``pdb2``/``chain_id2`` is the
    reference (ground truth), and the ratio is
    |site1 ∩ site2| / |site2|.
    """
    predicted = get_bind_site(pdb1, chain_id1)
    reference = get_bind_site(pdb2, chain_id2)
    shared = predicted & reference
    # The tiny epsilon keeps the division defined (giving 0.0) when the
    # reference site is empty.
    return len(shared) / (len(reference) + 1e-10)
def get_dihedral(pdb, chain):
    """Placeholder for dihedral analysis of one chain; not implemented yet.

    It currently loads the chain via get_traj_chain, so an unknown chain id
    still raises KeyError, and then returns None.
    """
    get_traj_chain(pdb, chain)  # result unused until the TODO is done
    # TODO: dihedral computation
    return None
def get_seq(pdb, chain_id):
    """Return the one-letter sequence of a chain in the first model.

    Every residue of the chain is included: ``is_aa`` is deliberately not
    applied, because this is used to read sequences back out of generated
    PDB files. Three-letter residue names are converted by
    ``Bio.SeqUtils.seq1``.
    """
    residues = PDBParser().get_structure('X', pdb)[0][chain_id]
    three_letter = "".join(res.get_resname() for res in residues)
    return seq1(three_letter)
def get_mpnn_seqs(path):
    """Read a FASTA file and return each record's sequence as a list of characters.

    Records are returned in file order, one list per record.
    """
    return [list(str(record.seq)) for record in SeqIO.parse(path, "fasta")]