aksell committed on
Commit
386fb31
·
1 Parent(s): 6002f51

Remove Prot-T5 and ProtGPT2

Browse files

Attention visualization was never implemented for these models
because they don't tokenize single residues, and the models are
too large to be loaded with the current infrastructure.

They could be added back in the future if needed.

Files changed (3) hide show
  1. hexviz/attention.py +1 -29
  2. hexviz/models.py +4 -27
  3. tests/test_models.py +2 -14
hexviz/attention.py CHANGED
@@ -7,8 +7,7 @@ import streamlit as st
7
  import torch
8
  from Bio.PDB import PDBParser, Polypeptide, Structure
9
 
10
- from hexviz.models import (ModelType, get_prot_bert, get_protgpt2, get_protT5,
11
- get_tape_bert, get_zymctrl)
12
 
13
 
14
  def get_structure(pdb_code: str) -> Structure:
@@ -71,20 +70,6 @@ def get_attention(
71
  attention_stacked = torch.stack([attention for attention in attention_squeezed])
72
  attentions = attention_stacked
73
 
74
- elif model_type == ModelType.ProtGPT2:
75
- tokenizer, model = get_protgpt2()
76
- input_ids = tokenizer.encode(input, return_tensors='pt').to(device)
77
- with torch.no_grad():
78
- outputs = model(inputs, attention_mask=attention_mask, output_attentions=True)
79
- attentions = outputs.attentions
80
-
81
- # torch.Size([1, n_heads, n_res, n_res]) -> torch.Size([n_heads, n_res, n_res])
82
- attention_squeezed = [torch.squeeze(attention) for attention in attentions]
83
- # ([n_heads, n_res, n_res]*n_layers) -> [n_layers, n_heads, n_res, n_res]
84
- attention_stacked = torch.stack([attention for attention in attention_squeezed])
85
- attentions = attention_stacked
86
- # TODO extend attentions to be per token, not per word piece
87
- # simplest way to draw attention for multi residue token models for now
88
  elif model_type == ModelType.PROT_BERT:
89
  tokenizer, model = get_prot_bert()
90
  token_idxs = tokenizer.encode(sequence)
@@ -95,19 +80,6 @@ def get_attention(
95
  attentions = [attention[:, :, 1:-1, 1:-1] for attention in attentions]
96
  attentions = torch.stack([attention.squeeze(0) for attention in attentions])
97
 
98
- elif model_type == ModelType.PROT_T5:
99
- # Introduce white-space between all amino acids
100
- sequence = " ".join(sequence)
101
- # tokenize sequences and pad up to the longest sequence in the batch
102
- ids = tokenizer.encode_plus(sequence, add_special_tokens=True, padding="longest")
103
-
104
- input_ids = torch.tensor(ids['input_ids']).to(device)
105
- attention_mask = torch.tensor(ids['attention_mask']).to(device)
106
-
107
- with torch.no_grad():
108
- attns = model(input_ids=input_ids,attention_mask=attention_mask)[-1]
109
-
110
- tokenizer, model = get_protT5()
111
  else:
112
  raise ValueError(f"Model {model_type} not supported")
113
 
 
7
  import torch
8
  from Bio.PDB import PDBParser, Polypeptide, Structure
9
 
10
+ from hexviz.models import ModelType, get_prot_bert, get_tape_bert, get_zymctrl
 
11
 
12
 
13
  def get_structure(pdb_code: str) -> Structure:
 
70
  attention_stacked = torch.stack([attention for attention in attention_squeezed])
71
  attentions = attention_stacked
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  elif model_type == ModelType.PROT_BERT:
74
  tokenizer, model = get_prot_bert()
75
  token_idxs = tokenizer.encode(sequence)
 
80
  attentions = [attention[:, :, 1:-1, 1:-1] for attention in attentions]
81
  attentions = torch.stack([attention.squeeze(0) for attention in attentions])
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  else:
84
  raise ValueError(f"Model {model_type} not supported")
85
 
hexviz/models.py CHANGED
@@ -5,14 +5,12 @@ import streamlit as st
5
  import torch
6
  from tape import ProteinBertModel, TAPETokenizer
7
  from transformers import (AutoTokenizer, BertForMaskedLM, BertTokenizer,
8
- GPT2LMHeadModel, T5EncoderModel, T5Tokenizer)
9
 
10
 
11
  class ModelType(str, Enum):
12
  TAPE_BERT = "TAPE-BERT"
13
- PROT_T5 = "prot_t5_xl_half_uniref50-enc"
14
  ZymCTRL = "ZymCTRL"
15
- ProtGPT2 = "ProtGPT2"
16
  PROT_BERT = "ProtBert"
17
 
18
 
@@ -22,20 +20,6 @@ class Model:
22
  self.layers: int = layers
23
  self.heads: int = heads
24
 
25
- @st.cache
26
- def get_protT5() -> Tuple[T5Tokenizer, T5EncoderModel]:
27
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
28
- tokenizer = T5Tokenizer.from_pretrained(
29
- "Rostlab/prot_t5_xl_half_uniref50-enc", do_lower_case=False
30
- )
31
-
32
- model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc").to(
33
- device
34
- )
35
-
36
- model.full() if device == "cpu" else model.half()
37
-
38
- return tokenizer, model
39
 
40
  @st.cache
41
  def get_tape_bert() -> Tuple[TAPETokenizer, ProteinBertModel]:
@@ -43,12 +27,6 @@ def get_tape_bert() -> Tuple[TAPETokenizer, ProteinBertModel]:
43
  model = ProteinBertModel.from_pretrained('bert-base', output_attentions=True)
44
  return tokenizer, model
45
 
46
- @st.cache
47
- def get_prot_bert() -> Tuple[BertTokenizer, BertForMaskedLM]:
48
- tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False )
49
- model = BertForMaskedLM.from_pretrained("Rostlab/prot_bert")
50
- return tokenizer, model
51
-
52
  @st.cache
