| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from __future__ import annotations |
|
|
| import copy |
| import logging |
| import re |
| import warnings |
| from dataclasses import dataclass |
| from functools import partial |
| from typing import Dict, List, Tuple |
|
|
| import numpy as np |
| import torch |
|
|
| import src.data.protein.polyseq as polyseq |
| import src.data.protein.starparser as sp |
| from src.data import constants |
| import gzip |
| import io |
|
|
@dataclass
class SystemAssemblyInfo:
    """A class for representing the assembly information for System objects.

    assemblies (dict): a dictionary of assemblies with keys being assembly IDs
        and values being dictionaries of the following structure:
            {
                "details": "complete icosahedral assembly",
                "instructions": [
                    {
                        "oper_expression": "(1-60)",
                        "chains": [0, 1, 2],

                        # Each assembly instruction has information for generating
                        # one or more images, with image `i` generated by applying
                        # the sequence of operations with IDs in `operations[i]` to the
                        # list of chains in `chains`. The corresponding operations
                        # are described under `assembly_info["operations"][ID]`.
                        "operations": [["X0", "1", "2", "3"], ["X0", "4", "5", "6"]],
                    },
                    ...
                ],
            }

    operations (dict): a dictionary with symmetry operations. Keys are operation IDs
        and values being dictionaries with the following structure:
            {
                "type": "identity operation",
                "name": "1_555",
                "matrix": np.array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]),
                "vector": np.array([0., 0., 0.]),
            },
            ...
    """

    assemblies: dict
    operations: dict

    def __init__(self, assemblies: dict = None, operations: dict = None):
        """Store the given dictionaries, or fresh empty ones when omitted.

        Args:
            assemblies (dict, optional): assembly dictionary (see class docstring).
                The object is stored as-is, not copied.
            operations (dict, optional): operation dictionary (see class docstring).
                The object is stored as-is, not copied.
        """
        # A new dict is created per instance; a `dict()` default in the signature
        # would be evaluated once and shared by every instance.
        self.assemblies = {} if assemblies is None else assemblies
        self.operations = {} if operations is None else operations

    @staticmethod
    def make_operation(type: str, name: str, matrix: list, vector: list):
        """Build an operation dictionary from flat lists of values.

        Args:
            type (str): operation type, stored under "type".
            name (str): operation name, stored under "name".
            matrix (list): 9 values convertible to float, in row-major order.
            vector (list): 3 values convertible to float.

        Returns:
            dict: keys "type", "name", "matrix" (float array of shape `(3, 3)`)
                and "vector" (float array of shape `(3, 1)`).

        Raises:
            AssertionError: if `matrix` does not have 9 elements or `vector`
                does not have 3.
        """
        assert len(matrix) == 9, "expected 9 elements in rotation matrix"
        assert len(vector) == 3, "expected 3 elements in translation vector"
        return {
            "type": type,
            "name": name,
            "matrix": np.array([float(v) for v in matrix]).reshape(3, 3),
            "vector": np.array([float(v) for v in vector]).reshape(3, 1),
        }

    def delete_chain(self, cid: str):
        """Deletes the mention of the chain from assembly information.

        Args:
            cid (str): Chain ID to delete.
        """
        for assembly in self.assemblies.values():
            for instruction in assembly["instructions"]:
                instruction["chains"] = [
                    chain_id for chain_id in instruction["chains"] if chain_id != cid
                ]

    def rename_chain(self, old_cid: str, new_cid: str):
        """Renames all mentions of a chain to its new chain ID.

        Args:
            old_cid (str): Chain ID to rename.
            new_cid (str): Newly assigned Chain ID.
        """
        for assembly in self.assemblies.values():
            for instruction in assembly["instructions"]:
                instruction["chains"] = [
                    new_cid if chain_id == old_cid else chain_id
                    for chain_id in instruction["chains"]
                ]
|
|
|
|
class StringList:
    """A class for representing and accessing a list of strings in a highly memory-efficient
    manner. Access is constant time, but modification is linear time in length of list.

    All elements are stored back to back in one string (`self.string`); `self.rng`
    holds one `(start, length)` row per element.
    """

    def __init__(self, init_list: List[str] = None):
        self.string = ""
        self.rng = ArrayList(2, dtype=int)
        if init_list is not None:
            for item in init_list:
                self.append(item)

    def _position(self, i: int) -> int:
        """Map a possibly negative element index to a non-negative one.

        Raises:
            IndexError: if the index is outside the list.
        """
        n = len(self)
        pos = i + n if i < 0 else i
        if not 0 <= pos < n:
            raise IndexError(f"index {i} out of range for stringList of length {n}")
        return pos

    def __getitem__(self, i: int):
        beg, length = self.rng[self._position(i)]
        return self.string[beg : beg + length]

    def __setitem__(self, i: int, new_string: str):
        # Normalize first: with a raw negative index, `i + 1 :` below would
        # select the wrong rows (e.g. every row for i == -1).
        i = self._position(i)
        beg, length = self.rng[i]
        self.string = self.string[:beg] + new_string + self.string[beg + length :]
        delta = len(new_string) - length
        if delta != 0:
            self.rng[i, 1] = len(new_string)
            # every later element starts `delta` characters further along
            self.rng[i + 1 :, 0] = self.rng[i + 1 :, 0] + delta

    def __str__(self):
        return self.string

    def __len__(self):
        return len(self.rng)

    def copy(self):
        new_list = StringList()
        new_list.string = self.string
        new_list.rng = self.rng.copy()
        return new_list

    def append(self, new_string: str):
        self.rng.append([len(self.string), len(new_string)])
        self.string = self.string + new_string

    def insert(self, i: int, new_string: str):
        """Insert `new_string` so that it ends up at index `i` of the resulting list.

        A negative `i` is counted from the end of the resulting list, so
        `insert(-1, s)` is the same as `append(s)`.

        Raises:
            Exception: if `i` does not address a position in the resulting list.
        """
        n = len(self)
        pos = i + n + 1 if i < 0 else i
        if not 0 <= pos <= n:
            raise Exception(
                f"cannot insert in position {i} for stringList of length {len(self)}"
            )
        if pos < n:
            ix, _ = self.rng[pos]
        elif n == 0:
            ix = 0
        else:
            ix = self.rng[pos - 1].sum()
        self.string = self.string[0:ix] + new_string + self.string[ix:]
        self.rng.insert(pos, [ix, len(new_string)])
        if len(new_string) > 0:
            self.rng[pos + 1 :, 0] = self.rng[pos + 1 :, 0] + len(new_string)

    def pop(self, i: int):
        """Remove and return the string at index `i`."""
        i = self._position(i)
        beg, length = self.rng[i]
        val = self.string[beg : beg + length]
        self.string = self.string[0:beg] + self.string[beg + length :]
        self.rng[i + 1 :, 0] = self.rng[i + 1 :, 0] - len(val)
        self.rng.pop(i)
        return val

    def delete_range(self, rng: range):
        """Delete every element from the smallest to the largest index in `rng`.

        Only the two extreme indices are used, so `rng` is treated as a
        contiguous run. An empty `rng` is a no-op.
        """
        indices = sorted(rng)
        if not indices:
            return
        first, last = indices[0], indices[-1]
        beg, _ = self.rng[first]
        end = self.rng[last].sum()
        self.string = self.string[0:beg] + self.string[end:]
        # Exactly `end - beg` characters were removed, so the element that used to
        # start at `end` must now start at `beg`.
        self.rng[last + 1 :, 0] = self.rng[last + 1 :, 0] - (end - beg)
        self.rng.delete_range(indices)
|
|
|
|
class NameList:
    """A class for representing and accessing a list of "names"--i.e., strings that tend to
    have generic values, such that many repeat values are expected in a given list.

    Each distinct name is stored once in `unique_names`; `indices` holds, per
    element, the position of its name in `unique_names`; `name_indicies` maps a
    name to that position; `index_use` counts how many elements use each position.
    """

    def __init__(self, init_list: List[str] = None):
        self._reindex(() if init_list is None else init_list)

    def _reindex(self, init_list: List[str]):
        """Reset all internal tables and rebuild them from `init_list`."""
        self.unique_names = []
        self.name_indicies = dict()
        self.index_use = dict()
        self.indices = ArrayList(1, dtype=int)
        for name in init_list:
            self.append(name)

    def copy(self):
        new_list = NameList()
        new_list.unique_names = self.unique_names.copy()
        new_list.name_indicies = self.name_indicies.copy()
        new_list.index_use = self.index_use.copy()
        new_list.indices = self.indices.copy()
        return new_list

    def _check_index(self):
        """Rebuild the tables when most stored unique names are no longer in use."""
        num_names = len(self.unique_names)
        num_used = len(self.index_use)
        if (num_names > 2 * num_used) and (num_names - num_used > 10):
            # the list argument is fully built before `_reindex` clears anything
            self._reindex([self[i] for i in range(len(self))])

    def __getitem__(self, i: int):
        try:
            idx = self.indices[i].item()
        except IndexError as e:
            raise IndexError(f"index {i} out of range for nameList\n" + str(e)) from e
        return self.unique_names[idx]

    def __setitem__(self, i: int, new_name: str):
        try:
            old_idx = self.indices[i].item()
        except IndexError as e:
            raise IndexError(f"index {i} out of range for nameList\n" + str(e)) from e
        # release the old name before registering the new one
        self._update_use(old_idx, -1)
        new_idx = self._get_name_index(new_name)
        self.indices[i] = new_idx
        self._update_use(new_idx, 1)
        self._check_index()

    def __str__(self):
        return str([self[i] for i in range(len(self))])

    def __len__(self):
        return len(self.indices)

    def _update_use(self, idx, delta):
        """Change the use count of name position `idx`; drop the entry at zero."""
        self.index_use[idx] = self.index_use.get(idx, 0) + delta
        if self.index_use[idx] <= 0:
            del self.index_use[idx]

    def _get_name_index(self, name: str):
        """Return the position of `name` in `unique_names`, adding it if new."""
        if name not in self.name_indicies:
            idx = len(self.name_indicies)
            self.name_indicies[name] = idx
            self.unique_names.append(name)
        else:
            idx = self.name_indicies[name]
        return idx

    def append(self, name: str):
        idx = self._get_name_index(name)
        self.indices.append(idx)
        self._update_use(idx, 1)

    def insert(self, i: int, new_string: str):
        idx = self._get_name_index(new_string)
        self.indices.insert(i, idx)
        self._update_use(idx, 1)

    def pop(self, i: int):
        """Remove and return the name at index `i`."""
        idx = self.indices.pop(i).item()
        val = self.unique_names[idx]
        self._update_use(idx, -1)
        self._check_index()
        return val

    def delete_range(self, rng: range):
        """Remove every index in `rng`, highest first so lower indices stay valid."""
        for i in sorted(rng, reverse=True):
            self.pop(i)
|
|
|
|
class ArrayList:
    """A growable list of fixed-width rows backed by one NumPy array.

    `_array` is the allocated storage (its first dimension is the capacity) and
    `array` is a view of the first `length` rows. With `ndims == 1` the storage
    is one-dimensional and each element is a scalar; otherwise each element is a
    row of `ndims` values.
    """

    def __init__(self, ndims: int, dtype: type, length: int = 0, val=0):
        capacity = max(length, 2)
        shape = capacity if ndims == 1 else (capacity, ndims)
        self._array = np.ndarray(shape=shape, dtype=dtype)
        self.ndims = ndims
        self._array[:] = val
        self.length = length
        # view of the occupied part of the storage
        self.array = self._array[: self.length]

    def convert_negative_slice(self, slice_obj):
        """Return `slice_obj` with missing or negative bounds resolved against `length`."""
        start = slice_obj.start if slice_obj.start is not None else 0
        stop = slice_obj.stop if slice_obj.stop is not None else self.length

        if start < 0:
            start = self.length + start
        if stop < 0:
            stop = self.length + stop

        return slice(start, stop, slice_obj.step)

    def copy(self):
        new_list = ArrayList(ndims=self.ndims, dtype=self.array.dtype, length=len(self))
        new_list[:] = self[:]
        return new_list

    def __len__(self):
        return self.length

    def capacity(self):
        return self._array.shape[0]

    def __getitem__(self, i: int):
        return self.array[i]

    def __setitem__(self, i: int, row: list):
        self.array[i] = row

    def resize(self, delta):
        """Change the logical length by `delta`, reallocating storage when needed."""
        new_length = self.length + delta
        cap = self._array.shape[0]
        # grow when full; shrink when less than a third of the storage is used
        if (new_length > cap) or (new_length < cap / 3):
            self._resize(2 * new_length)
        self.length = new_length
        self.array = self._array[: self.length]

    def _resize(self, new_size):
        # refcheck=False is needed because `self.array` is a view of `_array`;
        # `resize` replaces that view immediately after this call.
        if self.ndims == 1:
            self._array.resize((new_size), refcheck=False)
        else:
            self._array.resize((new_size, self.ndims), refcheck=False)

    def items(self):
        """Yield the elements in order (scalars for `ndims == 1`, rows otherwise)."""
        for i in range(self.length):
            # plain `[i]` works for 1-D storage too, where `[i, :]` raises IndexError
            yield self.array[i]

    def append(self, row: list):
        self.resize(1)
        self.array[-1] = row

    def insert(self, i: int, row: list):
        """Insert the row such that it ends up being at index ``i`` in the new arrayList.

        A negative ``i`` is counted from the end of the new list, so
        ``insert(-1, row)`` is the same as ``append(row)``.

        Raises:
            IndexError: if ``i`` does not address a position in the new list.
        """
        pos = i + self.length + 1 if i < 0 else i
        # validate before growing, so a bad index leaves the list untouched
        if not 0 <= pos <= self.length:
            raise IndexError(
                f"cannot insert in position {i} for arrayList of length {self.length}"
            )

        # make room for the new row
        self.resize(1)

        # shift everything from the insertion point one slot towards the end
        self.array[pos + 1 :] = self.array[pos:-1]

        self.array[pos] = row

    def pop(self, i: int):
        """Remove and return element at index i (negative indices count from the end).

        Raises:
            IndexError: if ``i`` is outside the list.
        """
        pos = i + self.length if i < 0 else i
        if not 0 <= pos < self.length:
            raise IndexError(
                f"index {i} out of range for arrayList of length {self.length}"
            )

        # copy, because the storage slot is about to be overwritten
        row = self.array[pos].copy()

        # close the gap
        self.array[pos:-1] = self.array[pos + 1 :]

        self.resize(-1)

        return row

    def delete_range(self, rng: range):
        """Delete every element from the smallest to the largest index in `rng`.

        Only the two extreme indices are used, so `rng` is treated as a
        contiguous run. An empty `rng` is a no-op.
        """
        indices = list(rng)
        if not indices:
            return
        i, j = min(indices), max(indices)

        # move the tail over the deleted run
        cut_length = j - i + 1
        new_length = len(self) - cut_length
        self.array[i:new_length] = self.array[j + 1 :]

        self.resize(-cut_length)

    def __str__(self):
        return str([self[i] for i in range(len(self))])
|
|
|
|
@dataclass
class HierarchicList:
    """A utility class that represents a hierarchy of lists. Each level represents
    a list of elements, each element having a set of properties (each property being
    stored as an array-like object over elements). Further, each element has a number
    of children corresponding to a range of elements in a lower-hierarchy list.

    `_num_children[i]` is the number of children of element `i`; it is `None` for
    a list without children. `_child_offset`, when not `None`, caches the index
    of each element's first child in `_child_list`.
    """

    _properties: dict
    _parent_list: HierarchicList
    _child_list: HierarchicList
    _num_children: ArrayList
    _child_offset: ArrayList

    def __init__(
        self,
        properties: dict,
        parent_list: HierarchicList = None,
        num_children: ArrayList = ArrayList(1, dtype=int),
    ):
        # every property container is copied, so this list owns its storage
        self._properties = {key: value.copy() for key, value in properties.items()}
        self._parent_list = parent_list
        if self._parent_list is not None:
            # register this list as the child level of its parent
            self._parent_list._child_list = self
        self._child_list = None
        # the default argument is shared across calls, hence the copy
        self._num_children = num_children.copy() if num_children is not None else None
        # no offsets are cached until `reindex` is called
        self._child_offset = None

    def copy(self):
        """Return a copy with its own property, count and offset containers.

        NOTE(review): this goes through `__init__`, so when this list has a
        parent, the parent's `_child_list` is re-pointed at the returned copy.
        Confirm that callers rely on (or compensate for) this.
        """
        new_list = HierarchicList(
            self._properties, self._parent_list, self._num_children
        )
        new_list._child_list = self._child_list
        if self._child_offset is None:
            new_list._child_offset = None
        else:
            new_list._child_offset = self._child_offset.copy()
        return new_list

    def set_parent(self, parent_list: HierarchicList):
        self._parent_list = parent_list

    def child_index(self, i: int, at: int):
        """Return the index, in the child list, of child number `at` of element `i`."""
        if self._child_offset is not None:
            return self._child_offset[i] + at
        # no cached offsets: count the children of all preceding elements
        return self._num_children[0:i].sum() + at

    def reindex(self):
        """Recompute the cached first-child offsets from the child counts."""
        if self._num_children is not None:
            self._child_offset = ArrayList(
                1, dtype=int, length=len(self._num_children), val=0
            )
            for i in range(1, len(self)):
                self._child_offset[i] = (
                    self._child_offset[i - 1] + self._num_children[i - 1]
                )

    def append_child(self, properties):
        """Append a child to the last element of this list."""
        self._num_children[len(self._num_children) - 1] += 1
        self._child_list.append(properties)

    def insert_child(self, i: int, at: int, properties):
        """Insert a child at position `at` of element `i`; return its child-list index."""
        idx = self.child_index(i, at)
        self._num_children[i] += 1
        self._child_offset = None
        self._child_list.insert(idx, properties)
        return idx

    def delete_child(self, i: int, at: int):
        """Delete child number `at` of element `i`."""
        idx = self.child_index(i, at)
        self._num_children[i] -= 1
        self._child_offset = None
        self._child_list.delete(idx)

    def append(self, properties):
        if set(properties.keys()) != set(self._properties.keys()):
            raise Exception(f"unexpected set of attributes '{list(properties.keys())}")
        for key, value in properties.items():
            self._properties[key].append(value)
        if self._child_offset is not None:
            # the new element's children start where the previous element's end
            self._child_offset.append(
                self._child_offset[-1:].sum() + self._num_children[-1:].sum()
            )
        if self._num_children is not None:
            self._num_children.append(0)

    def insert(self, i: int, properties):
        if set(properties.keys()) != set(self._properties.keys()):
            raise Exception(f"unexpected set of attributes '{list(properties.keys())}")
        for key, value in properties.items():
            self._properties[key].insert(i, value)
        if self._child_offset is not None:
            if i >= len(self._child_offset):
                off = self._child_offset[-1:].sum() + self._num_children[-1:].sum()
            else:
                off = self._child_offset[i]
            self._child_offset.insert(i, off)
        if self._num_children is not None:
            self._num_children.insert(i, 0)

    def delete(self, i: int):
        """Delete element `i` together with all of its children."""
        for key in self._properties:
            self._properties[key].pop(i)
        if self._num_children is not None:
            # delete children last-to-first so remaining child positions stay valid
            for at in range(self._num_children[i] - 1, -1, -1):
                self.delete_child(i, at)
            # the count row goes away even for childless elements, keeping
            # `_num_children` the same length as the property containers
            self._num_children.pop(i)
        self._child_offset = None

    def delete_range(self, rng: range):
        """Delete the elements in `rng` together with all of their children.

        The property containers delete from the smallest to the largest index
        in `rng`, so `rng` is expected to be contiguous. An empty `rng` is a no-op.
        """
        indices = sorted(rng)
        if not indices:
            return
        for key in self._properties:
            self._properties[key].delete_range(indices)
        if self._num_children is not None:
            # iterate in reverse order so that child indices of lower elements
            # remain valid
            for i in reversed(indices):
                count = self._num_children[i]
                if count != 0:
                    idx = self.child_index(i, 0)
                    self._child_list.delete_range(range(idx, idx + count))
                    self._num_children[i] = 0
            # drop the count rows of the deleted elements as well
            self._num_children.delete_range(indices)
        self._child_offset = None

    def __len__(self):
        # all property containers have the same length; use the first one
        for key in self._properties:
            return len(self._properties[key])
        return 0

    def __getitem__(self, i: str):
        return self._properties[i]

    def num_children(self, i: int):
        return self._num_children[i]

    def has_children(self, i: int):
        return self._num_children is not None and self._num_children[i]

    def __str__(self):
        string = "Properties:\n"
        for key in self._properties:
            string += f"{key}: {str(self._properties[key])}\n"
        string += f"num_children: {str(self._num_children)}\n"
        string += f"child_offset: {str(self._child_offset)}\n"
        string += "----\n"
        string += str(self._child_list)
        return string
