| |
| import numpy as np |
| from collections import defaultdict |
| import prody as pr |
| import os |
|
|
| from datasets.constants import chi, atom_order, aa_long2short, aa_short2aa_idx, aa_idx2aa_short |
|
|
|
|
def get_dihedral_indices(resname, chi_num):
    """Return the atom-order positions of the four atoms defining one chi angle.

    NOTE: this name is rebound further down the module by a one-argument
    ``get_dihedral_indices(onehot_sequence)``. This two-argument version is
    therefore only reachable while the module-level lookup table is built
    at import time.

    Parameters
    ----------
    resname : hashable
        Key into the ``chi`` and ``atom_order`` tables.
    chi_num : hashable
        Key into ``chi[resname]`` selecting the chi angle.

    Returns
    -------
    numpy.ndarray
        Position, within ``atom_order[resname]``, of each atom name listed in
        ``chi[resname][chi_num]``. If ``resname`` is not in ``chi`` or
        ``chi_num`` is not in ``chi[resname]``, four NaN values are returned.
    """
    undefined = resname not in chi or chi_num not in chi[resname]
    if undefined:
        return np.full(4, np.nan)
    positions = [
        atom_order[resname].index(atom_name)
        for atom_name in chi[resname][chi_num]
    ]
    return np.array(positions)
|
|
|
|
# Import-time lookup table: for every key of ``atom_order``, the stacked
# results of the two-argument ``get_dihedral_indices`` for chi 1..4 (NaN rows
# mark undefined chi angles). This must run before the name
# ``get_dihedral_indices`` is rebound later in the module.
dihedral_indices = defaultdict(list)
for aa in atom_order:
    dihedral_indices[aa] = np.array(
        [get_dihedral_indices(aa, chi_num) for chi_num in range(1, 5)]
    )
|
|
|
|
def vector_batch(a, b):
    """Return the element-wise difference ``a - b``.

    Used on batches of points, where the result is the batch of vectors
    pointing from each point in ``b`` to the matching point in ``a``.
    """
    difference = a - b
    return difference
|
|
|
|
def unit_vector_batch(v):
    """Scale every row of ``v`` to unit Euclidean length.

    The norm is taken along axis 1 with ``keepdims=True`` so it broadcasts
    against ``v``. No zero-length guard is applied: a zero row divides by zero.
    """
    row_lengths = np.linalg.norm(v, axis=1, keepdims=True)
    return v / row_lengths
|
|
|
|
def dihedral_angle_batch(p):
    """Compute one dihedral angle, in degrees, per set of four points.

    Parameters
    ----------
    p : numpy.ndarray
        Batch of point quadruples addressed as ``p[:, k]`` for k = 0..3.
        Expected shape is (M, 4, 3); the in-file caller reshapes to
        ``(-1, 4, 3)`` before calling.

    Returns
    -------
    numpy.ndarray
        One angle per quadruple. ``arctan2`` yields values in (-180, 180];
        negative ones are shifted up by 360. NaN coordinates give NaN angles.
    """
    first, second, third, fourth = p[:, 0], p[:, 1], p[:, 2], p[:, 3]
    bond_a = vector_batch(first, second)
    bond_b = vector_batch(second, third)
    bond_c = vector_batch(third, fourth)

    # Normals of the planes spanned by consecutive bond pairs.
    normal_ab = np.cross(bond_a, bond_b)
    normal_bc = np.cross(bond_b, bond_c)

    # Vector perpendicular to both normal_ab and the central bond; together
    # with normal_ab it forms the frame in which normal_bc is measured.
    central_length = np.linalg.norm(bond_b, axis=1, keepdims=True)
    in_plane = np.cross(normal_ab, bond_b / central_length)

    # Row-wise dot products: proportional to cos and sin of the angle.
    cos_term = (normal_ab * normal_bc).sum(axis=1)
    sin_term = (in_plane * normal_bc).sum(axis=1)

    angles = np.degrees(np.arctan2(sin_term, cos_term))
    negative = angles < 0
    angles[negative] += 360
    return angles
