|
|
|
|
|
|
|
|
from copy import copy

from typing import Any, List, Tuple, Iterator, Optional

from utils import const
|
|
|
|
|
|
|
|
class MoleculeVocab:
    """Index vocabularies for blocks, atom elements and atom position codes.

    * Blocks are ``(symbol, abbreviation)`` pairs: four special tokens
      (PAD, MASK, UNK, latent) followed by the entries of ``const.aas``.
    * Atom types are three special tokens ('pad', 'msk', 'lat') followed by
      ``const.periodic_table``.
    * Atom position codes are three special tokens followed by a fixed list of
      position letters and a small-molecule marker ('sml').
    """

    MAX_ATOM_NUMBER = 14

    def __init__(self):
        self.backbone_atoms = ['N', 'CA', 'C', 'O']

        # Block-level special tokens, stored as (symbol, abbreviation) pairs.
        self.PAD, self.MASK, self.UNK, self.LAT = '#', '*', '?', '&'
        specials = [
            (self.PAD, 'PAD'), (self.MASK, 'MASK'), (self.UNK, 'UNK'),
            (self.LAT, '<L>'),
        ]
        aas = const.aas
        sms = []  # no small-molecule block types are registered

        # Atom-level special tokens (element vocabulary and position vocabulary).
        self.atom_pad, self.atom_mask, self.atom_latent = 'pad', 'msk', 'lat'
        self.atom_pos_pad, self.atom_pos_mask, self.atom_pos_latent = 'pad', 'msk', 'lat'
        self.atom_pos_sm = 'sml'

        # Block vocabulary: forward/backward lookup by symbol and by abbreviation.
        self.idx2block = specials + aas + sms
        self.symbol2idx, self.abrv2idx = {}, {}
        for i, (symbol, abrv) in enumerate(self.idx2block):
            self.symbol2idx[symbol] = i
            self.abrv2idx[abrv] = i
        # 1 marks a special token, 0 a real block type.
        self.special_mask = [1 for _ in specials] + [0 for _ in aas] + [0 for _ in sms]

        # Atom vocabularies.
        self.idx2atom = [self.atom_pad, self.atom_mask, self.atom_latent] + const.periodic_table
        self.idx2atom_pos = [self.atom_pos_pad, self.atom_pos_mask, self.atom_pos_latent, '', 'A', 'B', 'G', 'D', 'E', 'Z', 'H', 'XT', 'P', self.atom_pos_sm]
        self.atom2idx = {atom: i for i, atom in enumerate(self.idx2atom)}
        self.atom_pos2idx = {atom_pos: i for i, atom_pos in enumerate(self.idx2atom_pos)}

    # ----- block lookups -----

    def abrv_to_symbol(self, abrv):
        """Map an abbreviation to its symbol (unknown abbreviations give the UNK symbol)."""
        idx = self.abrv_to_idx(abrv)
        return None if idx is None else self.idx2block[idx][0]

    def symbol_to_abrv(self, symbol):
        """Map a symbol to its abbreviation (unknown symbols give 'UNK')."""
        idx = self.symbol_to_idx(symbol)
        return None if idx is None else self.idx2block[idx][1]

    def abrv_to_idx(self, abrv):
        """Return the block index of ``abrv`` (case-insensitive); unknown -> UNK index."""
        abrv = abrv.upper()
        return self.abrv2idx.get(abrv, self.abrv2idx['UNK'])

    def symbol_to_idx(self, symbol):
        """Return the block index of ``symbol`` (case-sensitive); unknown -> UNK index."""
        return self.symbol2idx.get(symbol, self.abrv2idx['UNK'])

    def idx_to_symbol(self, idx):
        return self.idx2block[idx][0]

    def idx_to_abrv(self, idx):
        return self.idx2block[idx][1]

    def get_pad_idx(self):
        return self.symbol_to_idx(self.PAD)

    def get_mask_idx(self):
        return self.symbol_to_idx(self.MASK)

    def get_special_mask(self):
        """Return a copy of the special-token mask (1 = special, 0 = real block)."""
        return copy(self.special_mask)

    # ----- atom lookups -----

    def get_atom_pad_idx(self):
        return self.atom2idx[self.atom_pad]

    def get_atom_mask_idx(self):
        return self.atom2idx[self.atom_mask]

    def get_atom_latent_idx(self):
        return self.atom2idx[self.atom_latent]

    def get_atom_pos_pad_idx(self):
        return self.atom_pos2idx[self.atom_pos_pad]

    def get_atom_pos_mask_idx(self):
        return self.atom_pos2idx[self.atom_pos_mask]

    def get_atom_pos_latent_idx(self):
        return self.atom_pos2idx[self.atom_pos_latent]

    def idx_to_atom(self, idx):
        return self.idx2atom[idx]

    def atom_to_idx(self, atom):
        """Return the index of atom token ``atom``; unknown atoms map to the mask index.

        An exact match is tried first, so the lower-case special tokens
        ('pad', 'msk', 'lat') round-trip through ``idx_to_atom``. Upper-casing
        them unconditionally would turn them into unknown atoms.
        Upper-case and capitalised forms are then tried, so 'CL', 'cl' and 'Cl'
        resolve the same way. The casing convention of ``const.periodic_table``
        is not visible here, so both forms are tried.
        """
        for key in (atom, atom.upper(), atom.capitalize()):
            idx = self.atom2idx.get(key)
            if idx is not None:
                return idx
        return self.atom2idx[self.atom_mask]

    def idx_to_atom_pos(self, idx):
        return self.idx2atom_pos[idx]

    def atom_pos_to_idx(self, atom_pos):
        """Return the index of position code ``atom_pos``; unknown codes map to the mask index."""
        return self.atom_pos2idx.get(atom_pos, self.atom_pos2idx[self.atom_pos_mask])

    # ----- sizes -----

    def get_num_atom_type(self):
        return len(self.idx2atom)

    def get_num_atom_pos(self):
        return len(self.idx2atom_pos)

    def get_num_block_type(self):
        """Number of non-special block types."""
        return len(self.special_mask) - sum(self.special_mask)

    def __len__(self):
        return len(self.symbol2idx)

    @property
    def ca_channel_idx(self):
        """Position of 'CA' within ``backbone_atoms``."""
        return self.backbone_atoms.index('CA')
