|
import time |
|
|
|
import gradio as gr |
|
import pandas as pd |
|
import torch |
|
from pathlib import Path |
|
from Bio import SeqIO |
|
from dscript.pretrained import get_pretrained |
|
from dscript.language_model import lm_embed |
|
from tqdm.auto import tqdm |
|
from uuid import uuid4 |
|
from predict_3di import get_3di_sequences, predictions_to_dict, one_hot_3di_sequence |
|
|
|
# UI model name (as offered in the "Model" dropdown) -> identifier handed to
# dscript.pretrained.get_pretrained when the model is loaded.
model_map = dict(
    [
        ("D-SCRIPT", "human_v1"),
        ("Topsy-Turvy", "human_v2"),
        ("TT3D", "human_tt3d"),
    ]
)
|
|
|
# Page heading shown for the app.
title = "D-SCRIPT: Predicting Protein-Protein Interactions"

# Short blurb asking users to report usage (kept verbatim; rendered as Markdown).
description = """
If you use this interface to make predictions, please let us know (by emailing samsl@mit.edu)!
We want to keep this web version free to use with GPU support, and to do that we need to demonstrate to
our funders that it is being used. Thank you!
"""

# Name of the Gradio theme passed to gr.Interface.
theme = "Default"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Footer text (Markdown) explaining how TT3D obtains 3Di sequences in this
# web version: ProstT5 translation rather than structure prediction.
article = """

Note that running here with the "TT3D" model does not run structure prediction on the sequences, but rather uses the [ProstT5](https://github.com/mheinzinger/ProstT5) language model to
translate amino acid to 3di sequences. This is much faster than running structure prediction, but the results may not be as accurate.
"""
|
|
|
# 3Di letter -> integer index, passed to one_hot_3di_sequence when encoding the
# (upper-cased) ProstT5 output. The letter order below fixes the indices:
# "D" is 0, ..., "N" is 19, and "X" is 20.
fold_vocab = {letter: index for index, letter in enumerate("DPVQAWKEITLFGSMHCRYNX")}
|
|
|
def _load_sequences(sequence_file):
    """Parse the uploaded FASTA file into a ``{record id: SeqRecord}`` dict.

    Args:
        sequence_file: Gradio file object (only its ``.name`` path is used), or
            ``None`` when nothing was uploaded.

    Raises:
        gr.Error: if no file was uploaded or the FASTA contains duplicate IDs.
    """
    if sequence_file is None:
        raise gr.Error("Please upload a sequence file (.fasta)")
    try:
        return SeqIO.to_dict(SeqIO.parse(sequence_file.name, "fasta"))
    except ValueError as err:
        # SeqIO.to_dict raises ValueError when two records share the same ID.
        raise gr.Error("Invalid FASTA file - duplicate entry") from err


def _load_pairs(pairs_file):
    """Read the uploaded pairs table into a DataFrame with columns protein1/protein2.

    The separator is chosen from the file extension (.csv -> comma, .tsv -> tab).
    The first row is consumed as a header (pandas default) and then replaced by
    the canonical column names. Cells are read as strings so that numeric-looking
    identifiers still match FASTA record IDs (which are always strings).

    Args:
        pairs_file: Gradio file object (only its ``.name`` path is used), or
            ``None`` when nothing was uploaded.

    Raises:
        gr.Error: if no file was uploaded, the extension is unsupported, or the
            table does not have exactly two columns.
    """
    if pairs_file is None:
        raise gr.Error("Please upload a pairs file (.csv/.tsv)")

    suffix = Path(pairs_file.name).suffix.lower()
    separators = {".csv": ",", ".tsv": "\t"}
    if suffix not in separators:
        # Previously fell through with `pairs` unbound (UnboundLocalError).
        raise gr.Error(
            f"Invalid pairs file - unsupported extension '{suffix}' (expected .csv or .tsv)"
        )

    pairs = pd.read_csv(pairs_file.name, sep=separators[suffix], dtype=str)
    try:
        pairs.columns = ["protein1", "protein2"]
    except ValueError as err:
        # Raised by pandas when the table does not have exactly two columns.
        raise gr.Error(
            "Invalid pairs file - must have two columns 'protein1' and 'protein2'"
        ) from err
    return pairs


def _translate_to_3di_embeddings(sequences, device):
    """Translate amino-acid sequences to 3Di with ProstT5 and one-hot encode them.

    Args:
        sequences: mapping of protein ID -> amino-acid sequence string.
        device: torch device forwarded to ``get_3di_sequences``.

    Returns:
        dict of ``one_hot_3di_sequence`` outputs, keyed as returned by
        ``predictions_to_dict`` (presumably the same IDs as *sequences*, since
        ``predict`` looks them up by protein ID — confirm in predict_3di).
    """
    gr.Info(f"Loading Foldseek embeddings -- this may take some time ({len(sequences)} embeddings)...")
    predictions = get_3di_sequences(
        sequences,
        model_dir="Rostlab/ProstT5",
        report_fn=gr.Info,
        error_fn=gr.Error,
        device=device,
    )
    foldseek_sequences = predictions_to_dict(predictions)
    return {
        protein: one_hot_3di_sequence(sequence.upper(), fold_vocab)
        for protein, sequence in foldseek_sequences.items()
    }


