File size: 2,247 Bytes
809fb87 ff2b104 809fb87 ff2b104 d43f920 809fb87 ff2b104 809fb87 ff2b104 809fb87 ff2b104 809fb87 ff2b104 809fb87 ff2b104 809fb87 ff2b104 809fb87 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
import gradio as gr
import pandas as pd
from pathlib import Path
from Bio import SeqIO
from dscript.pretrained import get_pretrained
from dscript.language_model import lm_embed
from tqdm.auto import tqdm
from uuid import uuid4
# Maps the model names offered in the UI to the identifiers handed to
# ``get_pretrained``.
model_map = {
    "D-SCRIPT": "human_v1",
    "Topsy-Turvy": "human_v2",
}
def predict(
    model,
    sequence_file,
    pairs_file,
    progress=gr.Progress(track_tqdm=True),
):
    """Score every protein pair listed in ``pairs_file`` with a pretrained model.

    Args:
        model: Display name of the model; must be a key of ``model_map``.
        sequence_file: Uploaded FASTA file object; read through its ``.name``
            path.
        pairs_file: Uploaded ``.csv`` or ``.tsv`` file object (``.name`` path)
            holding exactly two columns of protein identifiers.
        progress: Gradio progress tracker. It is declared as a default
            argument so Gradio can inject it; with ``track_tqdm=True`` the
            ``tqdm`` loop below drives the progress bar.

    Returns:
        A ``(results, file_path)`` tuple: a DataFrame with the columns
        ``Protein 1``, ``Protein 2`` and ``Interaction``, and the path of a
        TSV file containing the same table.

    Raises:
        gr.Error: If the FASTA file cannot be turned into a dictionary, the
            pairs file has an unsupported extension or column count, or a
            pair references a protein that is not in the FASTA file.
    """
    run_id = uuid4()

    gr.Info("Loading model...")
    # Throwaway embedding of a one-residue sequence; presumably forces the
    # language model to load before the timed loop - confirm against dscript.
    _ = lm_embed("M")
    # Use a separate name so the ``model`` argument is not rebound.
    predictor = get_pretrained(model_map[model])

    gr.Info("Loading files...")
    try:
        seqs = SeqIO.to_dict(SeqIO.parse(sequence_file.name, "fasta"))
    except ValueError as e:
        # gr.Error is an exception: it only reaches the user when raised.
        raise gr.Error("Invalid FASTA file - duplicate entry") from e

    separators = {".csv": ",", ".tsv": "\t"}
    suffix = Path(pairs_file.name).suffix.lower()
    if suffix not in separators:
        raise gr.Error("Invalid pairs file - expected a .csv or .tsv file")
    # dtype=str keeps numeric-looking identifiers comparable with FASTA keys.
    # NOTE(review): read_csv treats the first row as a header, which is then
    # overwritten below; a header-less pairs file would lose its first pair -
    # confirm the expected input format.
    pairs = pd.read_csv(pairs_file.name, sep=separators[suffix], dtype=str)
    if pairs.shape[1] != 2:
        raise gr.Error("Invalid pairs file - expected exactly two columns")
    pairs.columns = ["protein1", "protein2"]

    # Fail fast on unknown identifiers instead of hitting a KeyError after
    # part of the (expensive) embedding work has already been done.
    requested = pd.concat([pairs["protein1"], pairs["protein2"]]).unique()
    missing = [str(p) for p in requested if p not in seqs]
    if missing:
        preview = ", ".join(missing[:5])
        raise gr.Error(
            f"{len(missing)} protein ID(s) in the pairs file are missing "
            f"from the FASTA file: {preview}"
        )

    gr.Info("Predicting...")
    n_pairs = len(pairs)
    results = []
    pair_iter = zip(pairs["protein1"], pairs["protein2"])
    for count, (prot1, prot2) in enumerate(tqdm(pair_iter, total=n_pairs), start=1):
        gr.Info(f"[{count}/{n_pairs}]")
        lm1 = lm_embed(str(seqs[prot1].seq))
        lm2 = lm_embed(str(seqs[prot2].seq))
        interaction = predictor.predict(lm1, lm2).item()
        results.append([prot1, prot2, interaction])

    results = pd.DataFrame(
        results, columns=["Protein 1", "Protein 2", "Interaction"]
    )
    file_path = f"/tmp/{run_id}.tsv"
    results.to_csv(file_path, sep="\t", index=False, header=True)
    return results, file_path
# Gradio UI wired to ``predict``: a model selector and two file uploads in,
# a results table and a downloadable file out.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Dropdown(
            label="Model",
            choices=["D-SCRIPT", "Topsy-Turvy"],
            value="Topsy-Turvy",
        ),
        gr.File(label="Sequences (.fasta)", file_types=[".fasta"]),
        gr.File(label="Pairs (.csv/.tsv)", file_types=[".csv", ".tsv"]),
    ],
    outputs=[
        gr.DataFrame(
            label="Results",
            headers=["Protein 1", "Protein 2", "Interaction"],
        ),
        gr.File(label="Download results", type="file"),
    ],
)
if __name__ == "__main__":
    # Enable request queuing (max_size=20), then start the app.
    demo.queue(max_size=20)
    demo.launch()