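"""Predict protein structures with ESMFold (via Hugging Face Transformers).

Supports three input modes: a single --sequence, a multi-record --fasta_file
(optionally sharded into chunks for parallel jobs), or a --fasta_dir of
single-record FASTA files. Predicted structures are written as PDB files with
per-residue pLDDT stored in the B-factor column.
"""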
import torch
import os
import gc
import argparse

import biotite.structure.io as bsio
import pandas as pd
from tqdm import tqdm
from Bio import SeqIO
from transformers import AutoTokenizer, EsmForProteinFolding
from transformers.models.esm.openfold_utils.protein import to_pdb, Protein as OFProtein
from transformers.models.esm.openfold_utils.feats import atom14_to_atom37

def read_fasta(file_path, key):
    """Read a single-record FASTA file and return the given record attribute (e.g. 'seq') as a string."""
    return str(getattr(SeqIO.read(file_path, 'fasta'), key))

def read_multi_fasta(file_path):
    """
    params:
        file_path: path to a FASTA file
    return:
        a dictionary mapping each header line to its sequence
    """
    sequences = {}
    header = None  # initialized so a malformed file cannot trigger a NameError
    current_sequence = ''
    with open(file_path, 'r') as file:
        for line in file:
            line = line.strip()
            if line.startswith('>'):
                if current_sequence:
                    sequences[header] = current_sequence
                    current_sequence = ''
                header = line
            else:
                current_sequence += line
    if current_sequence:
        sequences[header] = current_sequence
    return sequences
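
# Example: for a file with records ">seq1 desc" / "MKT...", read_multi_fasta
# returns {'>seq1 desc': 'MKT...'} (illustrative record; note that keys keep
# the leading '>', which the main loop below strips off).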

def convert_outputs_to_pdb(outputs):
    """Convert ESMFold model outputs to a list of PDB strings (one per batch item)."""
    # Map the model's compact 14-atom representation to the full 37-atom representation.
    final_atom_positions = atom14_to_atom37(outputs["positions"][-1], outputs)
    outputs = {k: v.to("cpu").numpy() for k, v in outputs.items()}
    final_atom_positions = final_atom_positions.cpu().numpy()
    final_atom_mask = outputs["atom37_atom_exists"]
    pdbs = []
    for i in range(outputs["aatype"].shape[0]):
        aa = outputs["aatype"][i]
        pred_pos = final_atom_positions[i]
        mask = final_atom_mask[i]
        resid = outputs["residue_index"][i] + 1
        pred = OFProtein(
            aatype=aa,
            atom_positions=pred_pos,
            atom_mask=mask,
            residue_index=resid,
            # Per-residue pLDDT goes into the B-factor column so downstream tools can read it.
            b_factors=outputs["plddt"][i],
            chain_index=outputs["chain_index"][i] if "chain_index" in outputs else None,
        )
        pdbs.append(to_pdb(pred))
    return pdbs

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--sequence", type=str, default=None, help="single amino-acid sequence to fold")
    parser.add_argument("--fasta_file", type=str, default=None, help="multi-record FASTA file to fold")
    parser.add_argument("--fasta_chunk_num", type=int, default=None, help="total number of chunks to split the FASTA file into")
    parser.add_argument("--fasta_chunk_id", type=int, default=None, help="0-based index of the chunk to process")
    parser.add_argument("--fasta_dir", type=str, default=None, help="directory of single-record .fasta files")
    parser.add_argument("--out_dir", type=str, help="output directory for predicted PDB files")
    parser.add_argument("--out_file", type=str, default="result.pdb", help="output PDB path for --sequence mode")
    parser.add_argument("--out_info_file", type=str, default=None, help="optional CSV of per-protein mean pLDDT")
    parser.add_argument("--fold_chunk_size", type=int, help="chunk size for axial attention (lower = less memory)")
    args = parser.parse_args()

    # Legacy loading via the original esm package:
    # model = esm.pretrained.esmfold_v1()
    # model = model.eval().cuda()
    tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
    model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", low_cpu_mem_usage=True)
    model = model.cuda()
    # model.esm = model.esm.half()
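    # If GPU memory is tight, uncommenting the half() line above converts the
    # ESM-2 language-model stem to float16, which substantially reduces memory
    # usage while the folding trunk stays in float32.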
    torch.backends.cuda.matmul.allow_tf32 = True

    # Optionally set a chunk size for axial attention to reduce memory usage.
    # Lower sizes have lower memory requirements at the cost of slower inference.
    if args.fold_chunk_size is not None:
        model.trunk.set_chunk_size(args.fold_chunk_size)

    if args.fasta_file is not None:
        unfold_proteins = []  # names of sequences that failed to fold
        seq_dict = read_multi_fasta(args.fasta_file)
        os.makedirs(args.out_dir, exist_ok=True)
        names, sequences = list(seq_dict.keys()), list(seq_dict.values())
        if args.fasta_chunk_num is not None:
            # Shard the records into fasta_chunk_num roughly equal chunks and
            # process only chunk fasta_chunk_id, so several jobs can fold one
            # file in parallel.
            chunk_size = len(names) // args.fasta_chunk_num + 1
            start = args.fasta_chunk_id * chunk_size
            end = min((args.fasta_chunk_id + 1) * chunk_size, len(names))
            names, sequences = names[start:end], sequences[start:end]
        out_info_dict = {"name": [], "plddt": []}
        bar = tqdm(zip(names, sequences), total=len(names))
        for name, sequence in bar:
            bar.set_description(name)
            name = name[1:].split(" ")[0]  # strip '>' and any description
            out_file = os.path.join(args.out_dir, f"{name}.ef.pdb")
            if os.path.exists(out_file):
                # Already predicted: just record its mean pLDDT and move on.
                out_info_dict["name"].append(name)
                struct = bsio.load_structure(out_file, extra_fields=["b_factor"])
                out_info_dict["plddt"].append(struct.b_factor.mean())
                continue
            # Multimer prediction can be done with chains separated by ':'
            try:
                tokenized_input = tokenizer([sequence], return_tensors="pt", add_special_tokens=False)['input_ids'].cuda()
                with torch.no_grad():
                    output = model(tokenized_input)
            except Exception as e:
                print(e)
                print(f"Failed to predict {name}")
                unfold_proteins.append(name)
                continue
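            # Assumption (not in the original script): torch.cuda.empty_cache()
            # could also be called here to release cached GPU memory after a
            # failed (e.g. out-of-memory) prediction before moving on.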
            gc.collect()
            pdb = convert_outputs_to_pdb(output)
            with open(out_file, "w") as f:
                f.write("\n".join(pdb))
            out_info_dict["name"].append(name)
            struct = bsio.load_structure(out_file, extra_fields=["b_factor"])
            out_info_dict["plddt"].append(struct.b_factor.mean())
        if unfold_proteins:
            print(f"Failed to fold {len(unfold_proteins)} sequence(s): {unfold_proteins}")
        if args.out_info_file is not None:
            pd.DataFrame(out_info_dict).to_csv(args.out_info_file, index=False)

    if args.fasta_dir is not None:
        os.makedirs(args.out_dir, exist_ok=True)
        proteins = sorted(os.listdir(args.fasta_dir))
        bar = tqdm(proteins)
        for p in bar:
            name = p[:-6]  # strip the '.fasta' extension
            bar.set_description(name)
            out_file = os.path.join(args.out_dir, f"{name}.ef.pdb")
            if os.path.exists(out_file):
                continue
            sequence = read_fasta(os.path.join(args.fasta_dir, p), "seq")
            # Multimer prediction can be done with chains separated by ':'
            tokenized_input = tokenizer([sequence], return_tensors="pt", add_special_tokens=False)['input_ids'].cuda()
            with torch.no_grad():
                output = model(tokenized_input)
            pdb = convert_outputs_to_pdb(output)
            with open(out_file, "w") as f:
                f.write("\n".join(pdb))
            struct = bsio.load_structure(out_file, extra_fields=["b_factor"])
            print(p, struct.b_factor.mean())
    elif args.sequence is not None:
        sequence = args.sequence
        # Multimer prediction can be done with chains separated by ':'
        with torch.no_grad():
            output = model.infer_pdb(sequence)
        with open(args.out_file, "w") as f:
            f.write(output)
        struct = bsio.load_structure(args.out_file, extra_fields=["b_factor"])
        print(struct.b_factor.mean())
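
# Example invocations (illustrative script name, paths, and sequence; not part
# of the original):
#   python fold.py --sequence MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ --out_file result.pdb
#   python fold.py --fasta_file proteins.fasta --out_dir preds --out_info_file plddt.csv
#   python fold.py --fasta_file proteins.fasta --fasta_chunk_num 4 --fasta_chunk_id 0 --out_dir preds
#   python fold.py --fasta_dir fasta_files/ --out_dir preds --fold_chunk_size 64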