hdallatorre committed on
Commit
66427d7
1 Parent(s): 2fcf7c8

Upload SegmentNT

Browse files
Files changed (3) hide show
  1. config.json +1 -1
  2. modeling_segment_nt.py +29 -27
  3. pytorch_model.bin +2 -2
config.json CHANGED
@@ -40,7 +40,7 @@
40
  "num_layers_head": 2,
41
  "pad_token_id": 1,
42
  "position_embedding_type": "rotary",
43
- "rescaling_factor": null,
44
  "tie_word_embeddings": false,
45
  "token_dropout": false,
46
  "torch_dtype": "float32",
 
40
  "num_layers_head": 2,
41
  "pad_token_id": 1,
42
  "position_embedding_type": "rotary",
43
+ "rescaling_factor": 2.44140625,
44
  "tie_word_embeddings": false,
45
  "token_dropout": false,
46
  "torch_dtype": "float32",
modeling_segment_nt.py CHANGED
@@ -115,56 +115,58 @@ class RotaryEmbedding(torch.nn.Module):
115
  super().__init__()
116
 
117
  # Extract argument from the config
118
- rescaling_factor = rotary_embedding_config.rescaling_factor
119
- upper_freq = 10000
120
-
121
- if rescaling_factor is None:
122
- inv_freq = 1.0 / (upper_freq ** (torch.arange(0, dim, 2).float() / dim))
123
- else:
124
- updated_base = upper_freq * (
125
- rescaling_factor ** (dim / (dim - 2))
126
- )
127
- inv_freq = 1.0 / (
128
- updated_base ** (torch.arange(0, dim, 2).float() / dim)
129
- )
130
-
131
- self.register_buffer("inv_freq", inv_freq)
132
 
133
  self._seq_len_cached = None
134
  self._cos_cached = None
135
  self._sin_cached = None
136
 
137
- def _update_cos_sin_tables(self, x, seq_dimension=2):
 
 
138
  seq_len = x.shape[seq_dimension]
139
 
140
  # Reset the tables if the sequence length has changed,
141
  # or if we're on a new device (possibly due to tracing for instance)
142
- if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
143
- self._seq_len_cached = seq_len
144
- t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(
145
- self.inv_freq
146
- )
147
- freqs = torch.outer(t, self.inv_freq)
148
- emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
149
 
150
- self._cos_cached = emb.cos()[None, None, :, :]
151
- self._sin_cached = emb.sin()[None, None, :, :]
152
 
153
  return self._cos_cached, self._sin_cached
154
 
155
  def forward(
156
  self, q: torch.Tensor, k: torch.Tensor
157
  ) -> Tuple[torch.Tensor, torch.Tensor]:
158
- self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
159
- k, seq_dimension=-2
160
- )
 
 
 
 
 
 
 
161
 
 
 
 
 
162
  return (
163
  apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
164
  apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
165
  )
166
 
167
 
 
168
  class EsmContactPredictionHead(nn.Module):
169
  """Performs symmetrization, apc, and computes a logistic regression on the output features"""
170
 
 
115
  super().__init__()
116
 
117
  # Extract argument from the config
118
+ self.rescaling_factor = rotary_embedding_config.rescaling_factor
119
+ self.upper_freq = 10000
120
+ self.dim = dim
 
 
 
 
 
 
 
 
 
 
 
121
 
122
  self._seq_len_cached = None
123
  self._cos_cached = None
124
  self._sin_cached = None
125
 
126
+
127
+
128
+ def _compute_cos_sin_tables(self, x, inv_freq, seq_dimension=2):
129
  seq_len = x.shape[seq_dimension]
130
 
131
  # Reset the tables if the sequence length has changed,
132
  # or if we're on a new device (possibly due to tracing for instance)
133
+ self._seq_len_cached = seq_len
134
+ t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(
135
+ inv_freq
136
+ )
137
+ freqs = torch.outer(t, inv_freq)
138
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
 
139
 
140
+ self._cos_cached = emb.cos()[None, None, :, :]
141
+ self._sin_cached = emb.sin()[None, None, :, :]
142
 
143
  return self._cos_cached, self._sin_cached
144
 
145
  def forward(
146
  self, q: torch.Tensor, k: torch.Tensor
147
  ) -> Tuple[torch.Tensor, torch.Tensor]:
148
+
149
+ if self.rescaling_factor is None:
150
+ inv_freq = 1.0 / (self.upper_freq ** (torch.arange(0, self.dim, 2).float() / self.dim))
151
+ else:
152
+ updated_base = self.upper_freq * (
153
+ self.rescaling_factor ** (self.dim / (self.dim - 2))
154
+ )
155
+ inv_freq = 1.0 / (
156
+ updated_base ** (torch.arange(0, self.dim, 2).float() / self.dim)
157
+ )
158
 
159
+ self._cos_cached, self._sin_cached = self._compute_cos_sin_tables(
160
+ k, inv_freq, seq_dimension=-2,
161
+ )
162
+
163
  return (
164
  apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
165
  apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
166
  )
167
 
168
 
169
+
170
  class EsmContactPredictionHead(nn.Module):
171
  """Performs symmetrization, apc, and computes a logistic regression on the output features"""
172
 
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:ae081cbe0ee351a510930a8d2d5a94e150c1e40afdb93e69fea5d345639ad2cf
3
- size 2237478985
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bf3b06784e943efd3f33b6059ad921218490cd691d2a0ffb11db3da8ef424b5d
3
+ size 2237465429