numb3r3 committed
Commit 1a800ed
1 Parent(s): 583e9af

implement compute_score api

Files changed (1)
  1. modeling_bert.py +53 -2
modeling_bert.py CHANGED
@@ -421,7 +421,7 @@ class JinaBertSelfAttention(nn.Module):
         # seem a bit unusual, but is taken from the original Transformer paper.
         attention_probs = self.dropout(attention_probs)
 
-        # Add the alibi matrix to the attention_scores after the call to softmax
+        # Add the alibi matrix to the attention_scores after the call to softmax
         attention_scores += bias
 
         # Mask heads if we want to
@@ -435,7 +435,7 @@ class JinaBertSelfAttention(nn.Module):
         context_layer = context_layer.view(new_context_layer_shape)
 
         outputs = (
-            (context_layer, attention_probs if output_attention_probs else attention_scores)
+            (context_layer, attention_probs if output_attention_probs else attention_scores)
             if output_attentions else (context_layer,)
         )
 
@@ -2072,6 +2072,57 @@ class JinaBertForSequenceClassification(JinaBertPreTrainedModel):
             attentions=outputs.attentions,
         )
 
+    @torch.inference_mode()
+    def compute_score(
+        self,
+        sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
+        batch_size: int = 32,
+        device: Optional[torch.device] = None,
+        **tokenizer_kwargs,
+    ):
+        assert isinstance(sentence_pairs, list)
+        if isinstance(sentence_pairs[0], str):
+            sentence_pairs = [sentence_pairs]
+
+        if not hasattr(self, 'tokenizer'):
+            from transformers import AutoTokenizer
+
+            self.tokenizer = AutoTokenizer.from_pretrained(self._name_or_path)
+
+        is_training = self.training
+        self.eval()
+
+        if device is not None:
+            self.to(device)
+
+        all_scores = []
+        for start_index in range(
+            0, len(sentence_pairs), batch_size
+        ):
+            sentences_batch = sentence_pairs[
+                start_index : start_index + (batch_size or self._eval_batch_size)
+            ]
+ inputs = self._tokenizer(
2106
+ sentences_batch,
2107
+ padding=True,
2108
+ truncation=True,
2109
+ return_tensors='pt',
2110
+ **tokenizer_kwargs,
2111
+ ).to(self.device)
2112
+
2113
+ scores = (
2114
+ self.forward(**inputs, return_dict=True)
2115
+ .logits.view(
2116
+ -1,
2117
+ )
2118
+ .float()
2119
+ )
2120
+ all_scores.extend(scores.cpu().numpy().tolist())
2121
+
2122
+ if len(all_scores) == 1:
2123
+ return all_scores[0]
2124
+ return all_scores
2125
+
 
 @add_start_docstrings(
     """