Loli-Killer committed
Commit 1a696b5 • 1 Parent(s): 53fe34a

Added protein_bind methods

Files changed (3)
  1. .gitignore +1 -0
  2. app.py +32 -23
  3. proteinbind_new.py +10 -9
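
Summary: app.py now routes every modality embedding through a shared ProteinBind model. A new pass_through helper wraps a single encoder output in a {modality_key: tensor} dict, runs the joint model loaded by proteinbind_new.create_proteinbind(True), and returns the projected embedding for that key; nt_embed, aa_embed, se_embed and msa_embed now return these joint-space vectors (keys "dna", "aa", "text", "msa"). A minimal sketch of the added flow, mirroring the diff below:

import proteinbind_new

# Shared joint model; True loads the pretrained weights ('best_model.pth').
model = proteinbind_new.create_proteinbind(True)


def pass_through(torch_output, key: str):
    # Wrap one modality's raw embedding, run the joint model, return its projection.
    input_data = {
        key: torch_output,
    }
    output = model(input_data)
    return output[key]

# Each encoder helper now ends with, e.g.:
#     return pass_through(last_layer_CLS, "dna")   # nt_embed
#     return pass_through(temp, "msa")             # msa_embed
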
.gitignore ADDED
@@ -0,0 +1 @@
+ env/
app.py CHANGED
@@ -1,18 +1,18 @@
  # credit: https://huggingface.co/spaces/simonduerr/3dmol.js/blob/main/app.py
- from typing import Tuple
  import os
  import sys
  from urllib import request

+ import esm
  import gradio as gr
+ import progres as pg
  import requests
- from transformers import AutoTokenizer, AutoModelForMaskedLM, EsmModel, AutoModel
  import torch
- import progres as pg
- import esm
+ from transformers import (AutoModel, AutoModelForMaskedLM, AutoTokenizer,
+                           EsmModel)

  import msa
-
+ import proteinbind_new

  tokenizer_nt = AutoTokenizer.from_pretrained("InstaDeepAI/nucleotide-transformer-500m-1000g")
  model_nt = AutoModelForMaskedLM.from_pretrained("InstaDeepAI/nucleotide-transformer-500m-1000g")
@@ -30,6 +30,15 @@ msa_transformer, msa_transformer_alphabet = esm.pretrained.esm_msa1b_t12_100M_UR
  msa_transformer = msa_transformer.eval()
  msa_transformer_batch_converter = msa_transformer_alphabet.get_batch_converter()

+ model = proteinbind_new.create_proteinbind(True)
+
+
+ def pass_through(torch_output, key: str):
+     input_data = {
+         key: torch_output,
+     }
+     output = model(input_data)
+     return output[key]


  def nt_embed(sequence: str):
@@ -37,38 +46,38 @@ def nt_embed(sequence: str):
      attention_mask = tokens_ids != tokenizer_nt.pad_token_id
      with torch.no_grad():
          torch_outs = model_nt(
-             tokens_ids,#.to('cuda'),
-             attention_mask=attention_mask,#.to('cuda'),
+             tokens_ids,  # .to('cuda'),
+             attention_mask=attention_mask,  # .to('cuda'),
              output_hidden_states=True
          )
      last_layer_CLS = torch_outs.hidden_states[-1].detach()[:, 0, :][0]
-     return last_layer_CLS
+     return pass_through(last_layer_CLS, "dna")


  def aa_embed(sequence: str):
      tokens = tokenizer_aa([sequence], return_tensors="pt")
      with torch.no_grad():
          torch_outs = model_aa(**tokens)
-     return torch_outs[0]
+     return pass_through(torch_outs[0], "aa")


  def se_embed(sentence: str):
      encoded_input = tokenizer_se([sentence], return_tensors='pt')
      with torch.no_grad():
          model_output = model_se(**encoded_input)
-     return model_output[0]
+     return pass_through(model_output[0], "text")


  def msa_embed(sequences: list):
-     inputs = msa.greedy_select(sequences, num_seqs=128) # can change this to pass more/fewer sequences
+     inputs = msa.greedy_select(sequences, num_seqs=128)  # can change this to pass more/fewer sequences
      msa_transformer_batch_labels, msa_transformer_batch_strs, msa_transformer_batch_tokens = msa_transformer_batch_converter([inputs])
      msa_transformer_batch_tokens = msa_transformer_batch_tokens.to(next(msa_transformer.parameters()).device)
-
+
      with torch.no_grad():
-         temp = msa_transformer(msa_transformer_batch_tokens,repr_layers=[12])['representations']
-         temp = temp[12][:,:,0,:]
-         temp = torch.mean(temp,(0,1))
-     return temp
+         temp = msa_transformer(msa_transformer_batch_tokens, repr_layers=[12])['representations']
+         temp = temp[12][:, :, 0, :]
+         temp = torch.mean(temp, (0, 1))
+     return pass_through(temp, "msa")


  def go_embed(terms):
@@ -79,13 +88,13 @@ def download_data_if_required():
      url_base = f"https://zenodo.org/record/{pg.zenodo_record}/files"
      fps = [pg.trained_model_fp]
      urls = [f"{url_base}/trained_model.pt"]
-     #for targetdb in pre_embedded_dbs:
+     # for targetdb in pre_embedded_dbs:
      #    fps.append(os.path.join(database_dir, targetdb + ".pt"))
      #    urls.append(f"{url_base}/{targetdb}.pt")

      if not os.path.isdir(pg.trained_model_dir):
          os.makedirs(pg.trained_model_dir)
-     #if not os.path.isdir(database_dir):
+     # if not os.path.isdir(database_dir):
      #    os.makedirs(database_dir)

      printed = False
@@ -103,7 +112,7 @@ def download_data_if_required():
                  assert "model" in d
              else:
                  assert "embeddings" in d
-         except:
+         except Exception:
              if os.path.isfile(fp):
                  os.remove(fp)
              print("Failed to download from", url, "and save to", fp, file=sys.stderr)
@@ -119,7 +128,7 @@ def get_pdb(pdb_code="", filepath=""):
          try:
              with open(filepath.name) as f:
                  return f.read()
-         except AttributeError as e:
+         except AttributeError:
              return None
      else:
          return requests.get(f"https://files.rcsb.org/view/{pdb_code}.pdb").content.decode()
@@ -150,12 +159,12 @@ def molecule(pdb):
      </head>
      <body>
      <div id="container" class="mol-container"></div>
-
+
      <script>
      let pdb = `"""
          + pdb
          + """`
-
+
      $(document).ready(function () {
      let element = $("#container");
      let config = { backgroundColor: "black" };
@@ -272,4 +281,4 @@ with demo:

  if __name__ == "__main__":
      download_data_if_required()
-     demo.launch()
+     demo.launch()
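
With these changes every helper in app.py returns a vector in the shared ProteinBind space rather than a raw encoder output. A hypothetical smoke test (the sequences and the text prompt are illustrative only, and importing app loads all of the underlying encoder models):

from app import aa_embed, nt_embed, se_embed  # helpers defined in app.py above

protein_vec = aa_embed("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ")  # routed through key "aa"
dna_vec = nt_embed("ATGAAAACCGCATACATTGCC")                  # routed through key "dna"
text_vec = se_embed("a small heat-shock protein")            # routed through key "text"

# All three live in the joint embedding space; shapes depend on each encoder's output.
for name, vec in [("aa", protein_vec), ("dna", dna_vec), ("text", text_vec)]:
    print(name, tuple(vec.shape))

# Once reduced to single vectors (e.g. mean-pooled over tokens), they can be
# compared directly, for instance with torch.nn.functional.cosine_similarity.
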
proteinbind_new.py CHANGED
@@ -15,6 +15,7 @@ ModalityType = SimpleNamespace(
      TEXT="text",
  )

+
  class Normalize(nn.Module):
      def __init__(self, dim: int) -> None:
          super().__init__()
@@ -23,6 +24,7 @@ class Normalize(nn.Module):
      def forward(self, x):
          return torch.nn.functional.normalize(x, dim=self.dim, p=2)

+
  class EmbeddingDataset(Dataset):
      """
      The main class for turning any modality to a torch Dataset that can be passed to
@@ -42,6 +44,7 @@ class EmbeddingDataset(Dataset):
          embedding = self.embedding[idx]
          return {"aa": sequence, self.modality: embedding}

+
  class DualEmbeddingDataset(Dataset):
      """
      The main class for turning any modality to a torch Dataset that can be passed to
@@ -60,7 +63,8 @@ class DualEmbeddingDataset(Dataset):
          sequence_embedding = self.sequence_embedding[idx]
          embedding = self.embedding[idx]
          return {"aa": sequence_embedding, self.modality: embedding}
-
+
+
  class ProteinBindModel(nn.Module):

      def __init__(
@@ -92,7 +96,6 @@ class ProteinBindModel(nn.Module):
              out_embed_dim
          )

-
      def _create_modality_trunk(
              self,
              aa_embed_dim,
@@ -140,7 +143,7 @@ class ProteinBindModel(nn.Module):
              nn.ReLU(),
              nn.Linear(512, in_embed_dim),
          )
-
+
          modality_trunks[ModalityType.GO] = nn.Sequential(
              nn.Linear(go_embed_dim, 512),
              nn.ReLU(),
@@ -220,7 +223,6 @@ class ProteinBindModel(nn.Module):
          modality_postprocessors[ModalityType.GO] = Normalize(dim=-1)
          modality_postprocessors[ModalityType.MSA] = Normalize(dim=-1)

-
          return nn.ModuleDict(modality_postprocessors)

      def forward(self, inputs):
@@ -239,7 +241,6 @@ class ProteinBindModel(nn.Module):

          for modality_key, modality_value in inputs.items():

-
              modality_value = self.modality_trunks[modality_key](
                  modality_value
              )
@@ -247,10 +248,10 @@ class ProteinBindModel(nn.Module):
              modality_value = self.modality_heads[modality_key](
                  modality_value
              )
-
+
              modality_value = self.modality_postprocessors[modality_key](
-                 modality_value
-             )
+                 modality_value
+             )
              outputs[modality_key] = modality_value

          return outputs
@@ -274,7 +275,7 @@ def create_proteinbind(pretrained=False):
      )

      if pretrained:
-         #get path from config
+         # get path from config
          PATH = 'best_model.pth'

          model.load_state_dict(torch.load(PATH))
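
For reference, ProteinBindModel.forward consumes a dict keyed by modality name and returns a dict of trunk-, head- and Normalize-processed embeddings, and create_proteinbind(pretrained=True) additionally loads best_model.pth. A rough sketch of calling the model directly; the batch size and the 1024/768 input widths are placeholders, since the real per-modality dimensions come from constructor arguments not shown in this diff:

import torch

import proteinbind_new

model = proteinbind_new.create_proteinbind(pretrained=False)  # True expects best_model.pth on disk
model.eval()

# Placeholder widths; each tensor's last dim must match that modality's trunk input size.
inputs = {
    "aa": torch.randn(4, 1024),
    "text": torch.randn(4, 768),
}

with torch.no_grad():
    outputs = model(inputs)

for key, value in outputs.items():
    # Each output is L2-normalized by the Normalize(dim=-1) postprocessor.
    print(key, tuple(value.shape), value.norm(dim=-1))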