def predict(model_name, pairs_file, sequence_file, progress = gr.Progress()):
    """Score every protein pair in *pairs_file* with the selected model.

    Args:
        model_name: model name chosen in the UI; a key of ``model_map``.
        pairs_file: uploaded .csv/.tsv file with two columns of protein IDs.
        sequence_file: uploaded FASTA file whose record IDs must cover every ID
            listed in *pairs_file*.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        ``(results, file_path)``: a DataFrame with columns "Protein 1",
        "Protein 2", "Interaction", and the path of a TSV copy for download.

    Raises:
        gr.Error: on invalid input or any failure during prediction, so the
            message is shown in the UI. Previously, exceptions were swallowed:
            the error was constructed but never raised, and ``None, None`` was
            returned.
    """
    try:
        run_id = uuid4()
        device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

        # Validate the uploads before doing any (slow) model loading.
        seqs = _load_sequences(sequence_file)
        pairs = _load_pairs(pairs_file)

        needed = set(pairs["protein1"]).union(pairs["protein2"])
        missing = sorted(str(protein) for protein in needed if protein not in seqs)
        if missing:
            shown = ", ".join(missing[:10])
            more = f" (and {len(missing) - 10} more)" if len(missing) > 10 else ""
            raise gr.Error(
                f"Invalid pairs file - proteins missing from the FASTA file: {shown}{more}"
            )

        # Embed a one-residue sequence once, on the app's device, before the
        # main loop; presumably this forces the language-model weights to be
        # fetched/initialised up front (intent not visible here — confirm).
        _ = lm_embed("M", use_cuda = (device.type == "cuda"))

        model = get_pretrained(model_map[model_name]).to(device)
        model.eval()  # inference only: put modules in eval mode (e.g. no dropout)

        # Only TT3D consumes 3Di (Foldseek) features; skip ProstT5 otherwise.
        do_foldseek = model_name == "TT3D"
        foldseek_embeddings = {}
        if do_foldseek and needed:
            seqs_to_translate = {protein: str(seqs[protein].seq) for protein in needed}
            foldseek_embeddings = _translate_to_3di_embeddings(seqs_to_translate, device)

        progress(0, desc="Starting...")
        pair_list = list(zip(pairs["protein1"], pairs["protein2"]))
        results = []
        with torch.no_grad():  # no autograd graph needed for inference
            for prot1, prot2 in progress.tqdm(pair_list):
                fold1 = foldseek_embeddings[prot1].to(device) if do_foldseek else None
                fold2 = foldseek_embeddings[prot2].to(device) if do_foldseek else None

                lm1 = lm_embed(str(seqs[prot1].seq)).to(device)
                lm2 = lm_embed(str(seqs[prot2].seq)).to(device)

                interaction = model.predict(
                    lm1, lm2, embed_foldseek = do_foldseek, f0 = fold1, f1 = fold2
                ).item()
                results.append([prot1, prot2, interaction])

        results = pd.DataFrame(results, columns = ["Protein 1", "Protein 2", "Interaction"])

        file_path = f"/tmp/{run_id}.tsv"
        results.to_csv(file_path, sep="\t", index=False, header = True)

        return results, file_path

    except gr.Error:
        # Already a user-facing message; let Gradio display it.
        raise
    except Exception as err:
        raise gr.Error(f"Prediction failed: {err}") from err
|
|
|
# The Gradio app: model choice plus two uploads in, a results table and a
# downloadable TSV out. `title` and `description` are module-level strings that
# were previously defined but never passed here, so they never appeared.
demo = gr.Interface(
    fn=predict,
    inputs=[
        # Choices come from model_map so the dropdown cannot drift from the
        # models predict() knows how to load.
        gr.Dropdown(label="Model", choices=list(model_map), value="Topsy-Turvy"),
        gr.File(label="Pairs (.csv/.tsv)", file_types=[".csv", ".tsv"]),
        gr.File(label="Sequences (.fasta)", file_types=[".fasta"]),
    ],
    outputs=[
        gr.DataFrame(label="Results", headers=["Protein 1", "Protein 2", "Interaction"]),
        gr.File(label="Download results", type="file"),
    ],
    title=title,
    description=description,
    article=article,
    theme=theme,
)
|
|
|
if __name__ == "__main__":
    # Enable request queuing, with at most 20 pending requests, then start
    # serving the app.
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch()