| | import argparse |
| | import json |
| | import multiprocessing |
| | import pickle |
| | import traceback |
| | from dataclasses import asdict, dataclass, replace |
| | from functools import partial |
| | from pathlib import Path |
| | from typing import Any, Optional |
| |
|
| | import numpy as np |
| | import rdkit |
| | from mmcif import parse_mmcif |
| | from p_tqdm import p_umap |
| | from redis import Redis |
| | from tqdm import tqdm |
| |
|
| | from boltz.data.filter.static.filter import StaticFilter |
| | from boltz.data.filter.static.ligand import ExcludedLigands |
| | from boltz.data.filter.static.polymer import ( |
| | ClashingChainsFilter, |
| | ConsecutiveCA, |
| | MinimumLengthFilter, |
| | UnknownFilter, |
| | ) |
| | from boltz.data.types import ChainInfo, InterfaceInfo, Record, Target |
| |
|
| |
|
| | @dataclass(frozen=True, slots=True) |
| | class PDB: |
| | """A raw MMCIF PDB file.""" |
| |
|
| | id: str |
| | path: str |
| |
|
| |
|
class Resource:
    """A shared resource for processing.

    Wraps a Redis connection whose values are pickled Python objects and
    exposes them through ``get`` and subscript (``resource[key]``) lookups.
    """

    def __init__(self, host: str, port: int) -> None:
        """Initialize the redis database.

        Parameters
        ----------
        host : str
            The Redis host.
        port : int
            The Redis port.

        """
        self._redis = Redis(host=host, port=port)

    def get(self, key: str) -> Any:
        """Get an item from the Redis database.

        Parameters
        ----------
        key : str
            The key to look up.

        Returns
        -------
        Any
            The unpickled value, or None when the key has no value.

        """
        raw = self._redis.get(key)
        # NOTE(review): pickle.loads can execute arbitrary code, so this is
        # only safe when the Redis instance holds trusted data.
        return None if raw is None else pickle.loads(raw)

    def __getitem__(self, key: str) -> Any:
        """Get an item from the resource.

        Raises
        ------
        KeyError
            If ``get`` returns None for the key.

        """
        item = self.get(key)
        if item is not None:
            return item
        raise KeyError(key)
| |
|
| |
|
def fetch(datadir: Path, max_file_size: Optional[int] = None) -> list[PDB]:
    """Fetch the PDB files.

    Recursively collects every file under ``datadir`` whose name matches
    ``*.cif*`` (this includes compressed files such as ``1abc.cif.gz``).

    Parameters
    ----------
    datadir : Path
        The directory to search recursively.
    max_file_size : Optional[int]
        If given, files whose size in bytes exceeds this value are
        excluded. No limit is applied when None.

    Returns
    -------
    list[PDB]
        One entry per retained file, with a lowercased id.

    """
    data = []
    excluded = 0
    for file in datadir.rglob("*.cif*"):
        # The id is the file name up to the last ".cif". Path.stem only
        # strips the final suffix, so "1abc.cif.gz" would otherwise yield
        # "1abc.cif", which breaks cluster lookups and output file names.
        name = file.name
        cut = name.lower().rfind(".cif")
        pdb_id = (name[:cut] if cut > 0 else file.stem).lower()

        # Check file size and skip if too large
        if max_file_size is not None and (file.stat().st_size > max_file_size):
            excluded += 1
            continue

        data.append(PDB(id=pdb_id, path=str(file)))

    print(f"Excluded {excluded} files due to size.")
    return data