|
|
|
|
| @dataclass |
| class System: |
| """A class for storing, accessing, managing, and manipulating a molecular |
| system's structure, sequence, and topological information. The class is |
| organized as a hierarchy of objects: |
| |
| System: top-level class containing all information about a molecular system |
| -> Chain: a sub-portion of the System; for polymers this is generally a |
| chemically connected molecular graph belong to a System (e.g., for |
| protein complexes, this would be one of the proteins). |
| -> Residue: a generally chemically-connected molecular unit (for polymers, |
| the repeating unit), belonging to a Chain. |
| -> Atom: an atom belonging to a Residue with zero, one, or more locations. |
| -> AtomLocation: the location of an Atom (3D coordinates and other information). |
| |
| Attributes: |
| name (str): given name for System |
| _chains (list): a list of Chain objects |
| _entities (dict): a dictionary of SystemEntity objects, with keys being entity IDs |
| _chain_entities (list): `chain_entities[ci]` stores entity IDs (i.e., keys into |
| `entities`) corresponding to the entity for chain `ci` |
| _extra_models (list): a list of hierarchicList object, representing locations |
| for alternative models |
| _labels (dict): a dictionary of residue labels. A label is a string value, |
| under some category (also a string), associated with a residue. E.g., |
| the category could be "SSE" and the value could be "H" or "S". If entry |
| `labels[category][gti]` exists and is equal to `value`, this means that |
| residue with global template index `gti` has the label `category:value`. |
| _selections (dict): a dictionary of selections. Keys are selection names and |
| values are lists of corresponding gti indices. |
| _assembly_info (SystemAssemblyInfo): information on symmetric assemblies that can |
| be constructed from components of the molecular system. See ``SystemAssemblyInfo``. |
| """ |
|
|
| name: str |
| _chains: HierarchicList |
| _residues: HierarchicList |
| _atoms: HierarchicList |
| _locations: HierarchicList |
| _entities: Dict[int, SystemEntity] |
| _chain_entities: List[int] |
| _extra_models: List[HierarchicList] |
| _labels: Dict[str, Dict[int, str]] |
| _selections: Dict[str, List[int]] |
| _assembly_info: SystemAssemblyInfo |
|
|
def __init__(self, name: str = "system"):
    """Create an empty System named `name`.

    Four HierarchicList levels are created, each one given the previous level
    as its parent: chains, residues, atoms, locations. The locations level is
    created with `num_children=None`.
    """
    self.name = name

    chain_properties = {
        "cid": StringList(),
        "segid": StringList(),
        "authid": StringList(),
    }
    self._chains = HierarchicList(properties=chain_properties)

    residue_properties = {
        "name": NameList(),
        "resnum": ArrayList(1, dtype=int),
        "authresid": StringList(),
        "icode": ArrayList(1, dtype="U1"),
    }
    self._residues = HierarchicList(
        properties=residue_properties, parent_list=self._chains
    )

    atom_properties = {"name": NameList(), "het": ArrayList(1, dtype=bool)}
    self._atoms = HierarchicList(
        properties=atom_properties, parent_list=self._residues
    )

    location_properties = {
        "coor": ArrayList(5, dtype=float),
        "alt": ArrayList(1, dtype="U1"),
    }
    self._locations = HierarchicList(
        properties=location_properties,
        parent_list=self._atoms,
        num_children=None,
    )

    # bookkeeping that starts out empty
    self._entities = dict()
    self._chain_entities = []
    self._extra_models = []
    self._labels = dict()
    self._selections = dict()
    self._assembly_info = SystemAssemblyInfo()
|
|
def _reindex(self):
    """Call `reindex()` on every hierarchy level, from chains down to locations."""
    for level in (self._chains, self._residues, self._atoms, self._locations):
        level.reindex()
|
|
def _print_indexing(self):
    """Print, for each chain, residue and atom, the half-open range of its children
    in the next level down, and for each location whether it has children."""

    def child_span(level, ix):
        # first child index and one-past-the-last child index of element `ix`
        first = level.child_index(ix, 0)
        return first, first + level.num_children(ix)

    for chain in self.chains():
        lo, hi = child_span(self._chains, chain._ix)
        print(f"chain {chain._ix}, {chain}: [{lo} - {hi})")
        for residue in chain.residues():
            lo, hi = child_span(self._residues, residue._ix)
            print(f"\tresidue {residue._ix}, {residue}: [{lo} - {hi})")
            for atom in residue.atoms():
                lo, hi = child_span(self._atoms, atom._ix)
                print(f"\t\tatom {atom._ix}, {atom}: [{lo} - {hi})")
                for loc in atom.locations():
                    has_children = self._locations.has_children(loc._ix)
                    print(
                        f"\t\t\tlocation {loc._ix}, {loc}: has children? {has_children}"
                    )
|
|
@classmethod
def from_XCS(
    cls,
    X: torch.Tensor,
    C: torch.Tensor,
    S: torch.Tensor,
    alternate_alphabet: str = None,
) -> System:
    """Convert an XCS set of pytorch tensors to a new System object.

    B is batch size (Function only handles batch size of one now)
    N is the number of residues

    Args:
        X (torch.Tensor): Coordinates with shape `(1, num_residues, num_atoms, 3)`.
            The input is treated as all-atom when `num_atoms` is 14.
        C (torch.LongTensor): Chain map with shape `(1, num_residues)`. It codes
            positions as 0 when masked, positive integers for chain indices,
            and negative integers to represent missing residues of the
            corresponding positive integers.
        S (torch.LongTensor): Sequence with shape `(1, num_residues)`.
        alternate_alphabet (str, optional): Optional alternative alphabet for
            sequence encoding. Otherwise the default alphabet,
            `constants.AA20`, is used.

    Returns:
        System: A System object with the new XCS data.

    Raises:
        AssertionError: if the batch size is not one or the residue
            dimensions of `X`, `C` and `S` disagree.
    """
    alphabet = constants.AA20 if alternate_alphabet is None else alternate_alphabet
    all_atom = X.shape[2] == 14

    assert X.shape[0] == 1
    assert C.shape[0] == 1
    assert S.shape[0] == 1
    assert X.shape[1] == S.shape[1]
    assert C.shape[1] == S.shape[1]

    X, C, S = [T.squeeze(0).cpu().data.numpy() for T in [X, C, S]]

    # missing residues carry the negated chain index; fold them into their chain
    chain_ids = np.abs(C)

    new_system = cls("system")

    for chain_id in np.unique(chain_ids):
        # 0 marks masked positions, which belong to no chain
        if chain_id == 0:
            continue

        chain_bool = chain_ids == chain_id
        X_chain = X[chain_bool, :, :].tolist()
        C_chain = C[chain_bool].tolist()
        S_chain = S[chain_bool].tolist()

        # Build chain
        chain = new_system.add_chain("A")
        for chain_ix, (X_i, C_i, S_i) in enumerate(zip(X_chain, C_chain, S_chain)):
            resname = polyseq.to_triple(alphabet[int(S_i)])

            # Residues are numbered from 1 within their chain
            residue = chain.add_residue(
                resname, chain_ix + 1, str(chain_ix + 1), " "
            )

            # Only positions with a positive chain index get coordinates
            if C_i > 0:
                atom_names = constants.ATOMS_BB

                if all_atom and resname in constants.AA_GEOMETRY:
                    atom_names = (
                        atom_names + constants.AA_GEOMETRY[resname]["atoms"]
                    )

                for atom_ix, atom_name in enumerate(atom_names):
                    x, y, z = X_i[atom_ix]
                    residue.add_atom(atom_name, False, x, y, z, 1.0, 0.0, " ")

    # Add one entity per chain
    for ci, chain in enumerate(new_system.chains()):
        seq = [None] * chain.num_residues()
        het = [None] * chain.num_residues()
        for ri, res in enumerate(chain.residues()):
            seq[ri] = res.name
            het[ri] = all(a.het for a in res.atoms())
        entity_type, polymer_type = SystemEntity.guess_entity_and_polymer_type(seq)
        entity = SystemEntity(
            entity_type, f"chain {chain.cid}", polymer_type, seq, het
        )
        new_system.add_new_entity(entity, [ci])

    return new_system
