samsl commited on
Commit
c729a21
1 Parent(s): c680ea1

add message asking for notification of use

Browse files
Files changed (1) hide show
  1. app.py +83 -76
app.py CHANGED
@@ -20,6 +20,9 @@ model_map = {
20
  theme = "Default"
21
  title = "D-SCRIPT: Predicting Protein-Protein Interactions"
22
  description = """
 
 
 
23
  """
24
 
25
  # article = """
@@ -52,7 +55,6 @@ article = """
52
 
53
  Note that running here with the "TT3D" model does not run structure prediction on the sequences, but rather uses the [ProstT5](https://github.com/mheinzinger/ProstT5) language model to
54
  translate amino acid to 3di sequences. This is much faster than running structure prediction, but the results may not be as accurate.
55
-
56
  """
57
 
58
  fold_vocab = {
@@ -81,85 +83,90 @@ fold_vocab = {
81
 
82
  def predict(model_name, pairs_file, sequence_file, progress = gr.Progress()):
83
 
84
- run_id = uuid4()
85
- device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
86
-
87
- # gr.Info("Loading model...")
88
- _ = lm_embed("M", use_cuda = (device.type == "cuda"))
89
-
90
- model = get_pretrained(model_map[model_name]).to(device)
91
-
92
- # gr.Info("Loading files...")
93
  try:
94
- seqs = SeqIO.to_dict(SeqIO.parse(sequence_file.name, "fasta"))
95
- except ValueError as _:
96
- raise gr.Error("Invalid FASTA file - duplicate entry")
97
-
98
- if Path(pairs_file.name).suffix == ".csv":
99
- pairs = pd.read_csv(pairs_file.name)
100
- elif Path(pairs_file.name).suffix == ".tsv":
101
- pairs = pd.read_csv(pairs_file.name, sep="\t")
102
- try:
103
- pairs.columns = ["protein1", "protein2"]
104
- except ValueError as _:
105
- raise gr.Error("Invalid pairs file - must have two columns 'protein1' and 'protein2'")
106
-
107
- do_foldseek = False
108
- if model_name == "TT3D":
109
- do_foldseek = True
110
-
111
- need_to_translate = set(pairs["protein1"]).union(set(pairs["protein2"]))
112
- seqs_to_translate = {k: str(seqs[k].seq) for k in need_to_translate if k in seqs}
113
-
114
- half_precision = False
115
- assert not (half_precision and device=="cpu"), print("Running fp16 on CPU is not supported, yet")
116
-
117
- gr.Info(f"Loading Foldseek embeddings -- this may take some time ({len(seqs_to_translate)} embeddings)...")
118
- predictions = get_3di_sequences(
119
- seqs_to_translate,
120
- model_dir = "Rostlab/ProstT5",
121
- report_fn = gr.Info,
122
- error_fn = gr.Error,
123
- device=device,
124
- )
125
- foldseek_sequences = predictions_to_dict(predictions)
126
- foldseek_embeddings = {k: one_hot_3di_sequence(s.upper(), fold_vocab) for k, s in foldseek_sequences.items()}
127
-
128
- # for k in seqs_to_translate.keys():
129
- # print(seqs_to_translate[k])
130
- # print(len(seqs_to_translate[k]))
131
- # print(foldseek_embeddings[k])
132
- # print(foldseek_embeddings[k].shape)
133
-
134
- progress(0, desc="Starting...")
135
- results = []
136
- for i in progress.tqdm(range(len(pairs))):
137
-
138
- r = pairs.iloc[i]
139
-
140
- prot1 = r["protein1"]
141
- prot2 = r["protein2"]
142
-
143
- seq1 = str(seqs[prot1].seq)
144
- seq2 = str(seqs[prot2].seq)
145
 
146
- fold1 = foldseek_embeddings[prot1].to(device) if do_foldseek else None
147
- fold2 = foldseek_embeddings[prot2].to(device) if do_foldseek else None
148
-
149
- lm1 = lm_embed(seq1).to(device)
150
- lm2 = lm_embed(seq2).to(device)
151
-
152
- interaction = model.predict(lm1, lm2, embed_foldseek = do_foldseek, f0 = fold1, f1 = fold2).item()
153
-
154
- results.append([prot1, prot2, interaction])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
- results = pd.DataFrame(results, columns = ["Protein 1", "Protein 2", "Interaction"])
157
 
158
- file_path = f"/tmp/{run_id}.tsv"
159
- with open(file_path, "w") as f:
160
- results.to_csv(f, sep="\t", index=False, header = True)
161
-
162
- return results, file_path
163
 
164
  demo = gr.Interface(
165
  fn=predict,
 
20
  theme = "Default"
21
  title = "D-SCRIPT: Predicting Protein-Protein Interactions"
22
  description = """
23
+ If you use this interface to make predictions, please let us know (email samsl@mit.edu)!
24
+ We want to keep this web version free to use with GPU support, and to do that we need to demonstrate to
25
+ our funders that it is being used. Thank you!
26
  """
27
 
28
  # article = """
 
55
 
56
  Note that running here with the "TT3D" model does not run structure prediction on the sequences, but rather uses the [ProstT5](https://github.com/mheinzinger/ProstT5) language model to
57
  translate amino acid to 3di sequences. This is much faster than running structure prediction, but the results may not be as accurate.
 
58
  """
59
 
60
  fold_vocab = {
 
83
 
84
  def predict(model_name, pairs_file, sequence_file, progress = gr.Progress()):
85
 
 
 
 
 
 
 
 
 
 
86
  try:
87
+ run_id = uuid4()
88
+ device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
+ # gr.Info("Loading model...")
91
+ _ = lm_embed("M", use_cuda = (device.type == "cuda"))
92
+
93
+ model = get_pretrained(model_map[model_name]).to(device)
94
+
95
+ # gr.Info("Loading files...")
96
+ try:
97
+ seqs = SeqIO.to_dict(SeqIO.parse(sequence_file.name, "fasta"))
98
+ except ValueError as _:
99
+ raise gr.Error("Invalid FASTA file - duplicate entry")
100
+
101
+ if Path(pairs_file.name).suffix == ".csv":
102
+ pairs = pd.read_csv(pairs_file.name)
103
+ elif Path(pairs_file.name).suffix == ".tsv":
104
+ pairs = pd.read_csv(pairs_file.name, sep="\t")
105
+ try:
106
+ pairs.columns = ["protein1", "protein2"]
107
+ except ValueError as _:
108
+ raise gr.Error("Invalid pairs file - must have two columns 'protein1' and 'protein2'")
109
+
110
+ do_foldseek = False
111
+ if model_name == "TT3D":
112
+ do_foldseek = True
113
+
114
+ need_to_translate = set(pairs["protein1"]).union(set(pairs["protein2"]))
115
+ seqs_to_translate = {k: str(seqs[k].seq) for k in need_to_translate if k in seqs}
116
+
117
+ half_precision = False
118
+ assert not (half_precision and device=="cpu"), print("Running fp16 on CPU is not supported, yet")
119
+
120
+ gr.Info(f"Loading Foldseek embeddings -- this may take some time ({len(seqs_to_translate)} embeddings)...")
121
+ predictions = get_3di_sequences(
122
+ seqs_to_translate,
123
+ model_dir = "Rostlab/ProstT5",
124
+ report_fn = gr.Info,
125
+ error_fn = gr.Error,
126
+ device=device,
127
+ )
128
+ foldseek_sequences = predictions_to_dict(predictions)
129
+ foldseek_embeddings = {k: one_hot_3di_sequence(s.upper(), fold_vocab) for k, s in foldseek_sequences.items()}
130
+
131
+ # for k in seqs_to_translate.keys():
132
+ # print(seqs_to_translate[k])
133
+ # print(len(seqs_to_translate[k]))
134
+ # print(foldseek_embeddings[k])
135
+ # print(foldseek_embeddings[k].shape)
136
+
137
+ progress(0, desc="Starting...")
138
+ results = []
139
+ for i in progress.tqdm(range(len(pairs))):
140
+
141
+ r = pairs.iloc[i]
142
+
143
+ prot1 = r["protein1"]
144
+ prot2 = r["protein2"]
145
+
146
+ seq1 = str(seqs[prot1].seq)
147
+ seq2 = str(seqs[prot2].seq)
148
+
149
+ fold1 = foldseek_embeddings[prot1].to(device) if do_foldseek else None
150
+ fold2 = foldseek_embeddings[prot2].to(device) if do_foldseek else None
151
+
152
+ lm1 = lm_embed(seq1).to(device)
153
+ lm2 = lm_embed(seq2).to(device)
154
+
155
+ interaction = model.predict(lm1, lm2, embed_foldseek = do_foldseek, f0 = fold1, f1 = fold2).item()
156
+
157
+ results.append([prot1, prot2, interaction])
158
+
159
+ results = pd.DataFrame(results, columns = ["Protein 1", "Protein 2", "Interaction"])
160
+
161
+ file_path = f"/tmp/{run_id}.tsv"
162
+ with open(file_path, "w") as f:
163
+ results.to_csv(f, sep="\t", index=False, header = True)
164
 
165
+ return results, file_path
166
 
167
+ except Exception as e:
168
+ gr.Error(e)
169
+ return None, None
 
 
170
 
171
  demo = gr.Interface(
172
  fn=predict,