import tempfile
import time
from pathlib import Path
from uuid import uuid4

import gradio as gr
import pandas as pd
import torch
from Bio import SeqIO
from dscript.pretrained import get_pretrained
from dscript.language_model import lm_embed
from tqdm.auto import tqdm

from predict_3di import get_3di_sequences, predictions_to_dict, one_hot_3di_sequence
# UI model name -> identifier handed to dscript's get_pretrained() in predict().
# Key order is the order the names appear in.
model_map = {
    "D-SCRIPT": "human_v1",
    "Topsy-Turvy": "human_v2",
    "TT3D": "human_tt3d",
}

# Gradio theme name passed to gr.Interface.
theme = "Default"

# Page title and description; neither is currently passed to gr.Interface
# (the corresponding keyword arguments are commented out there).
title = "D-SCRIPT: Predicting Protein-Protein Interactions"
description = "\n"
# article = """
#
#
#
# D-SCRIPT is a deep learning method for predicting a physical interaction between two proteins given just their sequences.
# It generalizes well to new species and is robust to limitations in training data size. Its design reflects the intuition that for two proteins to physically interact,
# a subset of amino acids from each protein should be in contact with the other. The intermediate stages of D-SCRIPT directly implement this intuition, with the penultimate stage
# in D-SCRIPT being a rough estimate of the inter-protein contact map of the protein dimer. This structurally-motivated design enhances the interpretability of the results and,
# since structure is more conserved evolutionarily than sequence, improves generalizability across species.
#
# Computational methods to predict protein-protein interaction (PPI) typically segregate into sequence-based "bottom-up" methods that infer properties from the characteristics of the
# individual protein sequences, or global "top-down" methods that infer properties from the pattern of already known PPIs in the species of interest. However, a way to incorporate
# top-down insights into sequence-based bottom-up PPI prediction methods has been elusive. Topsy-Turvy builds upon D-SCRIPT by synthesizing both views in a sequence-based,
# multi-scale, deep-learning model for PPI prediction. While Topsy-Turvy makes predictions using only sequence data, during the training phase it takes a transfer-learning approach by
# incorporating patterns from both global and molecular-level views of protein interaction. In a cross-species context, we show it achieves state-of-the-art performance, offering the
# ability to perform genome-scale, interpretable PPI prediction for non-model organisms with no existing experimental PPI data.
# """
# Markdown shown beneath the demo (passed as ``article`` to gr.Interface).
article = """
Note that running here with the "TT3D" model does not run structure prediction on the sequences, but rather uses the [ProstT5](https://github.com/mheinzinger/ProstT5) language model to
translate amino acid to 3di sequences. This is much faster than running structure prediction, but the results may not be as accurate.
"""

# Letter -> index vocabulary handed to one_hot_3di_sequence() for the TT3D
# model. A letter's position in the string below is its index
# (D=0, P=1, ..., N=19, X=20); "X" presumably stands for an unknown state.
fold_vocab = {letter: index for index, letter in enumerate("DPVQAWKEITLFGSMHCRYNX")}
def _upload_path(upload):
    """Return the filesystem path of a Gradio file upload.

    Some Gradio versions hand the callback a tempfile-like object exposing
    ``.name``; others may pass a plain path string. Accept both.
    """
    return getattr(upload, "name", upload)


def _read_sequences(sequence_file):
    """Parse the uploaded FASTA file into a ``{record id: SeqRecord}`` dict.

    Raises:
        gr.Error: if nothing was uploaded or the FASTA contains duplicate ids.
    """
    if sequence_file is None:
        raise gr.Error("Please upload a sequences (.fasta) file")
    try:
        return SeqIO.to_dict(SeqIO.parse(_upload_path(sequence_file), "fasta"))
    except ValueError as err:
        # SeqIO.to_dict raises ValueError when two records share an id.
        raise gr.Error("Invalid FASTA file - duplicate entry") from err


def _read_pairs(pairs_file):
    """Load the uploaded .csv/.tsv pairs table as columns ``protein1``/``protein2``.

    The first row is consumed as a header (pandas default) and the two
    columns are then renamed. Values are read as strings so numeric-looking
    ids still match FASTA record ids (which are always strings).

    Raises:
        gr.Error: if nothing was uploaded, the suffix is not .csv/.tsv, or
            the table does not have exactly two columns.
    """
    if pairs_file is None:
        raise gr.Error("Please upload a pairs (.csv/.tsv) file")
    path = _upload_path(pairs_file)
    suffix = Path(path).suffix.lower()
    if suffix == ".csv":
        sep = ","
    elif suffix == ".tsv":
        sep = "\t"
    else:
        # Without this branch the caller hit an UnboundLocalError.
        raise gr.Error("Invalid pairs file - must be a .csv or .tsv file")
    pairs = pd.read_csv(path, sep=sep, dtype=str)
    try:
        pairs.columns = ["protein1", "protein2"]
    except ValueError as err:
        raise gr.Error("Invalid pairs file - must have two columns 'protein1' and 'protein2'") from err
    return pairs