|
|
def to_XCS(
    self,
    all_atom: bool = False,
    batch_dimension: bool = True,
    mask_unknown: bool = True,
    unknown_token: int = 0,
    reorder_chain: bool = True,
    alternate_alphabet=None,
    alternate_atoms=None,
    get_indices=False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Convert System object to XCS format.

    `C` tensor has shape [num_residues], where it codes positions as 0
    when masked, positive integers for chain indices, and negative integers
    to represent missing residues of the corresponding positive integers.

    `S` tensor has shape [num_residues], it will map residue amino acid to alphabet integers.
    If it is not found in `alphabet`, it will default to `unknown_token`. Set `mask_unknown` to true if
    also want to mask `unk residue` in `chain_map`

    This function takes into account missing residues and updates chain_map
    accordingly.

    Args:
        all_atom (bool): Include side chain atoms. Default is `False`.
        batch_dimension (bool): Include a batch dimension. Default is `True`.
        mask_unknown (bool): Mask residues not found in the alphabet. Default is
            `True`.
        unknown_token (int): Default token index if a residue is not found in
            the alphabet. Default is `0`.
        reorder_chain (bool): If set to true will start indexing chain at 1,
            else will use the alphabet index (Default: True)
        alternate_alphabet (str): Alternative alphabet if not `None`.
        alternate_atoms (list): Alternate atom name subset for `X` if not `None`.
            Only used when `all_atom` is `False`.
        get_indices (bool): Also return the location indices corresponding to the
            returned `X` tensor.

    Returns:
        X (torch.Tensor): Coordinates with shape `(1, num_residues, num_atoms, 3)`.
            `num_atoms` is 14 if `all_atom=True`, otherwise the number of
            backbone (or alternate) atom names.
        C (torch.LongTensor): Chain map with shape `(1, num_residues)`. It codes
            positions as 0 when masked, positive integers for chain indices,
            and negative integers to represent missing residues of the
            corresponding positive integers.
        S (torch.LongTensor): Sequence with shape `(1, num_residues)`.
        location_indices (np.ndarray, optional): location indices corresponding to
            the coordinates in `X`. Only returned (as a fourth value) when
            `get_indices` is true.

    """
    alphabet = constants.AA20 if alternate_alphabet is None else alternate_alphabet

    # First index of every alphabet symbol, built once instead of per residue
    alphabet_index = {}
    for idx, symbol in enumerate(alphabet):
        alphabet_index.setdefault(symbol, idx)

    # Get chain map by grabbing each chain in the system and looking at its length
    chain_alphabet = list(constants.CHAIN_ALPHABET)
    C = []
    for ch_id, chain in enumerate(self.chains()):
        if reorder_chain:
            map_ch_id = ch_id + 1
        elif chain.cid in chain_alphabet:
            map_ch_id = chain_alphabet.index(chain.cid)
        else:
            # chain ID not in the alphabet: take the smallest index not used so far
            map_ch_id = np.setdiff1d(
                np.arange(1, len(chain_alphabet)), np.unique(C)
            )[0]
        C += [map_ch_id] * chain.num_residues()

    # Grab full sequence
    oneLetterSeq = self.sequence(format="one-letter-string")

    if len(oneLetterSeq) != len(C):
        logging.warning("Warning, System and chain_map length don't agree")

    # Initialize recipient arrays
    atom_names = None
    if all_atom:
        # the name -> slot map depends on the residue type and is built per residue
        num_atoms = 14
    else:
        if alternate_atoms is not None:
            atom_names = alternate_atoms
        else:
            atom_names = constants.ATOMS_BB
        num_atoms = len(atom_names)
        atom_names = {a: k for (k, a) in enumerate(atom_names)}
    num_residues = self.num_residues()
    X = np.zeros([num_residues, num_atoms, 3])
    location_indices = (
        np.zeros([num_residues * num_atoms], dtype=int) if get_indices else None
    )

    S = []  # sequence indices
    for i in range(num_residues):
        symbol = oneLetterSeq[i]

        # Whether this residue ends up masked
        is_mask = False

        # Add sequence
        token = alphabet_index.get(symbol)
        if token is None:
            S.append(unknown_token)
            if mask_unknown:
                is_mask = True
        else:
            S.append(token)

        # Get residue from system
        res = self.get_residue(i)
        if res is None or not res.has_structure():
            is_mask = True

        # If the residue is masked, negate its chain map entry and move on
        if is_mask:
            C[i] = -abs(C[i])
            continue

        if all_atom:
            code3 = constants.AA20_1_TO_3[symbol]
            residue_atoms = constants.ATOMS_BB + constants.AA_GEOMETRY[code3]["atoms"]
            atom_names = {a: k for (k, a) in enumerate(residue_atoms)}

        # NaN marks slots not filled yet, so repeated atom names can be skipped
        X[i, :] = np.nan
        num_rem = len(atom_names)
        for atom in res.atoms():
            name = System.protein_backbone_atom_type(atom.name, False, True)
            if name is None:
                name = atom.name
            ix = atom_names.get(name, None)
            if ix is None or not np.isnan(X[i, ix, 0]):
                continue
            # only the first location of the atom is used
            for loc in atom.locations():
                X[i, ix] = loc.coors
                if location_indices is not None:
                    location_indices[i * num_atoms + ix] = loc.get_index()
                num_rem -= 1
                break
            if num_rem == 0:
                break
        if num_rem != 0:
            # some expected atom was not found: treat the residue as missing
            C[i] = -abs(C[i])
            X[i, :] = 0
        # slots beyond this residue's atom list are still NaN; zero them in place
        np.nan_to_num(X[i, :], copy=False, nan=0)

    # Tensor everything
    X = torch.tensor(X).float()
    C = torch.tensor(C).type(torch.long)
    S = torch.tensor(S).type(torch.long)

    # Unsqueeze all the things
    if batch_dimension:
        X = X.unsqueeze(0)
        C = C.unsqueeze(0)
        S = S.unsqueeze(0)

    if location_indices is not None:
        return X, C, S, location_indices

    return X, C, S
|
|
def update_with_XCS(self, X, C=None, S=None, alternate_alphabet=None):
    """Update the System with XCS coordinates. NOTE: if the System has
    more than one model, and if the shape of the System changes (i.e.,
    atoms are added or deleted), the additional models will be wiped.

    Inputs may be given with a leading batch dimension of size one (as
    documented below) or without it.

    Args:
        X (Tensor): Coordinates with shape `(1, num_residues, num_atoms, 3)`.
        C (LongTensor): Chain map with shape `(1, num_residues)`. It codes
            positions as 0 when masked, positive integers for chain indices,
            and negative integers to represent missing residues of the
            corresponding positive integers. Defaults to the current System's
            chain map.
        S (LongTensor): Sequence with shape `(1, num_residues)`. Defaults to
            the current System's sequence.
        alternate_alphabet (str, optional): alphabet used to decode `S`.
            Defaults to `constants.AA20`.

    Raises:
        Exception: if the residue dimension of `X`, `C` or `S` disagrees with
            the number of residues in the System.
    """
    if C is None or S is None:
        _, _C, _S = self.to_XCS()
        if C is None:
            C = _C
        if S is None:
            S = _S

    def _process_inputs(T):
        # drop the batch dimension when present, then move to numpy
        if len(T.shape) == 2 or len(T.shape) == 4:
            T = T.squeeze(0)
        return T.to("cpu").detach().numpy()

    # keep the shapes as given, for the error message
    given_shapes = (X.shape, C.shape, S.shape)
    X, C, S = map(_process_inputs, [X, C, S])

    # Validate after the batch dimension is stripped, so batched and unbatched
    # inputs are checked the same way.
    num_residues = self.num_residues()
    if not (
        (X.shape[0] == num_residues)
        and (C.shape[0] == num_residues)
        and (S.shape[0] == num_residues)
    ):
        raise Exception(
            f"input tensor sizes {given_shapes[0]}, {given_shapes[1]}, and {given_shapes[2]}, disagree with System size {num_residues}"
        )

    shape_changed = False
    alphabet = constants.AA20 if alternate_alphabet is None else alternate_alphabet
    for i, res in enumerate(self.residues()):
        # residues without structure, or masked/missing in C, are left alone
        if not res.has_structure() or C[i] <= 0:
            continue

        # decode the residue name; out-of-alphabet tokens become "UNK"
        resname = "UNK"
        if S[i] < len(alphabet):
            resname = polyseq.to_triple(alphabet[S[i]])

        if res.name != resname:
            res.rename(resname)

        # atom names that X provides coordinates for (X.shape[1] is num_atoms here)
        atoms_sys = [atom.name for atom in res.atoms()]
        atoms_XCS = constants.ATOMS_BB
        if resname in constants.AA_GEOMETRY:
            atoms_XCS = atoms_XCS + constants.AA_GEOMETRY[resname]["atoms"]
        atoms_XCS = atoms_XCS[: X.shape[1]]

        # delete atoms that are not expected, and repeated atom names
        to_delete = []
        for ix_a, atom in enumerate(res.atoms()):
            name = atom.name
            if name not in atoms_XCS or name in atoms_sys[:ix_a]:
                to_delete.append(atom)
        if len(to_delete) > 0:
            shape_changed = True
            res.delete_atoms(to_delete)

        # update coordinates, adding atoms or locations where missing
        for x_id, atom_name in enumerate(atoms_XCS):
            atom = res.find_atom(atom_name)
            x, y, z = [X[i][x_id][k].item() for k in range(3)]
            if atom is not None and atom.num_locations() > 0:
                atom.x = x
                atom.y = y
                atom.z = z
            else:
                shape_changed = True
                if atom is not None:
                    atom.add_location(x, y, z)
                else:
                    res.add_atom(atom_name, False, x, y, z, 1.0, 0.0)

    # extra models no longer match the System once atoms were added or deleted
    if shape_changed:
        self._extra_models = []
|
|
def __str__(self):
    """Return the text "system " followed by the System's name."""
    prefix = "system "
    return prefix + self.name
|
|
def chains(self):
    """Chain iterator (generator function): yields a ChainView per chain index."""
    total = len(self._chains)
    ci = 0
    while ci < total:
        yield ChainView(ci, self)
        ci += 1
|
|
def get_chain(self, ci: int):
    """Returns the chain by index.

    Args:
        ci (int): Chain index (from 0). The index is passed on unchecked.

    Returns:
        ChainView object corresponding to the chain in question.
    """
    view = ChainView(ci, self)
    return view
|
|
def get_chain_by_id(self, cid: str, segid=False):
    """Returns the first chain whose ID equals `cid`.

    Args:
        cid (str): Chain ID.
        segid (bool, optional): If set to True (default is False) will
            compare against each chain's segment ID instead of its chain ID.

    Returns:
        ChainView object corresponding to the chain in question, or `None`
        when no chain matches.
    """
    for ci, chain in enumerate(self.chains()):
        candidate = chain.segid if segid else chain.cid
        if cid == candidate:
            return ChainView(ci, self)
    return None
|
|
def get_chains(self):
    """Returns the list of all chains, as one ChainView per chain index."""
    views = []
    for ci in range(len(self._chains)):
        views.append(ChainView(ci, self))
    return views
|
|
def get_chains_of_entity(self, entity_id: int, by=None):
    """Returns the list of chains that correspond to the given entity ID.

    Args:
        entity_id (int): Entity ID.
        by (str, optional): If specified as "index", will return a
            list of chain indices instead of ChainView objects.

    Returns:
        List of ChainView objects or chain indices.
    """
    matching = []
    for ci, eid in enumerate(self._chain_entities):
        if entity_id == eid:
            matching.append(ci)
    if by == "index":
        return matching
    return [ChainView(ci, self) for ci in matching]
|
|
def residues(self):
    """Yield every residue of every chain, chain by chain (generator function)."""
    for chain in self.chains():
        yield from chain.residues()
|
|
def get_residue(self, gti: int):
    """Fetch the residue sitting at a global (System-wide) residue index.

    Args:
        gti (int): Global residue index, counted across chains in order.

    Returns:
        Whatever the owning chain's `get_residue` returns for the
        chain-local index (a ResidueView).

    Raises:
        Exception: if `gti` is negative or not smaller than the total
            number of residues.
    """
    if gti < 0:
        raise Exception(f"negative residue index: {gti}")
    # Walk the chains, peeling each chain's residue count off the index
    # until it falls inside the current chain.
    local = gti
    for chain in self.chains():
        count = chain.num_residues()
        if local < count:
            return chain.get_residue(local)
        local -= count
    raise Exception(
        f"residue index {gti} out of range for System, which has {self.num_residues()} residues"
    )
|
|
def atoms(self):
    """Yield every atom of every residue of every chain (generator function)."""
    for chain in self.chains():
        for residue in chain.residues():
            yield from residue.atoms()
|
|
def get_atom(self, aidx: int):
    """Fetch the atom sitting at a global (System-wide) atom index.

    Args:
        aidx (int): Global atom index, counted across chains in order.

    Returns:
        Whatever the owning chain's `get_atom` returns for the chain-local
        index (an AtomView).

    Raises:
        Exception: if `aidx` is negative or not smaller than the total
            number of atoms.
    """
    if aidx < 0:
        raise Exception(f"negative atom index: {aidx}")
    # Subtract each chain's atom count until the index lands in a chain.
    local = aidx
    for chain in self.chains():
        count = chain.num_atoms()
        if local < count:
            return chain.get_atom(local)
        local -= count
    raise Exception(
        f"atom index {aidx} out of range for System, which has {self.num_atoms()} atoms"
    )
|
|
def locations(self):
    """Yield every location of every atom in this System (generator function).

    Atoms are visited chain by chain and residue by residue; an atom with
    several locations contributes each of them in turn.
    """
    for chain in self.chains():
        for residue in chain.residues():
            for atom in residue.atoms():
                yield from atom.locations()
|
|
def _new_locations(self):
    """Return a copy of the current locations with every "coor" row blanked.

    Each row of the copy's "coor" field is overwritten with five NaN values;
    the other fields of the copy are left as copied.
    """
    blanked = self._locations.copy()
    for row in range(len(blanked)):
        # A fresh list per row, so rows never share one list object.
        blanked["coor"][row] = [np.nan for _ in range(5)]
    return blanked
|
|
def select(self, expression: str, left_associativity: bool = True):
    """Evaluate a selection expression and return the selected atoms.

    Args:
        expression (str): selection expression.
        left_associativity (bool, optional): determines whether operators
            in the expression are left-associative.

    Returns:
        List of atom views for the selected atoms, ordered by increasing
        atom index.
    """
    hits, info = self._select(expression, left_associativity=left_associativity)

    # Translate the selected indices back into atom views, in index order.
    picked = []
    for idx in sorted(hits):
        picked.append(info["all_atoms"][idx].atom)
    return picked
|
|
def select_residues(
    self,
    expression: str,
    gti: bool = False,
    allow_unstructured=False,
    left_associativity: bool = True,
):
    """Evaluate a selection expression and return the residues it touches.

    A residue is part of the answer when at least one of its atoms is
    selected by the expression.

    Args:
        expression (str): selection expression.
        gti (bool): if True (default is False), return global residue
            indices instead of residue views.
        allow_unstructured (bool): If True (default is False), will allow
            unstructured residues to be selected.
        left_associativity (bool, optional): determines whether operators
            in the expression are left-associative.

    Returns:
        Sorted list of global residue indices when `gti` is True, otherwise
        a list of residue views ordered by global residue index.
    """
    hits, info = self._select(
        expression,
        unstructured=allow_unstructured,
        left_associativity=left_associativity,
    )

    if gti:
        return sorted({info["all_atoms"][i].rix for i in hits})

    # Several atoms map onto the same residue; key by residue index so
    # each residue appears once.
    by_rix = {}
    for i in hits:
        entry = info["all_atoms"][i]
        by_rix[entry.rix] = entry.atom.residue
    return [by_rix[rix] for rix in sorted(by_rix)]
|
|
def select_chains(
    self, expression: str, allow_unstructured=False, left_associativity: bool = True
):
    """Evaluate a selection expression and return the chains it touches.

    A chain is part of the answer when at least one of its atoms is
    selected by the expression.

    Args:
        expression (str): selection expression.
        allow_unstructured (bool): If True (default is False), will allow
            unstructured chains to be selected.
        left_associativity (bool, optional): determines whether operators
            in the expression are left-associative.

    Returns:
        List of chain views, ordered by chain index.
    """
    hits, info = self._select(
        expression,
        unstructured=allow_unstructured,
        left_associativity=left_associativity,
    )

    # Key by chain index so every chain is reported once.
    by_cix = {}
    for i in hits:
        entry = info["all_atoms"][i]
        by_cix[entry.cix] = entry.atom.chain
    return [by_cix[cix] for cix in sorted(by_cix)]
|
|
def _select(
    self,
    expression: str,
    unstructured: bool = False,
    left_associativity: bool = True,
):
    """Evaluate a selection expression into a set of atom indices.

    Args:
        expression (str): selection expression.
        unstructured (bool): when False (the default), placeholder atoms
            that stand in for atom-less residues are dropped from the result.
        left_associativity (bool): forwarded to the expression evaluator.

    Returns:
        Tuple `(val, _selex_info)`: `val` is the evaluator's result (filtered
        as described above) and `_selex_info` is a dict with keys
        "all_atoms" (the indexed atom table) and "all_indices_set" (the set
        of all indices in that table).
    """

    # Lightweight record tying an atom view to its atom/residue/chain
    # counters; hashes by atom index.
    @dataclass(frozen=True)
    class MappableAtom:
        atom: AtomView
        aix: int
        rix: int
        cix: int

        def __hash__(self) -> int:
            return self.aix

    # One slot per real atom up front; placeholder atoms extend the table.
    all_atoms = [None] * self.num_atoms()
    aix = 0
    rix = 0
    for cix, chain in enumerate(self.chains()):
        for residue in chain.residues():
            for atom in residue.atoms():
                all_atoms[aix] = MappableAtom(atom, aix, rix, cix)
                aix += 1

            # A residue without atoms gets a placeholder atom (flagged with
            # a `dummy` attribute) so that atom-based expressions can still
            # pick the residue.
            if residue.num_atoms() == 0:
                placeholder = DummyAtomView(residue)
                placeholder.dummy = True
                # placeholders are additional entries, so grow the table
                all_atoms.append(None)
                all_atoms[aix] = MappableAtom(placeholder, aix, rix, cix)
                aix += 1
            rix += 1

    _selex_info = {
        "all_atoms": all_atoms,
        "all_indices_set": {entry.aix for entry in all_atoms},
    }

    # The three lists are passed positionally to the evaluator; presumably
    # they are the zero-operand, one-operand and two-operand operators --
    # confirm against ExpressionTreeEvaluator.
    # NOTE(review): `_selex_eval` has a branch for "occ" that is not
    # registered here, and "resix" is registered here but has no branch in
    # `_selex_eval` (which returns None for it) -- confirm which is intended.
    tree = ExpressionTreeEvaluator(
        ["hyd", "all", "none"],
        ["not", "byres", "bychain", "first", "last",
         "chain", "authchain", "segid", "namesel", "gti", "resix", "resid",
         "authresid", "resname", "re", "x", "y", "z", "b", "icode", "name"],
        ["and", "or", "around", "saround"],
        eval_function=partial(self._selex_eval, _selex_info),
        left_associativity=left_associativity,
        debug=False,
    )

    val = tree.evaluate(expression)

    if unstructured:
        return val, _selex_info

    # Placeholder atoms only exist to make atom-less residues selectable;
    # strip them unless the caller asked for unstructured entries.
    structured_only = {
        i for i in val if not hasattr(_selex_info["all_atoms"][i].atom, "dummy")
    }
    return structured_only, _selex_info
|
|
def save_selection(
    self,
    expression: Optional[str] = None,
    gti: Optional[List[int]] = None,
    selname: str = "_default",
    allow_unstructured=False,
    left_associativity: bool = True,
):
    """Store a named residue selection on the System.

    The selection is a list of global residue indices, taken either
    directly from `gti` or from evaluating `expression`. When both are
    given, `gti` wins and a warning is issued.

    Args:
        expression (str): (optional) selection expression.
        gti (list of int): (optional) explicit list of global residue
            indices; stored in sorted order.
        selname (str): selection name.
        allow_unstructured (bool): If True (default is False), will allow
            unstructured residues to be selected.
        left_associativity (bool, optional): determines whether operators
            in the expression are left-associative.
    """
    if gti is None:
        chosen = self.select_residues(
            expression,
            allow_unstructured=allow_unstructured,
            left_associativity=left_associativity,
            gti=True,
        )
    else:
        if expression is not None:
            warnings.warn(
                f"Expression and gti are both not null, expression will be ignored"
                f" and gti will be used!"
            )
        chosen = sorted(gti)

    # A selection with the same name is overwritten.
    self._selections[selname] = chosen
|
|
def get_selected(self, selname: str = "_default"):
    """Look up a stored named selection.

    Args:
        selname (str): selection name.

    Returns:
        The list of global template indices stored under `selname`.

    Raises:
        Exception: if no selection with that name is stored.
    """
    stored = self._selections
    if selname in stored:
        return stored[selname]
    raise Exception(
        f"selection by name '{selname}' does not exist in the System"
    )
|
|
def has_selection(self, selname: str = "_default"):
    """Tell whether a selection is stored under the given name.

    Args:
        selname (str): selection name.

    Returns:
        True if `selname` is a key of the stored selections, else False.
    """
    is_known = selname in self._selections
    return is_known
|
|
def get_selection_names(self):
    """Return the names of all stored selections as a new list."""
    names = list(self._selections.keys())
    return names
|
|
def remove_selection(self, selname: str = "_default"):
    """Forget the selection stored under the given name.

    Args:
        selname (str): selection name.

    Raises:
        Exception: if no selection with that name is stored.
    """
    if selname in self._selections:
        del self._selections[selname]
        return
    raise Exception(
        f"selection by name '{selname}' does not exist in the System"
    )
|
|
def _selex_eval(self, _selex_info, op: str, left, right):
    """Evaluate a single operator of a selection expression.

    Args:
        _selex_info (dict): lookup tables built by `_select`; this method
            reads "all_atoms" (entries exposing `.atom`, `.aix`, `.rix`,
            `.cix`) and "all_indices_set".
        op (str): operator keyword (e.g. "and", "not", "chain", "resid").
        left: left operand, or None for operators that take none on the left.
        right: right operand, or None for operators that take none.
            An operand that is the outcome of an earlier evaluation is a dict
            with a "result" key; literal operands are lists of string tokens.

    Returns:
        `{"result": set_of_atom_indices}` on success, or None when the
        operator is unknown, the operands do not fit the operator, or (for
        "namesel") the named selection does not exist.
    """

    def _is_numeric(string: str) -> bool:
        try:
            float(string)
            return True
        except (TypeError, ValueError):
            return False

    def _is_int(string: str) -> bool:
        try:
            int(string)
            return True
        except (TypeError, ValueError):
            return False

    def _is_token_list(operand) -> bool:
        """True when the operand is a list made only of strings."""
        return isinstance(operand, list) and all(isinstance(t, str) for t in operand)

    def _single(operand):
        """Return the only element of a one-element sequence operand, else None.

        Guards against dict operands (earlier results) and None, for which
        `len(...) == 1` / `[0]` would otherwise raise instead of signalling
        a mismatch.
        """
        if isinstance(operand, (list, tuple, str)) and len(operand) == 1:
            return operand[0]
        return None

    def _unpack_operands(operands, dests):
        """Convert each operand to the kind named by the matching `dests` entry.

        Returns `(unpacked, succ)`; `succ` is False as soon as one operand
        does not fit its destination kind.
        """
        assert len(operands) == len(dests)
        unpacked = [None] * len(operands)
        for i, (operand, dest) in enumerate(zip(operands, dests)):
            if dest is None:
                if operand is not None:
                    return unpacked, False
            elif dest == "result":
                if not (isinstance(operand, dict) and "result" in operand):
                    return unpacked, False
                unpacked[i] = operand["result"]
            elif dest == "string":
                token = _single(operand)
                if not isinstance(token, str):
                    return unpacked, False
                unpacked[i] = token
            elif dest == "strings":
                # validate/return this operand, not the whole operand list
                if not _is_token_list(operand):
                    return unpacked, False
                unpacked[i] = operand
            elif dest == "float":
                token = _single(operand)
                if token is None or not _is_numeric(token):
                    return unpacked, False
                unpacked[i] = float(token)
            elif dest == "floats":
                if not (
                    isinstance(operand, list)
                    and all(_is_numeric(val) for val in operand)
                ):
                    return unpacked, False
                unpacked[i] = [float(val) for val in operand]
            elif dest == "range":
                test = _parse_range(operand)
                if test is None:
                    return unpacked, False
                unpacked[i] = test
            elif dest == "int":
                token = _single(operand)
                if token is None or not _is_int(token):
                    return unpacked, False
                unpacked[i] = int(token)
            elif dest == "ints":
                if not (
                    isinstance(operand, list)
                    and all(_is_int(val) for val in operand)
                ):
                    return unpacked, False
                unpacked[i] = [int(val) for val in operand]
            elif dest == "int_range":
                test = _parse_int_range(operand)
                if test is None:
                    return unpacked, False
                unpacked[i] = test
            elif dest == "int_range_string":
                test = _parse_int_range(operand, string=True)
                if test is None:
                    return unpacked, False
                unpacked[i] = test
        return unpacked, True

    def _parse_range(operands: list):
        """Parses range information given a list of operands that were originally separated
        by spaces. Allowed range expressions are of the form: `< n`, `> n`, `n:m`, or a
        single number `n`, with optional spaces allowed between operands. Returns a
        predicate, or None if the operands do not form a valid range. Note that the
        `n:m` form excludes both end points."""
        if not _is_token_list(operands):
            return None
        operand = "".join(operands)
        if operand.startswith(">") or operand.startswith("<"):
            if not _is_numeric(operand[1:]):
                return None
            num = float(operand[1:])
            if operand.startswith(">"):
                test = lambda x, cut=num: x > cut
            else:
                test = lambda x, cut=num: x < cut
        elif ":" in operand:
            parts = operand.split(":")
            if (len(parts) != 2) or not all(_is_numeric(p) for p in parts):
                return None
            parts = [float(p) for p in parts]
            test = lambda x, lims=parts: lims[0] < x < lims[1]
        elif _is_numeric(operand):
            target = float(operand)
            test = lambda x, t=target: x == t
        else:
            return None
        return test

    def _parse_int_range(operands: list, string: bool = False):
        """Parses range of integers information given a list of operands that were
        originally separated by spaces. Allowed range expressions are of the form:
        `n`, `n-m`, `n+m`, with optional spaces allowed anywhere and combinations
        also allowed (e.g., "n+m+s+r-p+a"). With `string=True` the returned predicate
        matches the string form of the integers. Returns None on malformed input."""
        if not _is_token_list(operands):
            return None
        ranges = []
        for part in "".join(operands).split("+"):
            # raw string: "\d" in a plain literal is an invalid escape sequence
            m = re.fullmatch(r"(.*\d)-(.+)", part)
            if m:
                if not all(_is_int(g) for g in m.groups()):
                    return None
                ranges.append(range(int(m.group(1)), int(m.group(2)) + 1))
            elif not _is_int(part):
                return None
            elif string:
                ranges.append({part})
            else:
                ranges.append({int(part)})
        if string:
            ranges = [[str(x) for x in r] for r in ranges]
        return lambda x, ranges=ranges: any(x in r for r in ranges)

    all_atoms = _selex_info["all_atoms"]
    result = set()
    if op in ("and", "or"):
        (Si, Sj), succ = _unpack_operands([left, right], ["result", "result"])
        if not succ:
            return None
        if op == "and":
            result = set(Si).intersection(set(Sj))
        else:
            result = set(Si).union(set(Sj))
    elif op == "not":
        (_, S), succ = _unpack_operands([left, right], [None, "result"])
        if not succ:
            return None
        result = _selex_info["all_indices_set"].difference(S)
    elif op == "all":
        (_, _), succ = _unpack_operands([left, right], [None, None])
        if not succ:
            return None
        # copy, so callers never hold a reference to the shared index set
        result = set(_selex_info["all_indices_set"])
    elif op == "none":
        (_, _), succ = _unpack_operands([left, right], [None, None])
        if not succ:
            return None
    elif op == "around":
        (S, rad), succ = _unpack_operands([left, right], ["result", "float"])
        if not succ:
            return None

        # Gather every location of every atom (with its owner's index) and
        # every location of the selected atoms.
        atom_indices = []
        coords_all = []
        for ai in all_atoms:
            for xi in ai.atom.locations():
                atom_indices.append(ai.aix)
                coords_all.append([xi.x, xi.y, xi.z])
        coords_sel = [
            [xi.x, xi.y, xi.z]
            for j in S
            for xi in all_atoms[j].atom.locations()
        ]
        # With nothing selected (or nothing located) there is nothing within
        # range; also avoids indexing 1-D empty arrays below.
        if coords_all and coords_sel:
            X_i = np.asarray(coords_all)
            X_j = np.asarray(coords_sel)
            D = np.sqrt(((X_j[np.newaxis, :, :] - X_i[:, np.newaxis, :]) ** 2).sum(-1))
            ix_match = (D <= rad).any(axis=1)
            result = set(np.asarray(atom_indices)[ix_match].tolist())
    elif op == "saround":
        (S, srad), succ = _unpack_operands([left, right], ["result", "int"])
        if not succ:
            return None
        for j in S:
            aj = all_atoms[j]
            rj = aj.rix
            for ai in all_atoms:
                # only residues of the same chain count as sequence neighbours
                if aj.atom.residue.chain != ai.atom.residue.chain:
                    continue
                if abs(ai.rix - rj) <= srad:
                    result.add(ai.aix)
    elif op == "byres":
        (_, S), succ = _unpack_operands([left, right], [None, "result"])
        if not succ:
            return None
        gtis = {all_atoms[j].rix for j in S}
        for a in all_atoms:
            if a.rix in gtis:
                result.add(a.aix)
    elif op == "bychain":
        (_, S), succ = _unpack_operands([left, right], [None, "result"])
        if not succ:
            return None
        cixs = {all_atoms[j].cix for j in S}
        for a in all_atoms:
            if a.cix in cixs:
                result.add(a.aix)
    elif op in ("first", "last"):
        (_, S), succ = _unpack_operands([left, right], [None, "result"])
        if not succ:
            return None
        # an empty selection has no first/last atom: keep the result empty
        if S:
            pick = min if op == "first" else max
            result.add(pick(all_atoms[i].aix for i in S))
    elif op == "name":
        (_, name), succ = _unpack_operands([left, right], [None, "string"])
        if not succ:
            return None
        for a in all_atoms:
            if a.atom.name == name:
                result.add(a.aix)
    elif op in ("re", "hyd"):
        if op == "re":
            (_, regex), succ = _unpack_operands([left, right], [None, "string"])
        else:
            (_, _), succ = _unpack_operands([left, right], [None, None])
            regex = "[0123456789]?H.*"
        if not succ:
            return None
        ex = re.compile(regex)
        for a in all_atoms:
            if a.atom.name is not None and ex.fullmatch(a.atom.name):
                result.add(a.aix)
    elif op in ("chain", "authchain", "segid"):
        (_, match_id), succ = _unpack_operands([left, right], [None, "string"])
        if not succ:
            return None
        prop = {"chain": "cid", "authchain": "authid", "segid": "segid"}[op]
        for a in all_atoms:
            if getattr(a.atom.residue.chain, prop) == match_id:
                result.add(a.aix)
    elif op == "resid":
        (_, test), succ = _unpack_operands([left, right], [None, "int_range"])
        if not succ:
            return None
        for a in all_atoms:
            if test(a.atom.residue.num):
                result.add(a.aix)
    elif op in ("resname", "icode"):
        (_, match_id), succ = _unpack_operands([left, right], [None, "string"])
        if not succ:
            return None
        prop = {"resname": "name", "icode": "icode"}[op]
        for a in all_atoms:
            if getattr(a.atom.residue, prop) == match_id:
                result.add(a.aix)
    elif op == "authresid":
        (_, test), succ = _unpack_operands(
            [left, right], [None, "int_range_string"]
        )
        if not succ:
            return None
        for a in all_atoms:
            if test(a.atom.residue.authid):
                result.add(a.aix)
    elif op == "gti":
        (_, test), succ = _unpack_operands([left, right], [None, "int_range"])
        if not succ:
            return None
        for a in all_atoms:
            if test(a.rix):
                result.add(a.aix)
    elif op in ("x", "y", "z", "b", "occ"):
        (_, test), succ = _unpack_operands([left, right], [None, "range"])
        if not succ:
            return None
        # the B-factor attribute on a location is spelled with a capital B
        prop = "B" if op == "b" else op
        for a in all_atoms:
            # an atom matches as soon as any one of its locations matches
            for loc in a.atom.locations():
                if test(getattr(loc, prop)):
                    result.add(a.aix)
                    break
    elif op == "namesel":
        (_, selname), succ = _unpack_operands([left, right], [None, "string"])
        if not succ:
            return None
        if selname not in self._selections:
            return None
        gtis = set(self._selections[selname])
        for a in all_atoms:
            if a.rix in gtis:
                result.add(a.aix)
    else:
        return None

    return {"result": result}
|
|
def __getitem__(self, chain_idx: int):
    """Index the System like a sequence of chains: `system[i]` is `get_chain(i)`."""
    chain = self.get_chain(chain_idx)
    return chain
|
|
def add_chain(
    self,
    cid: str,
    segid: str = None,
    authid: str = None,
    entity_id: int = None,
    auto_rename: bool = True,
    at: int = None,
):
    """Adds a new chain to the System and returns a view of it.

    Args:
        cid (str): Chain ID.
        segid (str): Segment ID; defaults to the (possibly renamed) chain ID.
        authid (str): Author chain ID; defaults to the (possibly renamed)
            chain ID.
        entity_id (int, optional): Entity ID of the entity corresponding to this chain.
        auto_rename (bool, optional): If True, the chain ID is passed through
            `_pick_unique_chain_name`, which picks a different ID when the
            requested one clashes with an existing chain.
        at (int, optional): Index at which to insert the chain. When omitted,
            the chain is appended after the existing chains.

    Returns:
        ChainView constructed from the new chain's index and this System.
    """
    if auto_rename:
        cid = self._pick_unique_chain_name(cid)
    segid = cid if segid is None else segid
    authid = cid if authid is None else authid

    record = {"cid": cid, "segid": segid, "authid": authid}
    if at is not None:
        self._chains.insert(at, record)
        self._chain_entities.insert(at, entity_id)
    else:
        # the index of an appended chain is the chain count before appending
        at = self.num_chains()
        self._chains.append(record)
        self._chain_entities.append(entity_id)
    return ChainView(at, self)
|
|
def _append_residue(self, name: str, num: int, authid: str, icode: str):
    """Add a new residue to the end of this System. Internal method, do not use.

    Args:
        name (str): Residue name.
        num (int): Residue number (i.e., residue ID).
        authid (str): Author residue ID.
        icode (str): Insertion code.

    Returns:
        Global index to the newly added residue (the residue count minus one
        after the addition).
    """
    properties = {"name": name, "resnum": num, "authresid": authid, "icode": icode}
    # the residue record is registered as a child of the chain container
    self._chains.append_child(properties)
    return len(self._residues) - 1
|
|
def _append_atom(
    self,
    name: str,
    het: bool,
    x: float = None,
    y: float = None,
    z: float = None,
    occ: float = None,
    B: float = None,
    alt: str = None,
):
    """Adds a new atom to the end of this System. Internal method, do not use.

    Only the atom record itself (name and hetero flag) is created here.

    Args:
        name (str): Atom name.
        het (bool): Whether it is a hetero-atom.
        x, y, z (float): Atom location coordinates. Accepted but not used
            by this method.
        occ (float): Occupancy. Accepted but not used by this method.
        B (float): B-factor. Accepted but not used by this method.
        alt (str): Alternative position character. Accepted but not used
            by this method.

    Returns:
        Global index to the newly added atom (the atom count minus one after
        the addition).
    """
    properties = {"name": name, "het": het}
    # the atom record is registered as a child of the residue container
    self._residues.append_child(properties)
    return len(self._atoms) - 1
|
|
def _append_location(self, x, y, z, occ, B, alt):
    """Adds a location to the end of this System. Internal method, do not use.

    Args:
        x, y, z (float): coordinates of the location.
        occ (float): occupancy for the location.
        B (float): B-factor for the location.
        alt (str): alternative location character.

    Returns:
        Global index to the newly added location (the location count minus
        one after the addition).
    """
    # "coor" packs the coordinates, occupancy and B-factor, in that order
    properties = {"coor": [x, y, z, occ, B], "alt": alt}
    self._atoms.append_child(properties)
    return len(self._locations) - 1
|
|
def add_new_entity(self, entity: SystemEntity, chain_indices: list):
    """Register a new entity and point the given chains at it.

    Args:
        entity (SystemEntity): The new entity to add to the System.
        chain_indices (list): a list of Chain indices for chains to
            assign to this entity.

    Returns:
        The entity ID of the newly added entity.
    """
    # Start from the current entity count and move up until a free ID is found.
    candidate = len(self._entities)
    while candidate in self._entities:
        candidate += 1

    self._entities[candidate] = entity
    for ci in chain_indices:
        self._chain_entities[ci] = candidate
    return candidate
|
|
def delete_entity(self, entity_id: int):
    """Deletes the entity with the specified ID. Takes care to unlink
    any chains belonging to this entity from it.

    Args:
        entity_id (int): Entity ID.

    Raises:
        KeyError: if no entity with this ID is stored.
    """
    # Ask for plain indices: the default return value is a list of
    # ChainView objects, which cannot be used to index `_chain_entities`.
    for ci in self.get_chains_of_entity(entity_id, by="index"):
        self._chain_entities[ci] = None
    del self._entities[entity_id]
|
|
def _pick_unique_chain_name(self, hint: str, verbose=False):
    """Pick a chain ID that no existing chain uses.

    The requested ID is returned when it is free. Otherwise the first free
    single-character ID (A-Z, a-z, 0-9) is used, and after that a
    multi-character ID made of a base (the hint, then each single character
    in turn) followed by a number from 0 to 999.

    Args:
        hint (str): Preferred chain ID.
        verbose (bool, optional): If True, warn when no single-character ID
            is free and a longer one has to be used.

    Returns:
        A chain ID not used by any chain currently in the System.

    Raises:
        Exception: if every candidate name is already taken.
    """
    good_names = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
    taken = {chain.cid for chain in self.chains()}

    # first try to pick a conforming chain name
    for cid in [hint, *good_names]:
        if cid not in taken:
            return cid
    if verbose:
        warnings.warn(
            "ran out of reasonable single-letter chain names, will use more than one character (PDB sctructure may be repeating chain IDs upon writing, but should still have unique segment IDs)!"
        )

    # if that does not work, get a longer chain name: first try to expand
    # the original chain ID, then each single-character name. Bases are
    # individual characters (not list slices), so names come out as "A0",
    # not "['A']0".
    for base in [hint, *good_names]:
        if base == "":
            continue
        for k in range(1000):
            long_name = f"{base}{k}"
            if long_name not in taken:
                return long_name
    raise Exception(
        "ran out of even multi-character chain names; PDB structure appears to have an enormous number of chains"
    )
|
|
def _ensure_unique_entity(self, ci: int):
    """Make sure the given chain does not share its entity with another chain.

    Entity-level information (e.g., sequence or hetero info) cannot be
    edited in place while several chains point to the same entity. If any
    other chain shares this chain's entity, the entity is deep-copied, the
    copy is registered as a new entity, and this chain is re-pointed to it.

    Args:
        ci (int): Index of the Chain for which we are trying to update
            entity information.

    Returns:
        The ID of the newly created entity when a copy was made; otherwise
        the chain's current entity ID (which is None for a chain without an
        entity).
    """
    chain = self.get_chain(ci)
    entity_id = chain.get_entity_id()
    if entity_id is None:
        return entity_id

    # does any *other* chain point at the same entity?
    shared = any(
        (other != chain) and (entity_id == other.get_entity_id())
        for other in self.chains()
    )
    if not shared:
        return entity_id

    # give this chain its own private copy of the entity
    duplicate = copy.deepcopy(self._entities[entity_id])
    return self.add_new_entity(duplicate, [ci])
|
|
def num_chains(self):
    """Count the chains currently stored in the System."""
    count = len(self._chains)
    return count
|
|
def num_chains_of_entity(self, entity_id: int):
    """Count the chains that are mapped to the given entity.

    Args:
        entity_id (int): Entity ID.

    Returns:
        number of chains mapping to the entity.
    """
    count = 0
    for eid in self._chain_entities:
        if entity_id == eid:
            count += 1
    return count
|
|
def num_molecules_of_entity(self, entity_id: int):
    """Count the molecules of the given entity.

    For an entity whose `is_polymer()` is true, this is the number of
    chains mapped to the entity. Otherwise it is the total number of
    residues over all chains mapped to the entity.

    Args:
        entity_id (int): Entity ID.

    Returns:
        number of molecules of the entity, counted as described above.
    """
    if self._entities[entity_id].is_polymer():
        return self.num_chains_of_entity(entity_id)

    total = 0
    for ci, eid in enumerate(self._chain_entities):
        if eid == entity_id:
            total += self[ci].num_residues()
    return total
|
|
def num_entities(self):
    """Count the entities currently stored in the System."""
    count = len(self._entities)
    return count
|
|
def num_residues(self):
    """Count the residues currently stored in the System."""
    count = len(self._residues)
    return count
|
|
def num_structured_residues(self):
    """Add up each chain's count of residues with any structure information."""
    total = 0
    for chain in self.chains():
        total = total + chain.num_structured_residues()
    return total
|
|
def num_atoms(self):
    """Count the atoms currently stored in the System."""
    count = len(self._atoms)
    return count
|
|
def num_structured_atoms(self):
    """Count the atoms that have at least one location."""
    return sum(
        atom.num_locations() > 0
        for chain in self.chains()
        for residue in chain.residues()
        for atom in residue.atoms()
    )
|
|
def num_atom_locations(self):
    """Count the stored atom locations.

    An atom can have several (alternative) locations; every one of them
    is counted.
    """
    count = len(self._locations)
    return count
|
|
def num_models(self):
    """Count the models (conformations) held by the System.

    The current model always counts as one; every stored extra model adds
    one more.
    """
    extra = len(self._extra_models)
    return extra + 1
|
|
def swap_model(self, i: int):
    """Exchange the current model (index 0) with the model at index `i`.

    Args:
        i (int): Model index; 0 is a no-op.

    Raises:
        Exception: if `i` is negative or not smaller than `num_models()`.
    """
    if i == 0:
        return
    if i < 0 or i >= self.num_models():
        raise Exception(f"model index {i} out of range")
    # extra models are stored without the current one, hence the offset
    slot = i - 1
    self._locations, self._extra_models[slot] = (
        self._extra_models[slot],
        self._locations,
    )
|
|
def add_model(self, other: System):
    """Append the current model of another System as a new model of this one.

    Only the number of atom locations is checked here; `other` is expected
    to describe the same atoms and residues as this System.

    Args:
        other (System): The System to take the model from.

    Raises:
        Exception: if the two Systems differ in their number of atom
            locations.
    """
    n_own = len(self._locations)
    n_other = len(other._locations)
    if n_own != n_other:
        raise Exception(
            f"System has {len(self._locations)} atom locations while {len(other._locations)} were provided"
        )
    # store a copy, then re-attach it to this System's atoms
    self._extra_models.append(other._locations.copy())
    self._extra_models[-1].set_parent(self._atoms)
|
|
def add_model_from_X(self, X: torch.Tensor):
    """Append a new model whose coordinates come from a tensor.

    The new model is a copy of the current locations in which the first
    three columns of "coor" are overwritten with the rows of `X`.

    Args:
        X (torch.Tensor): Coordinate tensor of shape
            `(residues, atoms (4 or 14), coordinates (3))`; its number of
            elements must be three times the number of atom locations.

    Raises:
        Exception: if the element count of `X` does not match the number of
            atom locations.
    """
    n_locations = len(self._locations)
    if n_locations != X.numel() / 3:
        raise Exception(
            f"System has {n_locations} atom locations while provided tensor shape is {X.shape}"
        )
    coords = X.detach().cpu()
    model = self._locations.copy()
    self._extra_models.append(model)
    # merge the first two axes so each row holds one location's x, y, z
    model["coor"][:, 0:3] = coords.flatten(0, 1)
    return None
|
|
def num_assemblies(self):
    """Count the biological assemblies defined in this System."""
    assemblies = self._assembly_info.assemblies
    return len(assemblies)
|
|
@staticmethod
def from_CIF_string(cif_string: str):
    """Build a System from CIF text held in memory.

    Args:
        cif_string (str): Contents of a CIF file.

    Returns:
        The first element of the value returned by `System._read_cif`.
    """
    import io

    # expose the text through a file-like object for the reader
    stream = io.StringIO(cif_string)
    parsed = System._read_cif(stream)
    return parsed[0]
|
|
@staticmethod
def from_CIF(input_file: str):
    """Initializes and returns a System object from a CIF file.

    Args:
        input_file (str): Path to the CIF file. A path ending in ".gz"
            (e.g. "*.cif.gz") is read as gzip-compressed UTF-8 text; any
            other path is opened as plain text.

    Returns:
        The first element of the value returned by `System._read_cif`.
    """
    if input_file.endswith(".gz"):
        # Decompress fully into memory and hand the reader an in-memory
        # text stream, as before.
        # NOTE(review): presumably kept in memory because the reader
        # peeks/seeks, which is slow on a gzip stream -- confirm.
        with gzip.open(input_file, "rb") as f_in:
            raw = f_in.read()
        f = io.TextIOWrapper(io.BytesIO(raw), encoding="utf-8")
    else:
        f = open(input_file, "r")
    # close the stream once parsing is done, even if the reader raises
    with f:
        return System._read_cif(f)[0]
|
|
| @staticmethod |
| def _read_cif(f, strict=False): |
| def _warn_or_error(strict: bool, msg: str): |
| if strict: |
| raise Exception(msg) |
| else: |
| warnings.warn(msg) |
|
|
| is_read = { |
| part: False for part in ["coors", "entities", "sequence", "entity_poly"] |
| } |
| category = "" |
| (in_loop, success) = (False, True) |
| peeked = sp.PeekedLine("", 0) |
| |
| num_of_mols = dict() |
|
|
| system = System("system") |
| while sp.peek_line(f, peeked): |
| if peeked.line.startswith("#"): |
| |
| sp.advance(f, peeked) |
| elif peeked.line.startswith("data_"): |
| |
| sp.advance(f, peeked) |
| elif peeked.line.startswith("loop_"): |
| in_loop = True |
| category = "" |
| sp.advance(f, peeked) |
| else: |
| (cat, name, val) = ("", "", "") |
| if peeked.line.startswith("_"): |
| (cat, name, val) = sp.star_item_parse(peeked.line) |
| if cat != category: |
| if category != "": |
| in_loop = False |
| category = cat |
|
|
| if (cat == "_entry") and (name == "id"): |
| if val != "": |
| system.name = val |
| sp.advance(f, peeked) |
| elif cat == "_entity_poly": |
| if is_read["entity_poly"]: |
| raise Exception("entity_poly block encountered multiple times") |
| tab = sp.star_read_data(f, ["entity_id", "type"], in_loop) |
| for row in tab: |
| ent_id = int(row[0]) - 1 |
| if ent_id not in system._entities: |
| system._entities[ent_id] = SystemEntity( |
| None, None, row[1], None, None |
| ) |
| else: |
| system._entities[ent_id]._polymer_type = row[1] |
| is_read["entity_poly"] = True |
| elif cat == "_entity": |
| if is_read["entities"]: |
| raise Exception( |
| f"entities block encountered multiple times: {peeked.line}" |
| ) |
| tab = sp.star_read_data( |
| f, |
| ["id", "type", "pdbx_description", "pdbx_number_of_molecules"], |
| in_loop, |
| ) |
| for row in tab: |
| ent_id = int(row[0]) - 1 |
| if ent_id not in system._entities: |
| system._entities[ent_id] = SystemEntity( |
| row[1], row[2], None, None, None |
| ) |
| else: |
| system._entities[ent_id]._type = row[1] |
| system._entities[ent_id]._desc = row[2] |
| if row[3].isnumeric(): |
| num_of_mols[ent_id] = int(row[3]) |
| is_read["entities"] = True |
| elif cat == "_entity_poly_seq": |
| if is_read["sequence"]: |
| raise Exception(f"sequence block encountered multiple times") |
| tab = sp.star_read_data( |
| f, ["entity_id", "num", "mon_id", "hetero"], in_loop |
| ) |
| (seq, het) = ([], []) |
| for i in range(len(tab)): |
| |
| seq.append(tab[i][2]) |
| het.append(tab[i][3].startswith("y")) |
| if (i == len(tab) - 1) or (tab[i][0] != tab[i + 1][0]): |
| ent_id = int(tab[i][0]) - 1 |
| system._entities[ent_id]._seq = seq |
| system._entities[ent_id]._het = het |
| (seq, het) = ([], []) |
| is_read["sequence"] = True |
| elif cat == "_pdbx_struct_assembly": |
| tab = sp.star_read_data(f, ["id", "details"], in_loop) |
| for row in tab: |
| system._assembly_info.assemblies[row[0]] = {"details": row[1]} |
| elif cat == "_pdbx_struct_assembly_gen": |
| tab = sp.star_read_data( |
| f, ["assembly_id", "oper_expression", "asym_id_list"], in_loop |
| ) |
| for row in tab: |
| assembly = system._assembly_info.assemblies[row[0]] |
| if "instructions" not in assembly: |
| assembly["instructions"] = [] |
| chain_ids = [cid.strip() for cid in row[2].strip().split(",")] |
| assembly["instructions"].append( |
| {"oper_expression": row[1], "chains": chain_ids} |
| ) |
| elif cat == "_pdbx_struct_oper_list": |
| tab = sp.star_read_data( |
| f, |
| [ |
| "id", |
| "type", |
| "name", |
| "matrix[1][1]", |
| "matrix[1][2]", |
| "matrix[1][3]", |
| "matrix[2][1]", |
| "matrix[2][2]", |
| "matrix[2][3]", |
| "matrix[3][1]", |
| "matrix[3][2]", |
| "matrix[3][3]", |
| "vector[1]", |
| "vector[2]", |
| "vector[3]", |
| ], |
| in_loop, |
| ) |
| for row in tab: |
| system._assembly_info.operations[ |
| row[0] |
| ] = SystemAssemblyInfo.make_operation( |
| row[1], row[2], row[3:12], row[12:15] |
| ) |
| elif cat == "_generate_selections": |
| tab = sp.star_read_data(f, ["name", "indices"], in_loop) |
| for row in tab: |
| system._selections[row[0]] = [ |
| int(gti.strip()) for gti in row[1].strip().split() |
| ] |
| elif cat == "_generate_labels": |
| tab = sp.star_read_data(f, ["name", "index", "value"], in_loop) |
| for row in tab: |
| if row[0] not in system._labels: |
| system._labels[row[0]] = dict() |
| idx = int(row[1]) |
| system._labels[row[0]][int(row[1])] = row[2] |
| elif cat == "_atom_site": |
| if is_read["coors"]: |
| raise Exception(f"ATOM_SITE block encountered multiple times") |
| |
| tab = sp.star_read_data( |
| f, |
| [ |
| "group_PDB", |
| "id", |
| "label_atom_id", |
| "label_alt_id", |
| "label_comp_id", |
| "label_asym_id", |
| "label_entity_id", |
| "label_seq_id", |
| "pdbx_PDB_ins_code", |
| "Cartn_x", |
| "Cartn_y", |
| "Cartn_z", |
| "occupancy", |
| "B_iso_or_equiv", |
| "pdbx_PDB_model_num", |
| "auth_seq_id", |
| "auth_asym_id", |
| ], |
| in_loop, |
| cols=False, |
| has_blocks=False, |
| ) |
|
|
| groupCol = 0 |
| idxCol = 1 |
| atomNameCol = 2 |
| altIdCol = 3 |
| resNameCol = 4 |
| chainNameCol = 5 |
| entityIdCol = 6 |
| seqIdCol = 7 |
| insCodeCol = 8 |
| xCol = 9 |
| yCol = 10 |
| zCol = 11 |
| occCol = 12 |
| bCol = 13 |
| modelCol = 14 |
| authSeqIdCol = 15 |
| authChainNameCol = 16 |
|
|
| ( |
| atom, |
| residue, |
| chain, |
| prev_chain, |
| prev_residue, |
| prev_atom, |
| prev_entity_id, |
| prev_seq_id, |
| prev_auth_seq_id, |
| ) = (None, None, None, None, None, None, None, None, None) |
| loc = None |
| aIdx = 0 |
| for i in range(len(tab)): |
| if i == 0: |
| first_model = tab[i][modelCol] |
| prev_model = first_model |
| elif (tab[i][modelCol] != prev_model) or ( |
| tab[i][modelCol] != first_model |
| ): |
| if tab[i][modelCol] != prev_model: |
| aIdx = 0 |
| num_loc = system.num_atom_locations() |
| |
| |
| |
| system._extra_models.append(system._new_locations()) |
| prev_model = tab[i][modelCol] |
| locations_generator = (l for l in system.locations()) |
|
|
| loc = next(locations_generator, None) |
| if aIdx >= num_loc: |
| _warn_or_error( |
| strict, |
| f"at atom id: {tab[i][idxCol]} -- too many atoms in model {tab[i][modelCol]} relative to first model {first_model}", |
| ) |
| success = False |
| system._extra_models.clear() |
| break |
|
|
| |
| same = ( |
| (loc is not None) |
| and (tab[i][chainNameCol] == loc.atom.residue.chain.cid) |
| and (tab[i][resNameCol] == loc.atom.residue.name) |
| and ( |
| int( |
| sp.star_value( |
| tab[i][seqIdCol], loc.atom.residue.num |
| ) |
| ) |
| == loc.atom.residue.num |
| ) |
| and (tab[i][atomNameCol] == loc.atom.name) |
| ) |
| if not same: |
| _warn_or_error( |
| strict, |
| f"at atom id: {tab[i][idxCol]} -- atoms in model {tab[i][modelCol]} do not correspond exactly to atoms in first model", |
| ) |
| success = False |
| system._extra_models.clear() |
| break |
|
|
| coor = [ |
| float(tab[i][c]) |
| for c in [xCol, yCol, zCol, occCol, bCol] |
| ] |
| system._extra_models[-1]["coor"][aIdx] = coor |
| system._extra_models[-1]["alt"][aIdx] = sp.star_value( |
| tab[i][altIdCol], " " |
| )[0] |
| aIdx = aIdx + 1 |
| continue |
|
|
| |
| if ( |
| (chain is None) |
| or (prev_entity_id != tab[i][entityIdCol]) |
| or (tab[i][chainNameCol] != chain.cid) |
| ): |
| authid = ( |
| tab[i][authChainNameCol] |
| if (tab[i][authChainNameCol] != "") |
| else tab[i][chainNameCol] |
| ) |
| chain = system.add_chain( |
| tab[i][chainNameCol], |
| tab[i][chainNameCol], |
| authid, |
| int(tab[i][entityIdCol]) - 1, |
| ) |
|
|
| |
| if ( |
| (residue is None) |
| or (chain != prev_chain) |
| or (prev_seq_id != tab[i][seqIdCol]) |
| or (prev_auth_seq_id != tab[i][authSeqIdCol]) |
| ): |
| resnum = ( |
| int(tab[i][seqIdCol]) |
| if sp.star_value_defined(tab[i][seqIdCol]) |
| else chain.num_residues() + 1 |
| ) |
| ri = system._append_residue( |
| tab[i][resNameCol], |
| resnum, |
| tab[i][authSeqIdCol], |
| sp.star_value(tab[i][insCodeCol], " ")[0], |
| ) |
| residue = ResidueView(ri, chain) |
|
|
| |
| |
| |
| x, y, z, occ, B = [ |
| float(tab[i][col]) |
| for col in [xCol, yCol, zCol, occCol, bCol] |
| ] |
| alt = sp.star_value(tab[i][altIdCol], " ")[0] |
| if ( |
| (atom is None) |
| or (residue != prev_residue) |
| or (tab[i][atomNameCol] != atom.name) |
| ): |
| ai = system._append_atom( |
| tab[i][atomNameCol], (tab[i][groupCol] == "HETATM") |
| ) |
| atom = AtomView(ai, residue) |
| system._append_location(x, y, z, occ, B, alt) |
|
|
| prev_chain = chain |
| prev_residue = residue |
| prev_entity_id = tab[i][entityIdCol] |
| prev_seq_id = tab[i][seqIdCol] |
| prev_auth_seq_id = tab[i][authSeqIdCol] |
| is_read["coors"] = True |
| else: |
| sp.advance(f, peeked) |
|
|
| |
| |
| for entity_id in num_of_mols: |
| if system._entities[entity_id].is_polymer(): |
| rem = num_of_mols[entity_id] - system.num_chains_of_entity(entity_id) |
| for _ in range(rem): |
| |
| system.add_chain("A", None, None, entity_id, auto_rename=True) |
|
|
| |
| |
| for chain in system.chains(): |
| entity = chain.get_entity() |
| if not entity.is_polymer() or entity._seq is None: |
| continue |
| k = 0 |
| for ri in range(len(entity._seq)): |
| cur_res = chain.get_residue(k) if k < chain.num_residues() else None |
| if (cur_res is None) or (cur_res.num > ri + 1): |
| |
| chain.add_residue(entity._seq[ri], ri + 1, str(ri + 1), " ", at=k) |
| elif cur_res.num < ri + 1: |
| _warn_or_error( |
| strict, f"inconsistent numbering in chain {chain.cid}" |
| ) |
| break |
| k = k + 1 |
|
|
| |
| for chain in system.chains(): |
| if not chain.check_sequence(): |
| _warn_or_error( |
| strict, |
| f"chain {chain.cid} did not pass sequence check against corresponding entity", |
| ) |
|
|
| system._reindex() |
| return system, success |
|
|
@staticmethod
def from_PDB_string(cif_string: str, options=""):
    """Initializes and returns a System object from a PDB string.

    Args:
        cif_string (str): the text of a PDB file. The parameter name is
            kept for backward compatibility; the text is parsed by
            `System._read_pdb`, i.e., as PDB rather than CIF.
        options (str, optional): reading options, forwarded to
            `System._read_pdb`.

    Returns:
        The object returned by `System._read_pdb`, with its `name` set to
        "from_string".
    """
    import io

    # `_read_pdb(f, strict=False, options="")` takes `strict` as its second
    # positional parameter, so the options have to be passed by keyword;
    # passed positionally they would land in `strict` and be ignored.
    system = System._read_pdb(io.StringIO(cif_string), options=options)
    system.name = "from_string"
    return system
|
|
@staticmethod
def from_PDB(input_file: str, options=""):
    """Initializes and returns a System object from a PDB file.

    Args:
        input_file (str): path of the PDB file to read.
        options (str, optional): reading options, forwarded to
            `System._read_pdb`.

    Returns:
        The object returned by `System._read_pdb`, with its `name` set to
        `input_file`.
    """
    # The context manager closes the file even if parsing raises.
    # `options` is passed by keyword because the second positional
    # parameter of `_read_pdb` is `strict`, not `options`.
    with open(input_file, "r") as f:
        system = System._read_pdb(f, options=options)
    system.name = input_file
    return system
|
|
@staticmethod
def _read_pdb(f, strict=False, options=""):
    """Parses PDB-formatted text into a System.

    Args:
        f: an iterable of text lines (e.g., an open file or `io.StringIO`).
        strict (bool, optional): accepted for interface compatibility; this
            parser does not consult it.
        options (str, optional): case-insensitive string; the presence of
            the following sub-strings changes the behavior:
              "USESEGID": use the segment ID field as the chain ID.
              "CHARMM": read the residue number from columns [23:27] and
                  treat the insertion code as blank.
              "CHARMM19": additionally map residue names HSE and HSD to
                  HIS and HSC to HSP (note that "CHARMM19" contains
                  "CHARMM", so both flags become active).
              "ALLOW DUPLICATE CIDS": do not rename repeated chain IDs.
              "ALLOW ILE CD": do not rename atom CD of ILE to CD1.
              "IGNORE-TER": TER records do not force a new chain.
              "QUIET": do not warn about repeated chain names.

    Returns:
        The System holding the first model; every later model closed by an
        ENDMDL record is attached to it through `add_model`.
    """

    def _to_float(strval, default):
        # Occupancy and B-factor columns can be blank or missing on short
        # lines; fall back to the default rather than reject the record.
        try:
            return float(strval)
        except ValueError:
            return default

    last_resnum = None
    last_resname = None
    last_icode = None
    last_chain_id = None
    last_alt = None
    chain = None
    residue = None

    # True whenever the next atom record has to open a new chain.
    ter = True

    options = options.upper()
    usese_gid = "USESEGID" in options
    charmm_format = "CHARMM" in options
    charmm19_format = "CHARMM19" in options
    uniq_chain_ids = "ALLOW DUPLICATE CIDS" not in options
    fix_Ile_CD = "ALLOW ILE CD" not in options
    # Residues that differ only by insertion code are always kept separate.
    icodes_as_sep_res = True
    ignore_ter = "IGNORE-TER" in options
    verbose = "QUIET" not in options

    # Chains of the first model whose ID collided with an earlier chain;
    # they carry a temporary "|to rename N" suffix until the end.
    chains_to_rename = []

    # `system` receives the atoms of the model being read; `all_system` is
    # the first model, to which later models are added.
    system = System("system")
    all_system = system
    model_index = 0
    for line in f:
        line = line.strip()
        if line.startswith("ENDMDL"):
            # The first model already lives in `all_system`; only later
            # models need to be merged in.
            if model_index:
                try:
                    all_system.add_model(system)
                except Exception as e:
                    warnings.warn(
                        f"error when adding model {model_index + 1}: {str(e)}, skipping model..."
                    )
            system = System("system")
            model_index = model_index + 1
            last_resnum = None
            last_resname = None
            last_icode = None
            last_chain_id = None
            last_alt = None
            chain = None
            residue = None
            continue
        if line.startswith("END"):
            break
        if line.startswith("MODEL"):
            continue
        if line.startswith("TER") and not ignore_ter:
            ter = True
            continue
        if not (line.startswith("ATOM") or line.startswith("HETATM")):
            continue

        # Now read the atom record. PDB lines are sometimes too short (when
        # trailing optional columns are absent), so pad the line to make
        # every fixed-column slice below safe.
        line += " " * 100
        # The atom serial number (columns [6:11]) is not used, so it is not
        # parsed: a non-numeric serial must not abort the read.
        atomname = line[12:16].strip()
        alt = line[16:17]
        resname = line[17:21].strip()
        chain_id = line[21:22].strip()
        resnum = int(line[23:27]) if charmm_format else int(line[22:26])
        icode = " " if charmm_format else line[26:27]
        x = float(line[30:38])
        y = float(line[38:46])
        z = float(line[46:54])
        seg_id = line[72:76].strip()
        B = _to_float(line[60:66], 0.0)
        occ = _to_float(line[54:60], 0.0)
        het = line.startswith("HETATM")

        if usese_gid:
            chain_id = seg_id
        elif (chain_id == "") and (len(seg_id) > 0) and seg_id[0].isalnum():
            # No chain ID given: borrow the first character of the seg ID.
            chain_id = seg_id[0:1]

        # Open a new chain when the chain ID changes or after a TER record.
        if (chain_id != last_chain_id) or ter:
            cid_used = system.get_chain_by_id(chain_id) is not None
            chain = system.add_chain(chain_id, seg_id, chain_id, auto_rename=False)
            if uniq_chain_ids and cid_used:
                # Tag the duplicate with a temporary suffix; only chains of
                # the first model are queued for renaming at the end.
                chain.cid = chain.cid + f"|to rename {len(chains_to_rename)}"
                if model_index == 0:
                    chains_to_rename.append(chain)
                if verbose:
                    warnings.warn(
                        "chain name '"
                        + chain_id
                        + "' was repeated while reading, will rename at the end..."
                    )

            # Force the first atom of the new chain to start a new residue.
            last_resnum = None
            last_resname = None
            ter = False

        if charmm19_format:
            # Sequential (not exclusive) checks: HSE becomes HSD and is
            # then mapped on to HIS by the next test.
            if resname == "HSE":
                resname = "HSD"
            if resname == "HSD":
                resname = "HIS"
            if resname == "HSC":
                resname = "HSP"

        if fix_Ile_CD and (resname == "ILE") and (atomname == "CD"):
            atomname = "CD1"

        really_new_atom = True
        if (
            (resnum != last_resnum)
            or (resname != last_resname)
            or (icodes_as_sep_res and (icode != last_icode))
        ):
            # Same residue number and insertion code but a different
            # residue name and alternative-location flag: the record is
            # skipped instead of starting a new residue.
            if (
                (resnum == last_resnum)
                and (resname != last_resname)
                and (alt != last_alt)
                and (not icodes_as_sep_res or (icode == last_icode))
            ):
                continue

            residue = chain.add_residue(
                resname, chain.num_residues() + 1, str(resnum), icode[0]
            )
        elif alt != " ":
            # Same residue with an alternative-location flag: if the atom
            # already exists, record this as one more location for it.
            existing = residue.find_atom(atomname)
            if existing is not None:
                really_new_atom = False
                existing.add_location(x, y, z, occ, B, alt[0])

        if really_new_atom:
            residue.add_atom(atomname, het, x, y, z, occ, B, alt[0])

        last_resnum = resnum
        last_icode = icode
        last_resname = resname
        last_chain_id = chain_id
        last_alt = alt

    # Replace the temporary "<cid>|to rename N" names with unique ones.
    for chain in chains_to_rename:
        parts = chain.cid.split("|")
        assert (
            len(parts) > 1
        ), "something went wrong when renaming a chain at the end of reading"
        name = all_system._pick_unique_chain_name(parts[0], verbose)
        chain.cid = name
        if len(name):
            chain.segid = name

    # PDB files carry no entity records, so create one entity per chain
    # from the residues that were actually read.
    for ci, chain in enumerate(all_system.chains()):
        seq = [None] * chain.num_residues()
        het = [None] * chain.num_residues()
        for ri, res in enumerate(chain.residues()):
            seq[ri] = res.name
            het[ri] = all(a.het for a in res.atoms())
        entity_type, polymer_type = SystemEntity.guess_entity_and_polymer_type(seq)
        entity = SystemEntity(
            entity_type, f"chain {chain.cid}", polymer_type, seq, het
        )
        all_system.add_new_entity(entity, [ci])

    return all_system
|
|
def to_CIF(self, output_file: str):
    """Writes the System to a CIF file.

    Args:
        output_file (str): path of the file to create or overwrite.
    """
    # The context manager flushes and closes the file even if writing fails.
    with open(output_file, "w") as f:
        self._write_cif(f)
|
|
def to_CIF_string(self):
    """Serializes the System in CIF format and returns the resulting text."""
    import io

    # Write into an in-memory text buffer and hand back its contents.
    with io.StringIO("") as buffer:
        self._write_cif(buffer)
        return buffer.getvalue()
|
|
def _write_cif(self, f):
    """Writes the System to an open text stream in CIF (STAR) format.

    The following categories are written, in this order: `_entry`,
    `_entity`, `_entity_poly_seq`, `_entity_poly`, the three assembly
    categories (only when the System has assemblies), `_atom_site`, and
    the `_generate_selections` / `_generate_labels` loops (only when the
    System has selections / labels).

    Args:
        f: a writable text stream.
    """
    # Two-letter atom-name prefixes that are reported as the element type.
    _specials_atom_names = {
        "MG", "CL", "FE", "ZN", "MN", "NI", "SE", "CU", "BR", "CO", "AS",
        "BE", "RU", "RB", "ZR", "OS", "SR", "GD", "MO", "AU", "AG", "PT",
        "AL", "XE", "CS", "EU", "IR", "AM", "TE", "BA", "SB",
    }
    # Names reported as a two-letter type only when the residue carries
    # the same name as the atom.
    _ambiguous_atom_names = ["CA", "CD", "NA", "HG", "PB"]

    def _guess_type(atom_name, res_name):
        # Derives the `type_symbol` value from the atom and residue names.
        if len(atom_name) > 0 and atom_name[0] == '"':
            atom_name = atom_name.replace('"', "")
        if atom_name[:2] in _specials_atom_names:
            return atom_name[:2]
        if atom_name in _ambiguous_atom_names and res_name == atom_name:
            return atom_name
        if atom_name == "UNK":
            return "X"
        return atom_name[:1]

    entry_id = self.name.strip()
    if entry_id == "":
        entry_id = "system"
    f.write(
        "data_GNR8\n#\n"
        + "_entry.id "
        + sp.star_string_escape(entry_id)
        + "\n#\n"
    )

    # Entities; internal entity keys are 0-based, written IDs are 1-based.
    sp.star_loop_header_write(
        f, "_entity", ["id", "type", "pdbx_description", "pdbx_number_of_molecules"]
    )
    for ent_id, entity in self._entities.items():
        num_mol = self.num_molecules_of_entity(ent_id)
        f.write(
            f"{ent_id + 1} {sp.star_string_escape(entity._type)} {sp.star_string_escape(entity._desc)} {num_mol}\n"
        )
    f.write("#\n")

    # Sequence of every entity that has one.
    sp.star_loop_header_write(
        f, "_entity_poly_seq", ["entity_id", "num", "mon_id", "hetero"]
    )
    for ent_id, entity in self._entities.items():
        if entity._seq is not None:
            for i, (res, het) in enumerate(zip(entity._seq, entity._het)):
                f.write(f"{ent_id + 1} {i + 1} {res} {'y' if het else 'n'}\n")
    f.write("#\n")

    # Polymer type of polymer entities.
    sp.star_loop_header_write(f, "_entity_poly", ["entity_id", "type"])
    for ent_id, entity in self._entities.items():
        if entity.is_polymer():
            f.write(f"{ent_id + 1} {sp.star_string_escape(entity._polymer_type)}\n")
    f.write("#\n")

    if self.num_assemblies():
        assemblies = self._assembly_info.assemblies
        ops = self._assembly_info.operations

        sp.star_loop_header_write(f, "_pdbx_struct_assembly", ["id", "details"])
        for assembly_id, assembly in assemblies.items():
            f.write(f"{assembly_id} {sp.star_string_escape(assembly['details'])}\n")
        f.write("#\n")

        sp.star_loop_header_write(
            f,
            "_pdbx_struct_assembly_gen",
            ["assembly_id", "oper_expression", "asym_id_list"],
        )
        for assembly_id, assembly in assemblies.items():
            # An assembly may have been registered with only its details
            # (no generation instructions); write no rows for it instead
            # of failing with a KeyError.
            for instruction in assembly.get("instructions", []):
                chain_list = ",".join(str(ci) for ci in instruction["chains"])
                f.write(
                    f"{assembly_id} {sp.star_string_escape(instruction['oper_expression'])} {chain_list}\n"
                )
        f.write("#\n")

        sp.star_loop_header_write(
            f,
            "_pdbx_struct_oper_list",
            [
                "id",
                "type",
                "name",
                "matrix[1][1]",
                "matrix[1][2]",
                "matrix[1][3]",
                "matrix[2][1]",
                "matrix[2][2]",
                "matrix[2][3]",
                "matrix[3][1]",
                "matrix[3][2]",
                "matrix[3][3]",
                "vector[1]",
                "vector[2]",
                "vector[3]",
            ],
        )
        for op_id, op in ops.items():
            # Row-major 3x3 matrix followed by the 3-vector.
            numbers = [op["matrix"][r][c] for r in range(3) for c in range(3)]
            numbers += [op["vector"][k] for k in range(3)]
            f.write(
                f"{op_id} {sp.star_string_escape(op['type'])} {sp.star_string_escape(op['name'])} "
            )
            f.write(" ".join(f"{float(v):g}" for v in numbers) + "\n")
        f.write("#\n")

    sp.star_loop_header_write(
        f,
        "_atom_site",
        [
            "group_PDB",
            "id",
            "label_atom_id",
            "label_alt_id",
            "label_comp_id",
            "label_asym_id",
            "label_entity_id",
            "label_seq_id",
            "pdbx_PDB_ins_code",
            "Cartn_x",
            "Cartn_y",
            "Cartn_z",
            "occupancy",
            "B_iso_or_equiv",
            "pdbx_PDB_model_num",
            "auth_seq_id",
            "auth_asym_id",
            "type_symbol",
        ],
    )
    # The atom ID advances once per atom (also for atoms without defined
    # locations) and keeps counting across models; all locations of one
    # atom share the same ID.
    atom_id = 0
    for model_index in range(self.num_models()):
        # `swap_model` is called with the same index before and after the
        # model is written -- presumably swap-in / swap-back (TODO confirm).
        self.swap_model(model_index)
        for chain, entity_id in zip(self.chains(), self._chain_entities):
            authchainid = (
                chain.authid if sp.star_value_defined(chain.authid) else chain.cid
            )
            entity_id_str = f"{entity_id + 1}" if entity_id is not None else "?"
            for residue in chain.residues():
                authresid = (
                    residue.authid
                    if sp.star_value_defined(residue.authid)
                    else residue.num
                )
                for atom in residue.atoms():
                    atom_id += 1
                    for location in atom.locations():
                        # Locations without coordinates are not written.
                        if not location.defined():
                            continue

                        coor = location.coor_info
                        f.write("HETATM " if atom.het else "ATOM ")
                        f.write(
                            f"{atom_id} {atom.name} {sp.atom_site_token(location.alt)} "
                        )
                        f.write(
                            f"{residue.name} {chain.cid} {entity_id_str} {residue.num} "
                        )
                        f.write(
                            f"{sp.atom_site_token(residue.icode)} {coor[0]:g} {coor[1]:g} {coor[2]:g} "
                        )
                        f.write(f"{coor[3]:g} {coor[4]:g} {model_index} ")
                        f.write(
                            f"{authresid} {authchainid} {_guess_type(atom.name, residue.name)}\n"
                        )
        self.swap_model(model_index)
    f.write("#\n")

    # Named selections: a name plus a quoted, space-separated index list.
    if len(self._selections):
        sp.star_loop_header_write(f, "_generate_selections", ["name", "indices"])
        for name, indices in self._selections.items():
            index_list = " ".join(str(i) for i in indices)
            f.write(f'{sp.star_string_escape(name)} "{index_list}"\n')
        f.write("#\n")

    # Labels: one row per (category, index, value) triple.
    if len(self._labels):
        sp.star_loop_header_write(f, "_generate_labels", ["name", "index", "value"])
        for category, label_dict in self._labels.items():
            for gti, label in label_dict.items():
                f.write(
                    f"{sp.star_string_escape(category)} {gti} {sp.star_string_escape(label)}\n"
                )
        f.write("#\n")
|
|
def to_PDB(self, output_file: str, options: str = "", mask_indices=None, seq=None):
    """Writes the System to a PDB file.

    Args:
        output_file (str): output PDB file name.
        options (str, optional): a string specifying various options for
            the writing process. The presence of certain sub-strings will
            trigger specific behaviors. Currently recognized sub-strings
            include "CHARMM", "CHARMM19", "CHARMM22", "RENUMBER", "NOEND",
            "NOTER", and "NOALT". This option is case-insensitive.
        mask_indices (optional): forwarded unchanged to `_write_pdb`.
        seq (optional): forwarded unchanged to `_write_pdb`.
    """
    # The context manager flushes and closes the file even if writing fails.
    with open(output_file, "w") as f:
        self._write_pdb(f, options, mask_indices=mask_indices, seq=seq)
|
|
def to_PDB_string(self, options=""):
    """Serializes the System in PDB format and returns the resulting text.

    The `options` string is interpreted exactly as in `to_PDB`.
    """
    import io

    # Write into an in-memory text buffer and hand back its contents.
    with io.StringIO("") as buffer:
        self._write_pdb(buffer, options)
        return buffer.getvalue()
|
|
def _write_pdb(self, f, options="", mask_indices=None, seq=None):
    """Writes the System to an open text stream in PDB format.

    Args:
        f: a writable text stream.
        options (str, optional): case-insensitive string; the presence of
            the following sub-strings changes the behavior:
              "CHARMM": write CHARMM atom/residue names (ILE CD1 as CD,
                  O/OXT of the last residue of a chain as OT1/OT2, HOH as
                  TIP3).
              "CHARMM19" / "CHARMM22": additionally rename histidine
                  variants; the two are mutually exclusive.
              "RENUMBER": number residues 1, 2, ... within each chain.
              "NOEND": do not write the final END record.
              "NOTER": do not write TER records.
              "NOALT": write only the first location of every atom.
        mask_indices (optional): when given (and "RENUMBER" is not set),
            `B` of the owning atom view is set to 0 for residues whose
            index within the chain is in `mask_indices` and to 1 otherwise.
        seq (optional): when given, the residue name written for the
            residue at chain position `ri` is `str(seq[ri])`.

    Raises:
        Exception: if both "CHARMM19" and "CHARMM22" are requested.
    """

    def _pdb_line(loc: AtomLocationView, ai: int, ri=None, rn=None, an=None):
        # Formats one ATOM/HETATM record; `ri`, `rn`, and `an` override the
        # residue number, residue name, and atom name taken from `loc`.
        if rn is None:
            rn = loc.atom.residue.name
        if ri is None:
            ri = loc.atom.residue.num
        if an is None:
            an = loc.atom.name
        icode = loc.atom.residue.icode
        cid = loc.atom.residue.chain.cid
        if len(cid) > 1:
            cid = cid[0]
        segid = loc.atom.residue.chain.segid
        if len(segid) > 4:
            segid = segid[0:4]

        # Names shorter than four characters start in the second column of
        # the atom-name field.
        if len(an) < 4:
            an_str = " %-.3s" % an
        else:
            an_str = "%.4s" % an

        # Fixed-column layout matching the slices used by `_read_pdb`:
        # name [12:16], alt [16], residue [17:21], chain [21], number
        # [22:26], icode [26], x/y/z [30:54], occupancy [54:60],
        # B [60:66], segment ID [72:76]. `%1.1s` keeps the chain column
        # one character wide even for an empty chain ID.
        return (
            "%6s%5d %-4s%c%-4s%1.1s%4d%c   %8.3f%8.3f%8.3f%6.2f%6.2f      %.4s"
            % (
                "HETATM" if loc.atom.het else "ATOM  ",
                ai % 100000,
                an_str,
                loc.alt,
                rn,
                cid,
                ri % 10000,
                icode,
                loc.x,
                loc.y,
                loc.z,
                loc.occ,
                loc.B,
                segid,
            )
        )

    options = options.upper()
    charmmFormat = "CHARMM" in options
    charmm19Format = "CHARMM19" in options
    charmm22Format = "CHARMM22" in options
    renumber = "RENUMBER" in options
    noend = "NOEND" in options
    noter = "NOTER" in options
    # Alternative locations are written unless "NOALT" is requested.
    writeAlt = "NOALT" not in options
    # Conversion to generic PDB names is not exposed through `options`.
    genericFormat = False

    if charmm19Format and charmm22Format:
        raise Exception(
            "CHARMM 19 and 22 formatting options cannot be specified together"
        )

    atomIndex = 1
    for ci, chain in enumerate(self.chains()):
        for ri, residue in enumerate(chain.residues()):
            for ai, atom in enumerate(residue.atoms()):
                atomname = atom.name
                resname = residue.name
                if seq is not None:
                    resname = str(seq[ri])
                if charmmFormat:
                    if (residue.name == "ILE") and (atom.name == "CD1"):
                        atomname = "CD"
                    if (atom.name == "O") and (ri == chain.num_residues() - 1):
                        atomname = "OT1"
                    if (atom.name == "OXT") and (ri == chain.num_residues() - 1):
                        atomname = "OT2"
                    if residue.name == "HOH":
                        resname = "TIP3"

                if charmm19Format:
                    if residue.name == "HSD":
                        resname = "HIS"
                    if residue.name == "HSE":
                        resname = "HSD"
                    if residue.name == "HSC":
                        resname = "HSP"
                elif charmm22Format:
                    # This converts from CHARMM19 to CHARMM22 as well as
                    # from a generic downloaded PDB file to one ready for
                    # use in CHARMM22. In the all-hydrogen topology the HIS
                    # protonation state must be given explicitly, so there
                    # is no HIS per se, whereas downloaded PDB files almost
                    # always use plain HIS. HIS is therefore renamed to
                    # HSD, the neutral form with the proton on ND1 -- an
                    # assumption, but something has to be assumed.
                    if residue.name == "HSD":
                        resname = "HSE"
                    if residue.name == "HIS":
                        resname = "HSD"
                    if residue.name == "HSP":
                        resname = "HSC"
                elif genericFormat:
                    if residue.name in ["HSD", "HSP", "HSE", "HSC"]:
                        resname = "HIS"
                    if (residue.name == "ILE") and (atom.name == "CD"):
                        atomname = "CD1"

                # With "NOALT" only the first location (if any) is written.
                num_written = atom.num_locations()
                if not writeAlt:
                    num_written = min(num_written, 1)
                for li in range(num_written):
                    if renumber:
                        f.write(
                            _pdb_line(
                                atom.get_location(li),
                                atomIndex,
                                ri=ri + 1,
                                rn=resname,
                                an=atomname,
                            )
                            + "\n"
                        )
                    else:
                        a = atom.get_location(li)
                        if mask_indices is not None:
                            # NOTE(review): the record prints `loc.B`;
                            # whether assigning `B` on the atom view
                            # changes that value is not visible here.
                            if ri in mask_indices:
                                a.atom.B = 0
                            else:
                                a.atom.B = 1
                        f.write(
                            _pdb_line(
                                a,
                                atomIndex,
                                rn=resname,
                                an=atomname,
                            )
                            + "\n"
                        )
                    atomIndex = atomIndex + 1

            if not noter and (ri == chain.num_residues() - 1):
                f.write("TER\n")
        if not noend and (ci == self.num_chains() - 1):
            f.write("END\n")
|
|
def canonicalize_protein(
    self,
    level=2,
    drop_coors_unknowns=False,
    drop_coors_missing_backbone=False,
    filter_by_entity=False,
):
    """Canonicalize the calling System object (in place) by assuming that it represents
    a protein molecular system. Different canonicalization rigor and options
    can be specified but are all optional.

    Args:
        level (int): Canonicalization level that determines which nonstandard-to-standard
            residue mappings are performed. Possible values are 1, 2 or 3, with 2 being
            the default and higher values meaning more rigorous (and less conservative)
            canonicalization. With level 1, only truly equivalent mappings are performed
            (e.g., different His protonation states are mapped to the canonical residue
            name HIS that does not specify protonation). Level 2 adds to this some less
            exact but still quite close mappings--i.e., seleno-methionine (MSE) and
            seleno-cysteine (SEC) to methionine (MET) and cysteine (CYS). Level 3 further
            adds even less equivalent but still reasonable mappings--i.e., phosphorylated
            SER, THR, TYR, and HIS to their unphosphorylated counterparts as well as
            S-oxy Cys to Cys.
        drop_coors_unknowns (bool, optional): if True, will discard structural information
            for all residues that are not natural or mappable under the current level.
            NOTE: any sequence record for these residues (i.e., if they are part of a
            polymer entity) will be preserved.
        drop_coors_missing_backbone (bool, optional): if True, will discard structural
            information for residues that do not have at least the N, CA, C, and O
            backbone atoms. Same note applies regarding the full sequence record as in
            the above.
        filter_by_entity (bool, optional): if True, will remove any chains that do not
            represent polymer/polypeptide entities. This is convenient for cases where a
            System object has both protein and non-protein components. However, depending
            on how the System object was generated, entity metadata may not have been filled,
            so applying this canonicalization approach will remove the entire structure.
            For this reason, the option is False by default.

    Raises:
        Exception: if `level` is not 1, 2, or 3.
    """

    def _mod_to_standard_aa_mappings(
        less_standard: bool, almost_standard: bool, standard: bool
    ):
        # Builds the modified-to-standard residue-name map for the
        # requested rigor; each flag enables one group of mappings.

        # Histidine protonation variants (level 1 and above).
        standard_map = {"HSD": "HIS", "HSE": "HIS", "HSC": "HIS", "HSP": "HIS"}

        # Selenium-containing analogs (level 2 and above).
        almost_standard_map = {"MSE": "MET", "SEC": "CYS"}

        # Phosphorylated / oxidized residues (level 3 only).
        less_standard_map = {
            "HIP": "HIS",
            "CSX": "CYS",
            "SEP": "SER",
            "TPO": "THR",
            "PTR": "TYR",
        }

        ret = dict()
        if standard:
            ret.update(standard_map)
        if almost_standard:
            ret.update(almost_standard_map)
        if less_standard:
            ret.update(less_standard_map)
        return ret

    def _to_standard_aa_mappings(
        less_standard: bool, almost_standard: bool, standard: bool
    ):
        # Same as above, plus every canonical amino acid mapped to itself.
        mapping = _mod_to_standard_aa_mappings(
            less_standard, almost_standard, standard
        )

        # NOTE(review): this local import path differs from the polyseq
        # import at the top of the file -- confirm that it resolves.
        import src.chroma.utility.polyseq as polyseq

        for aa in polyseq.canonical_amino_acids():
            mapping[aa] = aa

        return mapping

    less_standard, almost_standard, standard = False, False, False
    if level == 3:
        less_standard, almost_standard, standard = True, True, True
    elif level == 2:
        less_standard, almost_standard, standard = False, True, True
    elif level == 1:
        less_standard, almost_standard, standard = False, False, True
    else:
        raise Exception(f"unknown canonicalization level {level}")

    to_standard = _to_standard_aa_mappings(less_standard, almost_standard, standard)

    # Chains are only collected here and deleted at the very end, and
    # renames of residues that belong to an entity are deferred (grouped by
    # entity and chain), so views and chain indices stay valid meanwhile.
    chains_to_delete = []
    residues_to_rename = dict()
    for ci, chain in enumerate(self.chains()):
        entity = chain.get_entity()
        if filter_by_entity:
            # `_polymer_type` can be None when no polymer record was
            # provided for the entity; treat that as "not a polypeptide".
            if (
                (entity is None)
                or (entity._type != "polymer")
                or ("polypeptide" not in (entity._polymer_type or ""))
            ):
                chains_to_delete.append(chain)
                continue

        cleared_residues = 0
        for residue in reversed(list(chain.residues())):
            aa = residue.name
            delete_atoms = False
            if aa in to_standard:
                aa_new = to_standard[aa]
                if aa != aa_new:
                    # Adjust atoms so that they match the new residue name.
                    if (
                        (aa == "HSD")
                        or (aa == "HSE")
                        or (aa == "HSC")
                        or (aa == "HSP")
                    ) and (aa_new == "HIS"):
                        pass
                    elif ((aa == "MSE") and (aa_new == "MET")) or (
                        (aa == "SEC") and (aa_new == "CYS")
                    ):
                        SE = residue.find_atom("SE")
                        if SE is not None:
                            # NOTE(review): this calls rename() on the
                            # residue, not on the SE atom; it looks as if
                            # the atom was meant to become SD/SG. Left
                            # unchanged because the atom-renaming API is
                            # not visible here.
                            if aa == "MSE":
                                SE.residue.rename("SD")
                            else:
                                SE.residue.rename("SG")
                    elif (
                        ((aa == "HIP") and (aa_new == "HIS"))
                        or ((aa == "SEP") and (aa_new == "SER"))
                        or ((aa == "TPO") and (aa_new == "THR"))
                        or ((aa == "PTR") and (aa_new == "TYR"))
                    ):
                        # Remove the phosphate group.
                        for atomname in ["P", "O1P", "O2P", "O3P", "HOP2", "HOP3"]:
                            a = residue.find_atom(atomname)
                            if a is not None:
                                a.delete()
                    elif (aa == "CSX") and (aa_new == "CYS"):
                        a = residue.find_atom("OD")
                        if a is not None:
                            a.delete()

                    # Residues without an entity are renamed right away;
                    # the others are queued so that the entity sequence can
                    # be updated consistently below.
                    entity_id = chain.get_entity_id()
                    if entity_id is None:
                        residue.rename(aa_new)
                    else:
                        if entity_id not in residues_to_rename:
                            residues_to_rename[entity_id] = dict()
                        if ci not in residues_to_rename[entity_id]:
                            residues_to_rename[entity_id][ci] = list()
                        residues_to_rename[entity_id][ci].append(
                            (residue.get_index_in_chain(), aa_new)
                        )
                else:
                    if aa == "ARG":
                        # Compare the CD-NE-CZ-NH1 and CD-NE-CZ-NH2
                        # dihedrals (first location of each atom) and swap
                        # the NH1/NH2 labels when the former is larger in
                        # magnitude.
                        A = {an: None for an in ["CD", "NE", "CZ", "NH1", "NH2"]}
                        for an in A:
                            atom = residue.find_atom(an)
                            if atom is not None and atom.num_locations():
                                A[an] = atom.get_location(0)
                        if all([a is not None for n, a in A.items()]):
                            dihe1 = System.dihedral(
                                A["CD"], A["NE"], A["CZ"], A["NH1"]
                            )
                            dihe2 = System.dihedral(
                                A["CD"], A["NE"], A["CZ"], A["NH2"]
                            )
                            if abs(dihe1) > abs(dihe2):
                                # NOTE(review): `name` is assigned on
                                # location views -- confirm that this
                                # renames the underlying atoms.
                                A["NH1"].name = "NH2"
                                A["NH2"].name = "NH1"
            elif drop_coors_unknowns:
                delete_atoms = True

            # Drop residues with an incomplete backbone only on request.
            if drop_coors_missing_backbone:
                if not delete_atoms and not residue.has_full_backbone():
                    delete_atoms = True

            if delete_atoms:
                residue.delete_atoms()
                cleared_residues += 1

        # A chain whose residues were all cleared is removed (with
        # `filter_by_entity` the chain selection was already made above).
        if (
            not filter_by_entity
            and (cleared_residues != 0)
            and (cleared_residues == chain.num_residues())
        ):
            chains_to_delete.append(chain)

    # Apply the deferred renames. The entity is renamed in place (no fork)
    # only when every chain of the entity gets the same set of renames.
    for entity_id, ops in residues_to_rename.items():
        chain_indices = set(ops.keys())
        entity_chains = set(self.get_chains_of_entity(entity_id, by="index"))
        unique_renames = set([tuple(v) for v in ops.values()])
        fork = True
        if (chain_indices == entity_chains) and (len(unique_renames) == 1):
            fork = False
        for ci, renames in ops.items():
            chain = self.get_chain(ci)
            for ri, new_name in renames:
                chain.get_residue(ri).rename(new_name, fork_entity=fork)

    for chain in reversed(chains_to_delete):
        chain.delete()

    self._reindex()
|
|
def sequence(self, format="three-letter-list"):
    """Returns the sequence of the whole System, i.e., the per-chain
    sequences concatenated in the order in which the chains are stored.

    Args:
        format (str): either "three-letter-list" (default), which
            accumulates into a list, or any other value (e.g.,
            "one-letter-string"), which accumulates into a string. The
            value is passed on to each chain's `sequence` method.

    Returns:
        List (default) or string.
    """
    # The accumulator type follows the requested format.
    combined = [] if format == "three-letter-list" else ""
    for chain_view in self.chains():
        combined = combined + chain_view.sequence(format)
    return combined
|
|
@staticmethod
def distance(a1: AtomLocationView, a2: AtomLocationView):
    """Returns the Euclidean norm of the difference between the `coors`
    of atom locations `a1` and `a2`."""
    return np.linalg.norm(a1.coors - a2.coors)
|
|
| @staticmethod |
| def angle( |
| a1: AtomLocationView, a2: AtomLocationView, a3: AtomLocationView, radians=False |
| ): |
| """Computes the angle formed by three 3D points represented by AtomLocationView objects. |
| |
| Args: |
| a1, a2, a3 (AtomLocationView): three 3D points. |
| radian (bool, optional): if True (default False), will return the angle in radians. |
| Otherwise, in degrees. |
| |
| Returns: |
| Angle `a1`-`a2`-`a3`. |
| """ |
| v21 = a1.coors - a2.coors |
| v23 = a3.coors - a2.coors |
| v21 = v21 / np.linalg.norm(v21) |
| v23 = v23 / np.linalg.norm(v23) |
| c = np.dot(v21, v23) |
| return np.arctan2(np.sqrt(1 - c * c), c) * (1 if radians else 180.0 / np.pi) |
|
|
| @staticmethod |
| def dihedral( |
| a1: AtomLocationView, |
| a2: AtomLocationView, |
| a3: AtomLocationView, |
| a4: AtomLocationView, |
| radians=False, |
| ): |
| """Computes the dihedral angle formed by four 3D points represented by AtomLocationView objects. |
| |
| Args: |
| a1, a2, a3, a4 (AtomLocationView): four 3D points. |
| radian (bool, optional): if True (default False), will return the angle in radians. |
| Otherwise, in degrees. |
| |
| Returns: |
| Dihedral angle `a1`-`a2`-`a3`-`a4`. |
| """ |
| AB = a1.coors - a2.coors |
| CB = a3.coors - a2.coors |
| DC = a4.coors - a3.coors |
|
|
| if min([np.linalg.norm(p) for p in [AB, CB, DC]]) == 0.0: |
| raise Exception("some points coincide in dihedral calculation") |
|
|
| ABxCB = np.cross(AB, CB) |
| ABxCB = ABxCB / np.linalg.norm(ABxCB) |
| DCxCB = np.cross(DC, CB) |
| DCxCB = DCxCB / np.linalg.norm(DCxCB) |
|
|
| |
| dotp = np.dot(ABxCB, DCxCB) |
| if dotp > 1.0: |
| dotp = 1.0 |
| elif dotp < -1.0: |
| dotp = -1.0 |
|
|
| angle = np.arccos(dotp) |
| if np.dot(ABxCB, DC) > 0: |
| angle *= -1 |
| if not radians: |
| angle *= 180.0 / np.pi |
|
|
| return angle |
|
|
| @staticmethod |
| def protein_backbone_atom_type(atom_name: str, no_hyd=True, by_name=True): |
| """Backbone atoms can be either nitrogens, carbons, oxigens, or hydrogens. |
| Specifically, possible known names in each category are: |
| 'N', 'NT' |
| 'CA', 'C', 'CY', 'CAY' |
| 'OY', 'O', 'OCT*', 'OXT', 'OT1', 'OT2' |
| 'H', 'HY*', 'HA*', 'HN', 'HT*', '1H', '2H', '3H' |
| """ |
| array = ["N", "CA", "C", "O", "H"] if by_name else [0, 1, 2, 3, 4] |
| if atom_name in ["N", "NT"]: |
| return array[0] |
| if atom_name == "CA": |
| return array[1] |
| if (atom_name == "C") or (atom_name == "CY"): |
| return array[2] |
| if atom_name in ["O", "OY", "OXT", "OT1", "OT2"] or atom_name.startswith("OCT"): |
| return array[3] |
| if not no_hyd: |
| if atom_name in ["H", "HA", "HN"]: |
| return array[4] |
| if atom_name.startswith("HT") or atom_name.startswith("HY"): |
| return array[4] |
| |
| if ( |
| atom_name.startswith("1H") |
| or atom_name.startswith("2H") |
| or atom_name.startswith("3H") |
| ): |
| return array[4] |
| return None |
|
|
|
|
@dataclass
class SystemEntity:
    """A molecular entity represented in a molecular system."""

    # Entity type; the value "polymer" marks polymer entities.
    _type: str
    # Free-text description of the entity.
    _desc: str
    # Polymer type name (None when no polymer type was determined).
    _polymer_type: str
    # Residue names making up the entity sequence.
    _seq: list
    # Per-entity hetero information (exposed as `hetero`).
    _het: list

    @property
    def type(self):
        return self._type

    @property
    def description(self):
        return self._desc

    @property
    def polymer_type(self):
        return self._polymer_type

    @property
    def sequence(self):
        return self._seq

    @property
    def hetero(self):
        return self._het

    def is_polymer(self):
        """Tell whether this entity is a polymer."""
        return self._type == "polymer"

    @classmethod
    def guess_entity_and_polymer_type(cls, seq: List):
        """Guess the entity type and polymer type from a list of residue names.

        A sequence is called a polymer when more than 80% of its residues
        are polymer residues according to `polyseq.is_polymer_residue`;
        its polymer type is the first one in `polyseq.polymerType` that
        more than 80% of the residues match.

        Args:
            seq (list): residue names.

        Returns:
            Tuple `(entity_type, polymer_type)`, where `entity_type` is
            "polymer" or "unknown" and `polymer_type` is a polymer type
            name or None.
        """

        def matching_fraction(ptype):
            # Fraction of residues that qualify as polymer residues of `ptype`.
            return np.mean([polyseq.is_polymer_residue(res, ptype) for res in seq])

        if not (matching_fraction(None) > 0.8):
            return "unknown", None

        for ptype in polyseq.polymerType:
            if matching_fraction(ptype) > 0.8:
                return "polymer", polyseq.polymer_type_name(ptype)
        return "polymer", None
|
|
|
|
@dataclass
class BaseView:
    """An abstract base "view" class for accessing different parts of System."""

    # Index of the viewed object.
    _ix: int
    # The view (or System) one level up in the hierarchy.
    _parent: object

    @property
    def parent(self):
        return self._parent

    def get_index(self):
        """Return the index stored in this view."""
        return self._ix

    def is_valid(self):
        """Tell whether the view has a non-negative index and a parent."""
        has_index = self._ix >= 0
        return has_index and self._parent is not None

    def _delete(self):
        """Remove the viewed object from its parent's list of children."""
        owner = self.parent
        # Position among the parent's children: own index minus the index
        # of the parent's first child.
        first_child = owner._siblings.child_index(owner._ix, 0)
        owner._siblings.delete_child(owner._ix, self._ix - first_child)
|
|
|
|
@dataclass
class ChainView(BaseView):
    """A Chain view, allowing hierarchical exploration and editing.

    The view stores the index of the chain and a reference to its System;
    chain attributes are read from and written to `system._chains`.
    """

    def __init__(self, ix: int, system: System):
        self._ix = ix
        self._parent = system
        self._siblings = system._chains

    def __str__(self):
        return f"{self.cid} ({self.segid}/{self.authid}) -> {str(self.system)}"

    def __getitem__(self, res_idx: int):
        """Shorthand for `get_residue(res_idx)`."""
        return self.get_residue(res_idx)

    def residues(self):
        """Iterate over the residues of this chain as ResidueView objects."""
        for rn in range(self.num_residues()):
            yield ResidueView(self._siblings.child_index(self._ix, rn), self)

    def num_residues(self):
        """Returns the number of residues in the Chain."""
        return self._siblings.num_children(self._ix)

    def num_structured_residues(self):
        """Returns the number of residues for which `has_structure()` is true."""
        return sum(res.has_structure() for res in self.residues())

    def num_atoms(self):
        """Returns the total number of atoms over all residues of the Chain."""
        return sum(res.num_atoms() for res in self.residues())

    def num_atom_locations(self):
        """Returns the total number of atom locations over all residues of the Chain."""
        return sum(res.num_atom_locations() for res in self.residues())

    def sequence(self, format="three-letter-list"):
        """Returns the sequence of this chain. See `System::sequence()` for
        possible formats.

        Raises:
            Exception: if `format` is not a known sequence format.
        """
        if format == "three-letter-list":
            return [residue.name for residue in self.residues()]
        if format == "one-letter-string":
            return "".join(
                polyseq.to_single(residue.name) for residue in self.residues()
            )
        raise Exception(f"unknown sequence format {format}")

    def get_residue(self, ri: int):
        """Get the residue at the specified index within the Chain.

        Args:
            ri (int): Residue index within the Chain.

        Returns:
            ResidueView object corresponding to the residue in question.

        Raises:
            Exception: if `ri` is out of range.
        """
        if ri < 0 or ri >= self.num_residues():
            raise Exception(
                f"residue index {ri} out of range for Chain, which has {self.num_residues()} residues"
            )
        return ResidueView(self._siblings.child_index(self._ix, ri), self)

    def get_residue_index(self, residue: ResidueView):
        """Get the index of the given residue in this Chain."""
        return residue._ix - self._siblings.child_index(self._ix, 0)

    def get_atom(self, aidx: int):
        """Get the atom at index `aidx` within this chain.

        Raises:
            Exception: if `aidx` is negative or not smaller than the
                number of atoms in the chain.
        """
        if aidx < 0:
            raise Exception(f"negative atom index: {aidx}")
        # Walk the residues, keeping track of how many atoms precede each one.
        off = 0
        for residue in self.residues():
            na = residue.num_atoms()
            if aidx < off + na:
                return residue.get_atom(aidx - off)
            off = off + na
        # The index is relative to this Chain, so the message reports the Chain.
        raise Exception(
            f"atom index {aidx} out of range for Chain, which has {self.num_atoms()} atoms"
        )

    def get_atoms(self):
        """Return a list of all atoms in this chain."""
        atoms_views = []
        for residue in self.residues():
            atoms_views.extend(residue.get_atoms())
        return atoms_views

    def get_entity_id(self):
        """Return the entity ID corresponding to this chain."""
        return self.system._chain_entities[self._ix]

    def get_entity(self):
        """Return the entity this chain belongs to (None if it has no entity ID)."""
        entity_id = self.get_entity_id()
        if entity_id is None:
            return None
        return self.system._entities[entity_id]

    def check_sequence(self):
        """Compare the list of residue names of this chain to the corresponding entity sequence record.

        Returns:
            False if the chain has a polymer entity whose sequence differs
            from the chain's residue names; True otherwise (including when
            the chain has no entity or the entity is not a polymer).
        """
        entity = self.get_entity()
        if entity is not None and entity.is_polymer():
            if self.num_residues() != len(entity._seq):
                return False
            for res, ent_aan in zip(self.residues(), entity._seq):
                if res.name != ent_aan:
                    return False
        return True

    def add_residue(self, name: str, num: int, authid: str, icode: str = " ", at=None):
        """Add a new residue to this chain.

        Args:
            name (str): Residue name.
            num (int): Residue number (i.e., residue ID).
            authid (str): Author residue ID.
            icode (str): Insertion code.
            at (int, optional): Index at which to insert the residue. Default
                is to append to the end of the chain (i.e., equivalent of `at`
                being equal to the present length of the chain).

        Returns:
            ResidueView object corresponding to the newly added residue.
        """
        if at is None:
            at = self.num_residues()
        ri = self._siblings.insert_child(
            self._ix,
            at,
            {"name": name, "resnum": num, "authresid": authid, "icode": icode},
        )
        return ResidueView(ri, self)

    def delete(self, keep_entity=False):
        """Deletes this Chain from its System.

        Args:
            keep_entity (bool, optional): If False (default) and if the chain
                being deleted happens to be the last representative of the
                entity it belongs to, the entity will be deleted. If True, the
                entity will always be kept.
        """
        # Remove the mention of the chain from the assembly information.
        self.system._assembly_info.delete_chain(self.cid)

        # Optionally, delete the entity if no other chain refers to it.
        # NOTE(review): this count is taken before the chain is removed from
        # `_chain_entities`. If `num_chains_of_entity` counts the chain being
        # deleted, the condition below can never hold -- confirm against
        # `System.num_chains_of_entity` / `System.delete_entity`.
        if not keep_entity:
            eid = self.get_entity_id()
            if self.system.num_chains_of_entity(eid) == 0:
                self.system.delete_entity(eid)

        self.system._chain_entities.pop(self._ix)
        self._siblings.delete(self._ix)
        # Mark the view as no longer pointing at a chain.
        self._ix = -1

    @property
    def system(self):
        return self._parent

    @property
    def cid(self):
        return self._siblings["cid"][self._ix]

    @cid.setter
    def cid(self, val):
        self._siblings["cid"][self._ix] = val

    @property
    def segid(self):
        return self._siblings["segid"][self._ix]

    @segid.setter
    def segid(self, val):
        self._siblings["segid"][self._ix] = val

    @property
    def authid(self):
        return self._siblings["authid"][self._ix]

    @authid.setter
    def authid(self, val):
        self._siblings["authid"][self._ix] = val
|
|
|
|
@dataclass
class ResidueView(BaseView):
    """A Residue view, allowing hierarchical exploration and editing.

    The view stores the index of the residue and a reference to its parent
    ChainView; residue attributes are read from and written to the
    System's `_residues` table.
    """

    def __init__(self, ix: int, chain: ChainView):
        self._ix = ix
        self._parent = chain
        self._siblings = chain.system._residues

    def __str__(self):
        return f"{self.name} {self.num} ({self.authid}) -> {str(self.chain)}"

    def __getitem__(self, atom_idx: int):
        """Shorthand for `get_atom(atom_idx)`."""
        return self.get_atom(atom_idx)

    def atoms(self):
        """Iterate over the atoms of this residue as AtomView objects."""
        off = self._siblings.child_index(self._ix, 0)
        for an in range(self.num_atoms()):
            yield AtomView(off + an, self)

    def num_atoms(self):
        """Returns the number of atoms in the Residue."""
        return self._siblings.num_children(self._ix)

    def num_atom_locations(self):
        """Returns the total number of locations over all atoms of the Residue."""
        return sum(a.num_locations() for a in self.atoms())

    def has_structure(self):
        """Returns whether the residue has any structural information (i.e.,
        at least one of its atoms has one or more locations)."""
        return any(a.num_locations() for a in self.atoms())

    def get_atom(self, ai: int):
        """Get the atom at the specified index within the Residue.

        Args:
            ai (int): Atom index within the Residue.

        Returns:
            AtomView object corresponding to the atom in question.

        Raises:
            Exception: if `ai` is out of range.
        """
        if ai < 0 or ai >= self.num_atoms():
            raise Exception(
                f"atom index {ai} out of range for Residue, which has {self.num_atoms()} atoms"
            )
        return AtomView(self._siblings.child_index(self._ix, ai), self)

    def get_atom_index(self, atom: AtomView):
        """Get the index of the given atom in this Residue."""
        return atom._ix - self._siblings.child_index(self._ix, 0)

    def find_atom(self, name):
        """Find and return the first atom (as AtomView object) with the given name
        within the Residue or None."""
        for atom in self.atoms():
            if atom.name == name:
                return atom
        return None

    def get_index_in_chain(self):
        """Return the index of the Residue in its parent Chain."""
        return self.chain.get_residue_index(self)

    def rename(self, new_name: str, fork_entity=True):
        """Assigns the residue a new name with all proper updates.

        Args:
            new_name (str): New residue name.
            fork_entity (bool, optional): If True (default) and if parent
                chain corresponds to an entity that has other chains
                associated with it and there is a real renaming (i.e.,
                the old name is not the same as the new name), will
                make a new (duplicate) entity for this chain and
                will edit the new one, leaving the old one unchanged.
                If False, will not perform this regardless. NOTE:
                setting this to False can create an inconsistent state
                between chain and entity sequence information.
        """
        entity_id = self.chain.get_entity_id()
        if entity_id is not None:
            entity = self.system._entities[entity_id]
            ri = self.get_index_in_chain()
            if fork_entity and (entity._seq[ri] != new_name):
                # Make sure the entity being edited is used by this chain only.
                ci = self.chain.get_index()
                entity_id = self.system._ensure_unique_entity(ci)
                entity = self.system._entities[entity_id]
            entity._seq[ri] = new_name
        self._siblings["name"][self._ix] = new_name

    def add_atom(
        self,
        name: str,
        het: bool,
        x: float = None,
        y: float = None,
        z: float = None,
        occ: float = 1.0,
        B: float = 0.0,
        alt: str = " ",
        at=None,
    ):
        """Adds a new atom to the residue (by default appending it at the end)
        and returns an AtomView to it. If atom location information is
        specified, will also add a location to the atom.

        Args:
            name (str): Atom name.
            het (bool): Whether it is a hetero-atom.
            x, y, z (float): Atom location coordinates.
            occ (float): Occupancy.
            B (float): B-factor.
            alt (str): Alternative position character.
            at (int, optional): Index at which to insert the atom. Default
                is to append to the end of the residue (i.e., equivalent of
                `at` being equal to the number of atoms in the residue).

        Returns:
            AtomView object corresponding to the newly added atom.
        """
        if at is None:
            at = self.num_atoms()
        ai = self._siblings.insert_child(self._ix, at, {"name": name, "het": het})
        atom = AtomView(ai, self)

        # A location is added only when an x coordinate was supplied.
        if x is not None:
            atom.add_location(x, y, z, occ, B, alt)

        return atom

    def delete(self, fork_entity=True):
        """Deletes this residue from its Chain/System.

        Args:
            fork_entity (bool, optional): If True (default) and if parent
                chain corresponds to an entity that has other chains
                associated with it, will make a new (duplicate) entity
                for this chain and will edit the new one, leaving the
                old one unchanged. If False, will not perform this.
                NOTE: setting this to False can create an inconsistent state
                between chain and entity sequence information.
        """
        # Update the entity sequence record first.
        entity_id = self.chain.get_entity_id()
        if entity_id is not None:
            entity = self.system._entities[entity_id]
            ri = self.get_index_in_chain()
            if fork_entity:
                ci = self.chain.get_index()
                entity_id = self.system._ensure_unique_entity(ci)
                entity = self.system._entities[entity_id]
            entity._seq.pop(ri)

        # Then remove the residue itself and invalidate the view.
        self._delete()
        self._ix = -1

    def delete_atoms(self, atoms=None):
        """Delete either the specified list of atoms or all atoms from the residue.

        Args:
            atoms (list, optional): List of AtomView objects corresponding to the
                atoms to delete. If not specified, will delete all atoms in the residue.

        Raises:
            Exception: if one of the atoms does not belong to this residue.
        """
        if atoms is None:
            atoms = list(self.atoms())
        # The list is processed back to front.
        for atom in reversed(atoms):
            if atom.residue != self:
                raise Exception(f"Atom {atom} does not belong to Residue {self}")
            atom.delete()

    @property
    def chain(self):
        return self._parent

    @property
    def system(self):
        return self.chain.system

    @property
    def name(self):
        return self._siblings["name"][self._ix]

    @property
    def num(self):
        return self._siblings["resnum"][self._ix]

    @property
    def authid(self):
        return self._siblings["authresid"][self._ix]

    @property
    def icode(self):
        return self._siblings["icode"][self._ix]

    def get_backbone(self, no_hyd=True):
        """Assuming that this is a protein residue (i.e., an amino acid), returns the
        list of atoms corresponding to the residue's backbone, in the order:
        backbone amide (N), alpha carbon (CA), carbonyl carbon (C), carbonyl oxygen (O),
        and amide hydrogen (H, optional).

        Args:
            no_hyd (bool, optional): If True (default), will exclude the amide hydrogen
                and only return four atoms. If False, will include the amide hydrogen.

        Returns:
            A list with each entry being an AtomView object corresponding to the backbone
            atom in the order above or None if the atom does not exist in the residue.
        """
        bb = [None] * (4 if no_hyd else 5)
        left = len(bb)
        for atom in self.atoms():
            # Request the integer category (by_name=False): the result is
            # used as a position in `bb`, which a string label cannot be.
            i = System.protein_backbone_atom_type(atom.name, no_hyd, by_name=False)
            if i is None or bb[i] is not None:
                # Not a backbone atom, or this slot is already taken by an
                # earlier atom of the same category.
                continue
            bb[i] = atom
            left = left - 1
            if left == 0:
                break
        return bb

    def has_full_backbone(self, no_hyd=True):
        """Assuming that this is a protein residue (i.e., an amino acid), returns
        whether the residue harbors a structurally defined backbone (i.e., has
        all backbone atoms each of which has location information).

        Args:
            no_hyd (bool, optional): If True (default), will ignore whether the amide
                hydrogen exists or not (if False will consider it).

        Returns:
            Boolean indicating whether there is a full backbone in the residue.
        """
        bb = self.get_backbone(no_hyd)
        return all((a is not None) and a.num_locations() for a in bb)

    def delete_non_backbone(self, no_hyd=True):
        """Assuming that this is a protein residue (i.e., an amino acid), deletes
        all atoms except backbone atoms.

        Args:
            no_hyd (bool, optional): If True (default), will not consider the amide
                hydrogen as a backbone atom (if False will consider it).
        """
        to_delete = [
            atom
            for atom in self.atoms()
            if System.protein_backbone_atom_type(atom.name, no_hyd) is None
        ]
        self.delete_atoms(to_delete)
|
|
|
|
@dataclass
class AtomView(BaseView):
    """An Atom view, allowing hierarchical exploration and editing.

    The view stores the index of the atom and a reference to its parent
    ResidueView; atom attributes live in the System's `_atoms` table and
    location data in the System's `_locations` table.
    """

    def __init__(self, ix: int, residue: ResidueView):
        self._ix = ix
        self._parent = residue
        self._siblings = residue.system._atoms

    def __str__(self):
        text = self.name + (" (HET) " if self.het else " ")
        if self.num_locations() > 0:
            text = text + str(self.get_location(0))
        text = text + f" ({self.num_locations()})"
        return text + " -> " + str(self.residue)

    def __getitem__(self, loc_idx: int):
        """Shorthand for `get_location(loc_idx)`."""
        return self.get_location(loc_idx)

    def locations(self):
        """Iterate over the locations of this atom as AtomLocationView objects."""
        first = self._siblings.child_index(self._ix, 0)
        for ln in range(self.num_locations()):
            yield AtomLocationView(first + ln, self)

    def num_locations(self):
        """Returns the number of locations stored for this atom."""
        return self._siblings.num_children(self._ix)

    def get_location(self, li: int = 0):
        """Returns the (li+1)-th location of the atom.

        Raises:
            Exception: if `li` is out of range.
        """
        if li < 0 or li >= self.num_locations():
            raise Exception(
                f"location index {li} out of range for Atom with {self.num_locations()} locations"
            )
        return AtomLocationView(self._siblings.child_index(self._ix, li), self)

    def add_location(self, x, y, z, occ=1.0, B=0.0, alt=" ", at=None):
        """Adds a location to this atom, by default appending it at the end.

        Args:
            x, y, z (float): coordinates of the location.
            occ (float): occupancy for the location.
            B (float): B-factor for the location.
            alt (str): alternative location character.
            at (int, optional): Index at which to insert the location. Default
                is to append at the end (i.e., equivalent of `at` being equal
                to the current number of locations).

        Returns:
            AtomLocationView object corresponding to the new location.
        """
        if at is None:
            at = self.num_locations()
        record = {"coor": [x, y, z, occ, B], "alt": alt}
        li = self._siblings.insert_child(self._ix, at, record)
        return AtomLocationView(li, self)

    def delete(self):
        """Deletes this atom from its Residue/Chain/System."""
        self._delete()
        self._ix = -1

    @property
    def residue(self):
        return self._parent

    @property
    def chain(self):
        return self.residue.chain

    @property
    def system(self):
        return self.chain.system

    @property
    def name(self):
        return self._siblings["name"][self._ix]

    @property
    def het(self):
        return self._siblings["het"][self._ix]

    # Location information getters and setters operate on the default (first)
    # location for this atom and raise an Exception if there are no locations.

    def _default_location_index(self):
        """Return the index of the first location of this atom.

        Raises:
            Exception: if the atom has no locations.
        """
        if self._siblings.num_children(self._ix) == 0:
            raise Exception("atom has no locations")
        return self._siblings.child_index(self._ix, 0)

    @property
    def x(self):
        row = self._default_location_index()
        return self.system._locations["coor"][row, 0]

    @x.setter
    def x(self, val):
        row = self._default_location_index()
        self.system._locations["coor"][row, 0] = val

    @property
    def y(self):
        row = self._default_location_index()
        return self.system._locations["coor"][row, 1]

    @y.setter
    def y(self, val):
        row = self._default_location_index()
        self.system._locations["coor"][row, 1] = val

    @property
    def z(self):
        row = self._default_location_index()
        return self.system._locations["coor"][row, 2]

    @z.setter
    def z(self, val):
        row = self._default_location_index()
        self.system._locations["coor"][row, 2] = val

    @property
    def coors(self):
        row = self._default_location_index()
        return self.system._locations["coor"][row, 0:3]

    @property
    def occ(self):
        row = self._default_location_index()
        return self.system._locations["coor"][row, 3]

    @occ.setter
    def occ(self, val):
        row = self._default_location_index()
        self.system._locations["coor"][row, 3] = val

    @property
    def B(self):
        row = self._default_location_index()
        return self.system._locations["coor"][row, 4]

    @B.setter
    def B(self, val):
        row = self._default_location_index()
        self.system._locations["coor"][row, 4] = val

    @property
    def alt(self):
        row = self._default_location_index()
        return self.system._locations["alt"][row]

    @alt.setter
    def alt(self, val):
        row = self._default_location_index()
        self.system._locations["alt"][row] = val
|
|
|
|
class DummyAtomView(AtomView):
    """A placeholder Atom view that can be attached to a residue but that has
    no locations and carries no other information.

    Its index is always -1, its name and hetero flag are None, and every
    accessor of location data raises an Exception.
    """

    def __init__(self, residue: ResidueView):
        self._parent = residue
        self._ix = -1

    def __str__(self):
        return "DUMMY -> " + str(self.residue)

    def __getitem__(self, loc_idx: int):
        return None

    def locations(self):
        """Empty generator: a dummy atom has no locations."""
        yield from ()

    def num_locations(self):
        return 0

    def get_location(self, li: int = 0):
        raise Exception("no locations in DUMMY atom")

    def add_location(self, x, y, z, occ, B, alt, at=None):
        raise Exception("can't add no locations to DUMMY atom")

    @property
    def residue(self):
        return self._parent

    @property
    def chain(self):
        return self.residue.chain

    @property
    def system(self):
        return self.chain.system

    @property
    def name(self):
        return None

    @property
    def het(self):
        return None

    # Every location-related getter and setter below refuses to work.

    @property
    def x(self):
        raise Exception("no coordinates in DUMMY atom")

    @x.setter
    def x(self, val):
        raise Exception("can't set coordinate for DUMMY atom")

    @property
    def y(self):
        raise Exception("no coordinates in DUMMY atom")

    @y.setter
    def y(self, val):
        raise Exception("can't set coordinate for DUMMY atom")

    @property
    def z(self):
        raise Exception("no coordinates in DUMMY atom")

    @z.setter
    def z(self, val):
        raise Exception("can't set coordinate for DUMMY atom")

    @property
    def occ(self):
        raise Exception("no occupancy in DUMMY atom")

    @occ.setter
    def occ(self, val):
        raise Exception("can't set occupancy for DUMMY atom")

    @property
    def B(self):
        raise Exception("no B-factor in DUMMY atom")

    @B.setter
    def B(self, val):
        raise Exception("can't set B-factor for DUMMY atom")

    @property
    def alt(self):
        raise Exception("no alt flag in DUMMY atom")

    @alt.setter
    def alt(self, val):
        raise Exception("can't set alt flag for DUMMY atom")
|
|
|
|
@dataclass
class AtomLocationView(BaseView):
    """An AtomLocation view, allowing hierarchical exploration and editing.

    The view stores the index of the location and a reference to its
    parent AtomView; all values are read from and written to the System's
    `_locations` table ("coor" columns 0..4 hold x, y, z, occupancy and
    B-factor; "alt" holds the alternative location flag).
    """

    def __init__(self, ix: int, atom: AtomView):
        self._ix = ix
        self._parent = atom
        self._siblings = atom.system._locations

    def __str__(self):
        return f"{self.x} {self.y} {self.z}"

    def _coor_table(self):
        """Return the coordinate table of the owning System."""
        return self.system._locations["coor"]

    def swap(self, other: AtomLocationView):
        """Swaps information between itself and the provided atom location.

        Args:
            other (AtomLocationView): the other atom location to swap with.
        """
        for field in ("x", "y", "z", "occ", "B", "alt"):
            theirs, mine = getattr(other, field), getattr(self, field)
            setattr(self, field, theirs)
            setattr(other, field, mine)

    def defined(self):
        """Return whether this is a valid location (x, y and z are all not None)."""
        return all(getattr(self, axis) is not None for axis in ("x", "y", "z"))

    @property
    def atom(self):
        return self._parent

    @property
    def residue(self):
        return self.atom.residue

    @property
    def chain(self):
        return self.residue.chain

    @property
    def system(self):
        return self.chain.system

    @property
    def x(self):
        return self._coor_table()[self._ix, 0]

    @x.setter
    def x(self, val):
        self._coor_table()[self._ix, 0] = val

    @property
    def y(self):
        return self._coor_table()[self._ix, 1]

    @y.setter
    def y(self, val):
        self._coor_table()[self._ix, 1] = val

    @property
    def z(self):
        return self._coor_table()[self._ix, 2]

    @z.setter
    def z(self, val):
        self._coor_table()[self._ix, 2] = val

    @property
    def occ(self):
        return self._coor_table()[self._ix, 3]

    @occ.setter
    def occ(self, val):
        self._coor_table()[self._ix, 3] = val

    @property
    def B(self):
        return self._coor_table()[self._ix, 4]

    @B.setter
    def B(self, val):
        self._coor_table()[self._ix, 4] = val

    @property
    def alt(self):
        return self.system._locations["alt"][self._ix]

    @alt.setter
    def alt(self, val):
        self.system._locations["alt"][self._ix] = val

    @property
    def coors(self):
        # Wrapped in np.array() before being returned.
        return np.array(self._coor_table()[self._ix, 0:3])

    @coors.setter
    def coors(self, val):
        self._coor_table()[self._ix, 0:3] = val

    @property
    def coor_info(self):
        # The whole "coor" row of this location, wrapped in np.array().
        return np.array(self._coor_table()[self._ix])

    @coor_info.setter
    def coor_info(self, val):
        self._coor_table()[self._ix] = val
|
|
|
|
| class ExpressionTreeEvaluator: |
| """A class for evaluating custom logical parenthetical expressions. The |
| implementation is very generic, supports nullary, unary, and binary |
| operators, and does not know anything about what the expressions actually |
| mean. Instead the class interprets the expression as a tree of sub- |
| expressions, governed by parentheses and operators, and traverses the |
| calling upon a user-specified evaluation function to evaluate leaf |
| nodes as the tree is gradually collapsed into a single node. This |
| can be used for evaluating set expressions, algebraic expressions, and |
| others. |
| |
| Args: |
| operators_nullary (list): A list of strings designating nullary operators |
| (i.e., operators that do not have any operands). E.g., if the language |
| describes selection algebra, these could be "hyd", "all", or "none"]. |
| operators_unary (list): A list of strings designating unary operators |
| (i.e., operators that have one operand, which must comes to the right |
| of the operator). E.g., if the language describes selection algebra, |
| these could be "name", "resid", or "chain". |
| operators_binary (list): A list of strings designating binary operators |
| (i.e., operators that have two operands, one on each side of the |
| operator). E.g., if the language describes selection algebra, thse |
| could be "and", "or", or "around". |
| eval_function (str): A function that is able to evaluate a leaf node of |
| the expression tree. It shall accept three parameters: |
| |
| operator (str): name of the operator |
| left: the left operand. Will be None if the left operand is missing or |
| not relevant. Otherwise, can be either a list of strings, which |
| should represent an evaluatable sub-expression corresponding to the |
| left operand, or the result of a prior evaluation of this function. |
| right: Same as `left` but for the right operand. |
| |
| The function should attempt to evaluate the resulting expression and |
| return None in the case of failing or a dictionary with the result of |
| the evaluation stored under key "result". |
| left_associativity (bool): If True (the default), operators are taken to be |
| left-associative. Meaning something like "A and B or C" is "(A and B) or C". |
| If False, the operators are taken to be right-associative, such that |
| the same expression becomes "A and (B or C)". NOTE: MST is right-associative |
| but often human intiution tends to be left-associative. |
| debug (bool): If True (default is false), will print a great deal of debugging |
| messages to help diagnose any evaluation problems. |
| """ |
|
|
def __init__(
    self,
    operators_nullary: list,
    operators_unary: list,
    operators_binary: list,
    eval_function: function,
    left_associativity: bool = True,
    debug: bool = False,
):
    """Store the operator vocabulary and the evaluation settings.

    Args:
        operators_nullary (list): names of operators without operands.
        operators_unary (list): names of operators with one operand.
        operators_binary (list): names of operators with two operands.
        eval_function: callable used to evaluate leaf nodes; it is called
            as `eval_function(operator, left, right)`.
        left_associativity (bool): whether operators are left-associative.
        debug (bool): whether to print debugging messages.
    """
    # Evaluation settings.
    self.eval_function = eval_function
    self.left_associativity = left_associativity
    self.debug = debug

    # Operator vocabulary, kept per arity and as one combined collection
    # (nullary first, then unary, then binary).
    self.operators_nullary = operators_nullary
    self.operators_unary = operators_unary
    self.operators_binary = operators_binary
    self.operators = operators_nullary + operators_unary + operators_binary
|
|
def _traverse_expression_tree(self, E, i=0, eval_all=True, debug=False):
    """Evaluate the tokenized expression `E` starting from token index `i`.

    Args:
        E (list): expression tokens (operator names, parentheses and
            free-form operand strings).
        i (int): index of the first token to consume.
        eval_all (bool): If True, keep consuming tokens until `E` is
            exhausted. If False, return as soon as one complete operand has
            been produced: a parenthesized group, a nullary operator, or a
            single operator application.
        debug (bool): If True, print tracing messages while evaluating.

    Returns:
        tuple: `(value, next_index)` on success, where `next_index` is the
        index of the first unconsumed token; `(None, message)` on failure,
        where `message` describes the problem.
    """

    def _collect_operands(tokens, start):
        # Gather the run of consecutive non-operator tokens beginning at
        # `start`; also return the index of the token that ended the run
        # (an operator) or `len(tokens)`. An empty run yields `([], start)`.
        stop = start
        while stop < len(tokens) and tokens[stop] not in self.operators:
            stop += 1
        return tokens[start:stop], stop

    def _find_matching_close_paren(tokens, beg: int):
        # `tokens[beg]` is expected to be "("; track the nesting depth until
        # it drops back to zero. Returns None if the parenthesis never closes.
        depth = 0
        for pos in range(beg, len(tokens)):
            if tokens[pos] == "(":
                depth += 1
            elif tokens[pos] == ")":
                depth -= 1
            if depth == 0:
                return pos
        return None

    def operand_str(operand):
        # Printable form of an operand; long results are abbreviated to
        # their first and last five elements.
        if isinstance(operand, dict) and "result" in operand:
            try:
                size = len(operand["result"])
            except TypeError:
                # Results without a length (e.g., scalars) are printed as-is.
                size = 0
            if size > 15:
                vec = list(operand["result"])
                beg = ", ".join([str(v) for v in vec[:5]])
                end = ", ".join([str(v) for v in vec[-5:]])
                return "{'result': " + f"{beg} ... {end} ({len(vec)} long)" + "}"
        return str(operand)

    def _my_eval(op, left, right, debug=False):
        # Thin wrapper around the user-supplied evaluator that adds tracing.
        if debug:
            print(
                f"\t-> evaluating {operand_str(left)} | {op} | {operand_str(right)}"
            )
        result = self.eval_function(op, left, right)
        if debug:
            print(f"\t-> got result {operand_str(result)}")
        return result

    left, right, op = None, None, None
    if debug:
        print(f"-> received {E[i:]}")

    while i < len(E):
        if left is None and right is None and op is None:
            # Nothing collected yet: expect the start of an operand.
            if E[i] == "(":
                end = _find_matching_close_paren(E, i)
                if end is None:
                    return None, f"parenthesis imbalance starting with {E[i:]}"
                # The group is self-contained, so evaluate all of its interior.
                left, rem = self._traverse_expression_tree(
                    E[i + 1 : end], 0, eval_all=True, debug=debug
                )
                if left is None:
                    return None, rem
                i = end + 1
                if not eval_all:
                    return left, i
            elif E[i] in self.operators_nullary:
                left = _my_eval(E[i], None, None, debug)
                if left is None:
                    return None, f"failed to evaluate nullary operator '{E[i]}'"
                i = i + 1
                # A nullary operator is a complete operand by itself. Return
                # right away when only one operand was requested, mirroring
                # the parenthesized case; otherwise it would absorb the
                # binary operators that follow and break left-associativity.
                if not eval_all:
                    return left, i
            elif E[i] in self.operators_unary:
                op = E[i]
                i = i + 1
            elif E[i] in self.operators:
                # Only binary operators remain, and one cannot start an operand.
                return None, f"unexpected binary operator in the context {E[i:]}"
            else:
                # Bare tokens: collect them as a raw left operand.
                left, i = _collect_operands(E, i)
        elif (left is not None) and (op is None) and (right is None):
            # A left operand is available: the next token must be binary.
            if E[i] not in self.operators_binary:
                return (
                    None,
                    f"expected end or a binary operator when got '{E[i]}' in expression: {E}",
                )
            op = E[i]
            i = i + 1
        elif (
            (left is None) and (op in self.operators_unary) and (right is None)
        ) or (
            (left is not None) and (op in self.operators_binary) and (right is None)
        ):
            # An operator is pending: obtain its right operand and apply it.
            if (
                E[i] in (self.operators_nullary + self.operators_unary)
                or E[i] == "("
            ):
                # The right operand is itself an expression. For left
                # associativity only its first operand is taken; for right
                # associativity the whole remainder is.
                right, i = self._traverse_expression_tree(
                    E, i, eval_all=not self.left_associativity, debug=debug
                )
                if right is None:
                    # On failure the second element is the error message.
                    return None, i
            else:
                right, i = _collect_operands(E, i)

            result = _my_eval(op, left, right, debug)
            if result is None:
                return (
                    None,
                    f"failed to evaluate operator '{op}' (in expression {E}) with operands {operand_str(left)} and {operand_str(right)}",
                )
            if not eval_all:
                return result, i
            # The result becomes the left operand of whatever follows.
            left = result
            op, right = None, None
        else:
            return (
                None,
                f"encountered an unexpected condition when evaluating {E}: left is {operand_str(left)}, op is {op}, or right {operand_str(right)}",
            )

    # Tokens ran out: a dangling operator means the expression is incomplete.
    if (op is not None) or (right is not None):
        return None, "expression ended unexpectedly"
    if left is None:
        return None, f"failed to evaluate expression: {E}"

    return left, i
|
|
def evaluate(self, expression: str):
    """Evaluate the expression and return the result.

    Args:
        expression (str): the expression text. Parentheses are always
            separate tokens; all other tokens are delimited by whitespace
            or parentheses.

    Returns:
        The object stored under the "result" key of the final evaluation.

    Raises:
        Exception: if the expression cannot be evaluated, or does not reduce
            to an evaluated result.
    """
    # Each parenthesis is a token by itself; every other token is a maximal
    # run of characters that are neither whitespace nor parentheses. The raw
    # string avoids the invalid "\s" escape of a plain string literal.
    tokens = re.findall(r"[()]|[^\s()]+", expression)

    val, rem = self._traverse_expression_tree(tokens, debug=self.debug)
    if val is None:
        # On failure the second element carries the reason.
        raise Exception(
            f"failed to evaluate expression: '{expression}', reason: {rem}"
        )

    try:
        return val["result"]
    except (TypeError, KeyError, IndexError) as err:
        # E.g., an expression made only of bare operand tokens comes back as
        # a plain list rather than as an evaluation result.
        raise Exception(
            f"failed to evaluate expression: '{expression}', reason: "
            "expression did not reduce to an evaluated result"
        ) from err
|
|