53
  def get_zymctrl() -> Tuple[AutoTokenizer, GPT2LMHeadModel]:
54
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -57,8 +35,7 @@ def get_zymctrl() -> Tuple[AutoTokenizer, GPT2LMHeadModel]:
57
  return tokenizer, model
58
 
59
  @st.cache
60
- def get_protgpt2() -> Tuple[AutoTokenizer, GPT2LMHeadModel]:
61
- device = torch.device('cuda')
62
- tokenizer = AutoTokenizer.from_pretrained('nferruz/ProtGPT2')
63
- model = GPT2LMHeadModel.from_pretrained('nferruz/ProtGPT2').to(device)
64
  return tokenizer, model
 
5
  import torch
6
  from tape import ProteinBertModel, TAPETokenizer
7
  from transformers import (AutoTokenizer, BertForMaskedLM, BertTokenizer,
8
+ GPT2LMHeadModel)
9
 
10
 
11
  class ModelType(str, Enum):
12
  TAPE_BERT = "TAPE-BERT"
 
13
  ZymCTRL = "ZymCTRL"
 
14
  PROT_BERT = "ProtBert"
15
 
16
 
 
20
  self.layers: int = layers
21
  self.heads: int = heads
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  @st.cache
25
  def get_tape_bert() -> Tuple[TAPETokenizer, ProteinBertModel]:
 
27
  model = ProteinBertModel.from_pretrained('bert-base', output_attentions=True)
28
  return tokenizer, model
29
 
 
 
 
 
 
 
30
  @st.cache
31
  def get_zymctrl() -> Tuple[AutoTokenizer, GPT2LMHeadModel]:
32
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
35
  return tokenizer, model
36
 
37
  @st.cache
38
+ def get_prot_bert() -> Tuple[BertTokenizer, BertForMaskedLM]:
39
+ tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False )
40
+ model = BertForMaskedLM.from_pretrained("Rostlab/prot_bert")
 
41
  return tokenizer, model
tests/test_models.py CHANGED
@@ -1,21 +1,9 @@
1
 
2
- from transformers import (GPT2LMHeadModel, GPT2TokenizerFast, T5EncoderModel,
3
- T5Tokenizer)
4
 
5
- from hexviz.models import get_protT5, get_zymctrl
6
 
7
 
8
- def test_get_protT5():
9
- result = get_protT5()
10
-
11
- assert result is not None
12
- assert isinstance(result, tuple)
13
-
14
- tokenizer, model = result
15
-
16
- assert isinstance(tokenizer, T5Tokenizer)
17
- assert isinstance(model, T5EncoderModel)
18
-
19
  def test_get_zymctrl():
20
  result = get_zymctrl()
21
 
 
1
 
2
+ from transformers import GPT2LMHeadModel, GPT2TokenizerFast
 
3
 
4
+ from hexviz.models import get_zymctrl
5
 
6
 
 
 
 
 
 
 
 
 
 
 
 
7
  def test_get_zymctrl():
8
  result = get_zymctrl()
9