def _foldseek_embeddings(sequences, device):
    """Translate ``{id: amino-acid string}`` into one-hot 3Di encodings via ProstT5.

    Returns a dict keyed by the ids produced by ``predictions_to_dict``.
    """
    gr.Info(f"Loading Foldseek embeddings -- this may take some time ({len(sequences)} embeddings)...")
    predictions = get_3di_sequences(
        sequences,
        model_dir="Rostlab/ProstT5",
        report_fn=gr.Info,
        error_fn=gr.Error,
        device=device,
    )
    foldseek_sequences = predictions_to_dict(predictions)
    return {
        k: one_hot_3di_sequence(s.upper(), fold_vocab)
        for k, s in foldseek_sequences.items()
    }


def predict(model_name, pairs_file, sequence_file, progress=gr.Progress()):
    """Score every protein pair in ``pairs_file`` with the selected model.

    Args:
        model_name: A key of ``model_map`` ("D-SCRIPT", "Topsy-Turvy" or "TT3D").
        pairs_file: Uploaded .csv/.tsv with two columns of protein ids.
        sequence_file: Uploaded FASTA whose record ids cover every id in
            ``pairs_file``.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        ``(results, file_path)``: a DataFrame with columns "Protein 1",
        "Protein 2", "Interaction" (the model's score per pair) and the path
        of a TSV copy of it for download.

    Raises:
        gr.Error: on missing/invalid uploads or ids absent from the FASTA.
    """
    run_id = uuid4()

    # Validate the inputs before paying for model loading.
    seqs = _read_sequences(sequence_file)
    pairs = _read_pairs(pairs_file)
    needed = set(pairs["protein1"]) | set(pairs["protein2"])
    # str() so NaN (an empty cell) can be sorted and joined with real ids.
    missing = sorted(str(p) for p in needed if p not in seqs)
    if missing:
        shown = ", ".join(missing[:10])
        more = f" (and {len(missing) - 10} more)" if len(missing) > 10 else ""
        raise gr.Error(f"Proteins missing from FASTA file: {shown}{more}")

    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    use_cuda = device.type == "cuda"
    # Warm-up call; presumably makes the LM weights load before the loop — TODO confirm.
    _ = lm_embed("M", use_cuda=use_cuda)
    model = get_pretrained(model_map[model_name]).to(device)

    do_foldseek = model_name == "TT3D"
    foldseek_embeddings = {}
    if do_foldseek and needed:
        foldseek_embeddings = _foldseek_embeddings(
            {k: str(seqs[k].seq) for k in needed}, device
        )

    progress(0, desc="Starting...")
    results = []
    pair_list = list(zip(pairs["protein1"], pairs["protein2"]))
    # Inference only: no autograd graph needs to be kept.
    with torch.no_grad():
        for prot1, prot2 in progress.tqdm(pair_list):
            lm1 = lm_embed(str(seqs[prot1].seq), use_cuda=use_cuda).to(device)
            lm2 = lm_embed(str(seqs[prot2].seq), use_cuda=use_cuda).to(device)
            fold1 = foldseek_embeddings[prot1].to(device) if do_foldseek else None
            fold2 = foldseek_embeddings[prot2].to(device) if do_foldseek else None
            interaction = model.predict(
                lm1, lm2, embed_foldseek=do_foldseek, f0=fold1, f1=fold2
            ).item()
            results.append([prot1, prot2, interaction])

    results = pd.DataFrame(results, columns=["Protein 1", "Protein 2", "Interaction"])
    # Portable temp dir instead of a hard-coded /tmp; writing by path lets
    # pandas control line endings.
    file_path = str(Path(tempfile.gettempdir()) / f"{run_id}.tsv")
    results.to_csv(file_path, sep="\t", index=False, header=True)
    return results, file_path
# Gradio UI wiring for predict(). Dropdown choices are derived from model_map
# so the UI can never offer a name that predict() cannot look up.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Dropdown(label="Model", choices=list(model_map), value="Topsy-Turvy"),
        gr.File(label="Pairs (.csv/.tsv)", file_types=[".csv", ".tsv"]),
        gr.File(label="Sequences (.fasta)", file_types=[".fasta"]),
    ],
    outputs=[
        # Headers mirror the DataFrame columns returned by predict().
        gr.DataFrame(label='Results', headers=['Protein 1', 'Protein 2', 'Interaction']),
        gr.File(label="Download results", type="file"),
    ],
    # title = title,
    # description = description,
    article=article,
    theme=theme,
)
# Script entry point: enable request queuing (at most 20 pending requests)
# and start the Gradio server.
if __name__ == "__main__":
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch()