oweller2 committed · Commit 3cd88d6
Parent(s): 0f166fe

unpad

Files changed:
- modeling_flexbert.py +8 -3
- tokenizer.py +44 -4
modeling_flexbert.py
CHANGED
@@ -1727,9 +1727,14 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
 
 
         batch_size, seq_len = input_ids.shape[:2]
-
-            input_ids,
-
+        if self.unpad_embeddings:
+            input_ids, indices, cu_seqlens, max_seqlen, position_ids, _ = self.unpad_inputs(
+                input_ids, attention_mask, position_ids, None
+            )
+        else:
+            indices = torch.arange(seq_len, device=input_ids.device).unsqueeze(0).repeat(batch_size, 1)
+            cu_seqlens = None
+            max_seqlen = None
         return {
             "input_ids": input_ids,
             "attention_mask": attention_mask,
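The new branch packs the batch before generation: when unpad_embeddings is set, padding tokens are stripped and the batch is flattened into one packed token stream, with indices recording where each real token sits in the flattened (batch * seq) layout and cu_seqlens/max_seqlen describing the per-sequence boundaries that variable-length attention kernels consume. As a rough illustration of what such a helper computes (a minimal sketch only: the actual self.unpad_inputs may differ in signature and return order; the semantics below follow the common flash-attn bert_padding.unpad_input pattern, and unpad_inputs_sketch is a hypothetical name):

import torch
import torch.nn.functional as F

def unpad_inputs_sketch(input_ids, attention_mask):
    # Hypothetical stand-in for self.unpad_inputs.
    # Lengths of the real (non-padding) part of each sequence
    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)        # (batch,)
    # Flat positions of real tokens in the (batch * seq) layout
    indices = torch.nonzero(attention_mask.flatten()).flatten()    # (total_tokens,)
    # Cumulative boundaries [0, len0, len0+len1, ...] for varlen attention
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    max_seqlen = int(seqlens.max())
    # Gather only the real tokens into one packed sequence
    unpadded_ids = input_ids.flatten()[indices]
    return unpadded_ids, indices, cu_seqlens, max_seqlen

# Two sequences of lengths 3 and 1, padded to width 4
ids  = torch.tensor([[5, 6, 7, 0], [9, 0, 0, 0]])
mask = torch.tensor([[1, 1, 1, 0], [1, 0, 0, 0]])
unpadded_ids, indices, cu_seqlens, max_seqlen = unpad_inputs_sketch(ids, mask)
print(unpadded_ids)  # tensor([5, 6, 7, 9])
print(cu_seqlens)    # tensor([0, 3, 4], dtype=torch.int32)
print(max_seqlen)    # 3

In the padded else branch, cu_seqlens and max_seqlen stay None and indices is just a per-row arange, so downstream code can treat both paths uniformly.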
tokenizer.py
CHANGED
@@ -7,13 +7,53 @@ class ModernDecoderBERTTokenizer(PreTrainedTokenizerFast):
     def _batch_encode_plus(self, *args, **kwargs):
         outputs = super()._batch_encode_plus(*args, **kwargs)
         del outputs["token_type_ids"]
+
+        # Get the input_ids to check for EOS tokens
+        input_ids = outputs['input_ids']
+
+        # Function to check if sequence ends with EOS token
+        def ends_with_eos(sequence):
+            if len(sequence) == 0:
+                return False
+            return sequence[-1] == self.eos_token_id
+
+        # Check for EOS tokens using input_ids only
+        if isinstance(input_ids, torch.Tensor):
+            last_token_is_eos = torch.tensor([
+                ends_with_eos(seq) for seq in input_ids
+            ], dtype=torch.bool)
+        elif isinstance(input_ids, numpy.ndarray):
+            last_token_is_eos = numpy.array([
+                ends_with_eos(seq) for seq in input_ids
+            ], dtype=bool)
+        elif isinstance(input_ids, list):
+            last_token_is_eos = [ends_with_eos(seq) for seq in input_ids]
+
+        # Use the same last_token_is_eos check for both input_ids and attention_mask
         for key in ['input_ids', 'attention_mask']:
             if isinstance(outputs[key], torch.Tensor):
-
-
-                outputs[key] =
+                # Only remove last token where last_token_is_eos is True
+                mask = last_token_is_eos.unsqueeze(-1)
+                outputs[key] = torch.where(
+                    mask,
+                    outputs[key][..., :-1],
+                    outputs[key]
+                )
+            elif isinstance(outputs[key], numpy.ndarray):
+                # Expand dimensions for broadcasting
+                mask = numpy.expand_dims(last_token_is_eos, -1)
+                outputs[key] = numpy.where(
+                    mask,
+                    outputs[key][..., :-1],
+                    outputs[key]
+                )
             elif isinstance(outputs[key], list):
-
+                # For lists, use the same last_token_is_eos list for both keys
+                outputs[key] = [
+                    sequence[:-1] if is_eos else sequence
+                    for sequence, is_eos in zip(outputs[key], last_token_is_eos)
+                ]
+
         return outputs
 
 # Register the class
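The effect of the tokenizer change is that a trailing EOS token is dropped per sequence rather than unconditionally: only sequences whose last input_id equals eos_token_id lose their final token, and attention_mask is trimmed with the exact same per-sequence check so the two stay aligned. A hypothetical usage sketch (the checkpoint path is a placeholder, and this assumes the underlying fast tokenizer appends EOS during encoding):

from tokenizer import ModernDecoderBERTTokenizer

tok = ModernDecoderBERTTokenizer.from_pretrained("path/to/checkpoint")  # placeholder path

# With the default list output, sequences that ended in EOS come back
# one token shorter; sequences without a trailing EOS are untouched.
batch = tok(["first example", "second example"])
for seq, mask in zip(batch["input_ids"], batch["attention_mask"]):
    assert len(seq) == len(mask)

One caveat worth noting on the tensor branches: torch.where and numpy.where require their two value arguments to broadcast against each other, and outputs[key][..., :-1] with shape (batch, seq_len - 1) does not broadcast against (batch, seq_len) once seq_len exceeds 2, so those paths raise a shape error on batched tensor input; the list branch avoids the problem because Python lists can hold ragged sequences of different lengths.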