"""PDB data loader."""

import copy
import logging
import math
import os
import random

import esm
import numpy as np
import pandas as pd
import torch
import tree
from pytorch_lightning import LightningDataModule
from scipy.spatial.transform import Rotation as scipy_R
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler

from data import utils as du
from data.repr import get_pre_repr
from data.residue_constants import restype_atom37_mask, order2restype_with_mask
from openfold.data import data_transforms
from openfold.utils import rigid_utils

class PdbDataModule(LightningDataModule):
    def __init__(self, data_cfg):
        super().__init__()
        self.data_cfg = data_cfg
        self.loader_cfg = data_cfg.loader
        self.dataset_cfg = data_cfg.dataset
        self.sampler_cfg = data_cfg.sampler

    def setup(self, stage: str):
        self._train_dataset = PdbDataset(
            dataset_cfg=self.dataset_cfg,
            is_training=True,
        )
        self._valid_dataset = PdbDataset(
            dataset_cfg=self.dataset_cfg,
            is_training=False,
        )

    def train_dataloader(self, rank=None, num_replicas=None):
        num_workers = self.loader_cfg.num_workers
        return DataLoader(
            self._train_dataset,
            # When rank/num_replicas are None, DistributedSampler infers them
            # from the initialized process group.
            sampler=DistributedSampler(
                self._train_dataset, num_replicas=num_replicas, rank=rank, shuffle=True),
            num_workers=num_workers,
            prefetch_factor=None if num_workers == 0 else self.loader_cfg.prefetch_factor,
            persistent_workers=num_workers > 0,
        )

    def val_dataloader(self):
        num_workers = self.loader_cfg.num_workers
        return DataLoader(
            self._valid_dataset,
            sampler=DistributedSampler(self._valid_dataset, shuffle=False),
            num_workers=num_workers,
            prefetch_factor=None if num_workers == 0 else self.loader_cfg.prefetch_factor,
            # persistent_workers requires num_workers > 0, so guard it here too.
            persistent_workers=num_workers > 0,
        )
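
# Example usage (a minimal sketch; assumes a config object with `loader`,
# `dataset`, and `sampler` sub-configs exposing the fields accessed above,
# e.g. built with Hydra/OmegaConf -- the exact config layout is an assumption):
#
#   datamodule = PdbDataModule(cfg.data)
#   datamodule.setup(stage="fit")
#   train_loader = datamodule.train_dataloader()
#   valid_loader = datamodule.val_dataloader()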

class PdbDataset(Dataset):
    def __init__(
        self,
        *,
        dataset_cfg,
        is_training,
    ):
        self._log = logging.getLogger(__name__)
        self._is_training = is_training
        self._dataset_cfg = dataset_cfg
        self.split_frac = self._dataset_cfg.split_frac
        self.random_seed = self._dataset_cfg.seed
        self._init_metadata()

    @property
    def is_training(self):
        return self._is_training

    @property
    def dataset_cfg(self):
        return self._dataset_cfg

    def _init_metadata(self):
        """Initialize metadata from the processed-PDB CSV."""
        pdb_csv = pd.read_csv(self.dataset_cfg.csv_path)
        self.raw_csv = pdb_csv

        # Keep only chains whose modeled length falls inside the configured range.
        pdb_csv = pdb_csv[pdb_csv.modeled_seq_len <= self.dataset_cfg.max_num_res]
        pdb_csv = pdb_csv[pdb_csv.modeled_seq_len >= self.dataset_cfg.min_num_res]

        if self.dataset_cfg.subset is not None:
            pdb_csv = pdb_csv.iloc[:self.dataset_cfg.subset]
        pdb_csv = pdb_csv.sort_values('modeled_seq_len', ascending=False)

        if self.is_training:
            # Restrict to the training split, then subsample a fraction of it.
            self.csv = pdb_csv[pdb_csv['is_trainset']]
            self.csv = self.csv.sample(
                frac=self.split_frac, random_state=self.random_seed).reset_index()
            self.csv.to_csv(
                os.path.join(os.path.dirname(self.dataset_cfg.csv_path), "train.csv"),
                index=False)
            self._log.info(
                f"Training: {len(self.csv)} examples, len_range is "
                f"{self.csv['modeled_seq_len'].min()}-{self.csv['modeled_seq_len'].max()}")
        else:
            # Restrict to the held-out split and cap the evaluation length.
            self.csv = pdb_csv[~pdb_csv['is_trainset']]
            self.csv = self.csv[self.csv.modeled_seq_len <= self.dataset_cfg.max_eval_length]
            self.csv.to_csv(
                os.path.join(os.path.dirname(self.dataset_cfg.csv_path), "valid.csv"),
                index=False)
            self.csv = self.csv.sample(
                n=min(self.dataset_cfg.max_valid_num, len(self.csv)),
                random_state=self.random_seed).reset_index()
            self._log.info(
                f"Valid: {len(self.csv)} examples, len_range is "
                f"{self.csv['modeled_seq_len'].min()}-{self.csv['modeled_seq_len'].max()}")
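
    # The metadata CSV is expected to provide at least the columns used above
    # and in __getitem__: 'modeled_seq_len', 'is_trainset', 'processed_path',
    # and 'energy'; any further columns are carried along untouched.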

    def __len__(self):
        return len(self.csv)

    def __getitem__(self, idx):
        processed_path = self.csv.iloc[idx]['processed_path']
        chain_feats = du.read_pkl(processed_path)
        chain_feats['energy'] = torch.tensor(self.csv.iloc[idx]['energy'], dtype=torch.float32)
        energy = chain_feats['energy']

        if self.is_training and self._dataset_cfg.use_split:
            # Crop a random contiguous fragment whose length lies in
            # [min_num_res, min(split_len, chain length)].
            split_len = random.randint(
                self.dataset_cfg.min_num_res,
                min(self._dataset_cfg.split_len, chain_feats['aatype'].shape[0]))
            start = random.randint(0, chain_feats['aatype'].shape[0] - split_len)
            output_total = copy.deepcopy(chain_feats)
            # Replace the scalar energy with a per-residue placeholder so every
            # leaf can be sliced uniformly; the true value is restored below.
            output_total['energy'] = torch.ones(chain_feats['aatype'].shape)
            output_temp = tree.map_structure(
                lambda x: x[start:start + split_len], output_total)

            # Re-center the crop at its backbone center of mass.
            bb_center = np.sum(output_temp['bb_positions'], axis=0) / (
                np.sum(output_temp['res_mask'].numpy()) + 1e-5)
            output_temp['trans_1'] = (output_temp['trans_1'] - torch.from_numpy(bb_center[None, :])).float()
            output_temp['bb_positions'] = output_temp['bb_positions'] - bb_center[None, :]
            output_temp['all_atom_positions'] = output_temp['all_atom_positions'] - torch.from_numpy(bb_center[None, None, :])
            # The pairwise representation also needs its second dimension cropped.
            output_temp['pair_repr_pre'] = output_temp['pair_repr_pre'][:, start:start + split_len]

            # Re-center the ESMFold translations the same way.
            bb_center_esmfold = torch.sum(output_temp['trans_esmfold'], dim=0) / (
                np.sum(output_temp['res_mask'].numpy()) + 1e-5)
            output_temp['trans_esmfold'] = (output_temp['trans_esmfold'] - bb_center_esmfold[None, :]).float()

            chain_feats = output_temp
            chain_feats['energy'] = energy

        if self._dataset_cfg.use_rotate_enhance:
            # Rotation augmentation: rotate all atom positions by the rotation
            # given by a rotation vector with components drawn from [0, 1).
            rot_vec = [random.random() for _ in range(3)]
            rot_mat = torch.tensor(scipy_R.from_rotvec(rot_vec).as_matrix())
            chain_feats['all_atom_positions'] = torch.einsum(
                'lij,kj->lik', chain_feats['all_atom_positions'],
                rot_mat.type(chain_feats['all_atom_positions'].dtype))

        # Recompute backbone rigid frames from the (possibly cropped and
        # rotated) atom37 coordinates.
        all_atom_mask = np.array([restype_atom37_mask[i] for i in chain_feats['aatype']])
        chain_feats_temp = {
            'aatype': chain_feats['aatype'],
            'all_atom_positions': chain_feats['all_atom_positions'],
            'all_atom_mask': torch.tensor(all_atom_mask).double(),
        }
        chain_feats_temp = data_transforms.atom37_to_frames(chain_feats_temp)
        curr_rigid = rigid_utils.Rigid.from_tensor_4x4(chain_feats_temp['rigidgroups_gt_frames'])[:, 0]
        chain_feats['trans_1'] = curr_rigid.get_trans()
        chain_feats['rotmats_1'] = curr_rigid.get_rots().get_rot_mats()
        chain_feats['bb_positions'] = chain_feats['trans_1'].numpy().astype(chain_feats['bb_positions'].dtype)
        return chain_feats
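
# Example usage (a minimal sketch; the `dataset_cfg` fields referenced above --
# csv_path, split_frac, seed, subset, min_num_res, max_num_res, max_eval_length,
# max_valid_num, use_split, split_len, use_rotate_enhance -- come from the
# experiment config, whose exact layout is an assumption here):
#
#   dataset = PdbDataset(dataset_cfg=cfg.data.dataset, is_training=True)
#   feats = dataset[0]
#   # feats contains, among others: 'aatype', 'all_atom_positions',
#   # 'trans_1', 'rotmats_1', 'bb_positions', and 'energy'.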