oweller2 committed on
Commit 3cd88d6
1 Parent(s): 0f166fe
Files changed (2)
  1. modeling_flexbert.py +8 -3
  2. tokenizer.py +44 -4
modeling_flexbert.py CHANGED
@@ -1727,9 +1727,14 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
 
 
         batch_size, seq_len = input_ids.shape[:2]
-        input_ids, indices, cu_seqlens, max_seqlen, position_ids, _ = self.unpad_inputs(
-            input_ids, attention_mask, position_ids, None
-        )
+        if self.unpad_embeddings:
+            input_ids, indices, cu_seqlens, max_seqlen, position_ids, _ = self.unpad_inputs(
+                input_ids, attention_mask, position_ids, None
+            )
+        else:
+            indices = torch.arange(seq_len, device=input_ids.device).unsqueeze(0).repeat(batch_size, 1)
+            cu_seqlens = None
+            max_seqlen = None
         return {
             "input_ids": input_ids,
             "attention_mask": attention_mask,
tokenizer.py CHANGED
@@ -7,13 +7,53 @@ class ModernDecoderBERTTokenizer(PreTrainedTokenizerFast):
     def _batch_encode_plus(self, *args, **kwargs):
         outputs = super()._batch_encode_plus(*args, **kwargs)
         del outputs["token_type_ids"]
+
+        # Get the input_ids to check for EOS tokens
+        input_ids = outputs['input_ids']
+
+        # Function to check if sequence ends with EOS token
+        def ends_with_eos(sequence):
+            if len(sequence) == 0:
+                return False
+            return sequence[-1] == self.eos_token_id
+
+        # Check for EOS tokens using input_ids only
+        if isinstance(input_ids, torch.Tensor):
+            last_token_is_eos = torch.tensor([
+                ends_with_eos(seq) for seq in input_ids
+            ], dtype=torch.bool)
+        elif isinstance(input_ids, numpy.ndarray):
+            last_token_is_eos = numpy.array([
+                ends_with_eos(seq) for seq in input_ids
+            ], dtype=bool)
+        elif isinstance(input_ids, list):
+            last_token_is_eos = [ends_with_eos(seq) for seq in input_ids]
+
+        # Use the same last_token_is_eos check for both input_ids and attention_mask
         for key in ['input_ids', 'attention_mask']:
             if isinstance(outputs[key], torch.Tensor):
-                outputs[key] = outputs[key][..., :-1]
-            elif isinstance(outputs[key], numpy.ndarray):
-                outputs[key] = outputs[key][..., :-1]
+                # Only remove last token where last_token_is_eos is True
+                mask = last_token_is_eos.unsqueeze(-1)
+                outputs[key] = torch.where(
+                    mask,
+                    outputs[key][..., :-1],
+                    outputs[key]
+                )
+            elif isinstance(outputs[key], numpy.ndarray):
+                # Expand dimensions for broadcasting
+                mask = numpy.expand_dims(last_token_is_eos, -1)
+                outputs[key] = numpy.where(
+                    mask,
+                    outputs[key][..., :-1],
+                    outputs[key]
+                )
             elif isinstance(outputs[key], list):
+                # For lists, use the same last_token_is_eos list for both keys
+                outputs[key] = [
+                    sequence[:-1] if is_eos else sequence
+                    for sequence, is_eos in zip(outputs[key], last_token_is_eos)
+                ]
+
         return outputs
 
     # Register the class
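One caveat about the tensor and array branches above: torch.where and numpy.where broadcast their two value arguments, but outputs[key][..., :-1] has trailing dimension seq_len - 1 while outputs[key] has seq_len, so those two where() calls cannot broadcast and will raise a shape error for typical batches returned as tensors or arrays (the list branch is shape-safe because each row is sliced independently). A shape-safe rework is sketched below: instead of slicing, it pads out a trailing EOS, which keeps the batch rectangular. This is a hypothetical sketch under stated assumptions (right-padded batches; pad_token_id is a parameter of the sketch, not something the commit reads), not the committed code.

import numpy
import torch

def trim_trailing_eos(outputs, eos_token_id, pad_token_id):
    # Hypothetical shape-safe rework (not the committed code): overwrite a
    # trailing EOS with the pad id and zero its attention-mask slot, rather
    # than mixing a [batch, seq-1] slice with a [batch, seq] tensor in where().
    input_ids = outputs["input_ids"]
    if isinstance(input_ids, (torch.Tensor, numpy.ndarray)):
        last_is_eos = input_ids[:, -1] == eos_token_id  # [batch] bool mask
        ids = input_ids.clone() if isinstance(input_ids, torch.Tensor) else input_ids.copy()
        ids[last_is_eos, -1] = pad_token_id             # drop the EOS by padding it out
        outputs["input_ids"] = ids
        mask = outputs["attention_mask"]
        mask = mask.clone() if isinstance(mask, torch.Tensor) else mask.copy()
        mask[last_is_eos, -1] = 0                       # ignore the padded-out slot
        outputs["attention_mask"] = mask
    else:  # plain lists: per-row slicing keeps ragged rows in sync, as in the commit
        last_is_eos = [len(seq) > 0 and seq[-1] == eos_token_id for seq in input_ids]
        for key in ("input_ids", "attention_mask"):
            outputs[key] = [seq[:-1] if eos else seq
                            for seq, eos in zip(outputs[key], last_is_eos)]
    return outputs

As in the commit's list branch, only rows that genuinely end in eos_token_id are touched, and input_ids stays aligned with attention_mask.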