rrivera1849 committed
Commit
afd830f
1 Parent(s): 036bf15

Upload LUAR

Files changed (1): model.py (+29, -12)
model.py CHANGED
@@ -146,7 +146,7 @@ class LUAR(PreTrainedModel):
             config.k_bucket_size,
         )
         self.linear = nn.Linear(self.hidden_size, config.embedding_size)
-
+
     def create_transformer(self):
         """Creates the Transformer backbone.
         """
@@ -163,7 +163,7 @@ class LUAR(PreTrainedModel):
         sum_mask = torch.clamp(reduce(input_mask_expanded, 'b l d -> b d', 'sum'), min=1e-9)
         return sum_embeddings / sum_mask
 
-    def get_episode_embeddings(self, input_ids, attention_mask, output_attentions=False):
+    def get_episode_embeddings(self, input_ids, attention_mask, output_attentions=False, document_batch_size=0):
         """Computes the Author Embedding.
         """
         B, E, _ = attention_mask.shape
@@ -171,14 +171,31 @@ class LUAR(PreTrainedModel):
         input_ids = rearrange(input_ids, 'b e l -> (b e) l')
         attention_mask = rearrange(attention_mask, 'b e l -> (b e) l')
 
-        outputs = self.transformer(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            return_dict=True,
-            output_hidden_states=True,
-            output_attentions=output_attentions,
-        )
-
+        if document_batch_size > 0:
+            outputs = {"last_hidden_state": [], "attentions": []}
+            for i in range(0, len(input_ids), document_batch_size):
+                out = self.transformer(
+                    input_ids=input_ids[i:i+document_batch_size],
+                    attention_mask=attention_mask[i:i+document_batch_size],
+                    return_dict=True,
+                    output_hidden_states=False,
+                    output_attentions=output_attentions,
+                )
+                outputs["last_hidden_state"].append(out["last_hidden_state"])
+                if output_attentions:
+                    outputs["attentions"].append(out["attentions"])
+            outputs["last_hidden_state"] = torch.cat(outputs["last_hidden_state"], dim=0)
+            if output_attentions:
+                outputs["attentions"] = tuple([torch.cat([x[i] for x in outputs["attentions"]], dim=0) for i in range(len(outputs["attentions"][0]))])
+        else:
+            outputs = self.transformer(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                return_dict=True,
+                output_hidden_states=False,
+                output_attentions=output_attentions,
+            )
+
         # at this point, we're embedding individual "comments"
         comment_embeddings = self.mean_pooling(outputs['last_hidden_state'], attention_mask)
         comment_embeddings = rearrange(comment_embeddings, '(b e) l -> b e l', b=B, e=E)
@@ -194,9 +211,9 @@ class LUAR(PreTrainedModel):
 
         return episode_embeddings
 
-    def forward(self, input_ids, attention_mask, output_attentions=False):
+    def forward(self, input_ids, attention_mask, output_attentions=False, document_batch_size=0):
         """Calculates a fixed-length feature vector for a batch of episode samples.
         """
-        output = self.get_episode_embeddings(input_ids, attention_mask, output_attentions)
+        output = self.get_episode_embeddings(input_ids, attention_mask, output_attentions, document_batch_size)
 
         return output
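
The new document_batch_size argument lets callers push the flattened (batch x episode) documents through the transformer backbone in chunks, concatenating the hidden states (and the attentions, when requested) afterwards, so long episodes can be embedded without exhausting GPU memory; passing 0 keeps the original single-pass behavior. A minimal usage sketch follows, under stated assumptions: the repo id, tokenizer settings, and sequence length are illustrative and not part of this commit.

# Minimal usage sketch (assumptions: the checkpoint name, max_length, and the
# two-document episode below are illustrative; only the forward signature
# with document_batch_size comes from this commit).
import torch
from transformers import AutoModel, AutoTokenizer

repo_id = "rrivera1849/LUAR-MUD"  # assumed checkpoint from this account
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
model.eval()

# Two documents ("comments") written by the same author form one episode.
texts = ["first comment written by the author",
         "second comment written by the author"]
tok = tokenizer(texts, max_length=64, padding="max_length",
                truncation=True, return_tensors="pt")

# The model expects (batch, episode, length) tensors, so add a batch dimension.
input_ids = tok["input_ids"].unsqueeze(0)
attention_mask = tok["attention_mask"].unsqueeze(0)

with torch.no_grad():
    # document_batch_size > 0 runs the backbone over documents in chunks of that
    # size, trading a little speed for lower peak memory; 0 keeps the old behavior.
    author_embedding = model(input_ids, attention_mask, document_batch_size=1)

print(author_embedding.shape)  # expected: (1, embedding_size)

Because the per-chunk hidden states are concatenated back in order before mean pooling, the resulting author embedding should match the unbatched path; document_batch_size only bounds how many documents hit the backbone at once.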