Samuel Sledzieski commited on
Commit
ff2b104
1 Parent(s): 809fb87

File download and multiple models

Browse files
Files changed (1) hide show
  1. app.py +27 -6
app.py CHANGED
@@ -5,10 +5,23 @@ from Bio import SeqIO
5
  from dscript.pretrained import get_pretrained
6
  from dscript.language_model import lm_embed
7
  from tqdm.auto import tqdm
 
8
 
9
- def predict(sequence_file, pairs_file):
10
-
11
- model = get_pretrained('human_v1')
 
 
 
 
 
 
 
 
 
 
 
 
12
  seqs = SeqIO.to_dict(SeqIO.parse(sequence_file.name, "fasta"))
13
  if Path(pairs_file.name).suffix == ".csv":
14
  pairs = pd.read_csv(pairs_file.name)
@@ -16,9 +29,11 @@ def predict(sequence_file, pairs_file):
16
  pairs = pd.read_csv(pairs_file.name, sep="\t")
17
  pairs.columns = ["protein1", "protein2"]
18
 
 
19
  results = []
20
  progress = gr.Progress(track_tqdm=True)
21
  for i, r in tqdm(pairs.iterrows(), total=len(pairs)):
 
22
  prot1 = r["protein1"]
23
  prot2 = r["protein2"]
24
  seq1 = str(seqs[prot1].seq)
@@ -27,20 +42,26 @@ def predict(sequence_file, pairs_file):
27
  lm2 = lm_embed(seq2)
28
  interaction = model.predict(lm1, lm2).item()
29
  results.append([prot1, prot2, interaction])
30
- # progress((i, len(pairs)))
31
 
32
  results = pd.DataFrame(results, columns = ["Protein 1", "Protein 2", "Interaction"])
33
 
34
- return results
 
 
 
 
35
 
36
  demo = gr.Interface(
37
  fn=predict,
38
  inputs = [
 
39
  gr.File(label="Sequences (.fasta)", file_types = [".fasta"]),
40
  gr.File(label="Pairs (.csv/.tsv)", file_types = [".csv", ".tsv"])
41
  ],
42
  outputs = [
43
- gr.DataFrame(label='Results', headers=['Protein 1', 'Protein 2', 'Interaction'])
 
44
  ]
45
  )
46
 
 
5
  from dscript.pretrained import get_pretrained
6
  from dscript.language_model import lm_embed
7
  from tqdm.auto import tqdm
8
+ from uuid import uuid4
9
 
10
+ model_map = {
11
+ "D-SCRIPT": "human_v1",
12
+ "Topsy-Turvy": "human_v2"
13
+ }
14
+
15
+ def predict(model, sequence_file, pairs_file):
16
+
17
+ run_id = uuid4()
18
+
19
+ gr.Info("Loading model...")
20
+ _ = lm_embed("M")
21
+
22
+ model = get_pretrained(model_map[model])
23
+
24
+ gr.Info("Loading files...")
25
  seqs = SeqIO.to_dict(SeqIO.parse(sequence_file.name, "fasta"))
26
  if Path(pairs_file.name).suffix == ".csv":
27
  pairs = pd.read_csv(pairs_file.name)
 
29
  pairs = pd.read_csv(pairs_file.name, sep="\t")
30
  pairs.columns = ["protein1", "protein2"]
31
 
32
+ gr.Info("Predicting...")
33
  results = []
34
  progress = gr.Progress(track_tqdm=True)
35
  for i, r in tqdm(pairs.iterrows(), total=len(pairs)):
36
+ gr.Info(f"[{i+1}/{len(pairs)}]")
37
  prot1 = r["protein1"]
38
  prot2 = r["protein2"]
39
  seq1 = str(seqs[prot1].seq)
 
42
  lm2 = lm_embed(seq2)
43
  interaction = model.predict(lm1, lm2).item()
44
  results.append([prot1, prot2, interaction])
45
+ progress((i, len(pairs)))
46
 
47
  results = pd.DataFrame(results, columns = ["Protein 1", "Protein 2", "Interaction"])
48
 
49
+ file_path = f"/tmp/{run_id}.tsv"
50
+ with open(file_path, "w") as f:
51
+ results.to_csv(f, sep="\t", index=False, header = True)
52
+
53
+ return results, file_path
54
 
55
  demo = gr.Interface(
56
  fn=predict,
57
  inputs = [
58
+ gr.Dropdown(label="Model", choices = ["D-SCRIPT", "Topsy-Turvy"], value = "Topsy-Turvy"),
59
  gr.File(label="Sequences (.fasta)", file_types = [".fasta"]),
60
  gr.File(label="Pairs (.csv/.tsv)", file_types = [".csv", ".tsv"])
61
  ],
62
  outputs = [
63
+ gr.DataFrame(label='Results', headers=['Protein 1', 'Protein 2', 'Interaction']),
64
+ gr.File(label="Download results", type="file")
65
  ]
66
  )
67