| |
|
| |
|
def finalize(outdir: Path) -> None:
    """Run post-processing in main thread.

    Reads every file in ``outdir / "records"`` as JSON and writes the
    successfully parsed records as a single list to
    ``outdir / "manifest.json"``. Files that cannot be read or parsed are
    reported and left out of the manifest.

    Parameters
    ----------
    outdir : Path
        The output directory.

    """
    records_dir = outdir / "records"

    failed_count = 0
    records = []
    # Sorted so the manifest order does not depend on directory listing order.
    for path in sorted(records_dir.iterdir()):
        try:
            with path.open("r") as f:
                record = json.load(f)
        except Exception:
            # Best-effort: skip unreadable or corrupt records, but do not
            # swallow KeyboardInterrupt/SystemExit as a bare except would.
            failed_count += 1
            print(f"Failed to parse {path}")
            continue
        records.append(record)

    if failed_count > 0:
        print(f"Failed to parse {failed_count} entries.")
    else:
        print("All entries parsed successfully.")

    # Save manifest
    outpath = outdir / "manifest.json"
    with outpath.open("w") as f:
        json.dump(records, f)
| |
|
| |
|
def parse(data: PDB, resource: Resource, clusters: dict) -> Target:
    """Process a structure.

    Parameters
    ----------
    data : PDB
        The raw input data.
    resource : Resource
        The shared resource, passed through to the MMCIF parser.
    clusters : dict
        Mapping from ``"<lowercased id>_<entity_id>"`` to a cluster id.
        Chains without an entry get a cluster id of -1.

    Returns
    -------
    Target
        The processed data.

    """
    # Cluster keys are built from the lowercased entry id
    prefix = data.id.lower()

    # Parse the MMCIF file
    parsed = parse_mmcif(data.path, resource)
    structure = parsed.data
    structure_info = parsed.info

    # One ChainInfo per chain; the chain id is its position in the structure
    chains = [
        ChainInfo(
            chain_id=idx,
            chain_name=chain["name"],
            msa_id="",
            mol_type=int(chain["mol_type"]),
            cluster_id=clusters.get(f"{prefix}_{chain['entity_id']}", -1),
            num_residues=int(chain["res_num"]),
        )
        for idx, chain in enumerate(structure.chains)
    ]

    # One InterfaceInfo per interface, holding the two chain indices
    interfaces = [
        InterfaceInfo(
            chain_1=int(iface["chain_1"]),
            chain_2=int(iface["chain_2"]),
        )
        for iface in structure.interfaces
    ]

    # Bundle the metadata into a record
    record = Record(
        id=data.id,
        structure=structure_info,
        chains=chains,
        interfaces=interfaces,
    )
    return Target(structure=structure, record=record)
| |
|
| |
|
def process_structure(
    data: PDB,
    resource: Resource,
    outdir: Path,
    filters: list[StaticFilter],
    clusters: dict,
) -> None:
    """Process a target.

    Parses the entry, combines the masks returned by the filters with the
    structure mask, and writes the structure to
    ``outdir / "structures" / "<id>.npz"`` and the record to
    ``outdir / "records" / "<id>.json"``. Entries whose two output files
    already exist are skipped. Entries that raise during parsing or
    filtering are reported and skipped.

    Parameters
    ----------
    data : PDB
        The raw input data.
    resource : Resource
        The shared resource.
    outdir : Path
        The output directory.
    filters : list[StaticFilter]
        The filters to apply; None is accepted and applies no filter.
    clusters : dict
        The cluster mapping passed through to ``parse``.

    """
    # Check if we need to process
    struct_path = outdir / "structures" / f"{data.id}.npz"
    record_path = outdir / "records" / f"{data.id}.json"

    if struct_path.exists() and record_path.exists():
        return

    try:
        # Parse the target
        target: Target = parse(data, resource, clusters)
        structure = target.structure

        # Apply the filters: a chain stays valid only if every filter keeps it
        mask = structure.mask
        if filters is not None:
            for f in filters:
                mask = mask & f.filter(structure)
    except Exception:
        # Best-effort: one bad entry must not abort the whole run
        traceback.print_exc()
        print(f"Failed to parse {data.id}")
        return

    # Flag each chain with its mask entry
    chains = [
        replace(chain, valid=bool(mask[i]))
        for i, chain in enumerate(target.record.chains)
    ]

    # An interface is valid only when both of its chains are
    interfaces = [
        replace(
            interface,
            valid=(bool(mask[interface.chain_1]) and bool(mask[interface.chain_2])),
        )
        for interface in target.record.interfaces
    ]

    # Replace the chains and interfaces
    structure = replace(structure, mask=mask)
    record = replace(target.record, chains=chains, interfaces=interfaces)

    # Dump structure
    np.savez_compressed(struct_path, **asdict(structure))

    # Dump record
    with record_path.open("w") as f:
        json.dump(asdict(record), f)
