Truncate to 8k by default

#5
by Jackmin108 - opened
Files changed (1) hide show
  1. modeling_bert.py +3 -2
modeling_bert.py CHANGED
@@ -1195,7 +1195,9 @@ class JinaBertModel(JinaBertPreTrainedModel):
1195
  inverse_permutation = np.argsort(permutation)
1196
  sentences = [sentences[idx] for idx in permutation]
1197
 
1198
- padding = tokenizer_kwargs.pop('padding', True)
 
 
1199
 
1200
  all_embeddings = []
1201
 
@@ -1214,7 +1216,6 @@ class JinaBertModel(JinaBertPreTrainedModel):
1214
  encoded_input = self.tokenizer(
1215
  sentences[i : i + batch_size],
1216
  return_tensors='pt',
1217
- padding=padding,
1218
  **tokenizer_kwargs,
1219
  ).to(self.device)
1220
  token_embs = self.forward(**encoded_input)[0]
 
1195
  inverse_permutation = np.argsort(permutation)
1196
  sentences = [sentences[idx] for idx in permutation]
1197
 
1198
+ tokenizer_kwargs['padding'] = tokenizer_kwargs.get('padding', True)
1199
+ tokenizer_kwargs['max_length'] = tokenizer_kwargs.get('max_length', 8192)
1200
+ tokenizer_kwargs['truncation'] = tokenizer_kwargs.get('truncation', True)
1201
 
1202
  all_embeddings = []
1203
 
 
1216
  encoded_input = self.tokenizer(
1217
  sentences[i : i + batch_size],
1218
  return_tensors='pt',
 
1219
  **tokenizer_kwargs,
1220
  ).to(self.device)
1221
  token_embs = self.forward(**encoded_input)[0]