D-SCRIPT / app.py
samsl's picture
Add initial application
809fb87
raw
history blame
1.56 kB
import gradio as gr
import pandas as pd
from pathlib import Path
from Bio import SeqIO
from dscript.pretrained import get_pretrained
from dscript.language_model import lm_embed
from tqdm.auto import tqdm
def predict(sequence_file, pairs_file):
model = get_pretrained('human_v1')
seqs = SeqIO.to_dict(SeqIO.parse(sequence_file.name, "fasta"))
if Path(pairs_file.name).suffix == ".csv":
pairs = pd.read_csv(pairs_file.name)
elif Path(pairs_file.name).suffix == ".tsv":
pairs = pd.read_csv(pairs_file.name, sep="\t")
pairs.columns = ["protein1", "protein2"]
results = []
progress = gr.Progress(track_tqdm=True)
for i, r in tqdm(pairs.iterrows(), total=len(pairs)):
prot1 = r["protein1"]
prot2 = r["protein2"]
seq1 = str(seqs[prot1].seq)
seq2 = str(seqs[prot2].seq)
lm1 = lm_embed(seq1)
lm2 = lm_embed(seq2)
interaction = model.predict(lm1, lm2).item()
results.append([prot1, prot2, interaction])
# progress((i, len(pairs)))
results = pd.DataFrame(results, columns = ["Protein 1", "Protein 2", "Interaction"])
return results
demo = gr.Interface(
fn=predict,
inputs = [
gr.File(label="Sequences (.fasta)", file_types = [".fasta"]),
gr.File(label="Pairs (.csv/.tsv)", file_types = [".csv", ".tsv"])
],
outputs = [
gr.DataFrame(label='Results', headers=['Protein 1', 'Protein 2', 'Interaction'])
]
)
if __name__ == "__main__":
demo.queue(max_size=20)
demo.launch()