| |
|
| |
|
def process(args) -> None:
    """Run the data processing task.

    Creates the output directories, loads the cluster mapping, processes
    every MMCIF file found under ``args.datadir`` (in parallel when more
    than one process is usable) and finally writes the manifest.

    Parameters
    ----------
    args : argparse.Namespace
        The parsed arguments. Reads ``outdir``, ``clusters``, ``datadir``,
        ``num_processes``, ``redis_host``, ``redis_port`` and, when
        present, ``max_file_size``.

    """
    # NOTE(review): the command line in this file also defines
    # --use-assembly, but nothing in this function reads it.

    # Create output directories
    args.outdir.mkdir(parents=True, exist_ok=True)
    for subdir in ("records", "structures"):
        (args.outdir / subdir).mkdir(parents=True, exist_ok=True)

    # Load clusters; keys and values are lowercased
    with Path(args.clusters).open("r") as f:
        clusters: dict[str, str] = json.load(f)
        clusters = {k.lower(): v.lower() for k, v in clusters.items()}

    # Load filters
    filters = [
        ExcludedLigands(),
        MinimumLengthFilter(min_len=4, max_len=5000),
        UnknownFilter(),
        ConsecutiveCA(max_dist=10.0),
        ClashingChainsFilter(freq=0.3, dist=1.7),
    ]

    # Set default pickle properties
    pickle_option = rdkit.Chem.PropertyPickleOptions.AllProps
    rdkit.Chem.SetDefaultPickleProperties(pickle_option)

    # Load shared data from redis
    resource = Resource(host=args.redis_host, port=args.redis_port)

    # Get data points. The size limit is forwarded so --max-file-size takes
    # effect; getattr keeps namespaces without that attribute working.
    print("Fetching data...")
    data = fetch(args.datadir, getattr(args, "max_file_size", None))

    # Check if we can run in parallel
    max_processes = multiprocessing.cpu_count()
    num_processes = max(1, min(args.num_processes, max_processes, len(data)))
    parallel = num_processes > 1

    # Bind the shared arguments once for both execution paths
    fn = partial(
        process_structure,
        resource=resource,
        outdir=args.outdir,
        clusters=clusters,
        filters=filters,
    )

    # Run processing
    print("Processing data...")
    if parallel:
        p_umap(fn, data, num_cpus=num_processes)
    else:
        for item in tqdm(data):
            fn(item)

    # Finalize
    finalize(args.outdir)
| |
|
| |
|
# Command-line entry point: parse the arguments and run the processing task.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Process MMCIF structure data.")
    parser.add_argument(
        "--datadir",
        type=Path,
        required=True,
        help="The directory containing the MMCIF files.",
    )
    parser.add_argument(
        "--clusters",
        type=Path,
        required=True,
        help="Path to the cluster file.",
    )
    parser.add_argument(
        "--outdir",
        type=Path,
        default="data",
        help="The output directory.",
    )
    parser.add_argument(
        "--num-processes",
        type=int,
        default=multiprocessing.cpu_count(),
        help="The number of processes.",
    )
    parser.add_argument(
        "--redis-host",
        type=str,
        default="localhost",
        help="The Redis host.",
    )
    parser.add_argument(
        "--redis-port",
        type=int,
        default=7777,
        help="The Redis port.",
    )
    parser.add_argument(
        "--use-assembly",
        action="store_true",
        help="Whether to use assembly 1.",
    )
    parser.add_argument(
        "--max-file-size",
        type=int,
        default=None,
        help="The maximum MMCIF file size in bytes.",
    )
    args = parser.parse_args()
    process(args)
| |
|