from data.scripts.data_utils import parse_PDB from utils.utils import ClassConfig, DataCollatorForTokenRegression, process_in_batches_and_combine, get_dot_separated_name from models.T5_encoder_per_token import PT5_classification_model from data.scripts.get_enm_fluctuations_for_dataset import get_fluctuation_for_json_dict import argparse import os import yaml import torch from pathlib import Path from Bio import SeqIO import json import warnings from datetime import datetime from data.scripts.data_utils import modify_bfactor_biotite if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--input_file", type=str, required=True, help="Input file") parser.add_argument("--modality", type=str, required=True, help="Indicate 'Seq' or '3D' to use Flexpert-Seq or Flexpert-3D?") parser.add_argument("--splits_file", type=str, required=False, help="Path to the file defining the splits, in case that input_file is a dataset which should be subsampled.") parser.add_argument("--split", type=str, required=False, help="Specify test/train/val to subselect the respective split. If specified, the splits file needs to be provided as well.") parser.add_argument("--output_enm", action='store_true', help="If true, the ENM values will be outputted in separate file(s).") parser.add_argument("--output_fasta", action='store_true', help="If true, the sequences used for the prediction will be outputted in a fasta file (can be relevant when working with input list of PDB files).") parser.add_argument("--output_name", type=str, required=False, help="Name of the output file.") args = parser.parse_args() args.modality = args.modality.upper() filename, suffix = os.path.splitext(args.input_file) if args.modality not in ["SEQ", "3D"]: raise ValueError("Modality must be either Seq or 3D") if args.splits_file is not None and args.split is None: raise ValueError("If splits_file is provided, split must be specified.") if args.split is not None and args.splits_file is None: raise ValueError("If split is specified, splits_file must be provided.") if args.split is not None and args.split not in ["test", "train", "val", "validation"]: raise ValueError("Split must be either 'test', 'train', 'val' or 'validation'") if args.output_enm and (args.modality not in ["3D"]): raise ValueError("Output ENM is only supported for 3D modality") if not args.output_name: default_name = 'untitled_{}'.format(datetime.now().strftime('%Y%m%d_%H%M%S')) args.output_name = default_name warnings.warn("Output name is not provided, using default name: {}".format(default_name)) if args.splits_file is not None: with open(args.splits_file, 'r') as f: splits = json.load(f) if 'val' in splits.keys() and args.split == 'validation': args.split = 'val' elif 'validation' in splits.keys() and args.split == 'val': args.split = 'validation' datapoint_for_eval = splits[args.split] else: datapoint_for_eval = 'all' sequences = [] names = [] backbones = [] pdb_files = [] flucts_list = [] def process_pdb_file(pdb_file, backbones, sequences, names): parsed_name = os.path.splitext(os.path.basename(pdb_file))[0].split('_') if len(parsed_name[0]) != 4 or len(parsed_name[1]) != 1 or not parsed_name[1].isalpha(): raise ValueError("PDB file name is expected to be in the format of 'name_chain.pdb', e.g.: 1BUI_C.pdb") _name = parsed_name[0] _chain = parsed_name[1] parsed_pdb = parse_PDB(pdb_file, name=_name, input_chain_list=[_chain])[0] backbone, sequence = parsed_pdb['coords_chain_{}'.format(_chain)], parsed_pdb['seq_chain_{}'.format(_chain)] if len(sequence) > 1023: print("Sequence length is greater than 1023, skipping {}".format(_name + "." + _chain)) else: backbones.append(backbone) sequences.append(sequence) names.append(_name + "." + _chain) return backbones, sequences, names if suffix == ".fasta": if args.modality == "3D": raise ValueError("Flexpert-3D needs the structure, fasta is not enough") # Load FASTA file using Biopython for record in SeqIO.parse(args.input_file, "fasta"): if '_' in record.name: dot_separated_name = '.'.join(record.name.split('_')) elif '.' in record.name: dot_separated_name = record.name else: raise ValueError("Sequence name must contain either an underscore or a dot to separate the PDB code and the chain code.") if datapoint_for_eval == 'all' or dot_separated_name in datapoint_for_eval: names.append(dot_separated_name) sequences.append(str(record.seq)) backbones.append(None) elif suffix == ".pdb": backbones, sequences, names = process_pdb_file(args.input_file, backbones, sequences, names) pdb_files.append(args.input_file) elif suffix == ".jsonl": for line in open(args.input_file, 'r'): _dict = json.loads(line) if 'fluctuations' in _dict.keys(): print("fluctuations are precomputed, using them") dot_separated_name = get_dot_separated_name(key='pdb_name', _dict=_dict) if datapoint_for_eval == 'all' or dot_separated_name in datapoint_for_eval: names.append(_dict['pdb_name']) backbones.append(None) sequences.append(_dict['sequence']) flucts_list.append(_dict['fluctuations']+[0.0]) #padding for end cls token continue dot_separated_name = get_dot_separated_name(key='name', _dict=_dict) if datapoint_for_eval == 'all' or dot_separated_name in datapoint_for_eval: backbones.append(_dict['coords']) sequences.append(_dict['seq']) names.append(dot_separated_name) elif suffix == ".pdb_list": with open(args.input_file, 'r') as f: pdb_files = f.read().splitlines() for pdb_file in pdb_files: backbones, sequences, names = process_pdb_file(pdb_file, backbones, sequences, names) else: raise ValueError("Input file must be a fasta, pdb, jsonl file or a pdb list file") ### Set environment variables env_config = yaml.load(open('configs/env_config.yaml', 'r'), Loader=yaml.FullLoader) # Set folder for huggingface cache os.environ['HF_HOME'] = env_config['huggingface']['HF_HOME'] # Set gpu device os.environ["CUDA_VISIBLE_DEVICES"]= env_config['gpus']['cuda_visible_device'] config = yaml.load(open('configs/train_config.yaml', 'r'), Loader=yaml.FullLoader) class_config=ClassConfig(config) class_config.adaptor_architecture = 'no-adaptor' if args.modality == 'SEQ' else 'conv' model, tokenizer = PT5_classification_model(half_precision=config['mixed_precision'], class_config=class_config) model.to(config['inference_args']['device']) if args.modality == 'SEQ': state_dict = torch.load(config['inference_args']['seq_model_path'], map_location=config['inference_args']['device']) model.load_state_dict(state_dict, strict=False) elif args.modality == '3D': print("Loading 3D model from {}".format(config['inference_args']['3d_model_path'])) state_dict = torch.load(config['inference_args']['3d_model_path'], map_location=config['inference_args']['device']) model.load_state_dict(state_dict, strict=False) model.eval() data_to_collate = [] for idx, (backbone, sequence) in enumerate(zip(backbones, sequences)): if args.modality == '3D': if backbone is not None: _dict = {'coords': backbone, 'seq': sequence} flucts, _ = get_fluctuation_for_json_dict(_dict, enm_type = config['inference_args']['enm_type']) flucts = flucts.tolist() flucts.append(0.0) #To match the special token for the sequence flucts = torch.tensor(flucts).to(config['inference_args']['device']) else: flucts = flucts_list[idx] #Ensure that the missing residues in the sequence are not represented as '-' but as 'X' sequence = sequence.replace('-', 'X') #due to the tokenizer vocabulary tokenizer_out = tokenizer(' '.join(sequence), add_special_tokens=True, return_tensors='pt') tokenized_seq, attention_mask = tokenizer_out['input_ids'].to(config['inference_args']['device']), tokenizer_out['attention_mask'].to(config['inference_args']['device']) if args.modality == '3D': data_to_collate.append({'input_ids': tokenized_seq[0,:], 'attention_mask': attention_mask[0,:], 'enm_vals': flucts}) elif args.modality == 'SEQ': data_to_collate.append({'input_ids': tokenized_seq[0,:], 'attention_mask': attention_mask[0,:]}) # Use the data collator to process the input data_collator = DataCollatorForTokenRegression(tokenizer) batch = data_collator(data_to_collate) # Wrap in list since collator expects batch batch.to(model.device) for key in batch.keys(): print("___________-", key, "-___________") for b in batch[key]: if key == 'attention_mask': print(b.sum()) else: print(b.shape) # Predict with torch.no_grad(): output_logits = process_in_batches_and_combine(model, batch, config['inference_args']['batch_size']) predictions = output_logits[:,:,0] #includes the prediction for the added token # subselect the predictions using the attention mask output_filename = Path(config['inference_args']['prediction_output_dir'].format(args.output_name, args.modality, 'all' if not args.split else args.split)) output_filename.parent.mkdir(parents=True, exist_ok=True) #Write the predictions to files with open(output_filename.with_suffix('.txt'), 'w') as f: print("Saving predictions to {}.".format(output_filename)) for prediction, mask, name, sequence in zip(predictions, batch['attention_mask'], names, sequences): prediction = prediction[mask.bool()] if len(prediction) != len(sequence)+1: print("Prediction length {} is not equal to sequence length + 1 {}".format(len(prediction), len(sequence)+1)) assert len(prediction) == len(sequence)+1, "Prediction length {} is not equal to sequence length + 1 {}".format(len(prediction), len(sequence)+1) if '.' in name: name = name.replace('.', '_') f.write('>' + name + '\n') f.write(', '.join([str(p) for p in prediction.tolist()[:-1]]) + '\n') if suffix == ".pdb" or suffix == ".pdb_list": for name, pdb_file, prediction in zip(names, pdb_files, predictions): chain_id = name.split('.')[1] _prediction = prediction[:-1].reshape(1,-1) _outname = output_filename.with_name(output_filename.stem + '_{}.pdb'.format(name.replace('.', '_'))) print("Saving prediction to {}.".format(_outname)) modify_bfactor_biotite(pdb_file, chain_id, _outname, _prediction) #writing the prediction without the last token if args.output_enm: _outname = output_filename.with_name(output_filename.stem + '_enm.txt') with open(_outname, 'w') as f: print("Saving ENM predictions to {}.".format(_outname)) for enm_prediction, name in zip(batch['enm_vals'], names): f.write('>' + name + '\n') f.write(', '.join([str(p) for p in enm_prediction.tolist()[:-1]]) + '\n') if suffix == ".pdb" or suffix == ".pdb_list": for name, pdb_file, enm_vals_single in zip(names, pdb_files, batch['enm_vals']): _outname = output_filename.with_name(output_filename.stem + '_{}.pdb'.format(name.replace('.', '_'))) print("Saving ENM prediction to {}.".format(_outname)) chain_id = name.split('.')[1] _enm_vals = enm_vals_single[:-1].reshape(1,-1) modify_bfactor_biotite(pdb_file, chain_id, _outname, _enm_vals) #writing the prediction without the last token if args.output_fasta: _outname = output_filename.with_name(output_filename.stem + '_fasta.fasta') with open(_outname, 'w') as f: print("Saving fasta to {}.".format(_outname)) for name, sequence in zip(names, sequences): f.write('>' + name + '\n') f.write(sequence + '\n')