|
|
|
|
|
import os |
|
|
from typing import Optional, Any |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
from torch.nn.utils.rnn import pad_sequence |
|
|
|
|
|
from utils import register as R |
|
|
from utils.const import sidechain_atoms |
|
|
|
|
|
from data.converter.list_blocks_to_pdb import list_blocks_to_pdb |
|
|
|
|
|
from .format import VOCAB, Block, Atom |
|
|
from .mmap_dataset import MMAPDataset |
|
|
from .resample import ClusterResampler |
|
|
|
|
|
|
|
|
|
|
|
def calculate_covariance_matrix(point_cloud):
    """Return the covariance matrix of a point cloud.

    The input is interpreted as an (n_points, n_dims) array: rows are
    observations and columns are variables, hence ``rowvar=False``.
    The result is an (n_dims, n_dims) matrix (sample covariance, ddof=1).
    """
    return np.cov(point_cloud, rowvar=False)
|
|
|
|
|
|
|
|
@R.register('CoDesignDataset')
class CoDesignDataset(MMAPDataset):
    """Memory-mapped receptor/ligand co-design dataset.

    Each stored entry is a pair (receptor blocks, ligand blocks).
    ``__getitem__`` crops the receptor down to its pocket residues and
    converts everything to dense per-residue tensors: coordinates ``X``,
    residue types ``S``, a residue-level receptor/ligand ``mask``, and a
    per-atom presence ``atom_mask``, padded to ``MAX_N_ATOM`` atoms/residue.
    """

    # Fixed number of atom slots per residue: 4 backbone slots followed by
    # up to 10 sidechain slots (see the atom loop in __getitem__).
    MAX_N_ATOM = 14

    def __init__(
        self,
        mmap_dir: str,
        backbone_only: bool,
        specify_data: Optional[str] = None,
        specify_index: Optional[str] = None,
        padding_collate: bool = False,
        cluster: Optional[str] = None,
        use_covariance_matrix: bool = False
    ) -> None:
        """
        Args:
            mmap_dir: directory of the memory-mapped data (forwarded to MMAPDataset).
            backbone_only: if True, items keep only the first 4 atom slots.
            specify_data: optional explicit data file (forwarded to MMAPDataset).
            specify_index: optional explicit index file (forwarded to MMAPDataset).
            padding_collate: if True, collate_fn pads variable-length tensors
                with pad_sequence; otherwise it concatenates along dim 0.
            cluster: optional cluster file; enables per-epoch cluster resampling.
            use_covariance_matrix: if True, each item also carries 'L', the
                Cholesky factor of the receptor atom-slot-1 coordinate covariance.
        """
        super().__init__(mmap_dir, specify_data, specify_index)
        self.mmap_dir = mmap_dir
        self.backbone_only = backbone_only
        # Entry length = #pocket residues (comma-separated indices stored in the
        # last property field) + ligand length (property field 1).
        # NOTE(review): the _properties layout is produced by the preprocessing
        # that wrote the mmap index — confirm field meanings there.
        self._lengths = [len(prop[-1].split(',')) + int(prop[1]) for prop in self._properties]
        self.padding_collate = padding_collate
        self.resampler = ClusterResampler(cluster) if cluster else None
        self.use_covariance_matrix = use_covariance_matrix

        # Identity mapping by default; update_epoch() may replace it with a
        # cluster-resampled index list.
        self.dynamic_idxs = [i for i in range(len(self))]
        self.update_epoch()

    def update_epoch(self):
        # Redraw the epoch's index mapping when cluster resampling is enabled;
        # otherwise keep the identity mapping set in __init__.
        if self.resampler is not None:
            self.dynamic_idxs = self.resampler(len(self))

    def get_len(self, idx):
        # Length of the entry at (resampled) position idx.
        return self._lengths[self.dynamic_idxs[idx]]

    def get_summary(self, idx: int):
        """Return (id, reference pdb path, receptor chain, ligand chain) for idx.

        NOTE(review): unlike get_len/__getitem__, idx is NOT mapped through
        dynamic_idxs here — confirm callers pass a raw (unresampled) index.
        """
        props = self._properties[idx]
        _id = self._indexes[idx][0].split('.')[0]
        ref_pdb = os.path.join(self.mmap_dir, '..', 'pdbs', _id + '.pdb')
        rec_chain, lig_chain = props[4], props[5]
        return _id, ref_pdb, rec_chain, lig_chain

    def __getitem__(self, idx: int):
        # Map the epoch-local index to the underlying entry index.
        idx = self.dynamic_idxs[idx]
        rec_blocks, lig_blocks = super().__getitem__(idx)

        # Crop the receptor to its pocket residues, keeping the original
        # (pre-crop) 1-based residue positions for position_ids.
        pocket_idx = [int(i) for i in self._properties[idx][-1].split(',')]
        rec_position_ids = [i + 1 for i, _ in enumerate(rec_blocks)]
        rec_blocks = [rec_blocks[i] for i in pocket_idx]
        rec_position_ids = [rec_position_ids[i] for i in pocket_idx]
        rec_blocks = [Block.from_tuple(tup) for tup in rec_blocks]
        lig_blocks = [Block.from_tuple(tup) for tup in lig_blocks]

        # Residue-level segment mask: 0 = receptor, 1 = ligand.
        mask = [0 for _ in rec_blocks] + [1 for _ in lig_blocks]
        position_ids = rec_position_ids + [i + 1 for i, _ in enumerate(lig_blocks)]
        X, S, atom_mask = [], [], []
        for block in rec_blocks + lig_blocks:
            symbol = VOCAB.abrv_to_symbol(block.abrv)
            atom2coord = { unit.name: unit.get_coord() for unit in block.units }
            # Centroid of the residue's present atoms; used as the placeholder
            # coordinate for missing atoms and padded slots.
            bb_pos = np.mean(list(atom2coord.values()), axis=0).tolist()
            coords, coord_mask = [], []
            # Fixed atom ordering: backbone atoms first, then the
            # residue-type-specific sidechain atoms. coord_mask marks which
            # slots hold real atoms (1) vs placeholders (0).
            for atom_name in VOCAB.backbone_atoms + sidechain_atoms.get(symbol, []):
                if atom_name in atom2coord:
                    coords.append(atom2coord[atom_name])
                    coord_mask.append(1)
                else:
                    coords.append(bb_pos)
                    coord_mask.append(0)
            # Pad every residue to exactly MAX_N_ATOM atom slots.
            n_pad = self.MAX_N_ATOM - len(coords)
            for _ in range(n_pad):
                coords.append(bb_pos)
                coord_mask.append(0)

            X.append(coords)
            S.append(VOCAB.symbol_to_idx(symbol))
            atom_mask.append(coord_mask)

        X, atom_mask = torch.tensor(X, dtype=torch.float), torch.tensor(atom_mask, dtype=torch.bool)
        mask = torch.tensor(mask, dtype=torch.bool)
        if self.backbone_only:
            # Keep only the 4 backbone atom slots.
            X, atom_mask = X[:, :4], atom_mask[:, :4]

        if self.use_covariance_matrix:
            # Covariance of receptor (~mask) coordinates at atom slot 1
            # (presumably CA — confirm against VOCAB.backbone_atoms order),
            # restricted to residues whose slot-1 atom is real.
            cov = calculate_covariance_matrix(X[~mask][:, 1][atom_mask[~mask][:, 1]].numpy())
            # Diagonal jitter so the Cholesky factorization succeeds even for
            # (near-)singular covariance matrices.
            eps = 1e-4
            cov = cov + eps * np.identity(cov.shape[0])
            L = torch.from_numpy(np.linalg.cholesky(cov)).float().unsqueeze(0)
        else:
            L = None

        item = {
            'X': X,
            'S': torch.tensor(S, dtype=torch.long),
            'position_ids': torch.tensor(position_ids, dtype=torch.long),
            'mask': mask,
            'atom_mask': atom_mask,
            'lengths': len(S),
        }
        if L is not None:
            item['L'] = L
        return item

    def collate_fn(self, batch):
        """Collate a list of items into a batch.

        With padding_collate, variable-length tensors are padded to a common
        length (residue types use the vocabulary PAD index); otherwise they
        are concatenated along the residue dimension, with 'lengths'
        recording each item's span.
        """
        if self.padding_collate:
            results = {}
            pad_idx = VOCAB.symbol_to_idx(VOCAB.PAD)
            for key in batch[0]:
                values = [item[key] for item in batch]
                if values[0] is None:
                    results[key] = None
                    continue
                if key == 'lengths':
                    # Plain ints -> one long tensor.
                    results[key] = torch.tensor(values, dtype=torch.long)
                elif key == 'S':
                    # Residue types are padded with the PAD token index.
                    results[key] = pad_sequence(values, batch_first=True, padding_value=pad_idx)
                else:
                    results[key] = pad_sequence(values, batch_first=True, padding_value=0)
            return results
        else:
            results = {}
            for key in batch[0]:
                values = [item[key] for item in batch]
                if values[0] is None:
                    results[key] = None
                    continue
                if key == 'lengths':
                    results[key] = torch.tensor(values, dtype=torch.long)
                else:
                    # Flat concatenation along the residue dimension.
                    results[key] = torch.cat(values, dim=0)
            return results
