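"""SAbDab antibody-antigen structure dataset.

Reads the SAbDab summary TSV, parses the corresponding Chothia-numbered PDB
files, labels CDR regions, caches the parsed structures in LMDB, and builds
cluster-aware train/val/test splits with MMseqs2.
"""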
import os
import random
import logging
import datetime
import pandas as pd
import joblib
import pickle
import lmdb
import subprocess
import torch
from Bio import PDB, SeqRecord, SeqIO, Seq
from Bio.PDB import PDBExceptions
from Bio.PDB import Polypeptide
from torch.utils.data import Dataset
from tqdm.auto import tqdm

from ..utils.protein import parsers, constants
from ._base import register_dataset

# Antigen types kept when filtering the SAbDab summary.
# 'protein | protein | ...' is SAbDab's notation for multi-chain protein antigens.
ALLOWED_AG_TYPES = {
    'protein',
    'protein | protein',
    'protein | protein | protein',
    'protein | protein | protein | protein | protein',
    'protein | protein | protein | protein',
}

# Discard structures with resolution worse than this threshold (angstroms).
RESOLUTION_THRESHOLD = 4.0

# Antigens whose clusters are held out for the test split.
TEST_ANTIGENS = [
    'sars-cov-2 receptor binding domain',
    'hiv-1 envelope glycoprotein gp160',
    'mers s',
    'influenza a virus',
    'cd27 antigen',
]

def nan_to_empty_string(val):
    # NaN is the only value not equal to itself, so `val != val` catches
    # NaN cells coming from pandas.
    if val != val or not val:
        return ''
    else:
        return val


def nan_to_none(val):
    if val != val or not val:
        return None
    else:
        return val


def split_sabdab_delimited_str(val):
    # SAbDab stores multi-valued fields as '|'-delimited strings.
    if not val:
        return []
    else:
        return [s.strip() for s in val.split('|')]

def parse_sabdab_resolution(val):
    if val == 'NOT' or not val or val != val:
        return None
    elif isinstance(val, str) and ',' in val:
        return float(val.split(',')[0].strip())
    else:
        return float(val)
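# Expected behavior of `parse_sabdab_resolution`, for reference:
#   parse_sabdab_resolution('2.70')        -> 2.7
#   parse_sabdab_resolution('2.50, 2.80')  -> 2.5   (first value of a multi-entry field)
#   parse_sabdab_resolution('NOT')         -> None
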
def _aa_tensor_to_sequence(aa):
    # Convert a tensor of amino-acid type indices into a one-letter-code string.
    return ''.join([Polypeptide.index_to_one(a.item()) for a in aa.flatten()])
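# The two helpers below annotate Chothia CDR regions on a parsed chain and
# drop chains whose CDR3 is missing or longer than `max_cdr3_length`.
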
def _label_heavy_chain_cdr(data, seq_map, max_cdr3_length=30):
    if data is None or seq_map is None:
        return data, seq_map

    # Add CDR labels
    cdr_flag = torch.zeros_like(data['aa'])
    for position, idx in seq_map.items():
        resseq = position[1]
        cdr_type = constants.ChothiaCDRRange.to_cdr('H', resseq)
        if cdr_type is not None:
            cdr_flag[idx] = cdr_type
    data['cdr_flag'] = cdr_flag

    # Add CDR sequence annotations
    data['H1_seq'] = _aa_tensor_to_sequence(data['aa'][cdr_flag == constants.CDR.H1])
    data['H2_seq'] = _aa_tensor_to_sequence(data['aa'][cdr_flag == constants.CDR.H2])
    data['H3_seq'] = _aa_tensor_to_sequence(data['aa'][cdr_flag == constants.CDR.H3])

    cdr3_length = (cdr_flag == constants.CDR.H3).sum().item()
    # Discard chains whose CDR-H3 is too long
    if cdr3_length > max_cdr3_length:
        cdr_flag[cdr_flag == constants.CDR.H3] = 0
        logging.warning(f'CDR-H3 too long ({cdr3_length} residues). Removed.')
        return None, None

    # Ensure CDR-H3 exists
    if cdr3_length == 0:
        logging.warning('No CDR-H3 found in the heavy chain.')
        return None, None

    return data, seq_map

def _label_light_chain_cdr(data, seq_map, max_cdr3_length=30):
    if data is None or seq_map is None:
        return data, seq_map

    # Add CDR labels
    cdr_flag = torch.zeros_like(data['aa'])
    for position, idx in seq_map.items():
        resseq = position[1]
        cdr_type = constants.ChothiaCDRRange.to_cdr('L', resseq)
        if cdr_type is not None:
            cdr_flag[idx] = cdr_type
    data['cdr_flag'] = cdr_flag

    # Add CDR sequence annotations
    data['L1_seq'] = _aa_tensor_to_sequence(data['aa'][cdr_flag == constants.CDR.L1])
    data['L2_seq'] = _aa_tensor_to_sequence(data['aa'][cdr_flag == constants.CDR.L2])
    data['L3_seq'] = _aa_tensor_to_sequence(data['aa'][cdr_flag == constants.CDR.L3])

    cdr3_length = (cdr_flag == constants.CDR.L3).sum().item()
    # Discard chains whose CDR-L3 is too long
    if cdr3_length > max_cdr3_length:
        cdr_flag[cdr_flag == constants.CDR.L3] = 0
        logging.warning(f'CDR-L3 too long ({cdr3_length} residues). Removed.')
        return None, None

    # Ensure CDR-L3 exists
    if cdr3_length == 0:
        logging.warning('No CDR-L3 found in the light chain.')
        return None, None

    return data, seq_map

def preprocess_sabdab_structure(task):
    entry = task['entry']
    pdb_path = task['pdb_path']

    parser = PDB.PDBParser(QUIET=True)
    model = parser.get_structure(entry['id'], pdb_path)[0]

    parsed = {
        'id': entry['id'],
        'heavy': None,
        'heavy_seqmap': None,
        'light': None,
        'light_seqmap': None,
        'antigen': None,
        'antigen_seqmap': None,
    }
    try:
        if entry['H_chain'] is not None:
            (
                parsed['heavy'],
                parsed['heavy_seqmap'],
            ) = _label_heavy_chain_cdr(*parsers.parse_biopython_structure(
                model[entry['H_chain']],
                max_resseq = 113    # Chothia, end of heavy-chain Fv
            ))

        if entry['L_chain'] is not None:
            (
                parsed['light'],
                parsed['light_seqmap'],
            ) = _label_light_chain_cdr(*parsers.parse_biopython_structure(
                model[entry['L_chain']],
                max_resseq = 106    # Chothia, end of light-chain Fv
            ))

        if parsed['heavy'] is None and parsed['light'] is None:
            raise ValueError('Neither a valid H-chain nor a valid L-chain was found.')

        if len(entry['ag_chains']) > 0:
            chains = [model[c] for c in entry['ag_chains']]
            (
                parsed['antigen'],
                parsed['antigen_seqmap'],
            ) = parsers.parse_biopython_structure(chains)
    except (
        PDBExceptions.PDBConstructionException,
        parsers.ParsingException,
        KeyError,
        ValueError,
    ) as e:
        logging.warning('[{}] {}: {}'.format(
            task['id'],
            e.__class__.__name__,
            str(e),
        ))
        return None

    return parsed

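# Typical usage of the dataset below, mirroring its defaults and the
# `__main__` block at the bottom (a minimal sketch; adjust paths to your
# data layout):
#
#   dataset = SAbDabDataset(split='train')
#   item = dataset[0]   # parsed structure dict, transformed if a transform was given
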
class SAbDabDataset(Dataset):

    MAP_SIZE = 32*(1024*1024*1024)  # 32GB

    def __init__(
        self,
        summary_path = './data/sabdab_summary_all.tsv',
        chothia_dir = './data/all_structures/chothia',
        processed_dir = './data/processed',
        split = 'train',
        split_seed = 2022,
        transform = None,
        reset = False,
    ):
        super().__init__()
        self.summary_path = summary_path
        self.chothia_dir = chothia_dir
        if not os.path.exists(chothia_dir):
            raise FileNotFoundError(
                f"SAbDab structures not found in {chothia_dir}. "
                "Please download them from http://opig.stats.ox.ac.uk/webapps/newsabdab/sabdab/archive/all/"
            )
        self.processed_dir = processed_dir
        os.makedirs(processed_dir, exist_ok=True)

        self.sabdab_entries = None
        self._load_sabdab_entries()

        self.db_conn = None
        self.db_ids = None
        self._load_structures(reset)

        self.clusters = None
        self.id_to_cluster = None
        self._load_clusters(reset)

        self.ids_in_split = None
        self._load_split(split, split_seed)

        self.transform = transform

    def _load_sabdab_entries(self):
        df = pd.read_csv(self.summary_path, sep='\t')
        entries_all = []
        for i, row in tqdm(
            df.iterrows(),
            dynamic_ncols=True,
            desc='Loading entries',
            total=len(df),
        ):
            entry_id = "{pdbcode}_{H}_{L}_{Ag}".format(
                pdbcode = row['pdb'],
                H = nan_to_empty_string(row['Hchain']),
                L = nan_to_empty_string(row['Lchain']),
                Ag = ''.join(split_sabdab_delimited_str(
                    nan_to_empty_string(row['antigen_chain'])
                ))
            )
            ag_chains = split_sabdab_delimited_str(
                nan_to_empty_string(row['antigen_chain'])
            )
            resolution = parse_sabdab_resolution(row['resolution'])
            entry = {
                'id': entry_id,
                'pdbcode': row['pdb'],
                'H_chain': nan_to_none(row['Hchain']),
                'L_chain': nan_to_none(row['Lchain']),
                'ag_chains': ag_chains,
                'ag_type': nan_to_none(row['antigen_type']),
                'ag_name': nan_to_none(row['antigen_name']),
                'date': datetime.datetime.strptime(row['date'], '%m/%d/%y'),
                'resolution': resolution,
                'method': row['method'],
                'scfv': row['scfv'],
            }
            # Filtering
            if (
                (entry['ag_type'] in ALLOWED_AG_TYPES or entry['ag_type'] is None)
                and (entry['resolution'] is not None and entry['resolution'] <= RESOLUTION_THRESHOLD)
            ):
                entries_all.append(entry)
        self.sabdab_entries = entries_all

    def _load_structures(self, reset):
        if not os.path.exists(self._structure_cache_path) or reset:
            if os.path.exists(self._structure_cache_path):
                os.unlink(self._structure_cache_path)
            self._preprocess_structures()

        with open(self._structure_cache_path + '-ids', 'rb') as f:
            self.db_ids = pickle.load(f)
        self.sabdab_entries = list(
            filter(
                lambda e: e['id'] in self.db_ids,
                self.sabdab_entries
            )
        )

    @property
    def _structure_cache_path(self):
        return os.path.join(self.processed_dir, 'structures.lmdb')
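    # The structure cache is a single-file LMDB database keyed by entry id,
    # plus a companion '<path>-ids' pickle listing every id that was
    # successfully preprocessed.
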
    def _preprocess_structures(self):
        tasks = []
        for entry in self.sabdab_entries:
            pdb_path = os.path.join(self.chothia_dir, '{}.pdb'.format(entry['pdbcode']))
            if not os.path.exists(pdb_path):
                logging.warning(f"PDB not found: {pdb_path}")
                continue
            tasks.append({
                'id': entry['id'],
                'entry': entry,
                'pdb_path': pdb_path,
            })

        data_list = joblib.Parallel(
            n_jobs = max(joblib.cpu_count() // 2, 1),
        )(
            joblib.delayed(preprocess_sabdab_structure)(task)
            for task in tqdm(tasks, dynamic_ncols=True, desc='Preprocess')
        )

        db_conn = lmdb.open(
            self._structure_cache_path,
            map_size = self.MAP_SIZE,
            create=True,
            subdir=False,
            readonly=False,
        )
        ids = []
        with db_conn.begin(write=True, buffers=True) as txn:
            for data in tqdm(data_list, dynamic_ncols=True, desc='Write to LMDB'):
                if data is None:
                    continue
                ids.append(data['id'])
                txn.put(data['id'].encode('utf-8'), pickle.dumps(data))
        db_conn.close()

        with open(self._structure_cache_path + '-ids', 'wb') as f:
            pickle.dump(ids, f)

    @property
    def _cluster_path(self):
        return os.path.join(self.processed_dir, 'cluster_result_cluster.tsv')

    def _load_clusters(self, reset):
        if not os.path.exists(self._cluster_path) or reset:
            self._create_clusters()

        clusters, id_to_cluster = {}, {}
        with open(self._cluster_path, 'r') as f:
            for line in f.readlines():
                cluster_name, data_id = line.split()
                if cluster_name not in clusters:
                    clusters[cluster_name] = []
                clusters[cluster_name].append(data_id)
                id_to_cluster[data_id] = cluster_name
        self.clusters = clusters
        self.id_to_cluster = id_to_cluster

    def _create_clusters(self):
        cdr_records = []
        for id in self.db_ids:
            structure = self.get_structure(id)
            if structure['heavy'] is not None:
                cdr_records.append(SeqRecord.SeqRecord(
                    Seq.Seq(structure['heavy']['H3_seq']),
                    id = structure['id'],
                    name = '',
                    description = '',
                ))
            elif structure['light'] is not None:
                cdr_records.append(SeqRecord.SeqRecord(
                    Seq.Seq(structure['light']['L3_seq']),
                    id = structure['id'],
                    name = '',
                    description = '',
                ))
        fasta_path = os.path.join(self.processed_dir, 'cdr_sequences.fasta')
        SeqIO.write(cdr_records, fasta_path, 'fasta')

        cmd = ' '.join([
            'mmseqs', 'easy-cluster',
            os.path.realpath(fasta_path),
            'cluster_result', 'cluster_tmp',
            '--min-seq-id', '0.5',
            '-c', '0.8',
            '--cov-mode', '1',
        ])
        subprocess.run(cmd, cwd=self.processed_dir, shell=True, check=True)
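    # `mmseqs easy-cluster` writes `cluster_result_cluster.tsv` into
    # `processed_dir` (the cwd above): one representative/member id pair per
    # line, which is exactly what `_load_clusters` parses.
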
    def _load_split(self, split, split_seed):
        assert split in ('train', 'val', 'test')
        ids_test = [
            entry['id']
            for entry in self.sabdab_entries
            if entry['ag_name'] in TEST_ANTIGENS
        ]
        # Exclude every cluster containing a test entry from train/val, so
        # similar CDR sequences cannot leak across the split.
        test_relevant_clusters = set([self.id_to_cluster[id] for id in ids_test])
        ids_train_val = [
            entry['id']
            for entry in self.sabdab_entries
            if self.id_to_cluster[entry['id']] not in test_relevant_clusters
        ]

        random.Random(split_seed).shuffle(ids_train_val)
        if split == 'test':
            self.ids_in_split = ids_test
        elif split == 'val':
            self.ids_in_split = ids_train_val[:20]
        else:
            self.ids_in_split = ids_train_val[20:]

    def _connect_db(self):
        if self.db_conn is not None:
            return
        self.db_conn = lmdb.open(
            self._structure_cache_path,
            map_size=self.MAP_SIZE,
            create=False,
            subdir=False,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )

    def get_structure(self, id):
        self._connect_db()
        with self.db_conn.begin() as txn:
            return pickle.loads(txn.get(id.encode()))

    def __len__(self):
        return len(self.ids_in_split)

    def __getitem__(self, index):
        id = self.ids_in_split[index]
        data = self.get_structure(id)
        if self.transform is not None:
            data = self.transform(data)
        return data

@register_dataset('sabdab')
def get_sabdab_dataset(cfg, transform):
    return SAbDabDataset(
        summary_path = cfg.summary_path,
        chothia_dir = cfg.chothia_dir,
        processed_dir = cfg.processed_dir,
        split = cfg.split,
        split_seed = cfg.get('split_seed', 2022),
        transform = transform,
    )
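# A minimal config sketch for the factory above (an assumption: cfg is an
# OmegaConf-style object supporting both attribute access and .get with a
# default, as `cfg.get('split_seed', 2022)` suggests):
#
#   from omegaconf import OmegaConf
#   cfg = OmegaConf.create({
#       'summary_path': './data/sabdab_summary_all.tsv',
#       'chothia_dir': './data/all_structures/chothia',
#       'processed_dir': './data/processed',
#       'split': 'train',
#   })
#   dataset = get_sabdab_dataset(cfg, transform=None)
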
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--split', type=str, default='train')
    parser.add_argument('--processed_dir', type=str, default='./data/processed')
    parser.add_argument('--reset', action='store_true', default=False)
    args = parser.parse_args()

    if args.reset:
        sure = input('Are you sure you want to reset the processed cache? (y/n): ')
        if sure != 'y':
            exit()
    dataset = SAbDabDataset(
        processed_dir=args.processed_dir,
        split=args.split,
        reset=args.reset,
    )
    print(dataset[0])
    print(len(dataset), len(dataset.clusters))