|
|
|
|
def batch_compute_dihedral_angles(sidechains):
    """Convert ``sidechains`` to a NumPy array and return its dihedral angles.

    Thin wrapper around ``dihedral_angle_batch``. ``sidechains`` may be any
    array-like of point quadruples that ``np.array`` can convert.
    """
    return dihedral_angle_batch(np.array(sidechains))
|
|
|
|
def get_coords(prody_pdb):
    """Collect per-residue atom coordinates into an (N, 14, 3) array.

    Parameters
    ----------
    prody_pdb : ProDy atom group or selection
        Must expose ``.ca.getResindices()`` and ``.select(...)``.

    Returns
    -------
    numpy.ndarray of shape (N, 14, 3)
        One row per unique residue index among the CA atoms, in ascending
        order. Slot ``j`` holds the first coordinate set of the atom whose
        name is ``atom_order[code][j]``. ``code`` is the residue's entry in
        ``aa_long2short``, or ``'X'`` for unknown residue names. Atoms that
        are not found stay NaN.
    """
    residue_ids = sorted(set(prody_pdb.ca.getResindices()))
    # Pre-filled with NaN, so absent atoms need no explicit assignment.
    coords = np.full((len(residue_ids), 14, 3), np.nan)
    for row, residue_id in enumerate(residue_ids):
        residue = prody_pdb.select(f'resindex {residue_id}')
        long_name = residue.getResnames()[0]
        short_name = aa_long2short.get(long_name, 'X')
        for slot, atom_name in enumerate(atom_order[short_name]):
            atom = residue.select(f'name {atom_name}')
            if atom is None:
                continue
            coords[row, slot, :] = atom.getCoords()[0]
    return coords
|
|
|
|
def get_onehot_sequence(seq):
    """One-hot encode a residue sequence into a (len(seq), 20) float array.

    Each element of ``seq`` is mapped through ``aa_short2aa_idx``. Elements
    missing from that table fall back to column 7.
    NOTE(review): column 7 is presumably a specific default residue; confirm
    against ``aa_short2aa_idx``.
    """
    n_res = len(seq)
    columns = np.array([aa_short2aa_idx.get(aa, 7) for aa in seq], dtype=int)
    onehot = np.zeros((n_res, 20))
    onehot[np.arange(n_res), columns] = 1
    return onehot
|
|
|
|
def get_dihedral_indices(onehot_sequence):
    """Return the dihedral atom-index table row for every encoded residue.

    NOTE: this definition rebinds the name of the earlier two-argument
    ``get_dihedral_indices(resname, chi_num)``, which is used only while the
    module-level ``dihedral_indices`` table is built.

    Parameters
    ----------
    onehot_sequence : array-like, 2-D
        The column index of each non-zero entry is mapped through
        ``aa_idx2aa_short`` to a key of the ``dihedral_indices`` table.

    Returns
    -------
    numpy.ndarray
        The stacked table entries, one per non-zero entry in row-major order.
    """
    residue_columns = np.nonzero(onehot_sequence)[1]
    per_residue = [
        dihedral_indices[aa_idx2aa_short[column]] for column in residue_columns
    ]
    return np.array(per_residue)
|
|
|
|
def _get_chi_angles(coords, indices):
    """Gather the four atoms of every chi angle and compute the angles.

    Parameters
    ----------
    coords : numpy.ndarray
        Atom coordinates addressed as ``coords[residue, atom, xyz]``, with N
        residues along axis 0.
    indices : numpy.ndarray
        Expected shape (N, 4, 4): for each residue and each of four chi
        angles, the four atom positions along axis 1 of ``coords``. NaN
        entries mark chi angles that are not defined.

    Returns
    -------
    numpy.ndarray of shape (N, 4)
        Chi angles as produced by ``batch_compute_dihedral_angles``. Entries
        whose indices were NaN, or whose gathered coordinates contain NaN,
        are NaN.
    """
    n_res = coords.shape[0]
    undefined = np.isnan(indices)
    # Replace NaN with a valid placeholder index *before* the integer cast:
    # casting NaN to int is undefined and emits a RuntimeWarning on NumPy
    # >= 1.24. The placeholder rows are overwritten with NaN below.
    atom_idx = np.where(undefined, 0, indices).astype(int)
    # Advanced indexing returns a copy, so writing NaN leaves coords intact.
    gathered = coords[np.arange(n_res)[:, None, None], atom_idx, :]
    gathered[undefined] = np.nan
    angles = batch_compute_dihedral_angles(gathered.reshape(-1, 4, 3))
    return angles.reshape(n_res, 4)
|
|
|
|
def get_chi_angles(coords, seq, return_onehot=False):
    """Compute side-chain chi angles from atom coordinates and a sequence.

    Parameters
    ----------
    coords : numpy.ndarray
        Per-residue atom coordinates, e.g. the (N, 14, 3) array returned by
        ``get_coords``, with one row per residue.
    seq : sequence
        Residue codes, one per row of ``coords``, as understood by
        ``get_onehot_sequence``.
    return_onehot : bool, optional
        If True, also return the one-hot encoding of ``seq``. Default False.

    Returns
    -------
    numpy.ndarray of shape (N, 4)
        Chi angles in row order of ``coords``. A chi angle that is not
        defined for a residue, or whose atoms have NaN coordinates, is NaN.
    tuple of (numpy.ndarray, numpy.ndarray)
        ``(chi_angles, onehot)`` when ``return_onehot`` is True.
    """
    onehot = get_onehot_sequence(seq)
    # Deliberately not named ``dihedral_indices``: that would shadow the
    # module-level lookup table of the same name.
    atom_indices = get_dihedral_indices(onehot)
    chi_angles = _get_chi_angles(coords, atom_indices)
    if return_onehot:
        return chi_angles, onehot
    return chi_angles
|
|
|
|
def test_get_chi_angles(print_chi_angles=False):
    """Smoke-test ``get_chi_angles`` on chain A of PDB entry 6w70.

    The structure is loaded through ProDy's ``parsePDB`` (presumably fetched
    over the network when not cached; confirm for your ProDy setup). The
    archive ``6w70.pdb.gz`` is removed from the working directory afterwards,
    whether or not the assertions pass.

    Parameters
    ----------
    print_chi_angles : bool, optional
        If True, print the computed chi-angle array. Default False.

    Returns
    -------
    bool
        True when all assertions pass.

    Raises
    ------
    AssertionError
        If the result has an unexpected shape or the first chi angle is
        outside the expected (55, 56) degree window.
    """
    try:
        pdb = pr.parsePDB('6w70')
        prody_pdb = pdb.select('chain A')

        # get_chi_angles takes coordinates and a sequence, not a ProDy object.
        coords = get_coords(prody_pdb)

        # Build the sequence with one entry per unique residue index, in
        # ascending order, so it lines up with the rows made by get_coords.
        calphas = prody_pdb.ca
        resname_by_index = {}
        for resindex, resname in zip(calphas.getResindices(),
                                     calphas.getResnames()):
            resname_by_index.setdefault(resindex, resname)
        seq = [
            aa_long2short.get(resname_by_index[resindex], 'X')
            for resindex in sorted(resname_by_index)
        ]

        chi_angles = get_chi_angles(coords, seq)
        assert chi_angles.shape == (prody_pdb.ca.numAtoms(), 4)
        assert chi_angles[0, 0] < 56.0 and chi_angles[0, 0] > 55.0
        print('test_get_chi_angles passed')
    finally:
        # Best-effort cleanup: the archive may not exist, e.g. if the
        # structure came from a cache or the download failed.
        try:
            os.remove('6w70.pdb.gz')
        except OSError:
            pass
    if print_chi_angles:
        print(chi_angles)
    return True
|
|
|
|
# Script entry point: run the smoke test and print the computed chi angles.
if __name__ == "__main__":
    test_get_chi_angles(print_chi_angles=True)
|
|
|
|
|
|