|
|
|
|
|
|
|
|
@R.register('ShapeDataset')
class ShapeDataset(CoDesignDataset):
    """Dataset yielding a coarse two-point "shape" per residue.

    Each residue's coordinates are reduced to two atoms: the CA atom and the
    sidechain atom furthest from it, so item['X'] has shape (n_res, 2, 3).
    All other item fields are inherited from CoDesignDataset unchanged.
    """

    def __init__(
        self,
        mmap_dir: str,
        specify_data: Optional[str] = None,
        specify_index: Optional[str] = None,
        padding_collate: bool = False,
        cluster: Optional[str] = None
    ) -> None:
        # Full-atom coordinates are required to locate sidechain atoms,
        # so backbone_only is fixed to False.
        super().__init__(mmap_dir, False, specify_data, specify_index, padding_collate, cluster)
        self.ca_idx = VOCAB.backbone_atoms.index('CA')

    def __getitem__(self, idx: int):
        item = super().__getitem__(idx)

        X = item['X']                          # (n_res, MAX_N_ATOM, 3)
        atom_mask = item['atom_mask']          # (n_res, MAX_N_ATOM)
        ca_x = X[:, self.ca_idx].unsqueeze(1)  # (n_res, 1, 3)
        sc_x = X[:, 4:]                        # sidechain slots only
        dist = torch.norm(sc_x - ca_x, dim=-1)
        # BUGFIX: absent/padded atom slots must be EXCLUDED from the argmax,
        # so fill their distances with a very small value. The original filled
        # with +1e10, which made argmax select placeholder slots whenever a
        # residue had fewer than 10 sidechain atoms (i.e. almost always),
        # collapsing the "furthest atom" to the residue centroid.
        dist = dist.masked_fill(~atom_mask[:, 4:], -1e10)
        # For residues with no sidechain atoms at all (e.g. glycine) every
        # distance ties at -1e10, argmax falls back to slot 0, and the point
        # used is that slot's centroid placeholder coordinate.
        furthest_atom_x = sc_x[torch.arange(sc_x.shape[0]), torch.argmax(dist, dim=-1)]
        X = torch.cat([ca_x, furthest_atom_x.unsqueeze(1)], dim=1)

        item['X'] = X
        return item
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: build a backbone-only dataset from the mmap directory given
    # on the command line and print its first item.
    import sys

    ds = CoDesignDataset(sys.argv[1], backbone_only=True)
    print(ds[0])
|
|
|