|
|
|
|
|
|
|
|
# Shared module-level vocabulary. Building it runs MoleculeVocab.__init__,
# which reads const.aas and const.periodic_table.
VOCAB: MoleculeVocab = MoleculeVocab()
|
|
|
|
|
|
|
|
class Atom:
    """A single atom: name, coordinate, element symbol and position code.

    When ``pos_code`` is not given, it is the atom name with the leading
    element symbol removed once. For example, 'CA' with element 'C' gives 'A',
    'OXT' with element 'O' gives 'XT', and 'HH11' with element 'H' gives 'H11'.
    """

    def __init__(self, atom_name: str, coordinate: List[float], element: str, pos_code: str=None):
        self.name = atom_name
        self.coordinate = coordinate
        self.element = element
        if pos_code is None:
            # str.lstrip(chars) strips any leading characters in the *set*
            # ``chars``, not a prefix. With it, 'HH11' / element 'H' would lose
            # both H's, so remove the element prefix exactly once. The
            # comparison is case-insensitive in case the element symbol and
            # the atom name use different casing.
            if element and atom_name.upper().startswith(element.upper()):
                pos_code = atom_name[len(element):]
            else:
                pos_code = atom_name
        self.pos_code = pos_code

    def get_element(self):
        return self.element

    def get_coord(self):
        """Return a shallow copy of the coordinate."""
        return copy(self.coordinate)

    def get_pos_code(self):
        return self.pos_code

    def __str__(self) -> str:
        return self.name

    def __repr__(self) -> str:
        return f"Atom ({self.name}): {self.element}({self.pos_code}) [{','.join(['{:.4f}'.format(num) for num in self.coordinate])}]"

    def to_tuple(self):
        """Serialize as ``(name, coordinate, element, pos_code)``."""
        return (
            self.name,
            self.coordinate,
            self.element,
            self.pos_code
        )

    @classmethod
    def from_tuple(cls, data):
        """Inverse of :meth:`to_tuple`; builds an instance of ``cls`` (subclass-aware)."""
        return cls(
            atom_name=data[0],
            coordinate=data[1],
            element=data[2],
            pos_code=data[3]
        )
|
|
|
|
|
|
|
|
class Block:
    """An ordered group of atoms with an abbreviation and an optional identifier.

    Attributes:
        abrv: abbreviation naming the block type.
        units: the atoms of the block, in order.
        id: caller-supplied identifier, stored as-is (may be ``None``).
    """

    def __init__(self, abrv: str, units: List[Atom], id: Optional[Any]=None) -> None:
        self.abrv: str = abrv
        self.units: List[Atom] = units
        # atom name -> position in ``units``; with duplicate names the last one wins
        self._uname2idx = {unit.name: i for i, unit in enumerate(self.units)}
        self.id = id

    def __len__(self) -> int:
        return len(self.units)

    def __iter__(self) -> Iterator[Atom]:
        return iter(self.units)

    def get_unit_by_name(self, name: str) -> Atom:
        """Return the atom called ``name``; raises KeyError if absent."""
        return self.units[self._uname2idx[name]]

    def has_unit(self, name: str) -> bool:
        return name in self._uname2idx

    def to_tuple(self):
        """Serialize as ``(abrv, [atom tuples], id)``."""
        return (
            self.abrv,
            [unit.to_tuple() for unit in self.units],
            self.id
        )

    def is_residue(self):
        """Return True if the block contains all of the atoms N, CA, C and O."""
        return all(self.has_unit(name) for name in ('CA', 'N', 'C', 'O'))

    @classmethod
    def from_tuple(cls, data):
        """Inverse of :meth:`to_tuple`; builds an instance of ``cls`` (subclass-aware)."""
        return cls(
            abrv=data[0],
            units=[Atom.from_tuple(unit_data) for unit_data in data[1]],
            id=data[2]
        )

    def __repr__(self) -> str:
        return f"Block ({self.abrv}):\n\t" + '\n\t'.join([repr(at) for at in self.units]) + '\n'