PatrickHaller committed on
Commit
d59f795
1 Parent(s): 6fb6ad8

Upload NGMEForCausalLM

Files changed (8)
  1. config.json +35 -0
  2. configuration_ngme.py +177 -0
  3. generation_config.json +7 -0
  4. modeling_ngme.py +1206 -0
  5. ngme.py +178 -0
  6. pytorch_model.bin +3 -0
  7. sampling.py +205 -0
  8. tokenization_ngme.py +1299 -0
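
Because config.json maps `AutoConfig` and `AutoModelForCausalLM` to the custom classes uploaded in this commit, the checkpoint is meant to be loaded with `trust_remote_code=True`. The sketch below is a minimal loading example, not part of the commit itself: the repo id is a placeholder, and it assumes the custom tokenizer in tokenization_ngme.py is registered for auto-loading (not visible in this diff); sampling.py is also not shown, so the generation call is illustrative only.

```python
# Minimal loading sketch (assumption: transformers >= 4.31; the repo id is a placeholder).
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

repo_id = "<namespace>/<repo-name>"  # placeholder for this repository's Hub id

config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

# modeling_ngme.NGMEForCausalLM.sample() checks for a tokenizer on the model,
# so attach it before generating with do_sample=True.
model.tokenizer = tokenizer

inputs = tokenizer("Hello", return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=20, do_sample=True)
print(tokenizer.decode(out[0]))
```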
config.json ADDED
@@ -0,0 +1,35 @@
1
+ {
2
+ "_name_or_path": "/scratch/phmaker/ngme/babylm",
3
+ "architectures": [
4
+ "NGMEForCausalLM"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_ngme.NGMEConfig",
8
+ "AutoModelForCausalLM": "modeling_ngme.NGMEForCausalLM"
9
+ },
10
+ "bos_token_id": 1,
11
+ "eos_token_id": 0,
12
+ "ffn_dim": 512,
13
+ "hidden_act": "silu",
14
+ "hidden_size": 1024,
15
+ "initializer_range": 0.02,
16
+ "intermediate_size": 11008,
17
+ "max_position_embeddings": 2048,
18
+ "model_type": "ngme",
19
+ "num_attention_heads": 4,
20
+ "num_hidden_layers": 4,
21
+ "num_key_value_heads": 4,
22
+ "pad_token_id": 0,
23
+ "pretraining_tp": 1,
24
+ "rms_norm_eps": 1e-06,
25
+ "rope_scaling": null,
26
+ "tie_word_embeddings": false,
27
+ "torch_dtype": "float32",
28
+ "transformers_version": "4.31.0",
29
+ "unk_idx": 1,
30
+ "unk_token_id": 1,
31
+ "use_cache": true,
32
+ "use_flash_attn": false,
33
+ "use_small_embedding": false,
34
+ "vocab_size": 36484
35
+ }
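
A few of these values are tied together by the modeling code added below: `NGMEAttention` requires `hidden_size` to be divisible by `num_attention_heads`, and `NGMEMLP` reads `intermediate_size` (the separate `ffn_dim` entry is not referenced by the modules in this commit). A small sanity-check sketch using the numbers above:

```python
# Sanity check of derived sizes, with values copied from config.json above.
hidden_size = 1024
num_attention_heads = 4
num_key_value_heads = 4

head_dim = hidden_size // num_attention_heads            # 256
assert head_dim * num_attention_heads == hidden_size     # enforced in NGMEAttention.__init__
kv_groups = num_attention_heads // num_key_value_heads   # 1 -> plain multi-head attention
print(head_dim, kv_groups)
```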
configuration_ngme.py ADDED
@@ -0,0 +1,177 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+
21
+ from transformers.configuration_utils import PretrainedConfig
22
+ from transformers.utils import logging
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
+ class NGMEConfig(PretrainedConfig):
29
+ r"""
30
+ This is the configuration class to store the configuration of a [`NGMEModel`]. It is used to instantiate an NGME
31
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
32
+ defaults will yield a similar configuration to that of the LLaMA-7B.
33
+
34
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
35
+ documentation from [`PretrainedConfig`] for more information.
36
+
37
+
38
+ Args:
39
+ vocab_size (`int`, *optional*, defaults to 32000):
40
+ Vocabulary size of the NGME model. Defines the number of different tokens that can be represented by the
41
+ `inputs_ids` passed when calling [`NGMEModel`]
42
+ hidden_size (`int`, *optional*, defaults to 4096):
43
+ Dimension of the hidden representations.
44
+ intermediate_size (`int`, *optional*, defaults to 11008):
45
+ Dimension of the MLP representations.
46
+ num_hidden_layers (`int`, *optional*, defaults to 32):
47
+ Number of hidden layers in the Transformer encoder.
48
+ num_attention_heads (`int`, *optional*, defaults to 32):
49
+ Number of attention heads for each attention layer in the Transformer encoder.
50
+ num_key_value_heads (`int`, *optional*):
51
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
52
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
53
+ `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
54
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
55
+ by meanpooling all the original heads within that group. For more details, check out [this
56
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
57
+ `num_attention_heads`.
58
+ pretraining_tp (`int`, *optional*, defaults to `1`):
59
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
60
+ document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
61
+ necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
62
+ issue](https://github.com/pytorch/pytorch/issues/76232).
63
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
64
+ The non-linear activation function (function or string) in the decoder.
65
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
66
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
67
+ just in case (e.g., 512 or 1024 or 2048).
68
+ initializer_range (`float`, *optional*, defaults to 0.02):
69
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
70
+ rms_norm_eps (`float`, *optional*, defaults to 1e-6):
71
+ The epsilon used by the rms normalization layers.
72
+ use_cache (`bool`, *optional*, defaults to `True`):
73
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
74
+ relevant if `config.is_decoder=True`.
75
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
76
+ Whether to tie the input and output word embeddings.
77
+ rope_scaling (`Dict`, *optional*):
78
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
79
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format
80
+ is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
81
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
82
+ these scaling strategies behave:
83
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
84
+ experimental feature, subject to breaking API changes in future versions.
85
+
86
+ Example:
87
+
88
+ ```python
89
+ >>> from transformers import NGMEModel, NGMEConfig
90
+
91
+ >>> # Initializing an NGME configuration with LLaMA-7B-style defaults
92
+ >>> configuration = NGMEConfig()
93
+
94
+ >>> # Initializing a model from that configuration
95
+ >>> model = NGMEModel(configuration)
96
+
97
+ >>> # Accessing the model configuration
98
+ >>> configuration = model.config
99
+ ```"""
100
+ model_type = "ngme"
101
+ keys_to_ignore_at_inference = ["past_key_values"]
102
+
103
+ def __init__(
104
+ self,
105
+ vocab_size=32000,
106
+ hidden_size=4096,
107
+ intermediate_size=11008,
108
+ num_hidden_layers=32,
109
+ num_attention_heads=32,
110
+ num_key_value_heads=None,
111
+ hidden_act="silu",
112
+ max_position_embeddings=2048,
113
+ initializer_range=0.02,
114
+ rms_norm_eps=1e-6,
115
+ use_cache=True,
116
+ pad_token_id=None,
117
+ bos_token_id=1,
118
+ eos_token_id=2,
119
+ pretraining_tp=1,
120
+ tie_word_embeddings=False,
121
+ rope_scaling=None,
122
+ use_flash_attn=False,
123
+ use_small_embedding=False,
124
+ unk_idx=-1,
125
+ **kwargs,
126
+ ):
127
+ self.vocab_size = vocab_size
128
+ self.unk_idx = unk_idx
129
+ self.max_position_embeddings = max_position_embeddings
130
+ self.hidden_size = hidden_size
131
+ self.intermediate_size = intermediate_size
132
+ self.num_hidden_layers = num_hidden_layers
133
+ self.num_attention_heads = num_attention_heads
134
+
135
+ # for backward compatibility
136
+ if num_key_value_heads is None:
137
+ num_key_value_heads = num_attention_heads
138
+
139
+ self.num_key_value_heads = num_key_value_heads
140
+ self.hidden_act = hidden_act
141
+ self.initializer_range = initializer_range
142
+ self.rms_norm_eps = rms_norm_eps
143
+ self.pretraining_tp = pretraining_tp
144
+ self.use_cache = use_cache
145
+ self.rope_scaling = rope_scaling
146
+ self._rope_scaling_validation()
147
+ self.use_flash_attn = use_flash_attn
148
+ self.use_small_embedding = use_small_embedding
149
+
150
+ super().__init__(
151
+ pad_token_id=pad_token_id,
152
+ bos_token_id=bos_token_id,
153
+ eos_token_id=eos_token_id,
154
+ tie_word_embeddings=tie_word_embeddings,
155
+ **kwargs,
156
+ )
157
+
158
+ def _rope_scaling_validation(self):
159
+ """
160
+ Validate the `rope_scaling` configuration.
161
+ """
162
+ if self.rope_scaling is None:
163
+ return
164
+
165
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
166
+ raise ValueError(
167
+ "`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, "
168
+ f"got {self.rope_scaling}"
169
+ )
170
+ rope_scaling_type = self.rope_scaling.get("type", None)
171
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
172
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
173
+ raise ValueError(
174
+ f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
175
+ )
176
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
177
+ raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}")
generation_config.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 0,
5
+ "pad_token_id": 0,
6
+ "transformers_version": "4.31.0"
7
+ }
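
These defaults mirror config.json: token id 0 serves as both `eos_token_id` and `pad_token_id`, id 1 is `bos_token_id`, and `generate()` picks them up automatically. A tiny programmatic equivalent, assuming plain transformers:

```python
# Equivalent defaults expressed programmatically (GenerationConfig ships with transformers >= 4.26).
from transformers import GenerationConfig

gen_cfg = GenerationConfig(bos_token_id=1, eos_token_id=0, pad_token_id=0)
print(gen_cfg.eos_token_id == gen_cfg.pad_token_id)  # True: the eos id doubles as the padding id
```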
modeling_ngme.py ADDED
@@ -0,0 +1,1206 @@
1
+ import math
2
+ from typing import List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import torch.utils.checkpoint
7
+ from torch import nn
8
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
9
+
10
+ from transformers.activations import ACT2FN
11
+ from transformers.modeling_outputs import (
12
+ BaseModelOutputWithPast,
13
+ CausalLMOutputWithPast,
14
+ SequenceClassifierOutputWithPast,
15
+ )
16
+ from transformers.modeling_utils import PreTrainedModel
17
+ from transformers.utils import logging
18
+ from transformers.generation.utils import SampleOutput
19
+ from transformers.generation.logits_process import LogitsProcessorList
20
+ from transformers.generation.stopping_criteria import StoppingCriteriaList
21
+
22
+ from .tokenization_ngme import NGMETokenizer
23
+ from .sampling import sample as sample_ngme
24
+ from .configuration_ngme import NGMEConfig
25
+ from .ngme import (
26
+ soft_n_hot,
27
+ NGramsEmbedding,
28
+ collect_n_gram_sequences,
29
+ shift_with_pad,
30
+ )
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+
35
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
36
+ def _make_causal_mask(
37
+ input_ids_shape: torch.Size,
38
+ dtype: torch.dtype,
39
+ device: torch.device,
40
+ past_key_values_length: int = 0,
41
+ ):
42
+ """
43
+ Make causal mask used for bi-directional self-attention.
44
+ """
45
+ bsz, tgt_len = input_ids_shape
46
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
47
+ mask_cond = torch.arange(mask.size(-1), device=device)
48
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
49
+ mask = mask.to(dtype)
50
+
51
+ if past_key_values_length > 0:
52
+ mask = torch.cat(
53
+ [
54
+ torch.zeros(
55
+ tgt_len, past_key_values_length, dtype=dtype, device=device
56
+ ),
57
+ mask,
58
+ ],
59
+ dim=-1,
60
+ )
61
+ return mask[None, None, :, :].expand(
62
+ bsz, 1, tgt_len, tgt_len + past_key_values_length
63
+ )
64
+
65
+
66
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
67
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
68
+ """
69
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
70
+ """
71
+ bsz, src_len = mask.size()
72
+ tgt_len = tgt_len if tgt_len is not None else src_len
73
+
74
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
75
+
76
+ inverted_mask = 1.0 - expanded_mask
77
+
78
+ return inverted_mask.masked_fill(
79
+ inverted_mask.to(torch.bool), torch.finfo(dtype).min
80
+ )
81
+
82
+
83
+ class NGMERMSNorm(nn.Module):
84
+ def __init__(self, hidden_size, eps=1e-6):
85
+ """
86
+ NGMERMSNorm is equivalent to T5LayerNorm
87
+ """
88
+ super().__init__()
89
+ self.weight = nn.Parameter(torch.ones(hidden_size))
90
+ self.variance_epsilon = eps
91
+
92
+ def forward(self, hidden_states):
93
+ input_dtype = hidden_states.dtype
94
+ hidden_states = hidden_states.to(torch.float32)
95
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
96
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
97
+ return self.weight * hidden_states.to(input_dtype)
98
+
99
+
100
+ class NGMERotaryEmbedding(torch.nn.Module):
101
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
102
+ super().__init__()
103
+
104
+ self.dim = dim
105
+ self.max_position_embeddings = max_position_embeddings
106
+ self.base = base
107
+ inv_freq = 1.0 / (
108
+ self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
109
+ )
110
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
111
+
112
+ # Build here to make `torch.jit.trace` work.
113
+ self._set_cos_sin_cache(
114
+ seq_len=max_position_embeddings,
115
+ device=self.inv_freq.device,
116
+ dtype=torch.get_default_dtype(),
117
+ )
118
+
119
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
120
+ self.max_seq_len_cached = seq_len
121
+ t = torch.arange(
122
+ self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
123
+ )
124
+
125
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
126
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
127
+ emb = torch.cat((freqs, freqs), dim=-1)
128
+ self.register_buffer(
129
+ "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False
130
+ )
131
+ self.register_buffer(
132
+ "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False
133
+ )
134
+
135
+ def forward(self, x, seq_len=None):
136
+ # x: [bs, num_attention_heads, seq_len, head_size]
137
+ if seq_len > self.max_seq_len_cached:
138
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
139
+
140
+ return (
141
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
142
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
143
+ )
144
+
145
+
146
+ class NGMELinearScalingRotaryEmbedding(NGMERotaryEmbedding):
147
+ """NGMERotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
148
+
149
+ def __init__(
150
+ self,
151
+ dim,
152
+ max_position_embeddings=2048,
153
+ base=10000,
154
+ device=None,
155
+ scaling_factor=1.0,
156
+ ):
157
+ self.scaling_factor = scaling_factor
158
+ super().__init__(dim, max_position_embeddings, base, device)
159
+
160
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
161
+ self.max_seq_len_cached = seq_len
162
+ t = torch.arange(
163
+ self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
164
+ )
165
+ t = t / self.scaling_factor
166
+
167
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
168
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
169
+ emb = torch.cat((freqs, freqs), dim=-1)
170
+ self.register_buffer(
171
+ "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False
172
+ )
173
+ self.register_buffer(
174
+ "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False
175
+ )
176
+
177
+
178
+ class NGMEDynamicNTKScalingRotaryEmbedding(NGMERotaryEmbedding):
179
+ """NGMERotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
180
+
181
+ def __init__(
182
+ self,
183
+ dim,
184
+ max_position_embeddings=2048,
185
+ base=10000,
186
+ device=None,
187
+ scaling_factor=1.0,
188
+ ):
189
+ self.scaling_factor = scaling_factor
190
+ super().__init__(dim, max_position_embeddings, base, device)
191
+
192
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
193
+ self.max_seq_len_cached = seq_len
194
+
195
+ if seq_len > self.max_position_embeddings:
196
+ base = self.base * (
197
+ (self.scaling_factor * seq_len / self.max_position_embeddings)
198
+ - (self.scaling_factor - 1)
199
+ ) ** (self.dim / (self.dim - 2))
200
+ inv_freq = 1.0 / (
201
+ base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
202
+ )
203
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
204
+
205
+ t = torch.arange(
206
+ self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
207
+ )
208
+
209
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
210
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
211
+ emb = torch.cat((freqs, freqs), dim=-1)
212
+ self.register_buffer(
213
+ "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False
214
+ )
215
+ self.register_buffer(
216
+ "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False
217
+ )
218
+
219
+
220
+ def rotate_half(x):
221
+ """Rotates half the hidden dims of the input."""
222
+ x1 = x[..., : x.shape[-1] // 2]
223
+ x2 = x[..., x.shape[-1] // 2 :]
224
+ return torch.cat((-x2, x1), dim=-1)
225
+
226
+
227
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
228
+ # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
229
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
230
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
231
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
232
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
233
+ q_embed = (q * cos) + (rotate_half(q) * sin)
234
+ k_embed = (k * cos) + (rotate_half(k) * sin)
235
+ return q_embed, k_embed
236
+
237
+
238
+ class NGMEMLP(nn.Module):
239
+ def __init__(self, config):
240
+ super().__init__()
241
+ self.config = config
242
+ self.hidden_size = config.hidden_size
243
+ self.intermediate_size = config.intermediate_size
244
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
245
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
246
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
247
+ self.act_fn = ACT2FN[config.hidden_act]
248
+
249
+ def forward(self, x):
250
+ if self.config.pretraining_tp > 1:
251
+ slice = self.intermediate_size // self.config.pretraining_tp
252
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
253
+ up_proj_slices = self.up_proj.weight.split(slice, dim=0)
254
+ down_proj_slices = self.down_proj.weight.split(slice, dim=1)
255
+
256
+ gate_proj = torch.cat(
257
+ [
258
+ F.linear(x, gate_proj_slices[i])
259
+ for i in range(self.config.pretraining_tp)
260
+ ],
261
+ dim=-1,
262
+ )
263
+ up_proj = torch.cat(
264
+ [
265
+ F.linear(x, up_proj_slices[i])
266
+ for i in range(self.config.pretraining_tp)
267
+ ],
268
+ dim=-1,
269
+ )
270
+
271
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
272
+ down_proj = [
273
+ F.linear(intermediate_states[i], down_proj_slices[i])
274
+ for i in range(self.config.pretraining_tp)
275
+ ]
276
+ down_proj = sum(down_proj)
277
+ else:
278
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
279
+
280
+ return down_proj
281
+
282
+
283
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
284
+ """
285
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
286
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
287
+ """
288
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
289
+ if n_rep == 1:
290
+ return hidden_states
291
+ hidden_states = hidden_states[:, :, None, :, :].expand(
292
+ batch, num_key_value_heads, n_rep, slen, head_dim
293
+ )
294
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
295
+
296
+
297
+ class NGMEAttention(nn.Module):
298
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
299
+
300
+ def __init__(self, config: NGMEConfig):
301
+ super().__init__()
302
+ self.config = config
303
+ self.hidden_size = config.hidden_size
304
+ self.num_heads = config.num_attention_heads
305
+ self.head_dim = self.hidden_size // self.num_heads
306
+ self.num_key_value_heads = config.num_key_value_heads
307
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
308
+ self.max_position_embeddings = config.max_position_embeddings
309
+
310
+ if (self.head_dim * self.num_heads) != self.hidden_size:
311
+ raise ValueError(
312
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
313
+ f" and `num_heads`: {self.num_heads})."
314
+ )
315
+ self.q_proj = nn.Linear(
316
+ self.hidden_size, self.num_heads * self.head_dim, bias=False
317
+ )
318
+ self.k_proj = nn.Linear(
319
+ self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
320
+ )
321
+ self.v_proj = nn.Linear(
322
+ self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
323
+ )
324
+ self.o_proj = nn.Linear(
325
+ self.num_heads * self.head_dim, self.hidden_size, bias=False
326
+ )
327
+ self._init_rope()
328
+
329
+ def _init_rope(self):
330
+ if self.config.rope_scaling is None:
331
+ self.rotary_emb = NGMERotaryEmbedding(
332
+ self.head_dim, max_position_embeddings=self.max_position_embeddings
333
+ )
334
+ else:
335
+ scaling_type = self.config.rope_scaling["type"]
336
+ scaling_factor = self.config.rope_scaling["factor"]
337
+ if scaling_type == "linear":
338
+ self.rotary_emb = NGMELinearScalingRotaryEmbedding(
339
+ self.head_dim,
340
+ max_position_embeddings=self.max_position_embeddings,
341
+ scaling_factor=scaling_factor,
342
+ )
343
+ elif scaling_type == "dynamic":
344
+ self.rotary_emb = NGMEDynamicNTKScalingRotaryEmbedding(
345
+ self.head_dim,
346
+ max_position_embeddings=self.max_position_embeddings,
347
+ scaling_factor=scaling_factor,
348
+ )
349
+ else:
350
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
351
+
352
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
353
+ return (
354
+ tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
355
+ .transpose(1, 2)
356
+ .contiguous()
357
+ )
358
+
359
+ def forward(
360
+ self,
361
+ hidden_states: torch.Tensor,
362
+ attention_mask: Optional[torch.Tensor] = None,
363
+ position_ids: Optional[torch.LongTensor] = None,
364
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
365
+ output_attentions: bool = False,
366
+ use_cache: bool = False,
367
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
368
+ bsz, q_len, _ = hidden_states.size()
369
+
370
+ if self.config.pretraining_tp > 1:
371
+ key_value_slicing = (
372
+ self.num_key_value_heads * self.head_dim
373
+ ) // self.config.pretraining_tp
374
+ query_slices = self.q_proj.weight.split(
375
+ (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
376
+ )
377
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
378
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
379
+
380
+ query_states = [
381
+ F.linear(hidden_states, query_slices[i])
382
+ for i in range(self.config.pretraining_tp)
383
+ ]
384
+ query_states = torch.cat(query_states, dim=-1)
385
+
386
+ key_states = [
387
+ F.linear(hidden_states, key_slices[i])
388
+ for i in range(self.config.pretraining_tp)
389
+ ]
390
+ key_states = torch.cat(key_states, dim=-1)
391
+
392
+ value_states = [
393
+ F.linear(hidden_states, value_slices[i])
394
+ for i in range(self.config.pretraining_tp)
395
+ ]
396
+ value_states = torch.cat(value_states, dim=-1)
397
+
398
+ else:
399
+ query_states = self.q_proj(hidden_states)
400
+ key_states = self.k_proj(hidden_states)
401
+ value_states = self.v_proj(hidden_states)
402
+
403
+ query_states = query_states.view(
404
+ bsz, q_len, self.num_heads, self.head_dim
405
+ ).transpose(1, 2)
406
+ key_states = key_states.view(
407
+ bsz, q_len, self.num_key_value_heads, self.head_dim
408
+ ).transpose(1, 2)
409
+ value_states = value_states.view(
410
+ bsz, q_len, self.num_key_value_heads, self.head_dim
411
+ ).transpose(1, 2)
412
+
413
+ kv_seq_len = key_states.shape[-2]
414
+ if past_key_value is not None:
415
+ kv_seq_len += past_key_value[0].shape[-2]
416
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
417
+ query_states, key_states = apply_rotary_pos_emb(
418
+ query_states, key_states, cos, sin, position_ids
419
+ )
420
+
421
+ if past_key_value is not None:
422
+ # reuse k, v, self_attention
423
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
424
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
425
+
426
+ past_key_value = (key_states, value_states) if use_cache else None
427
+
428
+ # repeat k/v heads if n_kv_heads < n_heads
429
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
430
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
431
+
432
+ attn_weights = torch.matmul(
433
+ query_states, key_states.transpose(2, 3)
434
+ ) / math.sqrt(self.head_dim)
435
+
436
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
437
+ raise ValueError(
438
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
439
+ f" {attn_weights.size()}"
440
+ )
441
+
442
+ if attention_mask is not None:
443
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
444
+ raise ValueError(
445
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
446
+ )
447
+ attn_weights = attn_weights + attention_mask
448
+
449
+ # upcast attention to fp32
450
+ attn_weights = nn.functional.softmax(
451
+ attn_weights, dim=-1, dtype=torch.float32
452
+ ).to(query_states.dtype)
453
+ attn_output = torch.matmul(attn_weights, value_states)
454
+
455
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
456
+ raise ValueError(
457
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
458
+ f" {attn_output.size()}"
459
+ )
460
+
461
+ attn_output = attn_output.transpose(1, 2).contiguous()
462
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
463
+
464
+ if self.config.pretraining_tp > 1:
465
+ attn_output = attn_output.split(
466
+ self.hidden_size // self.config.pretraining_tp, dim=2
467
+ )
468
+ o_proj_slices = self.o_proj.weight.split(
469
+ self.hidden_size // self.config.pretraining_tp, dim=1
470
+ )
471
+ attn_output = sum(
472
+ [
473
+ F.linear(attn_output[i], o_proj_slices[i])
474
+ for i in range(self.config.pretraining_tp)
475
+ ]
476
+ )
477
+ else:
478
+ attn_output = self.o_proj(attn_output)
479
+
480
+ if not output_attentions:
481
+ attn_weights = None
482
+
483
+ return attn_output, attn_weights, past_key_value
484
+
485
+
486
+ class NGMEDecoderLayer(nn.Module):
487
+ def __init__(self, config: NGMEConfig):
488
+ super().__init__()
489
+ self.hidden_size = config.hidden_size
490
+ self.self_attn = NGMEAttention(config=config)
491
+ self.mlp = NGMEMLP(config)
492
+ self.input_layernorm = NGMERMSNorm(config.hidden_size, eps=config.rms_norm_eps)
493
+ self.post_attention_layernorm = NGMERMSNorm(
494
+ config.hidden_size, eps=config.rms_norm_eps
495
+ )
496
+
497
+ def forward(
498
+ self,
499
+ hidden_states: torch.Tensor,
500
+ attention_mask: Optional[torch.Tensor] = None,
501
+ position_ids: Optional[torch.LongTensor] = None,
502
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
503
+ output_attentions: Optional[bool] = False,
504
+ use_cache: Optional[bool] = False,
505
+ ) -> Tuple[
506
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
507
+ ]:
508
+ """
509
+ Args:
510
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
511
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
512
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
513
+ output_attentions (`bool`, *optional*):
514
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
515
+ returned tensors for more detail.
516
+ use_cache (`bool`, *optional*):
517
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
518
+ (see `past_key_values`).
519
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
520
+ """
521
+
522
+ residual = hidden_states
523
+
524
+ hidden_states = self.input_layernorm(hidden_states)
525
+
526
+ # Self Attention
527
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
528
+ hidden_states=hidden_states,
529
+ attention_mask=attention_mask,
530
+ position_ids=position_ids,
531
+ past_key_value=past_key_value,
532
+ output_attentions=output_attentions,
533
+ use_cache=use_cache,
534
+ )
535
+ hidden_states = residual + hidden_states
536
+
537
+ # Fully Connected
538
+ residual = hidden_states
539
+ hidden_states = self.post_attention_layernorm(hidden_states)
540
+ hidden_states = self.mlp(hidden_states)
541
+ hidden_states = residual + hidden_states
542
+
543
+ outputs = (hidden_states,)
544
+
545
+ if output_attentions:
546
+ outputs += (self_attn_weights,)
547
+
548
+ if use_cache:
549
+ outputs += (present_key_value,)
550
+
551
+ return outputs
552
+
553
+
554
+ class NGMEPreTrainedModel(PreTrainedModel):
555
+ config_class = NGMEConfig
556
+ base_model_prefix = "model"
557
+ supports_gradient_checkpointing = True
558
+ _no_split_modules = ["NGMEDecoderLayer"]
559
+ _skip_keys_device_placement = "past_key_values"
560
+
561
+ def _init_weights(self, module):
562
+ std = self.config.initializer_range
563
+ if isinstance(module, nn.Linear):
564
+ module.weight.data.normal_(mean=0.0, std=std)
565
+ if module.bias is not None:
566
+ module.bias.data.zero_()
567
+ elif isinstance(module, NGramsEmbedding):
568
+ if self.config.use_small_embedding:
569
+ nn.init.uniform_(module.weight, a=-1e-4, b=1e-4)
570
+ else:
571
+ module.weight.data.normal_(mean=0.0, std=std)
572
+ if module.padding_idx is not None:
573
+ module.weight.data[module.padding_idx].zero_()
574
+
575
+ def _set_gradient_checkpointing(self, module, value=False):
576
+ if isinstance(module, NGMEModel):
577
+ module.gradient_checkpointing = value
578
+
579
+
580
+ class NGMEModel(NGMEPreTrainedModel):
581
+ """
582
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`NGMEDecoderLayer`]
583
+
584
+ Args:
585
+ config: NGMEConfig
586
+ """
587
+
588
+ def __init__(self, config: NGMEConfig):
589
+ super().__init__(config)
590
+ self.padding_idx = config.pad_token_id
591
+ self.vocab_size = config.vocab_size
592
+
593
+ self.embed_tokens = NGramsEmbedding(
594
+ config.vocab_size,
595
+ config.hidden_size,
596
+ self.padding_idx,
597
+ unk_idx=config.unk_idx,
598
+ )
599
+
600
+ if self.config.use_small_embedding:
601
+ self.embed_layer_norm = nn.LayerNorm(config.hidden_size)
602
+
603
+ self.layers = nn.ModuleList(
604
+ [NGMEDecoderLayer(config) for _ in range(config.num_hidden_layers)]
605
+ )
606
+ self.norm = NGMERMSNorm(config.hidden_size, eps=config.rms_norm_eps)
607
+
608
+ self.gradient_checkpointing = False
609
+ # Initialize weights and apply final processing
610
+ self.post_init()
611
+
612
+ def get_input_embeddings(self):
613
+ return self.embed_tokens
614
+
615
+ def set_input_embeddings(self, value):
616
+ self.embed_tokens = value
617
+
618
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
619
+ def _prepare_decoder_attention_mask(
620
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
621
+ ):
622
+ # create causal mask
623
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
624
+ combined_attention_mask = None
625
+ if input_shape[-1] > 1:
626
+ combined_attention_mask = _make_causal_mask(
627
+ input_shape,
628
+ inputs_embeds.dtype,
629
+ device=inputs_embeds.device,
630
+ past_key_values_length=past_key_values_length,
631
+ )
632
+
633
+ if attention_mask is not None:
634
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
635
+ expanded_attn_mask = _expand_mask(
636
+ attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
637
+ ).to(inputs_embeds.device)
638
+ combined_attention_mask = (
639
+ expanded_attn_mask
640
+ if combined_attention_mask is None
641
+ else expanded_attn_mask + combined_attention_mask
642
+ )
643
+
644
+ return combined_attention_mask
645
+
646
+ def forward(
647
+ self,
648
+ input_ids: torch.LongTensor = None,
649
+ attention_mask: Optional[torch.Tensor] = None,
650
+ position_ids: Optional[torch.LongTensor] = None,
651
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
652
+ inputs_embeds: Optional[torch.FloatTensor] = None,
653
+ use_cache: Optional[bool] = None,
654
+ output_attentions: Optional[bool] = None,
655
+ output_hidden_states: Optional[bool] = None,
656
+ return_dict: Optional[bool] = None,
657
+ ngram_sequences: List[torch.Tensor] = [],
658
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
659
+ output_attentions = (
660
+ output_attentions
661
+ if output_attentions is not None
662
+ else self.config.output_attentions
663
+ )
664
+ output_hidden_states = (
665
+ output_hidden_states
666
+ if output_hidden_states is not None
667
+ else self.config.output_hidden_states
668
+ )
669
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
670
+
671
+ return_dict = (
672
+ return_dict if return_dict is not None else self.config.use_return_dict
673
+ )
674
+
675
+ # retrieve input_ids and inputs_embeds
676
+ if input_ids is not None and inputs_embeds is not None:
677
+ raise ValueError(
678
+ "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
679
+ )
680
+ elif input_ids is not None:
681
+ batch_size, seq_length = input_ids.shape
682
+ elif inputs_embeds is not None:
683
+ batch_size, seq_length, _ = inputs_embeds.shape
684
+ else:
685
+ raise ValueError(
686
+ "You have to specify either decoder_input_ids or decoder_inputs_embeds"
687
+ )
688
+
689
+ seq_length_with_past = seq_length
690
+ past_key_values_length = 0
691
+
692
+ if past_key_values is not None:
693
+ past_key_values_length = past_key_values[0][0].shape[2]
694
+ seq_length_with_past = seq_length_with_past + past_key_values_length
695
+
696
+ if position_ids is None:
697
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
698
+ position_ids = torch.arange(
699
+ past_key_values_length,
700
+ seq_length + past_key_values_length,
701
+ dtype=torch.long,
702
+ device=device,
703
+ )
704
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
705
+ else:
706
+ position_ids = position_ids.view(-1, seq_length).long()
707
+
708
+ if inputs_embeds is None:
709
+ inputs_embeds = self.embed_tokens(input_ids, ngram_sequences)
710
+
711
+ if self.config.use_small_embedding:
712
+ inputs_embeds = self.embed_layer_norm(inputs_embeds)
713
+
714
+ # embed positions
715
+ if attention_mask is None:
716
+ attention_mask = torch.ones(
717
+ (batch_size, seq_length_with_past),
718
+ dtype=torch.bool,
719
+ device=inputs_embeds.device,
720
+ )
721
+ attention_mask = self._prepare_decoder_attention_mask(
722
+ attention_mask,
723
+ (batch_size, seq_length),
724
+ inputs_embeds,
725
+ past_key_values_length,
726
+ )
727
+
728
+ hidden_states = inputs_embeds
729
+
730
+ if self.gradient_checkpointing and self.training:
731
+ if use_cache:
732
+ logger.warning_once(
733
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
734
+ )
735
+ use_cache = False
736
+
737
+ # decoder layers
738
+ all_hidden_states = () if output_hidden_states else None
739
+ all_self_attns = () if output_attentions else None
740
+ next_decoder_cache = () if use_cache else None
741
+
742
+ for idx, decoder_layer in enumerate(self.layers):
743
+ if output_hidden_states:
744
+ all_hidden_states += (hidden_states,)
745
+
746
+ past_key_value = (
747
+ past_key_values[idx] if past_key_values is not None else None
748
+ )
749
+
750
+ if self.gradient_checkpointing and self.training:
751
+
752
+ def create_custom_forward(module):
753
+ def custom_forward(*inputs):
754
+ # None for past_key_value
755
+ return module(*inputs, output_attentions, None)
756
+
757
+ return custom_forward
758
+
759
+ layer_outputs = torch.utils.checkpoint.checkpoint(
760
+ create_custom_forward(decoder_layer),
761
+ hidden_states,
762
+ attention_mask,
763
+ position_ids,
764
+ None,
765
+ )
766
+ else:
767
+ layer_outputs = decoder_layer(
768
+ hidden_states,
769
+ attention_mask=attention_mask,
770
+ position_ids=position_ids,
771
+ past_key_value=past_key_value,
772
+ output_attentions=output_attentions,
773
+ use_cache=use_cache,
774
+ )
775
+
776
+ hidden_states = layer_outputs[0]
777
+
778
+ if use_cache:
779
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
780
+
781
+ if output_attentions:
782
+ all_self_attns += (layer_outputs[1],)
783
+
784
+ hidden_states = self.norm(hidden_states)
785
+
786
+ # add hidden states from the last decoder layer
787
+ if output_hidden_states:
788
+ all_hidden_states += (hidden_states,)
789
+
790
+ next_cache = next_decoder_cache if use_cache else None
791
+ if not return_dict:
792
+ return tuple(
793
+ v
794
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
795
+ if v is not None
796
+ )
797
+ return BaseModelOutputWithPast(
798
+ last_hidden_state=hidden_states,
799
+ past_key_values=next_cache,
800
+ hidden_states=all_hidden_states,
801
+ attentions=all_self_attns,
802
+ )
803
+
804
+
805
+ class NGMEForCausalLM(NGMEPreTrainedModel):
806
+ _tied_weights_keys = ["lm_head.weight"]
807
+
808
+ def __init__(self, config):
809
+ super().__init__(config)
810
+ self.model = NGMEModel(config)
811
+ self.vocab_size = config.vocab_size
812
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
813
+
814
+ self.tokenizer: Optional[NGMETokenizer] = None
815
+ # Weight tensor for the vocabulary; ignore the unk token in the loss
+ weight = torch.ones(config.vocab_size)
+ weight[config.unk_idx] = 0
+ self.loss_fct = CrossEntropyLoss(weight=weight)
823
+
824
+ # Initialize weights and apply final processing
825
+ self.post_init()
826
+
827
+ def get_input_embeddings(self):
828
+ return self.model.embed_tokens
829
+
830
+ def set_input_embeddings(self, value):
831
+ self.model.embed_tokens = value
832
+
833
+ def get_output_embeddings(self):
834
+ return self.lm_head
835
+
836
+ def set_output_embeddings(self, new_embeddings):
837
+ self.lm_head = new_embeddings
838
+
839
+ def set_decoder(self, decoder):
840
+ self.model = decoder
841
+
842
+ def get_decoder(self):
843
+ return self.model
844
+
845
+ def _collect_ngram_labels(
846
+ self,
847
+ unigram_labels: torch.LongTensor,
848
+ label_gram_2_sequence: Optional[torch.LongTensor] = None,
849
+ label_gram_3_sequence: Optional[torch.LongTensor] = None,
850
+ label_gram_4_sequence: Optional[torch.LongTensor] = None,
851
+ label_target_gram_2_sequence: Optional[torch.LongTensor] = None,
852
+ label_target_gram_3_sequence: Optional[torch.LongTensor] = None,
853
+ label_target_gram_4_sequence: Optional[torch.LongTensor] = None,
854
+ ):
855
+ ngram_labels = [unigram_labels[..., 1:].contiguous()]
856
+
857
+ if label_gram_2_sequence is not None:
858
+ if label_target_gram_2_sequence is not None:
859
+ two_gram_labels = label_target_gram_2_sequence[..., 1:].contiguous()
860
+ else:
861
+ two_gram_labels = shift_with_pad(
862
+ label_gram_2_sequence, 2, unigram_labels
863
+ )
864
+ ngram_labels.append(two_gram_labels)
865
+
866
+ if label_gram_3_sequence is not None:
867
+ if label_target_gram_3_sequence is not None:
868
+ three_gram_labels = label_target_gram_3_sequence[..., 1:].contiguous()
869
+ else:
870
+ three_gram_labels = shift_with_pad(
871
+ label_gram_3_sequence, 3, unigram_labels
872
+ )
873
+ ngram_labels.append(three_gram_labels)
874
+
875
+ if label_gram_4_sequence is not None:
876
+ if label_target_gram_4_sequence is not None:
877
+ four_gram_labels = label_target_gram_4_sequence[..., 1:].contiguous()
878
+ else:
879
+ four_gram_labels = shift_with_pad(
880
+ label_gram_4_sequence, 4, unigram_labels
881
+ )
882
+ ngram_labels.append(four_gram_labels)
883
+
884
+ return ngram_labels
885
+
886
+ def forward(
887
+ self,
888
+ input_ids: torch.LongTensor = None,
889
+ attention_mask: Optional[torch.Tensor] = None,
890
+ position_ids: Optional[torch.LongTensor] = None,
891
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
892
+ inputs_embeds: Optional[torch.FloatTensor] = None,
893
+ labels: Optional[torch.LongTensor] = None,
894
+ use_cache: Optional[bool] = None,
895
+ output_attentions: Optional[bool] = None,
896
+ output_hidden_states: Optional[bool] = None,
897
+ return_dict: Optional[bool] = None,
898
+ label_gram_2_sequence: Optional[torch.LongTensor] = None,
899
+ label_gram_3_sequence: Optional[torch.LongTensor] = None,
900
+ label_gram_4_sequence: Optional[torch.LongTensor] = None,
901
+ label_target_gram_2_sequence: Optional[torch.LongTensor] = None,
902
+ label_target_gram_3_sequence: Optional[torch.LongTensor] = None,
903
+ label_target_gram_4_sequence: Optional[torch.LongTensor] = None,
904
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
905
+ output_attentions = (
906
+ output_attentions
907
+ if output_attentions is not None
908
+ else self.config.output_attentions
909
+ )
910
+ output_hidden_states = (
911
+ output_hidden_states
912
+ if output_hidden_states is not None
913
+ else self.config.output_hidden_states
914
+ )
915
+ return_dict = (
916
+ return_dict if return_dict is not None else self.config.use_return_dict
917
+ )
918
+
919
+ ngram_sequences = collect_n_gram_sequences(
920
+ gram_2_sequence=label_gram_2_sequence,
921
+ gram_3_sequence=label_gram_3_sequence,
922
+ gram_4_sequence=label_gram_4_sequence,
923
+ )
924
+
925
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
926
+ outputs = self.model(
927
+ input_ids=input_ids,
928
+ attention_mask=attention_mask,
929
+ position_ids=position_ids,
930
+ past_key_values=past_key_values,
931
+ inputs_embeds=inputs_embeds,
932
+ use_cache=use_cache,
933
+ output_attentions=output_attentions,
934
+ output_hidden_states=output_hidden_states,
935
+ return_dict=return_dict,
936
+ ngram_sequences=ngram_sequences,
937
+ )
938
+
939
+ hidden_states = outputs[0]
940
+ if self.config.pretraining_tp > 1:
941
+ lm_head_slices = self.lm_head.weight.split(
942
+ self.vocab_size // self.config.pretraining_tp, dim=0
943
+ )
944
+ logits = [
945
+ F.linear(hidden_states, lm_head_slices[i])
946
+ for i in range(self.config.pretraining_tp)
947
+ ]
948
+ logits = torch.cat(logits, dim=-1)
949
+ else:
950
+ logits = self.lm_head(hidden_states)
951
+
952
+ logits = logits.float()
953
+
954
+ loss = None
955
+ if labels is not None:
956
+ # Shift so that tokens < n predict n
957
+ shift_logits = logits[..., :-1, :].contiguous()
958
+
959
+ ngram_labels = self._collect_ngram_labels(
960
+ unigram_labels=labels,
961
+ label_gram_2_sequence=label_gram_2_sequence,
962
+ label_gram_3_sequence=label_gram_3_sequence,
963
+ label_gram_4_sequence=label_gram_4_sequence,
964
+ label_target_gram_2_sequence=label_target_gram_2_sequence,
965
+ label_target_gram_3_sequence=label_target_gram_3_sequence,
966
+ label_target_gram_4_sequence=label_target_gram_4_sequence,
967
+ )
968
+
969
+ shift_labels = torch.stack(ngram_labels, dim=0)
970
+ shift_labels = soft_n_hot(
971
+ shift_labels, self.config.vocab_size, strategy="exp"
972
+ )
973
+
974
+ # Flatten the tokens
975
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
976
+ shift_labels = shift_labels.view(-1, shift_labels.size(-1))
977
+ # Enable model parallelism
978
+ shift_labels = shift_labels.to(shift_logits.device)
979
+ loss = self.loss_fct(shift_logits, shift_labels)
980
+
981
+ if not return_dict:
982
+ output = (logits,) + outputs[1:]
983
+ return (loss,) + output if loss is not None else output
984
+
985
+ return CausalLMOutputWithPast(
986
+ loss=loss,
987
+ logits=logits,
988
+ past_key_values=outputs.past_key_values,
989
+ hidden_states=outputs.hidden_states,
990
+ attentions=outputs.attentions,
991
+ )
992
+
993
+ def prepare_inputs_for_generation(
994
+ self,
995
+ input_ids,
996
+ past_key_values=None,
997
+ attention_mask=None,
998
+ inputs_embeds=None,
999
+ **kwargs,
1000
+ ):
1001
+ if past_key_values:
1002
+ input_ids = input_ids[:, -1:]
1003
+
1004
+ position_ids = kwargs.get("position_ids", None)
1005
+ if attention_mask is not None and position_ids is None:
1006
+ # create position_ids on the fly for batch generation
1007
+ position_ids = attention_mask.long().cumsum(-1) - 1
1008
+ position_ids.masked_fill_(attention_mask == 0, 1)
1009
+ if past_key_values:
1010
+ position_ids = position_ids[:, -1].unsqueeze(-1)
1011
+
1012
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1013
+ if inputs_embeds is not None and past_key_values is None:
1014
+ model_inputs = {"inputs_embeds": inputs_embeds}
1015
+ else:
1016
+ model_inputs = {"input_ids": input_ids}
1017
+
1018
+ model_inputs.update(
1019
+ {
1020
+ "position_ids": position_ids,
1021
+ "past_key_values": past_key_values,
1022
+ "use_cache": kwargs.get("use_cache"),
1023
+ "attention_mask": attention_mask,
1024
+ }
1025
+ )
1026
+ return model_inputs
1027
+
1028
+ @staticmethod
1029
+ def _reorder_cache(past_key_values, beam_idx):
1030
+ reordered_past = ()
1031
+ for layer_past in past_key_values:
1032
+ reordered_past += (
1033
+ tuple(
1034
+ past_state.index_select(0, beam_idx.to(past_state.device))
1035
+ for past_state in layer_past
1036
+ ),
1037
+ )
1038
+ return reordered_past
1039
+
1040
+ def sample(
1041
+ self,
1042
+ input_ids: torch.LongTensor,
1043
+ logits_processor: Optional[LogitsProcessorList] = None,
1044
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
1045
+ logits_warper: Optional[LogitsProcessorList] = None,
1046
+ max_length: Optional[int] = None,
1047
+ pad_token_id: Optional[int] = None,
1048
+ eos_token_id: Optional[Union[int, List[int]]] = None,
1049
+ output_attentions: Optional[bool] = None,
1050
+ output_hidden_states: Optional[bool] = None,
1051
+ output_scores: Optional[bool] = None,
1052
+ return_dict_in_generate: Optional[bool] = None,
1053
+ synced_gpus: bool = False,
1054
+ streamer=None,
1055
+ **model_kwargs,
1056
+ ) -> Union[SampleOutput, torch.LongTensor]:
1057
+ if getattr(self, "tokenizer", None) is None:
1058
+ raise ValueError(
1059
+ "You are trying to sample from a model that does not have a tokenizer."
1060
+ "Add a tokenizer as an attribute of your model (either manually or automatically)."
1061
+ )
1062
+
1063
+ return sample_ngme(
1064
+ self,
1065
+ input_ids,
1066
+ logits_processor=logits_processor,
1067
+ stopping_criteria=stopping_criteria,
1068
+ logits_warper=logits_warper,
1069
+ max_length=max_length,
1070
+ pad_token_id=pad_token_id,
1071
+ eos_token_id=eos_token_id,
1072
+ output_attentions=output_attentions,
1073
+ output_hidden_states=output_hidden_states,
1074
+ output_scores=output_scores,
1075
+ return_dict_in_generate=return_dict_in_generate,
1076
+ synced_gpus=synced_gpus,
1077
+ streamer=streamer,
1078
+ **model_kwargs,
1079
+ )
1080
+
1081
+
1082
+ class NGMEForSequenceClassification(NGMEPreTrainedModel):
1083
+ def __init__(self, config):
1084
+ super().__init__(config)
1085
+ self.num_labels = config.num_labels
1086
+ self.model = NGMEModel(config)
1087
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1088
+
1089
+ # Initialize weights and apply final processing
1090
+ self.post_init()
1091
+
1092
+ def get_input_embeddings(self):
1093
+ return self.model.embed_tokens
1094
+
1095
+ def set_input_embeddings(self, value):
1096
+ self.model.embed_tokens = value
1097
+
1098
+ def forward(
1099
+ self,
1100
+ input_ids: torch.LongTensor = None,
1101
+ attention_mask: Optional[torch.Tensor] = None,
1102
+ position_ids: Optional[torch.LongTensor] = None,
1103
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1104
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1105
+ labels: Optional[torch.LongTensor] = None,
1106
+ use_cache: Optional[bool] = None,
1107
+ output_attentions: Optional[bool] = None,
1108
+ output_hidden_states: Optional[bool] = None,
1109
+ return_dict: Optional[bool] = None,
1110
+ label_gram_2_sequence: Optional[torch.LongTensor] = None,
1111
+ label_gram_3_sequence: Optional[torch.LongTensor] = None,
1112
+ label_gram_4_sequence: Optional[torch.LongTensor] = None,
1113
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1114
+ r"""
1115
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1116
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1117
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1118
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1119
+ """
1120
+ return_dict = (
1121
+ return_dict if return_dict is not None else self.config.use_return_dict
1122
+ )
1123
+
1124
+ ngram_sequences = collect_n_gram_sequences(
1125
+ gram_2_sequence=label_gram_2_sequence,
1126
+ gram_3_sequence=label_gram_3_sequence,
1127
+ gram_4_sequence=label_gram_4_sequence,
1128
+ )
1129
+
1130
+ transformer_outputs = self.model(
1131
+ input_ids,
1132
+ attention_mask=attention_mask,
1133
+ position_ids=position_ids,
1134
+ past_key_values=past_key_values,
1135
+ inputs_embeds=inputs_embeds,
1136
+ use_cache=use_cache,
1137
+ output_attentions=output_attentions,
1138
+ output_hidden_states=output_hidden_states,
1139
+ return_dict=return_dict,
1140
+ ngram_sequences=ngram_sequences,
1141
+ )
1142
+ hidden_states = transformer_outputs[0]
1143
+ logits = self.score(hidden_states)
1144
+
1145
+ if input_ids is not None:
1146
+ batch_size = input_ids.shape[0]
1147
+ else:
1148
+ batch_size = inputs_embeds.shape[0]
1149
+
1150
+ if self.config.pad_token_id is None and batch_size != 1:
1151
+ raise ValueError(
1152
+ "Cannot handle batch sizes > 1 if no padding token is defined."
1153
+ )
1154
+ if self.config.pad_token_id is None:
1155
+ sequence_lengths = -1
1156
+ else:
1157
+ if input_ids is not None:
1158
+ sequence_lengths = (
1159
+ torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1
1160
+ ).to(logits.device)
1161
+ else:
1162
+ sequence_lengths = -1
1163
+
1164
+ pooled_logits = logits[
1165
+ torch.arange(batch_size, device=logits.device), sequence_lengths
1166
+ ]
1167
+
1168
+ loss = None
1169
+
1170
+ if labels is not None:
1171
+ labels = labels.to(logits.device)
1172
+ if self.config.problem_type is None:
1173
+ if self.num_labels == 1:
1174
+ self.config.problem_type = "regression"
1175
+ elif self.num_labels > 1 and (
1176
+ labels.dtype == torch.long or labels.dtype == torch.int
1177
+ ):
1178
+ self.config.problem_type = "single_label_classification"
1179
+ else:
1180
+ self.config.problem_type = "multi_label_classification"
1181
+
1182
+ if self.config.problem_type == "regression":
1183
+ loss_fct = MSELoss()
1184
+ if self.num_labels == 1:
1185
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1186
+ else:
1187
+ loss = loss_fct(pooled_logits, labels)
1188
+ elif self.config.problem_type == "single_label_classification":
1189
+ loss_fct = CrossEntropyLoss()
1190
+ loss = loss_fct(
1191
+ pooled_logits.view(-1, self.num_labels), labels.view(-1)
1192
+ )
1193
+ elif self.config.problem_type == "multi_label_classification":
1194
+ loss_fct = BCEWithLogitsLoss()
1195
+ loss = loss_fct(pooled_logits, labels)
1196
+ if not return_dict:
1197
+ output = (pooled_logits,) + transformer_outputs[1:]
1198
+ return ((loss,) + output) if loss is not None else output
1199
+
1200
+ return SequenceClassifierOutputWithPast(
1201
+ loss=loss,
1202
+ logits=pooled_logits,
1203
+ past_key_values=transformer_outputs.past_key_values,
1204
+ hidden_states=transformer_outputs.hidden_states,
1205
+ attentions=transformer_outputs.attentions,
1206
+ )
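
The loss in `NGMEForCausalLM.forward` stacks the shifted unigram labels with any 2-/3-/4-gram label sequences and converts them into a soft target distribution via `soft_n_hot(..., strategy="exp")` before the unk-masked cross-entropy. The following self-contained sketch mirrors `n_dist()`/`soft_n_hot()` from ngme.py (added below) to show what those targets look like; note that the "exp" strategy is the quadratic map x -> x**2 in the `strats` table, and the toy label ids here are hypothetical.

```python
# Self-contained illustration of the soft n-gram targets used in the loss above,
# mirroring n_dist()/soft_n_hot() from ngme.py with the "exp" (x -> x**2) strategy.
import torch

def n_dist(n, strat=lambda x: x ** 2):
    xs = [strat(i) for i in range(1, n + 1)]
    return [x / sum(xs) for x in xs]

def soft_n_hot(labels, num_classes, weights):
    # labels: (n_grams, batch, seq_len) -> targets: (batch, seq_len, num_classes)
    targets = torch.zeros(*labels.shape[1:], num_classes)
    for weight, gram_labels in zip(weights, labels):
        targets.scatter_(-1, gram_labels.unsqueeze(-1), weight)
    return targets

labels = torch.tensor([[[1, 2, 3]],     # shifted unigram labels (hypothetical ids)
                       [[4, 5, 6]]])    # shifted 2-gram labels (hypothetical ids)
targets = soft_n_hot(labels, num_classes=10, weights=n_dist(labels.size(0)))
print(targets.shape)    # torch.Size([1, 3, 10])
print(targets.sum(-1))  # tensor([[1., 1., 1.]]) -- 0.2 on the unigram id, 0.8 on the 2-gram id
```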
ngme.py ADDED
@@ -0,0 +1,178 @@
1
+ import math
2
+ from typing import Optional, List
3
+ from functools import lru_cache
4
+ from itertools import chain, tee
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+ n_dists = {
10
+ 0: [1],
11
+ 1: [0.4, 0.6],
12
+ 2: [0.2, 0.3, 0.5],
13
+ 3: [0.1, 0.2, 0.3, 0.4],
14
+ 4: [0.1, 0.15, 0.2, 0.25, 0.3],
15
+ }
16
+
17
+ strats = {"linear": lambda x: x, "log": lambda x: math.log(x + 1), "exp": lambda x: x**2}
18
+
19
+ def pad_sequence(
20
+ sequence,
21
+ n,
22
+ pad_left=False,
23
+ pad_right=False,
24
+ left_pad_symbol=None,
25
+ right_pad_symbol=None,
26
+ ):
27
+ """Copied from NLTK"""
28
+ sequence = iter(sequence)
29
+ if pad_left:
30
+ sequence = chain((left_pad_symbol,) * (n - 1), sequence)
31
+ if pad_right:
32
+ sequence = chain(sequence, (right_pad_symbol,) * (n - 1))
33
+ return sequence
34
+
35
+ def ngrams(sequence, n, **kwargs):
36
+ """Copied from NLTK"""
37
+ sequence = pad_sequence(sequence, n, **kwargs)
38
+
39
+ # Creates the sliding window, of n no. of items.
40
+ # `iterables` is a tuple of iterables where each iterable is a window of n items.
41
+ iterables = tee(sequence, n)
42
+
43
+ for i, sub_iterable in enumerate(iterables): # For each window,
44
+ for _ in range(i): # iterate through every order of ngrams
45
+ next(sub_iterable, None) # generate the ngrams within the window.
46
+ return zip(*iterables) # Unpack and flatten the iterables.
47
+
48
+
49
+ @lru_cache(maxsize=5)
50
+ def soft_dist(n):
51
+ return [1 / n] * n
52
+
53
+
54
+ @lru_cache(maxsize=5)
55
+ def n_dist(n: int, strategy: str) -> list[float]:
56
+ """dist of ngram weight is logarithmic"""
57
+ ns = list(range(1, n + 1))
58
+ xs = list(map(strats[strategy], ns))
59
+ result = list(map(lambda x: x / sum(xs), xs))
60
+ return result
61
+
62
+ def soft_n_hot(
63
+ input,
64
+ num_classes: int,
65
+ strategy: Optional[str],
66
+ ):
67
+
68
+ shape = list(input.size())[1:]
69
+
70
+ shape.append(num_classes)
71
+
72
+ ret = torch.zeros(shape).to(input.device)
73
+
74
+ if strategy:
75
+ soft_labels = n_dist(input.size(0), strategy)
76
+ else:
77
+ soft_labels = [1] * input.size(0)
78
+
79
+ for i, t in enumerate(input):
80
+ ret.scatter_(-1, t.unsqueeze(-1), soft_labels[i])
81
+
82
+ return ret
83
+
84
+
85
+ def n_hot(t, num_classes, ngram_sequences: Optional[torch.Tensor] = None, unk_idx: Optional[int] = None):
86
+
87
+ shape = list(t.size())
88
+
89
+ if ngram_sequences is not None:
90
+ shape.append(num_classes)
91
+ ret = torch.zeros(shape).to(t.device)
92
+ ret.scatter_(-1, t.unsqueeze(-1), 1)
93
+ for seq in ngram_sequences:
94
+ if unk_idx is not None:
95
+ mask = torch.eq(seq, unk_idx)
96
+ seq[mask] = t[mask]
97
+ ret.scatter_(-1, seq.unsqueeze(-1), 1)
98
+ return ret
99
+
100
+ elif len(shape) == 2:
101
+ return F.one_hot(t, num_classes=num_classes).float()
102
+ else:
103
+ shape = shape[1:]
104
+ shape.append(num_classes)
105
+ ret = torch.zeros(shape).to(t.device)
106
+ # Expect that first dimension is for all n-grams
107
+ for seq in t:
108
+ ret.scatter_(-1, seq.unsqueeze(-1), 1)
109
+
110
+ return ret
111
+
112
+
113
+ class NGramsEmbedding(torch.nn.Embedding):
114
+ """N-Hot encoder"""
115
+
116
+ def __init__(
117
+ self,
118
+ num_embeddings: int,
119
+ embedding_dim: int,
120
+ padding_idx: Optional[int] = None,
121
+ max_norm: Optional[float] = None,
122
+ norm_type: float = 2,
123
+ scale_grad_by_freq: bool = False,
124
+ sparse: bool = False,
125
+ _weight: Optional[torch.Tensor] = None,
126
+ device=None,
127
+ dtype=None,
128
+ unk_idx: Optional[int] = None
129
+ ) -> None:
130
+ super().__init__(
131
+ num_embeddings,
132
+ embedding_dim,
133
+ padding_idx=padding_idx,
134
+ max_norm=max_norm,
135
+ norm_type=norm_type,
136
+ scale_grad_by_freq=scale_grad_by_freq,
137
+ sparse=sparse,
138
+ _weight=_weight,
139
+ device=device,
140
+ dtype=dtype,
141
+ )
142
+
143
+ self.num_classes = num_embeddings
144
+ self.unk_idx = unk_idx
145
+
146
+ def forward(self, input: torch.Tensor, ngram_sequences: Optional[torch.Tensor] = None):
147
+ return self._forward(
148
+ n_hot(input, self.num_classes, ngram_sequences, self.unk_idx)
149
+ )
150
+
151
+ def _forward(self, n_hot: torch.Tensor) -> torch.Tensor:
152
+ return F.linear(n_hot, self.weight.t())
153
+
154
+
155
+ def collect_n_gram_sequences(**kwargs) -> List[torch.Tensor]:
156
+ sequences = []
157
+ for n in range(2, len(kwargs)+2):
158
+ s = kwargs[f"gram_{n}_sequence"]
159
+ if s is not None:
160
+ sequences.append(s)
161
+ else:
162
+ break
163
+
164
+ return sequences
165
+
166
+ def shift_with_pad(target_tensor, n, from_tensor):
167
+ shifted = target_tensor[:, n:]
168
+
169
+ seq_size = target_tensor.size(1) - 1
170
+
171
+ missing_idxs = torch.arange(seq_size - (n-1), seq_size).to(target_tensor.device)
172
+
173
+ # Pad with missing idxs from unigram tensor
174
+ shifted = torch.concat(
175
+ (shifted, from_tensor.index_select(1, missing_idxs)), dim=1
176
+ )
177
+
178
+ return shifted
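
A rough usage sketch of the helpers above, assuming ngme.py is importable as a plain module and using made-up ids and shapes: NGramsEmbedding n-hot-encodes the unigram ids together with any higher-order gram ids and sums the matching embedding rows, while n_dist spreads target weight across the n-gram orders.

import torch
from ngme import NGramsEmbedding, n_dist

emb = NGramsEmbedding(num_embeddings=10, embedding_dim=4, unk_idx=1)

unigram_ids = torch.tensor([[2, 3, 4]])   # (batch, seq_len)
bigram_ids = torch.tensor([[5, 1, 6]])    # id 1 == unk_idx, so it falls back to the unigram id at that position

out = emb(unigram_ids, ngram_sequences=[bigram_ids])
print(out.shape)            # torch.Size([1, 3, 4]); each position sums the embeddings of its active ids

print(n_dist(3, "linear"))  # [0.166..., 0.333..., 0.5], i.e. weights over 1-, 2- and 3-grams
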
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1b26375a83551a0a47ff2fdc26e2ca3616198d3c3a12e5b5ef4ea70ac51f9296
3
+ size 907246313
sampling.py ADDED
@@ -0,0 +1,205 @@
1
+ import warnings
2
+ from typing import List, Optional, Union
3
+
4
+ import torch
5
+ import torch.distributed as dist
6
+ from torch import nn
7
+ from transformers import BatchEncoding
8
+ from transformers.generation.logits_process import (
9
+ LogitsProcessorList,
10
+ )
11
+ from transformers.generation.stopping_criteria import (
12
+ StoppingCriteriaList,
13
+ validate_stopping_criteria,
14
+ )
15
+
16
+ from transformers.generation.utils import SampleOutput, SampleEncoderDecoderOutput, SampleDecoderOnlyOutput
17
+
18
+ def sample(
19
+ self,
20
+ input_ids: torch.LongTensor,
21
+ logits_processor: Optional[LogitsProcessorList] = None,
22
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
23
+ logits_warper: Optional[LogitsProcessorList] = None,
24
+ max_length: Optional[int] = None,
25
+ pad_token_id: Optional[int] = None,
26
+ eos_token_id: Optional[Union[int, List[int]]] = None,
27
+ output_attentions: Optional[bool] = None,
28
+ output_hidden_states: Optional[bool] = None,
29
+ output_scores: Optional[bool] = None,
30
+ return_dict_in_generate: Optional[bool] = None,
31
+ synced_gpus: Optional[bool] = False,
32
+ **model_kwargs,
33
+ ) -> Union[SampleOutput, torch.LongTensor]:
34
+
35
+ if type(input_ids) in [dict, BatchEncoding]:
36
+ input_ids, ngram_sequences = input_ids["input_ids"], input_ids
37
+ del ngram_sequences["input_ids"]
38
+ del ngram_sequences["attention_mask"]
39
+ else:
40
+ ngram_sequences = {}
41
+
42
+ # init values
43
+ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
44
+ stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
45
+ if max_length is not None:
46
+ warnings.warn(
47
+ "`max_length` is deprecated in this function, use"
48
+ " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
49
+ UserWarning,
50
+ )
51
+ stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
52
+ logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
53
+ pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
54
+ eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
55
+ if isinstance(eos_token_id, int):
56
+ eos_token_id = [eos_token_id]
57
+
58
+ eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
59
+ output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
60
+ output_attentions = (
61
+ output_attentions if output_attentions is not None else self.generation_config.output_attentions
62
+ )
63
+ output_hidden_states = (
64
+ output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
65
+ )
66
+ return_dict_in_generate = (
67
+ return_dict_in_generate
68
+ if return_dict_in_generate is not None
69
+ else self.generation_config.return_dict_in_generate
70
+ )
71
+
72
+ # init attention / hidden states / scores tuples
73
+ scores = () if (return_dict_in_generate and output_scores) else None
74
+ decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
75
+ cross_attentions = () if (return_dict_in_generate and output_attentions) else None
76
+ decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
77
+
78
+ # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
79
+ if return_dict_in_generate and self.config.is_encoder_decoder:
80
+ encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
81
+ encoder_hidden_states = (
82
+ model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
83
+ )
84
+
85
+ # keep track of which sequences are already finished
86
+ unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
87
+
88
+ this_peer_finished = False # used by synced_gpus only
89
+ # auto-regressive generation
90
+ while True:
91
+ if synced_gpus:
92
+ # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
93
+ # The following logic allows an early break if all peers finished generating their sequence
94
+ this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
95
+ # send 0.0 if we finished, 1.0 otherwise
96
+ dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
97
+ # did all peers finish? the reduced sum will be 0.0 then
98
+ if this_peer_finished_flag.item() == 0.0:
99
+ break
100
+
101
+ # prepare model inputs
102
+ model_inputs = {"input_ids": input_ids}
103
+
104
+ # forward pass to get next token
105
+ outputs = self(
106
+ **model_inputs,
107
+ return_dict=True,
108
+ output_attentions=output_attentions,
109
+ output_hidden_states=output_hidden_states,
110
+ **ngram_sequences
111
+ )
112
+
113
+ if synced_gpus and this_peer_finished:
114
+ continue # don't waste resources running the code we don't need
115
+
116
+ next_token_logits = outputs.logits[:, -1, :]
117
+
118
+ # pre-process distribution
119
+ next_token_scores = logits_processor(input_ids, next_token_logits)
120
+ next_token_scores = logits_warper(input_ids, next_token_scores)
121
+
122
+ # Store scores, attentions and hidden_states when required
123
+ if return_dict_in_generate:
124
+ if output_scores:
125
+ scores += (next_token_scores,)
126
+ if output_attentions:
127
+ decoder_attentions += (
128
+ (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
129
+ )
130
+ if self.config.is_encoder_decoder:
131
+ cross_attentions += (outputs.cross_attentions,)
132
+
133
+ if output_hidden_states:
134
+ decoder_hidden_states += (
135
+ (outputs.decoder_hidden_states,)
136
+ if self.config.is_encoder_decoder
137
+ else (outputs.hidden_states,)
138
+ )
139
+
140
+ # sample
141
+ probs = nn.functional.softmax(next_token_scores, dim=-1)
142
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
143
+
144
+ # finished sentences should have their next token be a padding token
145
+ if eos_token_id is not None:
146
+ if pad_token_id is None:
147
+ raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
148
+ next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
149
+
150
+ # update generated ids, model inputs, and length for next step
151
+ input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
152
+ decoded = self.tokenizer.batch_decode(input_ids)[0]
153
+ encoded = self.tokenizer(
154
+ decoded, return_tensors="pt", return_ngram_sequences=True
155
+ )
156
+ input_ids = encoded.input_ids.to(self.device)
157
+
158
+ ngram_sequences = {}
159
+
160
+ if "label_gram_2_sequence" in encoded:
161
+ ngram_sequences["label_gram_2_sequence"] = encoded["label_gram_2_sequence"].to(self.device)
162
+
163
+ if "label_gram_3_sequence" in encoded:
164
+ ngram_sequences["label_gram_3_sequence"] = encoded["label_gram_3_sequence"].to(self.device)
165
+
166
+ if "label_gram_4_sequence" in encoded:
167
+ ngram_sequences["label_gram_4_sequence"] = encoded["label_gram_4_sequence"].to(self.device)
168
+
169
+ model_kwargs = self._update_model_kwargs_for_generation(
170
+ outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
171
+ )
172
+
173
+ # if eos_token was found in one sentence, set sentence to finished
174
+ if eos_token_id_tensor is not None:
175
+ unfinished_sequences = unfinished_sequences.mul(
176
+ next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
177
+ )
178
+
179
+ # stop when each sentence is finished, or if we exceed the maximum length
180
+ if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
181
+ if not synced_gpus:
182
+ break
183
+ else:
184
+ this_peer_finished = True
185
+
186
+ if return_dict_in_generate:
187
+ if self.config.is_encoder_decoder:
188
+ return SampleEncoderDecoderOutput(
189
+ sequences=input_ids,
190
+ scores=scores,
191
+ encoder_attentions=encoder_attentions,
192
+ encoder_hidden_states=encoder_hidden_states,
193
+ decoder_attentions=decoder_attentions,
194
+ cross_attentions=cross_attentions,
195
+ decoder_hidden_states=decoder_hidden_states,
196
+ )
197
+ else:
198
+ return SampleDecoderOnlyOutput(
199
+ sequences=input_ids,
200
+ scores=scores,
201
+ attentions=decoder_attentions,
202
+ hidden_states=decoder_hidden_states,
203
+ )
204
+ else:
205
+ return input_ids
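
The sampling loop above decodes the running sequence and re-tokenizes it after every step so that the higher-order gram streams stay aligned with the generated text, then feeds those streams back into the model as extra keyword arguments. A minimal sketch of what a single step does, with a placeholder checkpoint path:

from transformers import AutoModelForCausalLM, AutoTokenizer

path = "path/to/ngme-checkpoint"  # placeholder
tok = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True)

enc = tok("Once upon a time", return_tensors="pt", return_ngram_sequences=True)
gram_kwargs = {k: v for k, v in enc.items() if k.startswith("label_gram_")}

out = model(input_ids=enc["input_ids"], **gram_kwargs)  # n-gram streams ride along as keyword args
probs = out.logits[:, -1, :].softmax(-1)
next_id = probs.multinomial(num_samples=1)              # then decode, re-tokenize, and repeat
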
tokenization_ngme.py ADDED
@@ -0,0 +1,1299 @@
1
+ import json
2
+ import unittest
3
+ import os
4
+ from collections import Counter
5
+ from typing import Dict, List, Mapping, Optional, Sized, Tuple, Union, Any
6
+
7
+ import torch
8
+ import numpy as np
9
+ from tokenizers import AddedToken
10
+ from transformers import PreTrainedTokenizer
11
+ from transformers.tokenization_utils_base import (
12
+ BatchEncoding,
13
+ EncodedInput,
14
+ TruncationStrategy,
15
+ )
16
+ from transformers.utils import logging
17
+ from transformers.utils.generic import PaddingStrategy, TensorType, is_tf_tensor, is_torch_tensor, to_py_obj
18
+
19
+ from .ngme import ngrams as ngram_tokenizer
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
+ def load_vocab(vocab_file):
25
+ """Loads a vocabulary file into a dictionary."""
26
+ with open(vocab_file, "r", encoding="utf-8") as f:
27
+ vocab = json.load(f)
28
+ return vocab
29
+
30
+
31
+ def all_same(items):
32
+ return all(x == items[0] for x in items)
33
+
34
+
35
+ class NGMETokenizer(PreTrainedTokenizer):
36
+ model_input_names = ["input_ids", "attention_mask"]
37
+ vocab_file = "vocab.json"
38
+ vocab_files_names = {"vocab_file": vocab_file}
39
+
40
+ def __init__(
41
+ self,
42
+ vocab_file,
43
+ eos_token="\n",
44
+ pad_token="\n",
45
+ unk_token="<unk>",
46
+ eod_token="<eod>",
47
+ **kwargs,
48
+ ):
49
+ super().__init__(
50
+ eos_token=eos_token, pad_token=pad_token, unk_token=unk_token, **kwargs
51
+ )
52
+
53
+ eos_token = (
54
+ AddedToken(
55
+ eos_token,
56
+ lstrip=False,
57
+ rstrip=False,
58
+ )
59
+ if isinstance(eos_token, str)
60
+ else eos_token
61
+ )
62
+ pad_token = (
63
+ AddedToken(
64
+ pad_token,
65
+ lstrip=False,
66
+ rstrip=False,
67
+ )
68
+ if isinstance(pad_token, str)
69
+ else pad_token
70
+ )
71
+ unk_token = (
72
+ AddedToken(
73
+ unk_token,
74
+ lstrip=False,
75
+ rstrip=False,
76
+ )
77
+ if isinstance(unk_token, str)
78
+ else unk_token
79
+ )
80
+
81
+ self._ngram2word2idx = {}
82
+ self._ngram2idx2word = {}
83
+ self._current_max_idx = 0
84
+ self._frequencies: Counter = Counter()
85
+
86
+ self._load_from_file(vocab_file)
87
+
88
+ for n in range(2, self.ngram + 1):
89
+ self.model_input_names.append(f"ngram_{n}_sequence")
90
+
91
+ # TODO: Could also be whitespace if the (n+1)-grams don't contain it
92
+ self._special_token = "Ġ"
93
+ assert self._special_token not in self._ngram2word2idx[1]
94
+
95
+ def __call__(self, *args, **kwargs) -> BatchEncoding:
96
+ if "return_ngram_sequences" in kwargs:
97
+ return_ngram_sequences = kwargs["return_ngram_sequences"]
98
+ del kwargs["return_ngram_sequences"]
99
+ else:
100
+ return_ngram_sequences = False
101
+
102
+ # We could check the args and kwargs beforehand and apply the extra ngram sequences based on them, but
103
+ # we let HF handle all the logic and recover the character sequence from the ids afterwards
104
+ batch_encoding = super().__call__(*args, **kwargs)
105
+
106
+ if return_ngram_sequences:
107
+ ngram_sequences = self.create_ngram_sequences(args[0])
108
+ # NOTE: This is pretty hard-coded; let's just throw an error if the user wants to use it differently
109
+
110
+ if "padding" in kwargs:
111
+ if kwargs["padding"] == "max_length":
112
+ padded_sequences = {}
113
+ for n_key, sequence in ngram_sequences.items():
114
+ padded_sequences[n_key] = self.pad_sequence_right(
115
+ sequence,
116
+ len(batch_encoding["input_ids"][0]),
117
+ self.pad_token_id,
118
+ )
119
+
120
+ ngram_sequences = padded_sequences
121
+ elif kwargs["padding"] == "longest":
122
+ padded_sequences = {}
123
+ for n_key, sequence in ngram_sequences.items():
124
+ padded_sequences[n_key] = self.pad_sequence_right(
125
+ sequence,
126
+ max([len(seq) for seq in sequence]),
127
+ self.pad_token_id,
128
+ )
129
+ ngram_sequences = padded_sequences
130
+
131
+ else:
132
+ raise ValueError(
133
+ f"Padding {kwargs['padding']} not supported for ngram sequences"
134
+ )
135
+
136
+ if "truncation" in kwargs and kwargs["truncation"]:
137
+ truncated_sequences = {}
138
+ for n_key, sequence in ngram_sequences.items():
139
+ truncated_sequences[n_key] = self.truncate_sequence_right(
140
+ sequence, len(batch_encoding["input_ids"][0])
141
+ )
142
+ ngram_sequences = truncated_sequences
143
+
144
+ batch_encoding.update(ngram_sequences)
145
+
146
+ if "return_tensors" in kwargs:
147
+ batch_encoding.convert_to_tensors(kwargs["return_tensors"])
148
+
149
+ return batch_encoding
150
+
151
+ def pad_sequence_right(
152
+ self, batched_sequence: List[List[int]], padding_length: int, padding_value: int
153
+ ) -> List[List[int]]:
154
+ padded_sequence = []
155
+ for sequence in batched_sequence:
156
+ padded_sequence.append(
157
+ sequence + [padding_value] * (padding_length - len(sequence))
158
+ )
159
+ return padded_sequence
160
+
161
+ def truncate_sequence_right(
162
+ self, batched_sequence: List[List[int]], max_length: int
163
+ ) -> List[List[int]]:
164
+ truncated_sequence = []
165
+ for sequence in batched_sequence:
166
+ truncated_sequence.append(sequence[:max_length])
167
+ return truncated_sequence
168
+
169
+ def create_ngram_sequences(self, char_sequences: List[str]) -> Dict[str, Any]:
170
+ ngram_sequences_output = {}
171
+
172
+ if isinstance(char_sequences, str):
173
+ char_sequences = [char_sequences]
174
+
175
+ for n in range(2, self.ngram + 1):
176
+ ngram_sequences = []
177
+ for char_sequence in char_sequences:
178
+ ngrams = ["".join(ngram) for ngram in ngram_tokenizer(char_sequence, n)]
179
+ # Fill in the front with existing unigrams, to keep the same length and
180
+ # because the timestep t should not look ahead
181
+ ngrams = list(char_sequence[: n - 1]) + ngrams
182
+ encoded_ngrams = self.encode(ngrams) if len(ngrams) > 0 else []
183
+ ngram_sequences.append(encoded_ngrams)
184
+
185
+ ngram_sequences_output[f"label_gram_{n}_sequence"] = ngram_sequences
186
+
187
+ return ngram_sequences_output
188
+
189
+ def _seq_size(self, encoded) -> Union[int, List[int]]:
190
+ if isinstance(encoded, torch.Tensor):
191
+ encoded = encoded.tolist()
192
+
193
+ if isinstance(encoded[0], list):
194
+ return [len(enc) for enc in encoded]
195
+
196
+ return len(encoded)
197
+
198
+ def _load_from_file(self, filename: str):
199
+ """Loads a dictionary from a file."""
200
+ vocab_file = load_vocab(filename)
201
+ self.ngram = vocab_file["ngram"]
202
+
203
+ if "\n" not in vocab_file["vocab"]:
204
+ self._add_ngram("\n", 1)
205
+
206
+ for token in vocab_file["vocab"]:
207
+ self._add_ngram(token["token"], token["ngram"])
208
+ self._frequencies.update({token["token"]: token["frequency"]})
209
+
210
+ def _add_ngram(self, word, ngram: int) -> int:
211
+ """Add a new n-gram token to the dictionary."""
212
+ self._frequencies.update({word: 1})
213
+
214
+ if ngram not in self._ngram2idx2word:
215
+ self._ngram2idx2word[ngram] = {self._current_max_idx: word}
216
+ self._ngram2word2idx[ngram] = {word: self._current_max_idx}
217
+ self._current_max_idx += 1
218
+ else:
219
+ if word not in self._ngram2word2idx[ngram]:
220
+ self._ngram2idx2word[ngram][self._current_max_idx] = word
221
+ self._ngram2word2idx[ngram][word] = self._current_max_idx
222
+ self._current_max_idx += 1
223
+
224
+ return self._ngram2word2idx[ngram][word]
225
+
226
+ def _is_contiguous(self):
227
+ vocab_size = len(self)
228
+ return list(range(vocab_size)) == [idx for idx, token in self._get_all_tokens()]
229
+
230
+ def _get_all_tokens(self):
231
+ """Returns all tokens in the dictionary."""
232
+ for ngram in range(1, self.ngram + 1):
233
+ for idx, token in self._ngram2idx2word[ngram].items():
234
+ yield idx, token
235
+
236
+ def save_vocabulary(
237
+ self, save_directory: str, filename_prefix: Optional[str] = None
238
+ ) -> Tuple[str]:
239
+ filename = os.path.join(
240
+ save_directory,
241
+ (filename_prefix + "-" if filename_prefix else ""),
242
+ self.vocab_file,
243
+ )
244
+
245
+ index = 0
246
+ vocab = {"ngram": self.ngram, "vocab": []}
247
+
248
+ for ngram in range(1, self.ngram + 1):
249
+ for idx, token in self._ngram2idx2word[ngram].items():
250
+ if index != idx:
251
+ index = idx
252
+
253
+ try:
254
+ frequency = self._frequencies[token]
255
+ except KeyError:
256
+ frequency = -1
257
+
258
+ index += 1
259
+ vocab["vocab"].append(
260
+ {
261
+ "token": token,
262
+ "index": idx,
263
+ "frequency": frequency,
264
+ "ngram": ngram,
265
+ }
266
+ )
267
+
268
+ with open(filename, "w", encoding="utf-8") as writer:
269
+ json.dump(vocab, writer, indent=4, ensure_ascii=False)
270
+
271
+ return (filename,)
272
+
273
+ @property
274
+ def vocab_size(self) -> int:
275
+ return self._current_max_idx
276
+
277
+ def _tokenize(self, text: str) -> List[str]:
278
+ return list(text)
279
+
280
+ def get_idx(self, token: str, ngram: Optional[int] = None) -> int:
281
+ if ngram:
282
+ if token in self._ngram2word2idx[ngram]:
283
+ return self._ngram2word2idx[ngram][token]
284
+ else:
285
+ return self._ngram2word2idx[1]["<unk>"]
286
+
287
+ for ngram in range(1, self.ngram + 1):
288
+ if token in self._ngram2word2idx[ngram]:
289
+ return self._ngram2word2idx[ngram][token]
290
+
291
+ return self._ngram2word2idx[1]["<unk>"]
292
+
293
+ def _convert_ngram_tokens_to_ids(self, ngram_tokens: List[str]) -> List[int]:
294
+ return [self.get_idx(token) for token in ngram_tokens]
295
+
296
+ def convert_tokens_to_ids(self, tokens: List[str]):
297
+ if not tokens:
298
+ return []
299
+
300
+ if isinstance(tokens, str):
301
+ return self.get_idx(tokens)
302
+
303
+ return self._convert_ngram_tokens_to_ids(tokens)
304
+
305
+ def _convert_id_to_token(self, index: int) -> str:
306
+ return self.get_item_for_index(index)
307
+
308
+ def get_item_for_index(self, idx) -> str:
309
+ """Return the token for a given index."""
310
+ for idxs in self._ngram2idx2word.values():
311
+ if idx in idxs:
312
+ return idxs[idx]
313
+
314
+ return self.unk_token
315
+
316
+ def convert_tokens_to_string(self, tokens):
317
+ return "".join(tokens)
318
+
319
+ def create_weight_tensor(self) -> torch.Tensor:
320
+ unked_freqs = self._frequencies.most_common()
321
+
322
+ t = torch.ones(len(self))
323
+
324
+ for token, freq in unked_freqs:
325
+ t[self._ngram2word2idx[self._token_to_n_order(token)][token]] = freq
326
+
327
+ # Ensure the only whitespace character is weighted
328
+ t[self._ngram2word2idx[1][" "]] = 1.0
329
+
330
+ max_t = max(t)
331
+
332
+ normed_weights = torch.tensor([(1 - (x / (max_t + 1))).item() for x in t])
333
+
334
+ marker_tokens = [self.get_idx("<unk>", n) for n in range(1, self.ngram + 1)]
335
+ marker_tokens.extend(
336
+ [self.get_idx("<start>", n) for n in range(1, self.ngram + 1)]
337
+ )
338
+ # Instead of explicit ignore indexes, we use the weight vector and set target idxs to 0
339
+ for marker in marker_tokens:
340
+ normed_weights[marker] = 0
341
+
342
+ return normed_weights
343
+
344
+ def _token_to_n_order(self, token: str) -> int:
345
+ """Get N-gram order for a token"""
346
+ for n_gram, word2idx in self._ngram2word2idx.items():
347
+ if token in word2idx:
348
+ return n_gram
349
+
350
+ return 0
351
+
352
+
353
+ class GPTNGMETokenizer(PreTrainedTokenizer):
354
+ model_input_names = ["input_ids", "attention_mask"]
355
+ vocab_file = "vocab.json"
356
+ vocab_files_names = {"vocab_file": vocab_file}
357
+
358
+ def __init__(
359
+ self, vocab_file, eos_token="\n", pad_token="\n", unk_token="<unk>", **kwargs
360
+ ):
361
+ eos_token = (
362
+ AddedToken(
363
+ eos_token,
364
+ lstrip=False,
365
+ rstrip=False,
366
+ )
367
+ if isinstance(eos_token, str)
368
+ else eos_token
369
+ )
370
+ pad_token = (
371
+ AddedToken(
372
+ pad_token,
373
+ lstrip=False,
374
+ rstrip=False,
375
+ )
376
+ if isinstance(pad_token, str)
377
+ else pad_token
378
+ )
379
+ unk_token = (
380
+ AddedToken(
381
+ unk_token,
382
+ lstrip=False,
383
+ rstrip=False,
384
+ )
385
+ if isinstance(unk_token, str)
386
+ else unk_token
387
+ )
388
+
389
+ super().__init__(
390
+ eos_token=eos_token, pad_token=pad_token, unk_token=unk_token, **kwargs
391
+ )
392
+
393
+ self._ngram2word2idx = {}
394
+ self._ngram2idx2word = {}
395
+ self._current_max_idx = 0
396
+ self._frequencies: Counter = Counter()
397
+
398
+ self._load_from_file(vocab_file)
399
+
400
+ def _load_from_file(self, filename: str):
401
+ """Loads a dictionary from a file."""
402
+ vocab_file = load_vocab(filename)
403
+ self.ngram = vocab_file["ngram"]
404
+
405
+ if "\n" not in vocab_file["vocab"]:
406
+ self._add_ngram("\n", 1)
407
+
408
+ for token in vocab_file["vocab"]:
409
+ self._add_ngram(token["token"], token["ngram"])
410
+ self._frequencies.update({token["token"]: token["frequency"]})
411
+
412
+ def _add_ngram(self, word, ngram: int) -> int:
413
+ """Add a new n-gram token to the dictionary."""
414
+ self._frequencies.update({word: 1})
415
+
416
+ if ngram not in self._ngram2idx2word:
417
+ self._ngram2idx2word[ngram] = {self._current_max_idx: word}
418
+ self._ngram2word2idx[ngram] = {word: self._current_max_idx}
419
+ self._current_max_idx += 1
420
+ else:
421
+ if word not in self._ngram2word2idx[ngram]:
422
+ self._ngram2idx2word[ngram][self._current_max_idx] = word
423
+ self._ngram2word2idx[ngram][word] = self._current_max_idx
424
+ self._current_max_idx += 1
425
+
426
+ return self._ngram2word2idx[ngram][word]
427
+
428
+ def _is_contiguous(self):
429
+ vocab_size = len(self)
430
+ return list(range(vocab_size)) == [idx for idx, token in self._get_all_tokens()]
431
+
432
+ def _get_all_tokens(self):
433
+ """Returns all tokens in the dictionary."""
434
+ for ngram in range(1, self.ngram + 1):
435
+ for idx, token in self._ngram2idx2word[ngram].items():
436
+ yield idx, token
437
+
438
+ def save_vocabulary(
439
+ self, save_directory: str, filename_prefix: Optional[str] = None
440
+ ) -> Tuple[str]:
441
+ filename = os.path.join(
442
+ save_directory,
443
+ (filename_prefix + "-" if filename_prefix else ""),
444
+ self.vocab_file,
445
+ )
446
+
447
+ index = 0
448
+ vocab = {"ngram": self.ngram, "vocab": []}
449
+
450
+ for ngram in range(1, self.ngram + 1):
451
+ for idx, token in self._ngram2idx2word[ngram].items():
452
+ if index != idx:
453
+ index = idx
454
+
455
+ try:
456
+ frequency = self._frequencies[token]
457
+ except KeyError:
458
+ frequency = -1
459
+
460
+ index += 1
461
+ vocab["vocab"].append(
462
+ {
463
+ "token": token,
464
+ "index": idx,
465
+ "frequency": frequency,
466
+ "ngram": ngram,
467
+ }
468
+ )
469
+
470
+ with open(filename, "w", encoding="utf-8") as writer:
471
+ json.dump(vocab, writer, indent=4, ensure_ascii=False)
472
+
473
+ return (filename,)
474
+
475
+ @property
476
+ def vocab_size(self) -> int:
477
+ return self._current_max_idx
478
+
479
+ def retokenize(self, input_ids, *args, **kwargs):
480
+ decoded = self.convert_ids_to_tokens(input_ids)
481
+ sequence = "".join(decoded)
482
+ new_decoded = self(sequence, *args, **kwargs).input_ids
483
+ return new_decoded
484
+
485
+ def _tokenize(self, text):
486
+ ngram_sequences = []
487
+ for n in range(1, self.ngram + 1):
488
+ words = ["<start>" for _ in range(1, n)]
489
+ words.extend(list(text))
490
+
491
+ tokens = []
492
+ for i, word in enumerate(ngram_tokenizer(words, n)):
493
+ if "<start>" in word:
494
+ word = [w for w in list(word) if w != "<start>"]
495
+ tokens.append("".join(word))
496
+
497
+ ngram_sequences.append(tokens)
498
+
499
+ return ngram_sequences
500
+
501
+ def get_idx(self, token: str, ngram: Optional[int] = None) -> int:
502
+ if ngram:
503
+ if token in self._ngram2word2idx[ngram]:
504
+ return self._ngram2word2idx[ngram][token]
505
+ else:
506
+ return self._ngram2word2idx[1]["<unk>"]
507
+
508
+ for ngram in range(1, self.ngram + 1):
509
+ if token in self._ngram2word2idx[ngram]:
510
+ return self._ngram2word2idx[ngram][token]
511
+
512
+ return self._ngram2word2idx[1]["<unk>"]
513
+
514
+ def _convert_ngram_tokens_to_ids(self, ngram_tokens: List[str]) -> List[int]:
515
+ return [self.get_idx(token) for token in ngram_tokens]
516
+
517
+ def convert_tokens_to_ids(self, tokens: List[List[str]]):
518
+ if not tokens:
519
+ return []
520
+
521
+ if isinstance(tokens, str):
522
+ return self.get_idx(tokens)
523
+
524
+ return [
525
+ self._convert_ngram_tokens_to_ids(ngram_tokens) for ngram_tokens in tokens
526
+ ]
527
+
528
+ def _convert_id_to_token(self, index: int) -> str:
529
+ return self.get_item_for_index(index)
530
+
531
+ def get_item_for_index(self, idx) -> str:
532
+ """Return the token for a given index."""
533
+ for idxs in self._ngram2idx2word.values():
534
+ if idx in idxs:
535
+ return idxs[idx]
536
+
537
+ return self.unk_token
538
+
539
+ def _decode(
540
+ self, token_ids: List[List[int]], skip_special_tokens: bool = False, **kwargs
541
+ ) -> str:
542
+ return "".join(self.convert_ids_to_tokens(token_ids[0]))
543
+
544
+ def debug_decode(self, token_ids: List[List[int]]):
545
+ for n in range(1, self.ngram + 1):
546
+ print(f"{n}-gram: {self.convert_ids_to_tokens(token_ids[n-1])}")
547
+
548
+ def _pad(
549
+ self,
550
+ encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
551
+ max_length: Optional[int] = None,
552
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
553
+ pad_to_multiple_of: Optional[int] = None,
554
+ return_attention_mask: Optional[bool] = None,
555
+ ) -> dict:
556
+ """
557
+ Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
558
+
559
+ Args:
560
+ encoded_inputs:
561
+ Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
562
+ max_length: maximum length of the returned list and optionally padding length (see below).
563
+ Will truncate by taking into account the special tokens.
564
+ padding_strategy: PaddingStrategy to use for padding.
565
+
566
+ - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
567
+ - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
568
+ - PaddingStrategy.DO_NOT_PAD: Do not pad
569
+ The tokenizer padding sides are defined in self.padding_side:
570
+
571
+ - 'left': pads on the left of the sequences
572
+ - 'right': pads on the right of the sequences
573
+ pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
574
+ This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
575
+ `>= 7.5` (Volta).
576
+ return_attention_mask:
577
+ (optional) Set to False to avoid returning attention mask (default: set to model specifics)
578
+ """
579
+ # encoded_inputs == one sample -> List[List[int]]
580
+
581
+ # Load from model defaults
582
+ if return_attention_mask is None:
583
+ return_attention_mask = "attention_mask" in self.model_input_names
584
+
585
+ required_input = encoded_inputs[self.model_input_names[0]]
586
+ # PHA: Check if we have a list of list of list, then we unpack
587
+ if (
588
+ len(required_input) != 0
589
+ and isinstance(required_input[0], list)
590
+ and isinstance(required_input[0][0], list)
591
+ ):
592
+ required_input = required_input[0]
593
+
594
+ if padding_strategy == PaddingStrategy.LONGEST:
595
+ max_length = len(required_input)
596
+
597
+ if (
598
+ max_length is not None
599
+ and pad_to_multiple_of is not None
600
+ and (max_length % pad_to_multiple_of != 0)
601
+ ):
602
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
603
+
604
+ needs_to_be_padded = (
605
+ padding_strategy != PaddingStrategy.DO_NOT_PAD
606
+ and len(required_input[0]) != max_length
607
+ )
608
+
609
+ # Initialize attention mask if not present.
610
+ if return_attention_mask and "attention_mask" not in encoded_inputs:
611
+ if len(required_input) == 0:
612
+ encoded_inputs["attention_mask"] = []
613
+ else:
614
+ encoded_inputs["attention_mask"] = [1] * len(required_input[0])
615
+
616
+ if needs_to_be_padded:
617
+ difference = max_length - len(required_input[0])
618
+
619
+ if self.padding_side == "right":
620
+ if return_attention_mask:
621
+ encoded_inputs["attention_mask"] = (
622
+ encoded_inputs["attention_mask"] + [0] * difference
623
+ )
624
+ if "token_type_ids" in encoded_inputs:
625
+ encoded_inputs["token_type_ids"] = (
626
+ encoded_inputs["token_type_ids"]
627
+ + [self.pad_token_type_id] * difference
628
+ )
629
+ if "special_tokens_mask" in encoded_inputs:
630
+ encoded_inputs["special_tokens_mask"] = (
631
+ encoded_inputs["special_tokens_mask"] + [1] * difference
632
+ )
633
+ for i in range(len(encoded_inputs[self.model_input_names[0]])):
634
+ encoded_inputs[self.model_input_names[0]][i] = (
635
+ required_input[i] + [self.pad_token_id] * difference
636
+ )
637
+ elif self.padding_side == "left":
638
+ if return_attention_mask:
639
+ encoded_inputs["attention_mask"] = [
640
+ 0
641
+ ] * difference + encoded_inputs["attention_mask"]
642
+ if "token_type_ids" in encoded_inputs:
643
+ encoded_inputs["token_type_ids"] = [
644
+ self.pad_token_type_id
645
+ ] * difference + encoded_inputs["token_type_ids"]
646
+ if "special_tokens_mask" in encoded_inputs:
647
+ encoded_inputs["special_tokens_mask"] = [
648
+ 1
649
+ ] * difference + encoded_inputs["special_tokens_mask"]
650
+
651
+ for i in range(len(encoded_inputs[self.model_input_names[0]])):
652
+ encoded_inputs[self.model_input_names[0]][i] = [
653
+ self.pad_token_id
654
+ ] * difference + required_input[i]
655
+ else:
656
+ raise ValueError("Invalid padding strategy:" + str(self.padding_side))
657
+
658
+ return encoded_inputs
659
+
660
+ def pad(
661
+ self,
662
+ encoded_inputs: Union[
663
+ BatchEncoding,
664
+ List[BatchEncoding],
665
+ Dict[str, EncodedInput],
666
+ Dict[str, List[EncodedInput]],
667
+ List[Dict[str, EncodedInput]],
668
+ ],
669
+ padding: Union[bool, str, PaddingStrategy] = True,
670
+ max_length: Optional[int] = None,
671
+ pad_to_multiple_of: Optional[int] = None,
672
+ return_attention_mask: Optional[bool] = None,
673
+ return_tensors: Optional[Union[str, TensorType]] = None,
674
+ verbose: bool = True,
675
+ ) -> BatchEncoding:
676
+ """
677
+ Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length
678
+ in the batch.
679
+
680
+ Padding side (left/right) padding token ids are defined at the tokenizer level (with `self.padding_side`,
681
+
682
+ `self.pad_token_id` and `self.pad_token_type_id`).
683
+
684
+ Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the
685
+ text followed by a call to the `pad` method to get a padded encoding.
686
+
687
+ <Tip>
688
+
689
+ If the `encoded_inputs` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the
690
+ result will use the same type unless you provide a different tensor type with `return_tensors`. In the case of
691
+ PyTorch tensors, you will lose the specific device of your tensors however.
692
+
693
+ </Tip>
694
+
695
+ Args:
696
+ encoded_inputs ([`BatchEncoding`], list of [`BatchEncoding`], `Dict[str, List[int]]`, `Dict[str, List[List[int]]` or `List[Dict[str, List[int]]]`):
697
+ Tokenized inputs. Can represent one input ([`BatchEncoding`] or `Dict[str, List[int]]`) or a batch of
698
+ tokenized inputs (list of [`BatchEncoding`], *Dict[str, List[List[int]]]* or *List[Dict[str,
699
+ List[int]]]*) so you can use this method during preprocessing as well as in a PyTorch Dataloader
700
+ collate function.
701
+
702
+ Instead of `List[int]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors), see
703
+ the note above for the return type.
704
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
705
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
706
+ index) among:
707
+
708
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
709
+ sequence if provided).
710
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
711
+ acceptable input length for the model if that argument is not provided.
712
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
713
+ lengths).
714
+ max_length (`int`, *optional*):
715
+ Maximum length of the returned list and optionally padding length (see above).
716
+ pad_to_multiple_of (`int`, *optional*):
717
+ If set will pad the sequence to a multiple of the provided value.
718
+
719
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
720
+ `>= 7.5` (Volta).
721
+ return_attention_mask (`bool`, *optional*):
722
+ Whether to return the attention mask. If left to the default, will return the attention mask according
723
+ to the specific tokenizer's default, defined by the `return_outputs` attribute.
724
+
725
+ [What are attention masks?](../glossary#attention-mask)
726
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
727
+ If set, will return tensors instead of list of python integers. Acceptable values are:
728
+
729
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
730
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
731
+ - `'np'`: Return Numpy `np.ndarray` objects.
732
+ verbose (`bool`, *optional*, defaults to `True`):
733
+ Whether or not to print more information and warnings.
734
+ """
735
+
736
+ # Problem: the pad function checks whether encoded_inputs is a list or not.
737
+ # If it is a list, it assumes that we have batches.
738
+ # With NGME encoding the input is always a list.
739
+
740
+ # If we have a list of dicts, let's convert it in a dict of lists
741
+ # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
742
+ if isinstance(encoded_inputs, (list, tuple)) and isinstance(
743
+ encoded_inputs[0], Mapping
744
+ ):
745
+ encoded_inputs = {
746
+ key: [example[key] for example in encoded_inputs]
747
+ for key in encoded_inputs[0].keys()
748
+ }
749
+
750
+ # The model's main input name, usually `input_ids`, has to be passed for padding
751
+ if self.model_input_names[0] not in encoded_inputs:
752
+ raise ValueError(
753
+ "You should supply an encoding or a list of encodings to this method "
754
+ f"that includes {self.model_input_names[0]}, but you provided {list(encoded_inputs.keys())}"
755
+ )
756
+
757
+ required_input = encoded_inputs[self.model_input_names[0]]
758
+
759
+ if required_input is None or (
760
+ isinstance(required_input, Sized) and len(required_input) == 0
761
+ ):
762
+ if return_attention_mask:
763
+ encoded_inputs["attention_mask"] = []
764
+ return encoded_inputs
765
+
766
+ # If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects
767
+ # and rebuild them afterwards if no return_tensors is specified
768
+ # Note that we lose the specific device the tensor may be on for PyTorch
769
+
770
+ first_element = required_input[0]
771
+ # PHA: First element in ngme is a list of list
772
+ if isinstance(first_element, (list, tuple)):
773
+ # first_element might be an empty list/tuple in some edge cases so we grab the first non empty element.
774
+ for item in required_input:
775
+ if len(item) != 0:
776
+ first_element = item[0]
777
+ break
778
+ # At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
779
+ if not isinstance(first_element, (int, list, tuple)):
780
+ if is_tf_tensor(first_element):
781
+ return_tensors = "tf" if return_tensors is None else return_tensors
782
+ elif is_torch_tensor(first_element):
783
+ return_tensors = "pt" if return_tensors is None else return_tensors
784
+ elif isinstance(first_element, np.ndarray):
785
+ return_tensors = "np" if return_tensors is None else return_tensors
786
+ else:
787
+ raise ValueError(
788
+ f"type of {first_element} unknown: {type(first_element)}. "
789
+ "Should be one of a python, numpy, pytorch or tensorflow object."
790
+ )
791
+
792
+ for key, value in encoded_inputs.items():
793
+ encoded_inputs[key] = to_py_obj(value)
794
+
795
+ # Convert padding_strategy in PaddingStrategy
796
+ padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies(
797
+ padding=padding, max_length=max_length, verbose=verbose
798
+ )
799
+
800
+ required_input = encoded_inputs[self.model_input_names[0]]
801
+
802
+ if required_input:
803
+ if isinstance(required_input[0], (list, tuple)):
804
+ if len(required_input[0]) > 0 and not isinstance(
805
+ required_input[0][0], (list, tuple)
806
+ ):
807
+ encoded_inputs = self._pad(
808
+ encoded_inputs,
809
+ max_length=max_length,
810
+ padding_strategy=padding_strategy,
811
+ pad_to_multiple_of=pad_to_multiple_of,
812
+ return_attention_mask=return_attention_mask,
813
+ )
814
+ return BatchEncoding(encoded_inputs, tensor_type=return_tensors)
815
+
816
+ batch_size = len(required_input)
817
+ assert all(
818
+ len(v) == batch_size for v in encoded_inputs.values()
819
+ ), "Some items in the output dictionary have a different batch size than others."
820
+
821
+ if padding_strategy == PaddingStrategy.LONGEST:
822
+ max_length = max(len(inputs[0]) for inputs in required_input)
823
+ padding_strategy = PaddingStrategy.MAX_LENGTH
824
+
825
+ batch_outputs = {}
826
+ for i in range(batch_size):
827
+ inputs = dict((k, v[i]) for k, v in encoded_inputs.items())
828
+ outputs = self._pad(
829
+ inputs,
830
+ max_length=max_length,
831
+ padding_strategy=padding_strategy,
832
+ pad_to_multiple_of=pad_to_multiple_of,
833
+ return_attention_mask=return_attention_mask,
834
+ )
835
+
836
+ for key, value in outputs.items():
837
+ if key not in batch_outputs:
838
+ batch_outputs[key] = []
839
+ batch_outputs[key].append(value)
840
+
841
+ return BatchEncoding(batch_outputs, tensor_type=return_tensors)
842
+
843
+ def prepare_for_model(
844
+ self,
845
+ ids: List[int],
846
+ pair_ids: Optional[List[int]] = None,
847
+ add_special_tokens: bool = True,
848
+ padding: Union[bool, str, PaddingStrategy] = False,
849
+ truncation: Union[bool, str, TruncationStrategy] = None,
850
+ max_length: Optional[int] = None,
851
+ stride: int = 0,
852
+ pad_to_multiple_of: Optional[int] = None,
853
+ return_tensors: Optional[Union[str, TensorType]] = None,
854
+ return_token_type_ids: Optional[bool] = None,
855
+ return_attention_mask: Optional[bool] = None,
856
+ return_overflowing_tokens: bool = False,
857
+ return_special_tokens_mask: bool = False,
858
+ return_offsets_mapping: bool = False,
859
+ return_length: bool = False,
860
+ verbose: bool = True,
861
+ prepend_batch_axis: bool = False,
862
+ **kwargs,
863
+ ) -> BatchEncoding:
864
+ """
865
+ Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. It
866
+ adds special tokens, truncates sequences if overflowing while taking into account the special tokens and
867
+ manages a moving window (with user defined stride) for overflowing tokens. Please Note, for *pair_ids*
868
+ different than `None` and *truncation_strategy = longest_first* or `True`, it is not possible to return
869
+ overflowing tokens. Such a combination of arguments will raise an error.
870
+ Args:
871
+ ids (`List[int]`):
872
+ Tokenized input ids of the first sequence. Can be obtained from a string by chaining the `tokenize` and
873
+ `convert_tokens_to_ids` methods.
874
+ pair_ids (`List[int]`, *optional*):
875
+ Tokenized input ids of the second sequence. Can be obtained from a string by chaining the `tokenize`
876
+ and `convert_tokens_to_ids` methods.
877
+ """
878
+
879
+ # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
880
+ (
881
+ padding_strategy,
882
+ truncation_strategy,
883
+ max_length,
884
+ kwargs,
885
+ ) = self._get_padding_truncation_strategies(
886
+ padding=padding,
887
+ truncation=truncation,
888
+ max_length=max_length,
889
+ pad_to_multiple_of=pad_to_multiple_of,
890
+ verbose=verbose,
891
+ **kwargs,
892
+ )
893
+
894
+ pair = bool(pair_ids is not None)
895
+
896
+ if len(ids) == 0:
897
+ len_ids = 0
898
+ else:
899
+ len_ids = len(ids[0])
900
+
901
+ if pair and len(pair_ids) == 0:
902
+ len_pair_ids = 0
903
+ elif pair and len(pair_ids) > 0:
904
+ len_pair_ids = len(pair_ids[0])
905
+ else:
906
+ len_pair_ids = 0
907
+
908
+ if return_token_type_ids and not add_special_tokens:
909
+ raise ValueError(
910
+ "Asking to return token_type_ids while setting add_special_tokens to False "
911
+ "results in an undefined behavior. Please set add_special_tokens to True or "
912
+ "set return_token_type_ids to None."
913
+ )
914
+
915
+ if (
916
+ return_overflowing_tokens
917
+ and truncation_strategy == TruncationStrategy.LONGEST_FIRST
918
+ and pair_ids is not None
919
+ ):
920
+ raise ValueError(
921
+ "Not possible to return overflowing tokens for pair of sequences with the "
922
+ "`longest_first`. Please select another truncation strategy than `longest_first`, "
923
+ "for instance `only_second` or `only_first`."
924
+ )
925
+
926
+ # Load from model defaults
927
+ if return_token_type_ids is None:
928
+ return_token_type_ids = "token_type_ids" in self.model_input_names
929
+ if return_attention_mask is None:
930
+ return_attention_mask = "attention_mask" in self.model_input_names
931
+
932
+ encoded_inputs = {}
933
+
934
+ # Compute the total size of the returned encodings
935
+ total_len = (
936
+ len_ids
937
+ + len_pair_ids
938
+ + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0)
939
+ )
940
+
941
+ # Truncation: Handle max sequence length
942
+ overflowing_tokens = []
943
+ if (
944
+ truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE
945
+ and max_length
946
+ and total_len > max_length
947
+ ):
948
+ ids, pair_ids, overflowing_tokens = self.truncate_sequences(
949
+ ids,
950
+ pair_ids=pair_ids,
951
+ num_tokens_to_remove=total_len - max_length,
952
+ truncation_strategy=truncation_strategy,
953
+ stride=stride,
954
+ )
955
+
956
+ if return_overflowing_tokens:
957
+ encoded_inputs["overflowing_tokens"] = overflowing_tokens
958
+ encoded_inputs["num_truncated_tokens"] = total_len - max_length
959
+
960
+ # Add special tokens
961
+ if add_special_tokens:
962
+ sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
963
+ token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
964
+ else:
965
+ sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
966
+ token_type_ids = [0] * len(ids) + ([0] * len(pair_ids) if pair else [])
967
+
968
+ # Build output dictionary
969
+ encoded_inputs["input_ids"] = sequence
970
+ if return_token_type_ids:
971
+ encoded_inputs["token_type_ids"] = token_type_ids
972
+ if return_special_tokens_mask:
973
+ if add_special_tokens:
974
+ encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(
975
+ ids, pair_ids
976
+ )
977
+ else:
978
+ encoded_inputs["special_tokens_mask"] = [0] * len(sequence)
979
+
980
+ # Check lengths
981
+ self._eventual_warn_about_too_long_sequence(
982
+ encoded_inputs["input_ids"], max_length, verbose
983
+ )
984
+
985
+ # Padding
986
+ if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask:
987
+ encoded_inputs = self.pad(
988
+ encoded_inputs,
989
+ max_length=max_length,
990
+ padding=padding_strategy.value,
991
+ pad_to_multiple_of=pad_to_multiple_of,
992
+ return_attention_mask=return_attention_mask,
993
+ )
994
+
995
+ if return_length:
996
+ encoded_inputs["length"] = len(encoded_inputs["input_ids"])
997
+
998
+ batch_outputs = BatchEncoding(
999
+ encoded_inputs,
1000
+ tensor_type=return_tensors,
1001
+ prepend_batch_axis=prepend_batch_axis,
1002
+ )
1003
+
1004
+ return batch_outputs
1005
+
1006
+ def build_inputs_with_special_tokens(
1007
+ self,
1008
+ token_ids_0: List[List[int]],
1009
+ token_ids_1: Optional[List[List[int]]] = None,
1010
+ ) -> List[List[int]]:
1011
+ """
1012
+ Concatenate nested ngram sequences.
1013
+
1014
+ Args:
1015
+ token_ids_0 (`List[List[int]]`): The first tokenized sequence.
1016
+ token_ids_1 (`List[List[int]]`, *optional*): The second tokenized sequence.
1017
+
1018
+ Returns:
1019
+ `List[List[int]]`: The model input with special tokens.
1020
+ """
1021
+ if token_ids_1 is None or len(token_ids_1) == 0:
1022
+ return token_ids_0
1023
+
1024
+ if len(token_ids_0) == 0:
1025
+ return token_ids_1
1026
+
1027
+ return np.concatenate(
1028
+ (np.array(token_ids_0), np.array(token_ids_1)), axis=1
1029
+ ).tolist()
1030
+
1031
+ def truncate_sequences(
1032
+ self,
1033
+ ids: List[int],
1034
+ pair_ids: Optional[List[int]] = None,
1035
+ num_tokens_to_remove: int = 0,
1036
+ truncation_strategy: Union[str, TruncationStrategy] = "longest_first",
1037
+ stride: int = 0,
1038
+ ) -> Tuple[List[int], List[int], List[int]]:
1039
+ """
1040
+ Truncates a sequence pair in-place following the strategy.
1041
+ Args:
1042
+ ids (`List[int]`):
1043
+ Tokenized input ids of the first sequence. Can be obtained from a string by chaining the `tokenize` and
1044
+ `convert_tokens_to_ids` methods.
1045
+ pair_ids (`List[int]`, *optional*):
1046
+ Tokenized input ids of the second sequence. Can be obtained from a string by chaining the `tokenize`
1047
+ and `convert_tokens_to_ids` methods.
1048
+ num_tokens_to_remove (`int`, *optional*, defaults to 0):
1049
+ Number of tokens to remove using the truncation strategy.
1050
+ truncation_strategy (`str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`):
1051
+ The strategy to follow for truncation. Can be:
1052
+ - `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
1053
+ maximum acceptable input length for the model if that argument is not provided. This will truncate
1054
+ token by token, removing a token from the longest sequence in the pair if a pair of sequences (or a
1055
+ batch of pairs) is provided.
1056
+ - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
1057
+ maximum acceptable input length for the model if that argument is not provided. This will only
1058
+ truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
1059
+ - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the
1060
+ maximum acceptable input length for the model if that argument is not provided. This will only
1061
+ truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
1062
+ - `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths greater
1063
+ than the model maximum admissible input size).
1064
+ stride (`int`, *optional*, defaults to 0):
1065
+ If set to a positive number, the overflowing tokens returned will contain some tokens from the main
1066
+ sequence returned. The value of this argument defines the number of additional tokens.
1067
+ Returns:
1068
+ `Tuple[List[int], List[int], List[int]]`: The truncated `ids`, the truncated `pair_ids` and the list of
1069
+ overflowing tokens. Note: The *longest_first* strategy returns empty list of overflowing tokens if a pair
1070
+ of sequences (or a batch of pairs) is provided.
1071
+ """
1072
+ if num_tokens_to_remove <= 0:
1073
+ return ids, pair_ids, []
1074
+
1075
+ if not isinstance(truncation_strategy, TruncationStrategy):
1076
+ truncation_strategy = TruncationStrategy(truncation_strategy)
1077
+
1078
+ overflowing_tokens = []
1079
+ if truncation_strategy == TruncationStrategy.ONLY_FIRST or (
1080
+ truncation_strategy == TruncationStrategy.LONGEST_FIRST and pair_ids is None
1081
+ ):
1082
+ ids = np.array(ids)
1083
+
1084
+ # PHA: I think we only truncate with longest first
1085
+ if ids.shape[1] > num_tokens_to_remove:
1086
+ window_len = min(ids.shape[1], stride + num_tokens_to_remove)
1087
+ if self.truncation_side == "left":
1088
+ overflowing_tokens = ids[:, :window_len]
1089
+ ids = ids[:, num_tokens_to_remove:]
1090
+ elif self.truncation_side == "right":
1091
+ overflowing_tokens = ids[:, -window_len:]
1092
+ ids = ids[:, :-num_tokens_to_remove]
1093
+ else:
1094
+ raise ValueError(
1095
+ f"invalid truncation strategy: {self.truncation_side}, use 'left' or 'right'."
1096
+ )
1097
+
1098
+ ids = ids.tolist()
1099
+
1100
+ else:
1101
+ error_msg = (
1102
+ f"We need to remove {num_tokens_to_remove} to truncate the input "
1103
+ f"but the first sequence has a length {len(ids)}. "
1104
+ )
1105
+ if truncation_strategy == TruncationStrategy.ONLY_FIRST:
1106
+ error_msg = (
1107
+ error_msg + "Please select another truncation strategy than "
1108
+ f"{truncation_strategy}, for instance 'longest_first' or 'only_second'."
1109
+ )
1110
+ logger.error(error_msg)
1111
+ elif truncation_strategy == TruncationStrategy.LONGEST_FIRST:
1112
+ logger.warning(
1113
+ "Be aware, overflowing tokens are not returned for the setting you have chosen,"
1114
+ f" i.e. sequence pairs with the '{TruncationStrategy.LONGEST_FIRST.value}' "
1115
+ "truncation strategy. So the returned list will always be empty even if some "
1116
+ "tokens have been removed."
1117
+ )
1118
+ ids = np.array(ids)
1119
+ pair_ids = np.array(pair_ids)
1120
+
1121
+ for _ in range(num_tokens_to_remove):
1122
+ if pair_ids is None or ids.shape[1] > pair_ids.shape[1]:
1123
+ if self.truncation_side == "right":
1124
+ ids = ids[:, :-1]
1125
+ elif self.truncation_side == "left":
1126
+ ids = ids[:, 1:]
1127
+ else:
1128
+ raise ValueError(
1129
+ "invalid truncation strategy:" + str(self.truncation_side)
1130
+ )
1131
+ else:
1132
+ if self.truncation_side == "right":
1133
+ pair_ids = pair_ids[:, :-1]
1134
+ elif self.truncation_side == "left":
1135
+ pair_ids = pair_ids[:, 1:]
1136
+ else:
1137
+ raise ValueError(
1138
+ "invalid truncation strategy:" + str(self.truncation_side)
1139
+ )
1140
+
1141
+ ids = ids.tolist()
1142
+ pair_ids = pair_ids.tolist()
1143
+
1144
+ elif (
1145
+ truncation_strategy == TruncationStrategy.ONLY_SECOND
1146
+ and pair_ids is not None
1147
+ ):
1148
+ raise NotImplementedError(
1149
+ "PHA: I think we only truncate with longest first"
1150
+ )
1151
+ if len(pair_ids) > num_tokens_to_remove:
1152
+ window_len = min(len(pair_ids), stride + num_tokens_to_remove)
1153
+ if self.truncation_side == "right":
1154
+ overflowing_tokens = pair_ids[-window_len:]
1155
+ pair_ids = pair_ids[:-num_tokens_to_remove]
1156
+ elif self.truncation_side == "left":
1157
+ overflowing_tokens = pair_ids[:window_len]
1158
+ pair_ids = pair_ids[num_tokens_to_remove:]
1159
+ else:
1160
+ raise ValueError(
1161
+ "invalid truncation strategy:" + str(self.truncation_side)
1162
+ )
1163
+ else:
1164
+ logger.error(
1165
+ f"We need to remove {num_tokens_to_remove} to truncate the input "
1166
+ f"but the second sequence has a length {len(pair_ids)}. "
1167
+ f"Please select another truncation strategy than {truncation_strategy}, "
1168
+ "for instance 'longest_first' or 'only_first'."
1169
+ )
1170
+
1171
+ return (ids, pair_ids, overflowing_tokens)
1172
+
1173
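Editor's note: the override above truncates along the sequence axis of a 2-D `(n_gram, seq_len)` id matrix rather than a flat token list. The following is a minimal, illustrative NumPy sketch of that slicing (the array values and shapes are made up, not taken from the tokenizer):

import numpy as np

# Illustrative (n_gram, seq_len) id matrix: 2 n-gram orders, 6 positions.
ids = np.array([[16, 3, 11, 11, 8, 2],
                [208, 229, 230, 231, 1, 1]])

num_tokens_to_remove = 2
stride = 1
window_len = min(ids.shape[1], stride + num_tokens_to_remove)

# 'right' truncation: drop the last positions; the overflow keeps the removed
# tokens plus a stride-sized overlap with the kept part.
overflowing = ids[:, -window_len:]          # shape (2, 3)
truncated = ids[:, :-num_tokens_to_remove]  # shape (2, 4)

assert overflowing.shape == (2, 3)
assert truncated.shape == (2, 4)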
+     def _token_to_n_order(self, token: str) -> int:
+         """Get N-gram order for a token"""
+         for n_gram, word2idx in self._ngram2word2idx.items():
+             if token in word2idx:
+                 return n_gram
+
+         return 0
+
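Editor's note: a toy version of the nested vocab mapping this lookup walks. The indices here are invented for illustration, except that the tests at the end of this file suggest mappings such as order 1, id 16 -> "h" and order 2, id 208 -> "he":

# Hypothetical stand-in for the tokenizer's _ngram2word2idx attribute.
_ngram2word2idx = {
    1: {"<unk>": 1, "h": 16, "e": 3},
    2: {"he": 208, "el": 229},
}

def token_to_n_order(token: str) -> int:
    for n_gram, word2idx in _ngram2word2idx.items():
        if token in word2idx:
            return n_gram
    return 0

assert token_to_n_order("he") == 2
assert token_to_n_order("zz") == 0  # unknown tokens fall through to order 0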
+     def create_weight_tensor(self) -> torch.Tensor:
+         unked_freqs = self._frequencies.most_common()
+
+         t = torch.ones(len(self))
+
+         for token, freq in unked_freqs:
+             t[self._ngram2word2idx[self._token_to_n_order(token)][token]] = freq
+
+         # Ensure the only whitespace character is weighted
+         t[self._ngram2word2idx[1][" "]] = 1.0
+
+         normed_weights = torch.tensor([(1 - (x / (max(t) + 1))).item() for x in t])
+
+         marker_tokens = [self.get_idx("<unk>", n) for n in range(1, self.ngram + 1)]
+         marker_tokens.extend(
+             [self.get_idx("<start>", n) for n in range(1, self.ngram + 1)]
+         )
+         # Instead of explicit ignore indexes, we use the weight vector and set target idxs to 0
+         for marker in marker_tokens:
+             normed_weights[marker] = 0
+
+         return normed_weights
+
+
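Editor's note: the zeroed marker weights stand in for an explicit `ignore_index`. The training code is not part of this file, but a hypothetical sketch of how such a per-class weight vector is typically consumed by a weighted cross-entropy loss looks like this:

import torch

vocab_size = 8
weights = torch.ones(vocab_size)
weights[1] = 0.0  # e.g. an <unk>/<start> marker index

# Targets belonging to a zero-weight class contribute nothing to the loss,
# and with the default 'mean' reduction they do not skew the average either.
loss_fn = torch.nn.CrossEntropyLoss(weight=weights)

logits = torch.randn(4, vocab_size)   # (batch, vocab)
targets = torch.tensor([2, 1, 5, 1])  # the class-1 positions are effectively ignored

loss = loss_fn(logits, targets)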
+ class TestTokenizer(unittest.TestCase):
+     def test_one(self):
+         vocab_file = "/home/phmaker/Projects/ngme/vocabs/1-gram-babylm.json"
+
+         t = NGMETokenizer(vocab_file)
+         self.assertEqual(t.get_idx("<unk>", 1), 1)
+
+         result = t("hello world")
+         self.assertEqual(result.input_ids, [16, 3, 11, 11, 8, 2, 21, 8, 9, 11, 12])
+
+         result = t("<unk>")
+         self.assertEqual(result.input_ids, [1, 13, 5, 24, 1])
+
+         result = t(["hello world", "<unk>"])
+         self.assertEqual(
+             result.input_ids,
+             [[16, 3, 11, 11, 8, 2, 21, 8, 9, 11, 12], [1, 13, 5, 24, 1]],
+         )
+
+     def test_three(self):
+         vocab_file = "/home/phmaker/Projects/ngme/vocabs/3-gram-babylm.json"
+
+         t = NGMETokenizer(vocab_file)
+
+         result = t("hello world")
+         self.assertEqual(result.input_ids, [16, 3, 11, 11, 8, 2, 21, 8, 9, 11, 12])
+
+         result = t("hello", return_ngram_sequences=True)
+
+         result = t(["hello world"], return_ngram_sequences=True)
+         two_gram_expected = [[16, 208, 229, 230, 231, 1, 1, 312, 257, 499, 306]]
+
+         self.assertEqual(result["gram_2_sequence"], two_gram_expected)
+         self.assertEqual(t._ngram2idx2word[1][16], "h")
+         self.assertEqual(t._ngram2idx2word[2][208], "he")
+         self.assertEqual(t._ngram2idx2word[2][229], "el")
+
+     def test_unks(self):
+         vocab_file = "/home/phmaker/Projects/ngme/vocabs/2-gram-wiki-en.json"
+         t = NGMETokenizer(vocab_file)
+         result = t("OciVDjöShG", return_ngram_sequences=True, return_tensors="pt")
+
+     def test_decode(self):
+         vocab_file = "/home/phmaker/Projects/ngme/vocabs/3-gram-babylm.json"
+         t = NGMETokenizer(vocab_file)
+         decoded = t.decode(208)
+         assert decoded == "he"
+
+     def test_padding(self):
+         vocab_file = "/home/phmaker/Projects/ngme/vocabs/3-gram-babylm.json"
+         t = NGMETokenizer(vocab_file)
+         result = t(
+             "hello world",
+             return_tensors="pt",
+             padding="max_length",
+             max_length=20,
+             return_ngram_sequences=True,
+         )
+
+         self.assertEqual(result.input_ids.shape, (1, 20))
+         self.assertEqual(result.gram_2_sequence.shape, (1, 20))
+         self.assertEqual(result.gram_3_sequence.shape, (1, 20))
+
+     def test_truncation(self):
+         vocab_file = "/home/phmaker/Projects/ngme/vocabs/3-gram-babylm.json"
+         t = NGMETokenizer(vocab_file)
+
+         result = t(
+             "hello world",
+             return_tensors="pt",
+             truncation=True,
+             max_length=5,
+             return_ngram_sequences=True,
+         )
+         self.assertEqual(result.input_ids.shape, (1, 5))
+         self.assertEqual(result.gram_2_sequence.shape, (1, 5))
+
+     def test_padding_and_truncation(self):
+         vocab_file = "/home/phmaker/Projects/ngme/vocabs/3-gram-babylm.json"
+         t = NGMETokenizer(vocab_file)
+
+         result = t(
+             ["four", "something longer"],
+             return_tensors="pt",
+             padding="max_length",
+             truncation=True,
+             max_length=5,
+             return_ngram_sequences=True,
+         )
+         self.assertEqual(result.input_ids.shape, (2, 5))
+         self.assertEqual(result.gram_2_sequence.shape, (2, 5))
+
+
+ if __name__ == "__main__":
+     unittest.main()
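Editor's note: the tests above double as usage examples, but they depend on vocab JSON files under /home/phmaker/... that only exist on the author's machine. A more portable way to try the uploaded artifacts is to load them from the Hub with `trust_remote_code=True`. This is a hedged sketch: the repo id below is a placeholder, and it assumes the custom tokenizer and model classes are registered for auto-loading in the repo's config files.

from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder: replace with the actual Hub repo id of this upload.
repo_id = "PatrickHaller/<repo-name>"

tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

# Same call pattern as in the tests above.
inputs = tokenizer("hello world", return_tensors="pt")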