PatrickHaller committed
Commit
e0609f8
1 Parent(s): 1cfc099

Upload NGMEForCausalLM

Files changed (8)
  1. config.json +34 -0
  2. configuration_ngme.py +177 -0
  3. generation_config.json +7 -0
  4. modeling_ngme.py +1203 -0
  5. ngme.py +178 -0
  6. pytorch_model.bin +3 -0
  7. sampling.py +205 -0
  8. tokenization_ngme.py +1303 -0
config.json ADDED
@@ -0,0 +1,34 @@
1
+ {
2
+ "_name_or_path": "/glusterfs/dfs-gfs-dist/hallepat/ngme",
3
+ "architectures": [
4
+ "NGMEForCausalLM"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_ngme.NGMEConfig",
8
+ "AutoModelForCausalLM": "modeling_ngme.NGMEForCausalLM"
9
+ },
10
+ "bos_token_id": 1,
11
+ "eos_token_id": 0,
12
+ "hidden_act": "silu",
13
+ "hidden_size": 768,
14
+ "initializer_range": 0.02,
15
+ "intermediate_size": 3072,
16
+ "max_position_embeddings": 2048,
17
+ "model_type": "ngme",
18
+ "num_attention_heads": 12,
19
+ "num_hidden_layers": 12,
20
+ "num_key_value_heads": 12,
21
+ "pad_token_id": 0,
22
+ "pretraining_tp": 1,
23
+ "rms_norm_eps": 1e-06,
24
+ "rope_scaling": null,
25
+ "tie_word_embeddings": false,
26
+ "torch_dtype": "float32",
27
+ "transformers_version": "4.26.1",
28
+ "unk_idx": 1,
29
+ "unk_token_id": 1,
30
+ "use_cache": true,
31
+ "use_flash_attn": false,
32
+ "use_small_embedding": false,
33
+ "vocab_size": 15237
34
+ }
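Because `config.json` registers the custom classes through its `auto_map` entry, this checkpoint loads through the remote-code path. A minimal loading sketch, assuming the uploaded files sit in a local directory or Hub repo referred to here by the placeholder `path/to/this/repo`:

```python
# Minimal loading sketch; "path/to/this/repo" is a placeholder for the directory
# or Hub repo id that contains the files from this commit.
from transformers import AutoConfig, AutoModelForCausalLM

# trust_remote_code=True is required because the "auto_map" field points
# AutoConfig / AutoModelForCausalLM at configuration_ngme.py and modeling_ngme.py.
config = AutoConfig.from_pretrained("path/to/this/repo", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("path/to/this/repo", trust_remote_code=True)
print(type(model).__name__)  # NGMEForCausalLM
```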
configuration_ngme.py ADDED
@@ -0,0 +1,177 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+
21
+ from transformers.configuration_utils import PretrainedConfig
22
+ from transformers.utils import logging
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
+ class NGMEConfig(PretrainedConfig):
29
+ r"""
30
+ This is the configuration class to store the configuration of a [`NGMEModel`]. It is used to instantiate an NGME
31
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
32
+ defaults will yield a similar configuration to that of the LLaMA-7B.
33
+
34
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
35
+ documentation from [`PretrainedConfig`] for more information.
36
+
37
+
38
+ Args:
39
+ vocab_size (`int`, *optional*, defaults to 32000):
40
+ Vocabulary size of the NGME model. Defines the number of different tokens that can be represented by the
41
+ `inputs_ids` passed when calling [`NGMEModel`]
42
+ hidden_size (`int`, *optional*, defaults to 4096):
43
+ Dimension of the hidden representations.
44
+ intermediate_size (`int`, *optional*, defaults to 11008):
45
+ Dimension of the MLP representations.
46
+ num_hidden_layers (`int`, *optional*, defaults to 32):
47
+ Number of hidden layers in the Transformer encoder.
48
+ num_attention_heads (`int`, *optional*, defaults to 32):
49
+ Number of attention heads for each attention layer in the Transformer encoder.
50
+ num_key_value_heads (`int`, *optional*):
51
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
52
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
53
+ `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
54
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
55
+ by mean-pooling all the original heads within that group. For more details, check out [this
56
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, it will default to
57
+ `num_attention_heads`.
58
+ pretraining_tp (`int`, *optional*, defaults to `1`):
59
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
60
+ document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
61
+ necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
62
+ issue](https://github.com/pytorch/pytorch/issues/76232).
63
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
64
+ The non-linear activation function (function or string) in the decoder.
65
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
66
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
67
+ just in case (e.g., 512 or 1024 or 2048).
68
+ initializer_range (`float`, *optional*, defaults to 0.02):
69
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
70
+ rms_norm_eps (`float`, *optional*, defaults to 1e-6):
71
+ The epsilon used by the rms normalization layers.
72
+ use_cache (`bool`, *optional*, defaults to `True`):
73
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
74
+ relevant if `config.is_decoder=True`.
75
+ tie_word_embeddings(`bool`, *optional*, defaults to `False`):
76
+ Whether to tie weight embeddings
77
+ rope_scaling (`Dict`, *optional*):
78
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
79
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format
80
+ is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
81
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
82
+ these scaling strategies behave:
83
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
84
+ experimental feature, subject to breaking API changes in future versions.
85
+
86
+ Example:
87
+
88
+ ```python
89
+ >>> from transformers import NGMEModel, NGMEConfig
90
+
91
+ >>> # Initializing an NGME configuration (LLaMA-7B style defaults)
92
+ >>> configuration = NGMEConfig()
93
+
94
+ >>> # Initializing a model from that configuration
95
+ >>> model = NGMEModel(configuration)
96
+
97
+ >>> # Accessing the model configuration
98
+ >>> configuration = model.config
99
+ ```"""
100
+ model_type = "ngme"
101
+ keys_to_ignore_at_inference = ["past_key_values"]
102
+
103
+ def __init__(
104
+ self,
105
+ vocab_size=32000,
106
+ hidden_size=4096,
107
+ intermediate_size=11008,
108
+ num_hidden_layers=32,
109
+ num_attention_heads=32,
110
+ num_key_value_heads=None,
111
+ hidden_act="silu",
112
+ max_position_embeddings=2048,
113
+ initializer_range=0.02,
114
+ rms_norm_eps=1e-6,
115
+ use_cache=True,
116
+ pad_token_id=None,
117
+ bos_token_id=1,
118
+ eos_token_id=2,
119
+ pretraining_tp=1,
120
+ tie_word_embeddings=False,
121
+ rope_scaling=None,
122
+ use_flash_attn=False,
123
+ use_small_embedding=False,
124
+ unk_idx=-1,
125
+ **kwargs,
126
+ ):
127
+ self.vocab_size = vocab_size
128
+ self.unk_idx = unk_idx
129
+ self.max_position_embeddings = max_position_embeddings
130
+ self.hidden_size = hidden_size
131
+ self.intermediate_size = intermediate_size
132
+ self.num_hidden_layers = num_hidden_layers
133
+ self.num_attention_heads = num_attention_heads
134
+
135
+ # for backward compatibility
136
+ if num_key_value_heads is None:
137
+ num_key_value_heads = num_attention_heads
138
+
139
+ self.num_key_value_heads = num_key_value_heads
140
+ self.hidden_act = hidden_act
141
+ self.initializer_range = initializer_range
142
+ self.rms_norm_eps = rms_norm_eps
143
+ self.pretraining_tp = pretraining_tp
144
+ self.use_cache = use_cache
145
+ self.rope_scaling = rope_scaling
146
+ self._rope_scaling_validation()
147
+ self.use_flash_attn = use_flash_attn
148
+ self.use_small_embedding = use_small_embedding
149
+
150
+ super().__init__(
151
+ pad_token_id=pad_token_id,
152
+ bos_token_id=bos_token_id,
153
+ eos_token_id=eos_token_id,
154
+ tie_word_embeddings=tie_word_embeddings,
155
+ **kwargs,
156
+ )
157
+
158
+ def _rope_scaling_validation(self):
159
+ """
160
+ Validate the `rope_scaling` configuration.
161
+ """
162
+ if self.rope_scaling is None:
163
+ return
164
+
165
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
166
+ raise ValueError(
167
+ "`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, "
168
+ f"got {self.rope_scaling}"
169
+ )
170
+ rope_scaling_type = self.rope_scaling.get("type", None)
171
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
172
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
173
+ raise ValueError(
174
+ f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
175
+ )
176
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
177
+ raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}")
generation_config.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 0,
5
+ "pad_token_id": 0,
6
+ "transformers_version": "4.26.1"
7
+ }
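`generation_config.json` only pins the special token ids (bos 1, eos and pad 0). A hedged sketch of loading it and overriding decoding parameters via `GenerationConfig`; the sampling values below are illustrative, not part of this commit:

```python
# Sketch only; the decoding parameters are examples, "path/to/this/repo" is a placeholder.
from transformers import GenerationConfig

gen_config = GenerationConfig.from_pretrained("path/to/this/repo")
gen_config.do_sample = True
gen_config.max_new_tokens = 64
gen_config.temperature = 0.8
# Special token ids as uploaded: bos=1, eos=pad=0.
print(gen_config.bos_token_id, gen_config.eos_token_id, gen_config.pad_token_id)
```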
modeling_ngme.py ADDED
@@ -0,0 +1,1203 @@
1
+ import math
2
+ from typing import List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import torch.utils.checkpoint
7
+ from torch import nn
8
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
9
+
10
+ from transformers.activations import ACT2FN
11
+ from transformers.modeling_outputs import (
12
+ BaseModelOutputWithPast,
13
+ CausalLMOutputWithPast,
14
+ SequenceClassifierOutputWithPast,
15
+ )
16
+ from transformers.modeling_utils import PreTrainedModel
17
+ from transformers.utils import logging
18
+ from transformers.generation.utils import SampleOutput
19
+ from transformers.generation.logits_process import LogitsProcessorList
20
+ from transformers.generation.stopping_criteria import StoppingCriteriaList
21
+
22
+ from .tokenization_ngme import NGMETokenizer
23
+ from .sampling import sample as sample_ngme
24
+ from .configuration_ngme import NGMEConfig
25
+ from .ngme import (
26
+ soft_n_hot,
27
+ NGramsEmbedding,
28
+ collect_n_gram_sequences,
29
+ shift_with_pad,
30
+ )
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+
35
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
36
+ def _make_causal_mask(
37
+ input_ids_shape: torch.Size,
38
+ dtype: torch.dtype,
39
+ device: torch.device,
40
+ past_key_values_length: int = 0,
41
+ ):
42
+ """
43
+ Make causal mask used for bi-directional self-attention.
44
+ """
45
+ bsz, tgt_len = input_ids_shape
46
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
47
+ mask_cond = torch.arange(mask.size(-1), device=device)
48
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
49
+ mask = mask.to(dtype)
50
+
51
+ if past_key_values_length > 0:
52
+ mask = torch.cat(
53
+ [
54
+ torch.zeros(
55
+ tgt_len, past_key_values_length, dtype=dtype, device=device
56
+ ),
57
+ mask,
58
+ ],
59
+ dim=-1,
60
+ )
61
+ return mask[None, None, :, :].expand(
62
+ bsz, 1, tgt_len, tgt_len + past_key_values_length
63
+ )
64
+
65
+
66
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
67
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
68
+ """
69
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
70
+ """
71
+ bsz, src_len = mask.size()
72
+ tgt_len = tgt_len if tgt_len is not None else src_len
73
+
74
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
75
+
76
+ inverted_mask = 1.0 - expanded_mask
77
+
78
+ return inverted_mask.masked_fill(
79
+ inverted_mask.to(torch.bool), torch.finfo(dtype).min
80
+ )
81
+
82
+
83
+ class NGMERMSNorm(nn.Module):
84
+ def __init__(self, hidden_size, eps=1e-6):
85
+ """
86
+ NGMERMSNorm is equivalent to T5LayerNorm
87
+ """
88
+ super().__init__()
89
+ self.weight = nn.Parameter(torch.ones(hidden_size))
90
+ self.variance_epsilon = eps
91
+
92
+ def forward(self, hidden_states):
93
+ input_dtype = hidden_states.dtype
94
+ hidden_states = hidden_states.to(torch.float32)
95
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
96
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
97
+ return self.weight * hidden_states.to(input_dtype)
98
+
99
+
100
+ class NGMERotaryEmbedding(torch.nn.Module):
101
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
102
+ super().__init__()
103
+
104
+ self.dim = dim
105
+ self.max_position_embeddings = max_position_embeddings
106
+ self.base = base
107
+ inv_freq = 1.0 / (
108
+ self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
109
+ )
110
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
111
+
112
+ # Build here to make `torch.jit.trace` work.
113
+ self._set_cos_sin_cache(
114
+ seq_len=max_position_embeddings,
115
+ device=self.inv_freq.device,
116
+ dtype=torch.get_default_dtype(),
117
+ )
118
+
119
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
120
+ self.max_seq_len_cached = seq_len
121
+ t = torch.arange(
122
+ self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
123
+ )
124
+
125
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
126
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
127
+ emb = torch.cat((freqs, freqs), dim=-1)
128
+ self.register_buffer(
129
+ "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False
130
+ )
131
+ self.register_buffer(
132
+ "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False
133
+ )
134
+
135
+ def forward(self, x, seq_len=None):
136
+ # x: [bs, num_attention_heads, seq_len, head_size]
137
+ if seq_len > self.max_seq_len_cached:
138
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
139
+
140
+ return (
141
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
142
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
143
+ )
144
+
145
+
146
+ class NGMELinearScalingRotaryEmbedding(NGMERotaryEmbedding):
147
+ """NGMERotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
148
+
149
+ def __init__(
150
+ self,
151
+ dim,
152
+ max_position_embeddings=2048,
153
+ base=10000,
154
+ device=None,
155
+ scaling_factor=1.0,
156
+ ):
157
+ self.scaling_factor = scaling_factor
158
+ super().__init__(dim, max_position_embeddings, base, device)
159
+
160
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
161
+ self.max_seq_len_cached = seq_len
162
+ t = torch.arange(
163
+ self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
164
+ )
165
+ t = t / self.scaling_factor
166
+
167
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
168
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
169
+ emb = torch.cat((freqs, freqs), dim=-1)
170
+ self.register_buffer(
171
+ "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False
172
+ )
173
+ self.register_buffer(
174
+ "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False
175
+ )
176
+
177
+
178
+ class NGMEDynamicNTKScalingRotaryEmbedding(NGMERotaryEmbedding):
179
+ """NGMERotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
180
+
181
+ def __init__(
182
+ self,
183
+ dim,
184
+ max_position_embeddings=2048,
185
+ base=10000,
186
+ device=None,
187
+ scaling_factor=1.0,
188
+ ):
189
+ self.scaling_factor = scaling_factor
190
+ super().__init__(dim, max_position_embeddings, base, device)
191
+
192
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
193
+ self.max_seq_len_cached = seq_len
194
+
195
+ if seq_len > self.max_position_embeddings:
196
+ base = self.base * (
197
+ (self.scaling_factor * seq_len / self.max_position_embeddings)
198
+ - (self.scaling_factor - 1)
199
+ ) ** (self.dim / (self.dim - 2))
200
+ inv_freq = 1.0 / (
201
+ base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
202
+ )
203
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
204
+
205
+ t = torch.arange(
206
+ self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
207
+ )
208
+
209
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
210
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
211
+ emb = torch.cat((freqs, freqs), dim=-1)
212
+ self.register_buffer(
213
+ "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False
214
+ )
215
+ self.register_buffer(
216
+ "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False
217
+ )
218
+
219
+
220
+ def rotate_half(x):
221
+ """Rotates half the hidden dims of the input."""
222
+ x1 = x[..., : x.shape[-1] // 2]
223
+ x2 = x[..., x.shape[-1] // 2 :]
224
+ return torch.cat((-x2, x1), dim=-1)
225
+
226
+
227
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
228
+ # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
229
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
230
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
231
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
232
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
233
+ q_embed = (q * cos) + (rotate_half(q) * sin)
234
+ k_embed = (k * cos) + (rotate_half(k) * sin)
235
+ return q_embed, k_embed
236
+
237
+
238
+ class NGMEMLP(nn.Module):
239
+ def __init__(self, config):
240
+ super().__init__()
241
+ self.config = config
242
+ self.hidden_size = config.hidden_size
243
+ self.intermediate_size = config.intermediate_size
244
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
245
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
246
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
247
+ self.act_fn = ACT2FN[config.hidden_act]
248
+
249
+ def forward(self, x):
250
+ if self.config.pretraining_tp > 1:
251
+ slice = self.intermediate_size // self.config.pretraining_tp
252
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
253
+ up_proj_slices = self.up_proj.weight.split(slice, dim=0)
254
+ down_proj_slices = self.down_proj.weight.split(slice, dim=1)
255
+
256
+ gate_proj = torch.cat(
257
+ [
258
+ F.linear(x, gate_proj_slices[i])
259
+ for i in range(self.config.pretraining_tp)
260
+ ],
261
+ dim=-1,
262
+ )
263
+ up_proj = torch.cat(
264
+ [
265
+ F.linear(x, up_proj_slices[i])
266
+ for i in range(self.config.pretraining_tp)
267
+ ],
268
+ dim=-1,
269
+ )
270
+
271
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
272
+ down_proj = [
273
+ F.linear(intermediate_states[i], down_proj_slices[i])
274
+ for i in range(self.config.pretraining_tp)
275
+ ]
276
+ down_proj = sum(down_proj)
277
+ else:
278
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
279
+
280
+ return down_proj
281
+
282
+
283
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
284
+ """
285
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
286
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
287
+ """
288
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
289
+ if n_rep == 1:
290
+ return hidden_states
291
+ hidden_states = hidden_states[:, :, None, :, :].expand(
292
+ batch, num_key_value_heads, n_rep, slen, head_dim
293
+ )
294
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
295
+
296
+
297
+ class NGMEAttention(nn.Module):
298
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
299
+
300
+ def __init__(self, config: NGMEConfig):
301
+ super().__init__()
302
+ self.config = config
303
+ self.hidden_size = config.hidden_size
304
+ self.num_heads = config.num_attention_heads
305
+ self.head_dim = self.hidden_size // self.num_heads
306
+ self.num_key_value_heads = config.num_key_value_heads
307
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
308
+ self.max_position_embeddings = config.max_position_embeddings
309
+
310
+ if (self.head_dim * self.num_heads) != self.hidden_size:
311
+ raise ValueError(
312
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
313
+ f" and `num_heads`: {self.num_heads})."
314
+ )
315
+ self.q_proj = nn.Linear(
316
+ self.hidden_size, self.num_heads * self.head_dim, bias=False
317
+ )
318
+ self.k_proj = nn.Linear(
319
+ self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
320
+ )
321
+ self.v_proj = nn.Linear(
322
+ self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
323
+ )
324
+ self.o_proj = nn.Linear(
325
+ self.num_heads * self.head_dim, self.hidden_size, bias=False
326
+ )
327
+ self._init_rope()
328
+
329
+ def _init_rope(self):
330
+ if self.config.rope_scaling is None:
331
+ self.rotary_emb = NGMERotaryEmbedding(
332
+ self.head_dim, max_position_embeddings=self.max_position_embeddings
333
+ )
334
+ else:
335
+ scaling_type = self.config.rope_scaling["type"]
336
+ scaling_factor = self.config.rope_scaling["factor"]
337
+ if scaling_type == "linear":
338
+ self.rotary_emb = NGMELinearScalingRotaryEmbedding(
339
+ self.head_dim,
340
+ max_position_embeddings=self.max_position_embeddings,
341
+ scaling_factor=scaling_factor,
342
+ )
343
+ elif scaling_type == "dynamic":
344
+ self.rotary_emb = NGMEDynamicNTKScalingRotaryEmbedding(
345
+ self.head_dim,
346
+ max_position_embeddings=self.max_position_embeddings,
347
+ scaling_factor=scaling_factor,
348
+ )
349
+ else:
350
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
351
+
352
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
353
+ return (
354
+ tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
355
+ .transpose(1, 2)
356
+ .contiguous()
357
+ )
358
+
359
+ def forward(
360
+ self,
361
+ hidden_states: torch.Tensor,
362
+ attention_mask: Optional[torch.Tensor] = None,
363
+ position_ids: Optional[torch.LongTensor] = None,
364
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
365
+ output_attentions: bool = False,
366
+ use_cache: bool = False,
367
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
368
+ bsz, q_len, _ = hidden_states.size()
369
+
370
+ if self.config.pretraining_tp > 1:
371
+ key_value_slicing = (
372
+ self.num_key_value_heads * self.head_dim
373
+ ) // self.config.pretraining_tp
374
+ query_slices = self.q_proj.weight.split(
375
+ (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
376
+ )
377
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
378
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
379
+
380
+ query_states = [
381
+ F.linear(hidden_states, query_slices[i])
382
+ for i in range(self.config.pretraining_tp)
383
+ ]
384
+ query_states = torch.cat(query_states, dim=-1)
385
+
386
+ key_states = [
387
+ F.linear(hidden_states, key_slices[i])
388
+ for i in range(self.config.pretraining_tp)
389
+ ]
390
+ key_states = torch.cat(key_states, dim=-1)
391
+
392
+ value_states = [
393
+ F.linear(hidden_states, value_slices[i])
394
+ for i in range(self.config.pretraining_tp)
395
+ ]
396
+ value_states = torch.cat(value_states, dim=-1)
397
+
398
+ else:
399
+ query_states = self.q_proj(hidden_states)
400
+ key_states = self.k_proj(hidden_states)
401
+ value_states = self.v_proj(hidden_states)
402
+
403
+ query_states = query_states.view(
404
+ bsz, q_len, self.num_heads, self.head_dim
405
+ ).transpose(1, 2)
406
+ key_states = key_states.view(
407
+ bsz, q_len, self.num_key_value_heads, self.head_dim
408
+ ).transpose(1, 2)
409
+ value_states = value_states.view(
410
+ bsz, q_len, self.num_key_value_heads, self.head_dim
411
+ ).transpose(1, 2)
412
+
413
+ kv_seq_len = key_states.shape[-2]
414
+ if past_key_value is not None:
415
+ kv_seq_len += past_key_value[0].shape[-2]
416
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
417
+ query_states, key_states = apply_rotary_pos_emb(
418
+ query_states, key_states, cos, sin, position_ids
419
+ )
420
+
421
+ if past_key_value is not None:
422
+ # reuse k, v, self_attention
423
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
424
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
425
+
426
+ past_key_value = (key_states, value_states) if use_cache else None
427
+
428
+ # repeat k/v heads if n_kv_heads < n_heads
429
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
430
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
431
+
432
+ attn_weights = torch.matmul(
433
+ query_states, key_states.transpose(2, 3)
434
+ ) / math.sqrt(self.head_dim)
435
+
436
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
437
+ raise ValueError(
438
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
439
+ f" {attn_weights.size()}"
440
+ )
441
+
442
+ if attention_mask is not None:
443
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
444
+ raise ValueError(
445
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
446
+ )
447
+ attn_weights = attn_weights + attention_mask
448
+
449
+ # upcast attention to fp32
450
+ attn_weights = nn.functional.softmax(
451
+ attn_weights, dim=-1, dtype=torch.float32
452
+ ).to(query_states.dtype)
453
+ attn_output = torch.matmul(attn_weights, value_states)
454
+
455
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
456
+ raise ValueError(
457
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
458
+ f" {attn_output.size()}"
459
+ )
460
+
461
+ attn_output = attn_output.transpose(1, 2).contiguous()
462
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
463
+
464
+ if self.config.pretraining_tp > 1:
465
+ attn_output = attn_output.split(
466
+ self.hidden_size // self.config.pretraining_tp, dim=2
467
+ )
468
+ o_proj_slices = self.o_proj.weight.split(
469
+ self.hidden_size // self.config.pretraining_tp, dim=1
470
+ )
471
+ attn_output = sum(
472
+ [
473
+ F.linear(attn_output[i], o_proj_slices[i])
474
+ for i in range(self.config.pretraining_tp)
475
+ ]
476
+ )
477
+ else:
478
+ attn_output = self.o_proj(attn_output)
479
+
480
+ if not output_attentions:
481
+ attn_weights = None
482
+
483
+ return attn_output, attn_weights, past_key_value
484
+
485
+
486
+ class NGMEDecoderLayer(nn.Module):
487
+ def __init__(self, config: NGMEConfig):
488
+ super().__init__()
489
+ self.hidden_size = config.hidden_size
490
+ self.self_attn = NGMEAttention(config=config)
491
+ self.mlp = NGMEMLP(config)
492
+ self.input_layernorm = NGMERMSNorm(config.hidden_size, eps=config.rms_norm_eps)
493
+ self.post_attention_layernorm = NGMERMSNorm(
494
+ config.hidden_size, eps=config.rms_norm_eps
495
+ )
496
+
497
+ def forward(
498
+ self,
499
+ hidden_states: torch.Tensor,
500
+ attention_mask: Optional[torch.Tensor] = None,
501
+ position_ids: Optional[torch.LongTensor] = None,
502
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
503
+ output_attentions: Optional[bool] = False,
504
+ use_cache: Optional[bool] = False,
505
+ ) -> Tuple[
506
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
507
+ ]:
508
+ """
509
+ Args:
510
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
511
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
512
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
513
+ output_attentions (`bool`, *optional*):
514
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
515
+ returned tensors for more detail.
516
+ use_cache (`bool`, *optional*):
517
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
518
+ (see `past_key_values`).
519
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
520
+ """
521
+
522
+ residual = hidden_states
523
+
524
+ hidden_states = self.input_layernorm(hidden_states)
525
+
526
+ # Self Attention
527
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
528
+ hidden_states=hidden_states,
529
+ attention_mask=attention_mask,
530
+ position_ids=position_ids,
531
+ past_key_value=past_key_value,
532
+ output_attentions=output_attentions,
533
+ use_cache=use_cache,
534
+ )
535
+ hidden_states = residual + hidden_states
536
+
537
+ # Fully Connected
538
+ residual = hidden_states
539
+ hidden_states = self.post_attention_layernorm(hidden_states)
540
+ hidden_states = self.mlp(hidden_states)
541
+ hidden_states = residual + hidden_states
542
+
543
+ outputs = (hidden_states,)
544
+
545
+ if output_attentions:
546
+ outputs += (self_attn_weights,)
547
+
548
+ if use_cache:
549
+ outputs += (present_key_value,)
550
+
551
+ return outputs
552
+
553
+
554
+ class NGMEPreTrainedModel(PreTrainedModel):
555
+ config_class = NGMEConfig
556
+ base_model_prefix = "model"
557
+ supports_gradient_checkpointing = True
558
+ _no_split_modules = ["NGMEDecoderLayer"]
559
+ _skip_keys_device_placement = "past_key_values"
560
+
561
+ def _init_weights(self, module):
562
+ std = self.config.initializer_range
563
+ if isinstance(module, nn.Linear):
564
+ module.weight.data.normal_(mean=0.0, std=std)
565
+ if module.bias is not None:
566
+ module.bias.data.zero_()
567
+ elif isinstance(module, NGramsEmbedding):
568
+ if self.config.use_small_embedding:
569
+ nn.init.uniform_(module.weight, a=-1e-4, b=1e-4)
570
+ else:
571
+ module.weight.data.normal_(mean=0.0, std=std)
572
+ if module.padding_idx is not None:
573
+ module.weight.data[module.padding_idx].zero_()
574
+
575
+ def _set_gradient_checkpointing(self, module, value=False):
576
+ if isinstance(module, NGMEModel):
577
+ module.gradient_checkpointing = value
578
+
579
+
580
+ class NGMEModel(NGMEPreTrainedModel):
581
+ """
582
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`NGMEDecoderLayer`]
583
+
584
+ Args:
585
+ config: NGMEConfig
586
+ """
587
+
588
+ def __init__(self, config: NGMEConfig):
589
+ super().__init__(config)
590
+ self.padding_idx = config.pad_token_id
591
+ self.vocab_size = config.vocab_size
592
+
593
+ self.embed_tokens = NGramsEmbedding(
594
+ config.vocab_size,
595
+ config.hidden_size,
596
+ self.padding_idx,
597
+ unk_idx=config.unk_idx,
598
+ )
599
+
600
+ if self.config.use_small_embedding:
601
+ self.embed_layer_norm = nn.LayerNorm(config.hidden_size)
602
+
603
+ self.layers = nn.ModuleList(
604
+ [NGMEDecoderLayer(config) for _ in range(config.num_hidden_layers)]
605
+ )
606
+ self.norm = NGMERMSNorm(config.hidden_size, eps=config.rms_norm_eps)
607
+
608
+ self.gradient_checkpointing = False
609
+ # Initialize weights and apply final processing
610
+ self.post_init()
611
+
612
+ def get_input_embeddings(self):
613
+ return self.embed_tokens
614
+
615
+ def set_input_embeddings(self, value):
616
+ self.embed_tokens = value
617
+
618
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
619
+ def _prepare_decoder_attention_mask(
620
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
621
+ ):
622
+ # create causal mask
623
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
624
+ combined_attention_mask = None
625
+ if input_shape[-1] > 1:
626
+ combined_attention_mask = _make_causal_mask(
627
+ input_shape,
628
+ inputs_embeds.dtype,
629
+ device=inputs_embeds.device,
630
+ past_key_values_length=past_key_values_length,
631
+ )
632
+
633
+ if attention_mask is not None:
634
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
635
+ expanded_attn_mask = _expand_mask(
636
+ attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
637
+ ).to(inputs_embeds.device)
638
+ combined_attention_mask = (
639
+ expanded_attn_mask
640
+ if combined_attention_mask is None
641
+ else expanded_attn_mask + combined_attention_mask
642
+ )
643
+
644
+ return combined_attention_mask
645
+
646
+ def forward(
647
+ self,
648
+ input_ids: torch.LongTensor = None,
649
+ attention_mask: Optional[torch.Tensor] = None,
650
+ position_ids: Optional[torch.LongTensor] = None,
651
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
652
+ inputs_embeds: Optional[torch.FloatTensor] = None,
653
+ use_cache: Optional[bool] = None,
654
+ output_attentions: Optional[bool] = None,
655
+ output_hidden_states: Optional[bool] = None,
656
+ return_dict: Optional[bool] = None,
657
+ ngram_sequences: List[torch.Tensor] = [],
658
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
659
+ output_attentions = (
660
+ output_attentions
661
+ if output_attentions is not None
662
+ else self.config.output_attentions
663
+ )
664
+ output_hidden_states = (
665
+ output_hidden_states
666
+ if output_hidden_states is not None
667
+ else self.config.output_hidden_states
668
+ )
669
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
670
+
671
+ return_dict = (
672
+ return_dict if return_dict is not None else self.config.use_return_dict
673
+ )
674
+
675
+ # retrieve input_ids and inputs_embeds
676
+ if input_ids is not None and inputs_embeds is not None:
677
+ raise ValueError(
678
+ "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
679
+ )
680
+ elif input_ids is not None:
681
+ batch_size, seq_length = input_ids.shape
682
+ elif inputs_embeds is not None:
683
+ batch_size, seq_length, _ = inputs_embeds.shape
684
+ else:
685
+ raise ValueError(
686
+ "You have to specify either decoder_input_ids or decoder_inputs_embeds"
687
+ )
688
+
689
+ seq_length_with_past = seq_length
690
+ past_key_values_length = 0
691
+
692
+ if past_key_values is not None:
693
+ past_key_values_length = past_key_values[0][0].shape[2]
694
+ seq_length_with_past = seq_length_with_past + past_key_values_length
695
+
696
+ if position_ids is None:
697
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
698
+ position_ids = torch.arange(
699
+ past_key_values_length,
700
+ seq_length + past_key_values_length,
701
+ dtype=torch.long,
702
+ device=device,
703
+ )
704
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
705
+ else:
706
+ position_ids = position_ids.view(-1, seq_length).long()
707
+
708
+ if inputs_embeds is None:
709
+ inputs_embeds = self.embed_tokens(input_ids, ngram_sequences)
710
+
711
+ if self.config.use_small_embedding:
712
+ inputs_embeds = self.embed_layer_norm(inputs_embeds)
713
+
714
+ # embed positions
715
+ if attention_mask is None:
716
+ attention_mask = torch.ones(
717
+ (batch_size, seq_length_with_past),
718
+ dtype=torch.bool,
719
+ device=inputs_embeds.device,
720
+ )
721
+ attention_mask = self._prepare_decoder_attention_mask(
722
+ attention_mask,
723
+ (batch_size, seq_length),
724
+ inputs_embeds,
725
+ past_key_values_length,
726
+ )
727
+
728
+ hidden_states = inputs_embeds
729
+
730
+ if self.gradient_checkpointing and self.training:
731
+ if use_cache:
732
+ logger.warning_once(
733
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
734
+ )
735
+ use_cache = False
736
+
737
+ # decoder layers
738
+ all_hidden_states = () if output_hidden_states else None
739
+ all_self_attns = () if output_attentions else None
740
+ next_decoder_cache = () if use_cache else None
741
+
742
+ for idx, decoder_layer in enumerate(self.layers):
743
+ if output_hidden_states:
744
+ all_hidden_states += (hidden_states,)
745
+
746
+ past_key_value = (
747
+ past_key_values[idx] if past_key_values is not None else None
748
+ )
749
+
750
+ if self.gradient_checkpointing and self.training:
751
+
752
+ def create_custom_forward(module):
753
+ def custom_forward(*inputs):
754
+ # None for past_key_value
755
+ return module(*inputs, output_attentions, None)
756
+
757
+ return custom_forward
758
+
759
+ layer_outputs = torch.utils.checkpoint.checkpoint(
760
+ create_custom_forward(decoder_layer),
761
+ hidden_states,
762
+ attention_mask,
763
+ position_ids,
764
+ None,
765
+ )
766
+ else:
767
+ layer_outputs = decoder_layer(
768
+ hidden_states,
769
+ attention_mask=attention_mask,
770
+ position_ids=position_ids,
771
+ past_key_value=past_key_value,
772
+ output_attentions=output_attentions,
773
+ use_cache=use_cache,
774
+ )
775
+
776
+ hidden_states = layer_outputs[0]
777
+
778
+ if use_cache:
779
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
780
+
781
+ if output_attentions:
782
+ all_self_attns += (layer_outputs[1],)
783
+
784
+ hidden_states = self.norm(hidden_states)
785
+
786
+ # add hidden states from the last decoder layer
787
+ if output_hidden_states:
788
+ all_hidden_states += (hidden_states,)
789
+
790
+ next_cache = next_decoder_cache if use_cache else None
791
+ if not return_dict:
792
+ return tuple(
793
+ v
794
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
795
+ if v is not None
796
+ )
797
+ return BaseModelOutputWithPast(
798
+ last_hidden_state=hidden_states,
799
+ past_key_values=next_cache,
800
+ hidden_states=all_hidden_states,
801
+ attentions=all_self_attns,
802
+ )
803
+
804
+
805
+ class NGMEForCausalLM(NGMEPreTrainedModel):
806
+ _tied_weights_keys = ["lm_head.weight"]
807
+
808
+ def __init__(self, config):
809
+ super().__init__(config)
810
+ self.model = NGMEModel(config)
811
+ self.vocab_size = config.vocab_size
812
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
813
+
814
+ self.tokenizer: Optional[NGMETokenizer] = None
815
+
816
+ # Create a class-weight tensor for the loss and zero out the unk token so it is ignored
817
+ weight_tensor = torch.ones(config.vocab_size)
818
+ weight_tensor[config.unk_idx] = 0
819
+ self.loss_fct = nn.CrossEntropyLoss(weight=weight_tensor)
820
+
821
+ # Initialize weights and apply final processing
822
+ self.post_init()
823
+
824
+ def get_input_embeddings(self):
825
+ return self.model.embed_tokens
826
+
827
+ def set_input_embeddings(self, value):
828
+ self.model.embed_tokens = value
829
+
830
+ def get_output_embeddings(self):
831
+ return self.lm_head
832
+
833
+ def set_output_embeddings(self, new_embeddings):
834
+ self.lm_head = new_embeddings
835
+
836
+ def set_decoder(self, decoder):
837
+ self.model = decoder
838
+
839
+ def get_decoder(self):
840
+ return self.model
841
+
842
+ def _collect_ngram_labels(
843
+ self,
844
+ unigram_labels: torch.LongTensor,
845
+ label_gram_2_sequence: Optional[torch.LongTensor] = None,
846
+ label_gram_3_sequence: Optional[torch.LongTensor] = None,
847
+ # label_gram_4_sequence: Optional[torch.LongTensor] = None,
848
+ label_target_gram_2_sequence: Optional[torch.LongTensor] = None,
849
+ label_target_gram_3_sequence: Optional[torch.LongTensor] = None,
850
+ # label_target_gram_4_sequence: Optional[torch.LongTensor] = None,
851
+ ):
852
+ ngram_labels = [unigram_labels[..., 1:].contiguous()]
853
+
854
+ if label_gram_2_sequence is not None:
855
+ if label_target_gram_2_sequence is not None:
856
+ two_gram_labels = label_target_gram_2_sequence[..., 1:].contiguous()
857
+ else:
858
+ two_gram_labels = shift_with_pad(
859
+ label_gram_2_sequence, 2, unigram_labels
860
+ )
861
+ ngram_labels.append(two_gram_labels)
862
+
863
+ if label_gram_3_sequence is not None:
864
+ if label_target_gram_3_sequence is not None:
865
+ three_gram_labels = label_target_gram_3_sequence[..., 1:].contiguous()
866
+ else:
867
+ three_gram_labels = shift_with_pad(
868
+ label_gram_3_sequence, 3, unigram_labels
869
+ )
870
+ ngram_labels.append(three_gram_labels)
871
+
872
+ # if label_gram_4_sequence is not None:
873
+ # if label_target_gram_4_sequence is not None:
874
+ # four_gram_labels = label_target_gram_4_sequence[..., 1:].contiguous()
875
+ # else:
876
+ # four_gram_labels = shift_with_pad(
877
+ # label_gram_4_sequence, 4, unigram_labels
878
+ # )
879
+ # ngram_labels.append(four_gram_labels)
880
+
881
+ return ngram_labels
882
+
883
+ def forward(
884
+ self,
885
+ input_ids: torch.LongTensor = None,
886
+ attention_mask: Optional[torch.Tensor] = None,
887
+ position_ids: Optional[torch.LongTensor] = None,
888
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
889
+ inputs_embeds: Optional[torch.FloatTensor] = None,
890
+ labels: Optional[torch.LongTensor] = None,
891
+ use_cache: Optional[bool] = None,
892
+ output_attentions: Optional[bool] = None,
893
+ output_hidden_states: Optional[bool] = None,
894
+ return_dict: Optional[bool] = None,
895
+ label_gram_2_sequence: Optional[torch.LongTensor] = None,
896
+ label_gram_3_sequence: Optional[torch.LongTensor] = None,
897
+ label_gram_4_sequence: Optional[torch.LongTensor] = None,
898
+ label_target_gram_2_sequence: Optional[torch.LongTensor] = None,
899
+ label_target_gram_3_sequence: Optional[torch.LongTensor] = None,
900
+ label_target_gram_4_sequence: Optional[torch.LongTensor] = None,
901
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
902
+ output_attentions = (
903
+ output_attentions
904
+ if output_attentions is not None
905
+ else self.config.output_attentions
906
+ )
907
+ output_hidden_states = (
908
+ output_hidden_states
909
+ if output_hidden_states is not None
910
+ else self.config.output_hidden_states
911
+ )
912
+ return_dict = (
913
+ return_dict if return_dict is not None else self.config.use_return_dict
914
+ )
915
+
916
+ ngram_sequences = collect_n_gram_sequences(
917
+ gram_2_sequence=label_gram_2_sequence,
918
+ gram_3_sequence=label_gram_3_sequence,
919
+ gram_4_sequence=label_gram_4_sequence,
920
+ )
921
+
922
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
923
+ outputs = self.model(
924
+ input_ids=input_ids,
925
+ attention_mask=attention_mask,
926
+ position_ids=position_ids,
927
+ past_key_values=past_key_values,
928
+ inputs_embeds=inputs_embeds,
929
+ use_cache=use_cache,
930
+ output_attentions=output_attentions,
931
+ output_hidden_states=output_hidden_states,
932
+ return_dict=return_dict,
933
+ ngram_sequences=ngram_sequences,
934
+ )
935
+
936
+ hidden_states = outputs[0]
937
+ if self.config.pretraining_tp > 1:
938
+ lm_head_slices = self.lm_head.weight.split(
939
+ self.vocab_size // self.config.pretraining_tp, dim=0
940
+ )
941
+ logits = [
942
+ F.linear(hidden_states, lm_head_slices[i])
943
+ for i in range(self.config.pretraining_tp)
944
+ ]
945
+ logits = torch.cat(logits, dim=-1)
946
+ else:
947
+ logits = self.lm_head(hidden_states)
948
+
949
+ logits = logits.float()
950
+
951
+ loss = None
952
+ if labels is not None:
953
+ # Shift so that tokens < n predict n
954
+ shift_logits = logits[..., :-1, :].contiguous()
955
+
956
+ ngram_labels = self._collect_ngram_labels(
957
+ unigram_labels=labels,
958
+ label_gram_2_sequence=label_gram_2_sequence,
959
+ label_gram_3_sequence=label_gram_3_sequence,
960
+ # label_gram_4_sequence=label_gram_4_sequence,
961
+ label_target_gram_2_sequence=label_target_gram_2_sequence,
962
+ label_target_gram_3_sequence=label_target_gram_3_sequence,
963
+ # label_target_gram_4_sequence=label_target_gram_4_sequence,
964
+ )
965
+
966
+ shift_labels = torch.stack(ngram_labels, dim=0)
967
+ shift_labels = soft_n_hot(
968
+ shift_labels, self.config.vocab_size, strategy="exp"
969
+ )
970
+
971
+ # Flatten the tokens
972
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
973
+ shift_labels = shift_labels.view(-1, shift_labels.size(-1))
974
+ # Enable model parallelism
975
+ shift_labels = shift_labels.to(shift_logits.device)
976
+ loss = self.loss_fct(shift_logits, shift_labels)
977
+
978
+ if not return_dict:
979
+ output = (logits,) + outputs[1:]
980
+ return (loss,) + output if loss is not None else output
981
+
982
+ return CausalLMOutputWithPast(
983
+ loss=loss,
984
+ logits=logits,
985
+ past_key_values=outputs.past_key_values,
986
+ hidden_states=outputs.hidden_states,
987
+ attentions=outputs.attentions,
988
+ )
989
+
990
+ def prepare_inputs_for_generation(
991
+ self,
992
+ input_ids,
993
+ past_key_values=None,
994
+ attention_mask=None,
995
+ inputs_embeds=None,
996
+ **kwargs,
997
+ ):
998
+ if past_key_values:
999
+ input_ids = input_ids[:, -1:]
1000
+
1001
+ position_ids = kwargs.get("position_ids", None)
1002
+ if attention_mask is not None and position_ids is None:
1003
+ # create position_ids on the fly for batch generation
1004
+ position_ids = attention_mask.long().cumsum(-1) - 1
1005
+ position_ids.masked_fill_(attention_mask == 0, 1)
1006
+ if past_key_values:
1007
+ position_ids = position_ids[:, -1].unsqueeze(-1)
1008
+
1009
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1010
+ if inputs_embeds is not None and past_key_values is None:
1011
+ model_inputs = {"inputs_embeds": inputs_embeds}
1012
+ else:
1013
+ model_inputs = {"input_ids": input_ids}
1014
+
1015
+ model_inputs.update(
1016
+ {
1017
+ "position_ids": position_ids,
1018
+ "past_key_values": past_key_values,
1019
+ "use_cache": kwargs.get("use_cache"),
1020
+ "attention_mask": attention_mask,
1021
+ }
1022
+ )
1023
+ return model_inputs
1024
+
1025
+ @staticmethod
1026
+ def _reorder_cache(past_key_values, beam_idx):
1027
+ reordered_past = ()
1028
+ for layer_past in past_key_values:
1029
+ reordered_past += (
1030
+ tuple(
1031
+ past_state.index_select(0, beam_idx.to(past_state.device))
1032
+ for past_state in layer_past
1033
+ ),
1034
+ )
1035
+ return reordered_past
1036
+
1037
+ def sample(
1038
+ self,
1039
+ input_ids: torch.LongTensor,
1040
+ logits_processor: Optional[LogitsProcessorList] = None,
1041
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
1042
+ logits_warper: Optional[LogitsProcessorList] = None,
1043
+ max_length: Optional[int] = None,
1044
+ pad_token_id: Optional[int] = None,
1045
+ eos_token_id: Optional[Union[int, List[int]]] = None,
1046
+ output_attentions: Optional[bool] = None,
1047
+ output_hidden_states: Optional[bool] = None,
1048
+ output_scores: Optional[bool] = None,
1049
+ return_dict_in_generate: Optional[bool] = None,
1050
+ synced_gpus: bool = False,
1051
+ streamer=None,
1052
+ **model_kwargs,
1053
+ ) -> Union[SampleOutput, torch.LongTensor]:
1054
+ if getattr(self, "tokenizer", None) is None:
1055
+ raise ValueError(
1056
+ "You are trying to sample from a model that does not have a tokenizer."
1057
+ "Add a tokenizer as an attribute of your model (either manually or automatically)."
1058
+ )
1059
+
1060
+ return sample_ngme(
1061
+ self,
1062
+ input_ids,
1063
+ logits_processor=logits_processor,
1064
+ stopping_criteria=stopping_criteria,
1065
+ logits_warper=logits_warper,
1066
+ max_length=max_length,
1067
+ pad_token_id=pad_token_id,
1068
+ eos_token_id=eos_token_id,
1069
+ output_attentions=output_attentions,
1070
+ output_hidden_states=output_hidden_states,
1071
+ output_scores=output_scores,
1072
+ return_dict_in_generate=return_dict_in_generate,
1073
+ synced_gpus=synced_gpus,
1074
+ streamer=streamer,
1075
+ **model_kwargs,
1076
+ )
1077
+
1078
+
1079
+ class NGMEForSequenceClassification(NGMEPreTrainedModel):
1080
+ def __init__(self, config):
1081
+ super().__init__(config)
1082
+ self.num_labels = config.num_labels
1083
+ self.model = NGMEModel(config)
1084
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1085
+
1086
+ # Initialize weights and apply final processing
1087
+ self.post_init()
1088
+
1089
+ def get_input_embeddings(self):
1090
+ return self.model.embed_tokens
1091
+
1092
+ def set_input_embeddings(self, value):
1093
+ self.model.embed_tokens = value
1094
+
1095
+ def forward(
1096
+ self,
1097
+ input_ids: torch.LongTensor = None,
1098
+ attention_mask: Optional[torch.Tensor] = None,
1099
+ position_ids: Optional[torch.LongTensor] = None,
1100
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1101
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1102
+ labels: Optional[torch.LongTensor] = None,
1103
+ use_cache: Optional[bool] = None,
1104
+ output_attentions: Optional[bool] = None,
1105
+ output_hidden_states: Optional[bool] = None,
1106
+ return_dict: Optional[bool] = None,
1107
+ label_gram_2_sequence: Optional[torch.LongTensor] = None,
1108
+ label_gram_3_sequence: Optional[torch.LongTensor] = None,
1109
+ # label_gram_4_sequence: Optional[torch.LongTensor] = None,
1110
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1111
+ r"""
1112
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1113
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1114
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1115
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1116
+ """
1117
+ return_dict = (
1118
+ return_dict if return_dict is not None else self.config.use_return_dict
1119
+ )
1120
+
1121
+ ngram_sequences = collect_n_gram_sequences(
1122
+ gram_2_sequence=label_gram_2_sequence,
1123
+ gram_3_sequence=label_gram_3_sequence,
1124
+ # gram_4_sequence=label_gram_4_sequence,
1125
+ )
1126
+
1127
+ transformer_outputs = self.model(
1128
+ input_ids,
1129
+ attention_mask=attention_mask,
1130
+ position_ids=position_ids,
1131
+ past_key_values=past_key_values,
1132
+ inputs_embeds=inputs_embeds,
1133
+ use_cache=use_cache,
1134
+ output_attentions=output_attentions,
1135
+ output_hidden_states=output_hidden_states,
1136
+ return_dict=return_dict,
1137
+ ngram_sequences=ngram_sequences,
1138
+ )
1139
+ hidden_states = transformer_outputs[0]
1140
+ logits = self.score(hidden_states)
1141
+
1142
+ if input_ids is not None:
1143
+ batch_size = input_ids.shape[0]
1144
+ else:
1145
+ batch_size = inputs_embeds.shape[0]
1146
+
1147
+ if self.config.pad_token_id is None and batch_size != 1:
1148
+ raise ValueError(
1149
+ "Cannot handle batch sizes > 1 if no padding token is defined."
1150
+ )
1151
+ if self.config.pad_token_id is None:
1152
+ sequence_lengths = -1
1153
+ else:
1154
+ if input_ids is not None:
1155
+ sequence_lengths = (
1156
+ torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1
1157
+ ).to(logits.device)
1158
+ else:
1159
+ sequence_lengths = -1
1160
+
1161
+ pooled_logits = logits[
1162
+ torch.arange(batch_size, device=logits.device), sequence_lengths
1163
+ ]
1164
+
1165
+ loss = None
1166
+
1167
+ if labels is not None:
1168
+ labels = labels.to(logits.device)
1169
+ if self.config.problem_type is None:
1170
+ if self.num_labels == 1:
1171
+ self.config.problem_type = "regression"
1172
+ elif self.num_labels > 1 and (
1173
+ labels.dtype == torch.long or labels.dtype == torch.int
1174
+ ):
1175
+ self.config.problem_type = "single_label_classification"
1176
+ else:
1177
+ self.config.problem_type = "multi_label_classification"
1178
+
1179
+ if self.config.problem_type == "regression":
1180
+ loss_fct = MSELoss()
1181
+ if self.num_labels == 1:
1182
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1183
+ else:
1184
+ loss = loss_fct(pooled_logits, labels)
1185
+ elif self.config.problem_type == "single_label_classification":
1186
+ loss_fct = CrossEntropyLoss()
1187
+ loss = loss_fct(
1188
+ pooled_logits.view(-1, self.num_labels), labels.view(-1)
1189
+ )
1190
+ elif self.config.problem_type == "multi_label_classification":
1191
+ loss_fct = BCEWithLogitsLoss()
1192
+ loss = loss_fct(pooled_logits, labels)
1193
+ if not return_dict:
1194
+ output = (pooled_logits,) + transformer_outputs[1:]
1195
+ return ((loss,) + output) if loss is not None else output
1196
+
1197
+ return SequenceClassifierOutputWithPast(
1198
+ loss=loss,
1199
+ logits=pooled_logits,
1200
+ past_key_values=transformer_outputs.past_key_values,
1201
+ hidden_states=transformer_outputs.hidden_states,
1202
+ attentions=transformer_outputs.attentions,
1203
+ )
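Two module-level helpers above carry most of the attention plumbing: `_make_causal_mask` builds the additive causal mask and `repeat_kv` expands grouped key/value heads to the full head count. A minimal shape check, under the assumption that the model has been loaded via `trust_remote_code` so the dynamically created module can be reached through the model class; the tensor sizes are illustrative:

```python
# Shape check for _make_causal_mask and repeat_kv; "path/to/this/repo" is a placeholder.
import sys
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("path/to/this/repo", trust_remote_code=True)
modeling = sys.modules[type(model).__module__]  # the dynamically loaded modeling_ngme module

# Additive causal mask: [bsz, 1, tgt_len, tgt_len + past_key_values_length]
mask = modeling._make_causal_mask(
    (2, 5), torch.float32, device=torch.device("cpu"), past_key_values_length=3
)
print(mask.shape)  # torch.Size([2, 1, 5, 8])

# repeat_kv: (batch, num_kv_heads, seq, head_dim) -> (batch, num_kv_heads * n_rep, seq, head_dim)
kv = torch.randn(2, 4, 5, 64)
print(modeling.repeat_kv(kv, n_rep=3).shape)  # torch.Size([2, 12, 5, 64])
```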
ngme.py ADDED
@@ -0,0 +1,178 @@
1
+ import math
2
+ from typing import Optional, List
3
+ from functools import lru_cache
4
+ from itertools import chain, tee
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+ n_dists = {
10
+ 0: [1],
11
+ 1: [0.4, 0.6],
12
+ 2: [0.2, 0.3, 0.5],
13
+ 3: [0.1, 0.2, 0.3, 0.4],
14
+ 4: [0.1, 0.15, 0.2, 0.25, 0.3],
15
+ }
16
+
17
+ strats = {"linear": lambda x: x, "log": lambda x: math.log(x + 1), "exp": lambda x: x**2}
18
+
19
+ def pad_sequence(
20
+ sequence,
21
+ n,
22
+ pad_left=False,
23
+ pad_right=False,
24
+ left_pad_symbol=None,
25
+ right_pad_symbol=None,
26
+ ):
27
+ """Copied from NLTK"""
28
+ sequence = iter(sequence)
29
+ if pad_left:
30
+ sequence = chain((left_pad_symbol,) * (n - 1), sequence)
31
+ if pad_right:
32
+ sequence = chain(sequence, (right_pad_symbol,) * (n - 1))
33
+ return sequence
34
+
35
+ def ngrams(sequence, n, **kwargs):
36
+ """Copied from NLTK"""
37
+ sequence = pad_sequence(sequence, n, **kwargs)
38
+
39
+ # Creates the sliding window, of n no. of items.
40
+ # `iterables` is a tuple of iterables where each iterable is a window of n items.
41
+ iterables = tee(sequence, n)
42
+
43
+ for i, sub_iterable in enumerate(iterables): # For each window,
44
+ for _ in range(i): # iterate through every order of ngrams
45
+ next(sub_iterable, None) # generate the ngrams within the window.
46
+ return zip(*iterables) # Unpack and flattens the iterables.
47
+
48
+
49
+ @lru_cache(maxsize=5)
50
+ def soft_dist(n):
51
+ return [1 / n] * n
52
+
53
+
54
+ @lru_cache(maxsize=5)
55
+ def n_dist(n: int, strategy: str) -> List[float]:
56
+ """Distribution of n-gram weights for the given weighting strategy."""
57
+ ns = list(range(1, n + 1))
58
+ xs = list(map(strats[strategy], ns))
59
+ result = list(map(lambda x: x / sum(xs), xs))
60
+ return result
61
+
62
+ def soft_n_hot(
63
+ input,
64
+ num_classes: int,
65
+ strategy: Optional[str],
66
+ ):
67
+
68
+ shape = list(input.size())[1:]
69
+
70
+ shape.append(num_classes)
71
+
72
+ ret = torch.zeros(shape).to(input.device)
73
+
74
+ if strategy:
75
+ soft_labels = n_dist(input.size(0), strategy)
76
+ else:
77
+ soft_labels = [1] * input.size(0)
78
+
79
+ for i, t in enumerate(input):
80
+ ret.scatter_(-1, t.unsqueeze(-1), soft_labels[i])
81
+
82
+ return ret
83
+
84
+
85
+ def n_hot(t, num_classes, ngram_sequences: Optional[torch.Tensor] = None, unk_idx: Optional[int] = None):
86
+
87
+ shape = list(t.size())
88
+
89
+ if ngram_sequences is not None:
90
+ shape.append(num_classes)
91
+ ret = torch.zeros(shape).to(t.device)
92
+ ret.scatter_(-1, t.unsqueeze(-1), 1)
93
+ for seq in ngram_sequences:
94
+ if unk_idx is not None:
95
+ mask = torch.eq(seq, unk_idx)
96
+ seq[mask] = t[mask]
97
+ ret.scatter_(-1, seq.unsqueeze(-1), 1)
98
+ return ret
99
+
100
+ elif len(shape) == 2:
101
+ return F.one_hot(t, num_classes=num_classes).float()
102
+ else:
103
+ shape = shape[1:]
104
+ shape.append(num_classes)
105
+ ret = torch.zeros(shape).to(t.device)
106
+ # Expect that first dimension is for all n-grams
107
+ for seq in t:
108
+ ret.scatter_(-1, seq.unsqueeze(-1), 1)
109
+
110
+ return ret
111
+
112
+
113
+ class NGramsEmbedding(torch.nn.Embedding):
114
+ """N-Hot encoder"""
115
+
116
+ def __init__(
117
+ self,
118
+ num_embeddings: int,
119
+ embedding_dim: int,
120
+ padding_idx: Optional[int] = None,
121
+ max_norm: Optional[float] = None,
122
+ norm_type: float = 2,
123
+ scale_grad_by_freq: bool = False,
124
+ sparse: bool = False,
125
+ _weight: Optional[torch.Tensor] = None,
126
+ device=None,
127
+ dtype=None,
128
+ unk_idx: Optional[int] = None
129
+ ) -> None:
130
+ super().__init__(
131
+ num_embeddings,
132
+ embedding_dim,
133
+ padding_idx=padding_idx,
134
+ max_norm=max_norm,
135
+ norm_type=norm_type,
136
+ scale_grad_by_freq=scale_grad_by_freq,
137
+ sparse=sparse,
138
+ _weight=_weight,
139
+ device=device,
140
+ dtype=dtype,
141
+ )
142
+
143
+ self.num_classes = num_embeddings
144
+ self.unk_idx = unk_idx
145
+
146
+ def forward(self, input: torch.Tensor, ngram_sequences: Optional[torch.Tensor] = None):
147
+ return self._forward(
148
+ n_hot(input, self.num_classes, ngram_sequences, self.unk_idx)
149
+ )
150
+
151
+ def _forward(self, n_hot: torch.Tensor) -> torch.Tensor:
152
+ return F.linear(n_hot, self.weight.t())
153
+
154
+
155
+ def collect_n_gram_sequences(**kwargs) -> List[torch.Tensor]:
156
+ sequences = []
157
+ for n in range(2, len(kwargs)+2):
158
+ s = kwargs[f"gram_{n}_sequence"]
159
+ if s is not None:
160
+ sequences.append(s)
161
+ else:
162
+ break
163
+
164
+ return sequences
165
+
166
+ def shift_with_pad(target_tensor, n, from_tensor):
167
+ shifted = target_tensor[:, n:]
168
+
169
+ seq_size = target_tensor.size(1) - 1
170
+
171
+ missing_idxs = torch.arange(seq_size - (n-1), seq_size).to(target_tensor.device)
172
+
173
+ # Pad with missing idxs from unigram tensor
174
+ shifted = torch.concat(
175
+ (shifted, from_tensor.index_select(1, missing_idxs)), dim=1
176
+ )
177
+
178
+ return shifted
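
To make the helpers above concrete, here is a small usage sketch (illustrative only, not part of the uploaded file; it assumes ngme.py is importable from a local checkout). `ngrams` yields sliding character windows, `n_hot` marks the unigram id and every aligned higher-order id in a single multi-hot vector, and `NGramsEmbedding` folds that vector back into a dense embedding:

import torch
from ngme import ngrams, n_hot, NGramsEmbedding

# Sliding 2-gram windows over a character sequence
print(list(ngrams("hello", 2)))  # [('h', 'e'), ('e', 'l'), ('l', 'l'), ('l', 'o')]

vocab_size = 10
unigram_ids = torch.tensor([[1, 2, 3]])    # (batch, seq_len)
bigram_ids = [torch.tensor([[4, 5, 6]])]   # one aligned higher-order id stream

# Both the unigram id and its aligned 2-gram id are set at every position
encoded = n_hot(unigram_ids, vocab_size, bigram_ids)
print(encoded.shape)                        # torch.Size([1, 3, 10])
print(encoded[0, 0].nonzero().squeeze(-1))  # tensor([1, 4])

# The embedding layer multiplies the multi-hot vectors by its weight matrix,
# i.e. it sums the embeddings of all n-gram ids observed at each position
emb = NGramsEmbedding(vocab_size, 8)
out = emb(unigram_ids, ngram_sequences=bigram_ids)
print(out.shape)                            # torch.Size([1, 3, 8])
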
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8da3e30e422932a51e7c566af705ca91e71d6381a43a58683edc0a71e0516502
3
+ size 546774413
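
The entry above is only the Git LFS pointer; the ~547 MB checkpoint itself is stored via LFS. A minimal loading sketch (illustrative only; the repo id is a placeholder, and `trust_remote_code=True` is needed because the architecture is defined by the custom configuration_ngme.py / modeling_ngme.py modules in this commit rather than by a class built into transformers):

from transformers import AutoModelForCausalLM

repo_id = "<namespace>/<this-repo>"  # placeholder: substitute the actual Hub repo id

# trust_remote_code lets transformers import the NGME classes shipped with the checkpoint
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
print(type(model).__name__, model.num_parameters())
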
sampling.py ADDED
@@ -0,0 +1,205 @@
1
+ import warnings
2
+ from typing import List, Optional, Union
3
+
4
+ import torch
5
+ import torch.distributed as dist
6
+ from torch import nn
7
+ from transformers import BatchEncoding
8
+ from transformers.generation.logits_process import (
9
+ LogitsProcessorList,
10
+ )
11
+ from transformers.generation.stopping_criteria import (
12
+ StoppingCriteriaList,
13
+ validate_stopping_criteria,
14
+ )
15
+
16
+ from transformers.generation.utils import SampleOutput, SampleEncoderDecoderOutput, SampleDecoderOnlyOutput
17
+
18
+ def sample(
19
+ self,
20
+ input_ids: torch.LongTensor,
21
+ logits_processor: Optional[LogitsProcessorList] = None,
22
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
23
+ logits_warper: Optional[LogitsProcessorList] = None,
24
+ max_length: Optional[int] = None,
25
+ pad_token_id: Optional[int] = None,
26
+ eos_token_id: Optional[Union[int, List[int]]] = None,
27
+ output_attentions: Optional[bool] = None,
28
+ output_hidden_states: Optional[bool] = None,
29
+ output_scores: Optional[bool] = None,
30
+ return_dict_in_generate: Optional[bool] = None,
31
+ synced_gpus: Optional[bool] = False,
32
+ **model_kwargs,
33
+ ) -> Union[SampleOutput, torch.LongTensor]:
34
+
35
+ if type(input_ids) in [dict, BatchEncoding]:
36
+ input_ids, ngram_sequences = input_ids["input_ids"], input_ids
37
+ del ngram_sequences["input_ids"]
38
+ del ngram_sequences["attention_mask"]
39
+ else:
40
+ ngram_sequences = {}
41
+
42
+ # init values
43
+ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
44
+ stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
45
+ if max_length is not None:
46
+ warnings.warn(
47
+ "`max_length` is deprecated in this function, use"
48
+ " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
49
+ UserWarning,
50
+ )
51
+ stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
52
+ logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
53
+ pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
54
+ eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
55
+ if isinstance(eos_token_id, int):
56
+ eos_token_id = [eos_token_id]
57
+
58
+ eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
59
+ output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
60
+ output_attentions = (
61
+ output_attentions if output_attentions is not None else self.generation_config.output_attentions
62
+ )
63
+ output_hidden_states = (
64
+ output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
65
+ )
66
+ return_dict_in_generate = (
67
+ return_dict_in_generate
68
+ if return_dict_in_generate is not None
69
+ else self.generation_config.return_dict_in_generate
70
+ )
71
+
72
+ # init attention / hidden states / scores tuples
73
+ scores = () if (return_dict_in_generate and output_scores) else None
74
+ decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
75
+ cross_attentions = () if (return_dict_in_generate and output_attentions) else None
76
+ decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
77
+
78
+ # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
79
+ if return_dict_in_generate and self.config.is_encoder_decoder:
80
+ encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
81
+ encoder_hidden_states = (
82
+ model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
83
+ )
84
+
85
+ # keep track of which sequences are already finished
86
+ unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
87
+
88
+ this_peer_finished = False # used by synced_gpus only
89
+ # auto-regressive generation
90
+ while True:
91
+ if synced_gpus:
92
+ # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
93
+ # The following logic allows an early break if all peers finished generating their sequence
94
+ this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
95
+ # send 0.0 if we finished, 1.0 otherwise
96
+ dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
97
+ # did all peers finish? the reduced sum will be 0.0 then
98
+ if this_peer_finished_flag.item() == 0.0:
99
+ break
100
+
101
+ # prepare model inputs
102
+ model_inputs = {"input_ids": input_ids}
103
+
104
+ # forward pass to get next token
105
+ outputs = self(
106
+ **model_inputs,
107
+ return_dict=True,
108
+ output_attentions=output_attentions,
109
+ output_hidden_states=output_hidden_states,
110
+ **ngram_sequences
111
+ )
112
+
113
+ if synced_gpus and this_peer_finished:
114
+ continue # don't waste resources running the code we don't need
115
+
116
+ next_token_logits = outputs.logits[:, -1, :]
117
+
118
+ # pre-process distribution
119
+ next_token_scores = logits_processor(input_ids, next_token_logits)
120
+ next_token_scores = logits_warper(input_ids, next_token_scores)
121
+
122
+ # Store scores, attentions and hidden_states when required
123
+ if return_dict_in_generate:
124
+ if output_scores:
125
+ scores += (next_token_scores,)
126
+ if output_attentions:
127
+ decoder_attentions += (
128
+ (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
129
+ )
130
+ if self.config.is_encoder_decoder:
131
+ cross_attentions += (outputs.cross_attentions,)
132
+
133
+ if output_hidden_states:
134
+ decoder_hidden_states += (
135
+ (outputs.decoder_hidden_states,)
136
+ if self.config.is_encoder_decoder
137
+ else (outputs.hidden_states,)
138
+ )
139
+
140
+ # sample
141
+ probs = nn.functional.softmax(next_token_scores, dim=-1)
142
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
143
+
144
+ # finished sentences should have their next token be a padding token
145
+ if eos_token_id is not None:
146
+ if pad_token_id is None:
147
+ raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
148
+ next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
149
+
150
+ # update generated ids, model inputs, and length for next step
151
+ input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
152
+ decoded = self.tokenizer.batch_decode(input_ids)[0]
153
+ encoded = self.tokenizer(
154
+ decoded, return_tensors="pt", return_ngram_sequences=True
155
+ )
156
+ input_ids = encoded.input_ids.to(self.device)
157
+
158
+ ngram_sequences = {}
159
+
160
+ if "label_gram_2_sequence" in encoded:
161
+ ngram_sequences["label_gram_2_sequence"] = encoded["label_gram_2_sequence"].to(self.device)
162
+
163
+ if "label_gram_3_sequence" in encoded:
164
+ ngram_sequences["label_gram_3_sequence"] = encoded["label_gram_3_sequence"].to(self.device)
165
+
166
+ if "label_gram_4_sequence" in encoded:
167
+ ngram_sequences["label_gram_4_sequence"] = encoded["label_gram_4_sequence"].to(self.device)
168
+
169
+ model_kwargs = self._update_model_kwargs_for_generation(
170
+ outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
171
+ )
172
+
173
+ # if eos_token was found in one sentence, set sentence to finished
174
+ if eos_token_id_tensor is not None:
175
+ unfinished_sequences = unfinished_sequences.mul(
176
+ next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
177
+ )
178
+
179
+ # stop when each sentence is finished, or if we exceed the maximum length
180
+ if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
181
+ if not synced_gpus:
182
+ break
183
+ else:
184
+ this_peer_finished = True
185
+
186
+ if return_dict_in_generate:
187
+ if self.config.is_encoder_decoder:
188
+ return SampleEncoderDecoderOutput(
189
+ sequences=input_ids,
190
+ scores=scores,
191
+ encoder_attentions=encoder_attentions,
192
+ encoder_hidden_states=encoder_hidden_states,
193
+ decoder_attentions=decoder_attentions,
194
+ cross_attentions=cross_attentions,
195
+ decoder_hidden_states=decoder_hidden_states,
196
+ )
197
+ else:
198
+ return SampleDecoderOnlyOutput(
199
+ sequences=input_ids,
200
+ scores=scores,
201
+ attentions=decoder_attentions,
202
+ hidden_states=decoder_hidden_states,
203
+ )
204
+ else:
205
+ return input_ids
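
The loop above diverges from the stock `GenerationMixin.sample` mainly in one step: after each sampled token it decodes the running text and re-encodes it with `return_ngram_sequences=True`, so the higher-order n-gram streams stay aligned with the unigram ids for the next forward pass. A sketch of that handshake in isolation (illustrative only; `tokenizer` and `model` are assumed to be the custom NGME tokenizer and model instances):

# running text after appending the newly sampled token
decoded = "some generated text"

encoded = tokenizer(decoded, return_tensors="pt", return_ngram_sequences=True)
input_ids = encoded["input_ids"]  # unigram (character-level) ids

# keep only the aligned higher-order streams, e.g. label_gram_2_sequence, ...
ngram_kwargs = {k: v for k, v in encoded.items() if k.startswith("label_gram_")}

# the next forward pass consumes both the unigram ids and the n-gram streams
outputs = model(input_ids=input_ids, **ngram_kwargs)
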
tokenization_ngme.py ADDED
@@ -0,0 +1,1303 @@
1
+ import json
2
+ import unittest
3
+ import os
4
+ from collections import Counter
5
+ from typing import Dict, List, Mapping, Optional, Sized, Tuple, Union, Any
6
+
7
+ import torch
8
+ import numpy as np
9
+ from tokenizers import AddedToken
10
+ from transformers import PreTrainedTokenizer
11
+ from transformers.tokenization_utils_base import (
12
+ BatchEncoding,
13
+ EncodedInput,
14
+ TruncationStrategy,
15
+ )
16
+ from transformers.utils import logging
17
+ from transformers.utils.generic import PaddingStrategy, TensorType, is_tf_tensor, is_torch_tensor, to_py_obj
18
+
19
+ from .ngme import ngrams as ngram_tokenizer
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
+ def load_vocab(vocab_file):
25
+ """Loads a vocabulary file into a dictionary."""
26
+ with open(vocab_file, "r", encoding="utf-8") as f:
27
+ vocab = json.load(f)
28
+ return vocab
29
+
30
+
31
+ def all_same(items):
32
+ return all(x == items[0] for x in items)
33
+
34
+
35
+ class NGMETokenizer(PreTrainedTokenizer):
36
+ model_input_names = ["input_ids", "attention_mask"]
37
+ vocab_file = "vocab.json"
38
+ vocab_files_names = {"vocab_file": vocab_file}
39
+
40
+ def __init__(
41
+ self,
42
+ vocab_file,
43
+ ngram: Optional[int] = None,
44
+ eos_token="\n",
45
+ pad_token="\n",
46
+ unk_token="<unk>",
47
+ eod_token="<eod>",
48
+ **kwargs,
49
+ ):
50
+ super().__init__(
51
+ eos_token=eos_token, pad_token=pad_token, unk_token=unk_token, **kwargs
52
+ )
53
+
54
+ eos_token = (
55
+ AddedToken(
56
+ eos_token,
57
+ lstrip=False,
58
+ rstrip=False,
59
+ )
60
+ if isinstance(eos_token, str)
61
+ else eos_token
62
+ )
63
+ pad_token = (
64
+ AddedToken(
65
+ pad_token,
66
+ lstrip=False,
67
+ rstrip=False,
68
+ )
69
+ if isinstance(pad_token, str)
70
+ else pad_token
71
+ )
72
+ unk_token = (
73
+ AddedToken(
74
+ unk_token,
75
+ lstrip=False,
76
+ rstrip=False,
77
+ )
78
+ if isinstance(unk_token, str)
79
+ else unk_token
80
+ )
81
+
82
+ self._ngram2word2idx = {}
83
+ self._ngram2idx2word = {}
84
+ self._current_max_idx = 0
85
+ self._frequencies: Counter = Counter()
86
+ self.ngram = ngram
87
+
88
+ self._load_from_file(vocab_file)
89
+
90
+ for n in range(2, self.ngram + 1):
91
+ self.model_input_names.append(f"ngram_{n}_sequence")
92
+
93
+ # TODO: Could also be whitespace if the (n+1)-grams don't contain it
94
+ self._special_token = "Ġ"
95
+ assert self._special_token not in self._ngram2word2idx[1]
96
+
97
+ def __call__(self, *args, **kwargs) -> BatchEncoding:
98
+ if "return_ngram_sequences" in kwargs:
99
+ return_ngram_sequences = kwargs["return_ngram_sequences"]
100
+ del kwargs["return_ngram_sequences"]
101
+ else:
102
+ return_ngram_sequences = False
103
+
104
+ # We could check the args and kwargs beforehand and apply extra ngram sequences based on it, but
105
+ # we let HF handle all the logic and recover the character sequence from the ids afterwards
106
+ batch_encoding = super().__call__(*args, **kwargs)
107
+
108
+ if return_ngram_sequences:
109
+ ngram_sequences = self.create_ngram_sequences(args[0])
110
+ # NOTE: This is pretty hard-coded; let's just throw an error if the user wants to use it differently
111
+
112
+ if "padding" in kwargs:
113
+ if kwargs["padding"] == "max_length":
114
+ padded_sequences = {}
115
+ for n_key, sequence in ngram_sequences.items():
116
+ padded_sequences[n_key] = self.pad_sequence_right(
117
+ sequence,
118
+ len(batch_encoding["input_ids"][0]),
119
+ self.pad_token_id,
120
+ )
121
+
122
+ ngram_sequences = padded_sequences
123
+ elif kwargs["padding"] == "longest":
124
+ padded_sequences = {}
125
+ for n_key, sequence in ngram_sequences.items():
126
+ padded_sequences[n_key] = self.pad_sequence_right(
127
+ sequence,
128
+ max([len(seq) for seq in sequence]),
129
+ self.pad_token_id,
130
+ )
131
+ ngram_sequences = padded_sequences
132
+
133
+ else:
134
+ raise ValueError(
135
+ f"Padding {kwargs['padding']} not supported for ngram sequences"
136
+ )
137
+
138
+ if "truncation" in kwargs and kwargs["truncation"]:
139
+ truncated_sequences = {}
140
+ for n_key, sequence in ngram_sequences.items():
141
+ truncated_sequences[n_key] = self.truncate_sequence_right(
142
+ sequence, len(batch_encoding["input_ids"][0])
143
+ )
144
+ ngram_sequences = truncated_sequences
145
+
146
+ batch_encoding.update(ngram_sequences)
147
+
148
+ if "return_tensors" in kwargs:
149
+ batch_encoding.convert_to_tensors(kwargs["return_tensors"])
150
+
151
+ return batch_encoding
152
+
153
+ def pad_sequence_right(
154
+ self, batched_sequence: List[List[int]], padding_length: int, padding_value: int
155
+ ) -> List[List[int]]:
156
+ padded_sequence = []
157
+ for sequence in batched_sequence:
158
+ padded_sequence.append(
159
+ sequence + [padding_value] * (padding_length - len(sequence))
160
+ )
161
+ return padded_sequence
162
+
163
+ def truncate_sequence_right(
164
+ self, batched_sequence: List[List[int]], max_length: int
165
+ ) -> List[List[int]]:
166
+ truncated_sequence = []
167
+ for sequence in batched_sequence:
168
+ truncated_sequence.append(sequence[:max_length])
169
+ return truncated_sequence
170
+
171
+ def create_ngram_sequences(self, char_sequences: List[str]) -> Dict[str, Any]:
172
+ ngram_sequences_output = {}
173
+
174
+ if isinstance(char_sequences, str):
175
+ char_sequences = [char_sequences]
176
+
177
+ for n in range(2, self.ngram + 1):
178
+ ngram_sequences = []
179
+ for char_sequence in char_sequences:
180
+ ngrams = ["".join(ngram) for ngram in ngram_tokenizer(char_sequence, n)]
181
+ # Fill in the front with existing unigrams, to keep the same length and
182
+ # because the timestep t should not look ahead
183
+ ngrams = list(char_sequence[: n - 1]) + ngrams
184
+ encoded_ngrams = self.encode(ngrams) if len(ngrams) > 0 else []
185
+ ngram_sequences.append(encoded_ngrams)
186
+
187
+ ngram_sequences_output[f"label_gram_{n}_sequence"] = ngram_sequences
188
+
189
+ return ngram_sequences_output
190
+
191
+ def _seq_size(self, encoded) -> Union[int, List[int]]:
192
+ if isinstance(encoded, torch.Tensor):
193
+ encoded = encoded.tolist()
194
+
195
+ if isinstance(encoded[0], list):
196
+ return [len(enc) for enc in encoded]
197
+
198
+ return len(encoded)
199
+
200
+ def _load_from_file(self, filename: str):
201
+ """Loads a dictionary from a file."""
202
+ vocab_file = load_vocab(filename)
203
+ if not self.ngram:
204
+ self.ngram = vocab_file["ngram"]
205
+
206
+ if "\n" not in vocab_file["vocab"]:
207
+ self._add_ngram("\n", 1)
208
+
209
+ for token in vocab_file["vocab"]:
210
+ if token["ngram"] <= self.ngram:
211
+ self._add_ngram(token["token"], token["ngram"])
212
+ self._frequencies.update({token["token"]: token["frequency"]})
213
+
214
+ def _add_ngram(self, word, ngram: int) -> int:
215
+ """Add a new n-gram token to the dictionary."""
216
+ self._frequencies.update({word: 1})
217
+
218
+ if ngram not in self._ngram2idx2word:
219
+ self._ngram2idx2word[ngram] = {self._current_max_idx: word}
220
+ self._ngram2word2idx[ngram] = {word: self._current_max_idx}
221
+ self._current_max_idx += 1
222
+ else:
223
+ if word not in self._ngram2word2idx[ngram]:
224
+ self._ngram2idx2word[ngram][self._current_max_idx] = word
225
+ self._ngram2word2idx[ngram][word] = self._current_max_idx
226
+ self._current_max_idx += 1
227
+
228
+ return self._ngram2word2idx[ngram][word]
229
+
230
+ def _is_contiguous(self):
231
+ vocab_size = len(self)
232
+ return list(range(vocab_size)) == [idx for idx, token in self._get_all_tokens()]
233
+
234
+ def _get_all_tokens(self):
235
+ """Returns all tokens in the dictionary."""
236
+ for ngram in range(1, self.ngram + 1):
237
+ for idx, token in self._ngram2idx2word[ngram].items():
238
+ yield idx, token
239
+
240
+ def save_vocabulary(
241
+ self, save_directory: str, filename_prefix: Optional[str] = None
242
+ ) -> Tuple[str]:
243
+ filename = os.path.join(
244
+ save_directory,
245
+ (filename_prefix + "-" if filename_prefix else "") +
246
+ self.vocab_file,
247
+ )
248
+
249
+ index = 0
250
+ vocab = {"ngram": self.ngram, "vocab": []}
251
+
252
+ for ngram in range(1, self.ngram + 1):
253
+ for idx, token in self._ngram2idx2word[ngram].items():
254
+ if index != idx:
255
+ index = idx
256
+
257
+ try:
258
+ frequency = self._frequencies[token]
259
+ except KeyError:
260
+ frequency = -1
261
+
262
+ index += 1
263
+ vocab["vocab"].append(
264
+ {
265
+ "token": token,
266
+ "index": idx,
267
+ "frequency": frequency,
268
+ "ngram": ngram,
269
+ }
270
+ )
271
+
272
+ with open(filename, "w", encoding="utf-8") as writer:
273
+ json.dump(vocab, writer, indent=4, ensure_ascii=False)
274
+
275
+ return (filename,)
276
+
277
+ @property
278
+ def vocab_size(self) -> int:
279
+ return self._current_max_idx
280
+
281
+ def _tokenize(self, text: str) -> List[str]:
282
+ return list(text)
283
+
284
+ def get_idx(self, token: str, ngram: Optional[int] = None) -> int:
285
+ if ngram:
286
+ if token in self._ngram2word2idx[ngram]:
287
+ return self._ngram2word2idx[ngram][token]
288
+ else:
289
+ return self._ngram2word2idx[1]["<unk>"]
290
+
291
+ for ngram in range(1, self.ngram + 1):
292
+ if token in self._ngram2word2idx[ngram]:
293
+ return self._ngram2word2idx[ngram][token]
294
+
295
+ return self._ngram2word2idx[1]["<unk>"]
296
+
297
+ def _convert_ngram_tokens_to_ids(self, ngram_tokens: List[str]) -> List[int]:
298
+ return [self.get_idx(token) for token in ngram_tokens]
299
+
300
+ def convert_tokens_to_ids(self, tokens: List[str]):
301
+ if not tokens:
302
+ return []
303
+
304
+ if isinstance(tokens, str):
305
+ return self.get_idx(tokens)
306
+
307
+ return self._convert_ngram_tokens_to_ids(tokens)
308
+
309
+ def _convert_id_to_token(self, index: int) -> str:
310
+ return self.get_item_for_index(index)
311
+
312
+ def get_item_for_index(self, idx) -> str:
313
+ """Return the token for a given index."""
314
+ for idxs in self._ngram2idx2word.values():
315
+ if idx in idxs:
316
+ return idxs[idx]
317
+
318
+ return self.unk_token
319
+
320
+ def convert_tokens_to_string(self, tokens):
321
+ return "".join(tokens)
322
+
323
+ def create_weight_tensor(self) -> torch.Tensor:
324
+ unked_freqs = self._frequencies.most_common()
325
+
326
+ t = torch.ones(len(self))
327
+
328
+ for token, freq in unked_freqs:
329
+ t[self._ngram2word2idx[self._token_to_n_order(token)][token]] = freq
330
+
331
+ # Ensure the only whitespace character is weighted
332
+ t[self._ngram2word2idx[1][" "]] = 1.0
333
+
334
+ max_t = max(t)
335
+
336
+ normed_weights = torch.tensor([(1 - (x / (max_t + 1))).item() for x in t])
337
+
338
+ marker_tokens = [self.get_idx("<unk>", n) for n in range(1, self.ngram + 1)]
339
+ marker_tokens.extend(
340
+ [self.get_idx("<start>", n) for n in range(1, self.ngram + 1)]
341
+ )
342
+ # Instead of explicit ignore indexes, we use the weight vector and set target idxs to 0
343
+ for marker in marker_tokens:
344
+ normed_weights[marker] = 0
345
+
346
+ return normed_weights
347
+
348
+ def _token_to_n_order(self, token: str) -> int:
349
+ """Get N-gram order for a token"""
350
+ for n_gram, word2idx in self._ngram2word2idx.items():
351
+ if token in word2idx:
352
+ return n_gram
353
+
354
+ return 0
355
+
356
+
357
+ class GPTNGMETokenizer(PreTrainedTokenizer):
358
+ model_input_names = ["input_ids", "attention_mask"]
359
+ vocab_file = "vocab.json"
360
+ vocab_files_names = {"vocab_file": vocab_file}
361
+
362
+ def __init__(
363
+ self, vocab_file, eos_token="\n", pad_token="\n", unk_token="<unk>", **kwargs
364
+ ):
365
+ eos_token = (
366
+ AddedToken(
367
+ eos_token,
368
+ lstrip=False,
369
+ rstrip=False,
370
+ )
371
+ if isinstance(eos_token, str)
372
+ else eos_token
373
+ )
374
+ pad_token = (
375
+ AddedToken(
376
+ pad_token,
377
+ lstrip=False,
378
+ rstrip=False,
379
+ )
380
+ if isinstance(pad_token, str)
381
+ else pad_token
382
+ )
383
+ unk_token = (
384
+ AddedToken(
385
+ unk_token,
386
+ lstrip=False,
387
+ rstrip=False,
388
+ )
389
+ if isinstance(unk_token, str)
390
+ else unk_token
391
+ )
392
+
393
+ super().__init__(
394
+ eos_token=eos_token, pad_token=pad_token, unk_token=unk_token, **kwargs
395
+ )
396
+
397
+ self._ngram2word2idx = {}
398
+ self._ngram2idx2word = {}
399
+ self._current_max_idx = 0
400
+ self._frequencies: Counter = Counter()
401
+
402
+ self._load_from_file(vocab_file)
403
+
404
+ def _load_from_file(self, filename: str):
405
+ """Loads a dictionary from a file."""
406
+ vocab_file = load_vocab(filename)
407
+ self.ngram = vocab_file["ngram"]
408
+
409
+ if "\n" not in vocab_file["vocab"]:
410
+ self._add_ngram("\n", 1)
411
+
412
+ for token in vocab_file["vocab"]:
413
+ self._add_ngram(token["token"], token["ngram"])
414
+ self._frequencies.update({token["token"]: token["frequency"]})
415
+
416
+ def _add_ngram(self, word, ngram: int) -> int:
417
+ """Add a new n-gram token to the dictionary."""
418
+ self._frequencies.update({word: 1})
419
+
420
+ if ngram not in self._ngram2idx2word:
421
+ self._ngram2idx2word[ngram] = {self._current_max_idx: word}
422
+ self._ngram2word2idx[ngram] = {word: self._current_max_idx}
423
+ self._current_max_idx += 1
424
+ else:
425
+ if word not in self._ngram2word2idx[ngram]:
426
+ self._ngram2idx2word[ngram][self._current_max_idx] = word
427
+ self._ngram2word2idx[ngram][word] = self._current_max_idx
428
+ self._current_max_idx += 1
429
+
430
+ return self._ngram2word2idx[ngram][word]
431
+
432
+ def _is_contiguous(self):
433
+ vocab_size = len(self)
434
+ return list(range(vocab_size)) == [idx for idx, token in self._get_all_tokens()]
435
+
436
+ def _get_all_tokens(self):
437
+ """Returns all tokens in the dictionary."""
438
+ for ngram in range(1, self.ngram + 1):
439
+ for idx, token in self._ngram2idx2word[ngram].items():
440
+ yield idx, token
441
+
442
+ def save_vocabulary(
443
+ self, save_directory: str, filename_prefix: Optional[str] = None
444
+ ) -> Tuple[str]:
445
+ filename = os.path.join(
446
+ save_directory,
447
+ (filename_prefix + "-" if filename_prefix else "") +
448
+ self.vocab_file,
449
+ )
450
+
451
+ index = 0
452
+ vocab = {"ngram": self.ngram, "vocab": []}
453
+
454
+ for ngram in range(1, self.ngram + 1):
455
+ for idx, token in self._ngram2idx2word[ngram].items():
456
+ if index != idx:
457
+ index = idx
458
+
459
+ try:
460
+ frequency = self._frequencies[token]
461
+ except KeyError:
462
+ frequency = -1
463
+
464
+ index += 1
465
+ vocab["vocab"].append(
466
+ {
467
+ "token": token,
468
+ "index": idx,
469
+ "frequency": frequency,
470
+ "ngram": ngram,
471
+ }
472
+ )
473
+
474
+ with open(filename, "w", encoding="utf-8") as writer:
475
+ json.dump(vocab, writer, indent=4, ensure_ascii=False)
476
+
477
+ return (filename,)
478
+
479
+ @property
480
+ def vocab_size(self) -> int:
481
+ return self._current_max_idx
482
+
483
+ def retokenize(self, input_ids, *args, **kwargs):
484
+ decoded = self.convert_ids_to_tokens(input_ids)
485
+ sequence = "".join(decoded)
486
+ new_decoded = self(sequence, *args, **kwargs).input_ids
487
+ return new_decoded
488
+
489
+ def _tokenize(self, text):
490
+ ngram_sequences = []
491
+ for n in range(1, self.ngram + 1):
492
+ words = ["<start>" for _ in range(1, n)]
493
+ words.extend(list(text))
494
+
495
+ tokens = []
496
+ for i, word in enumerate(ngram_tokenizer(words, n)):
497
+ if "<start>" in word:
498
+ word = [w for w in list(word) if w != "<start>"]
499
+ tokens.append("".join(word))
500
+
501
+ ngram_sequences.append(tokens)
502
+
503
+ return ngram_sequences
504
+
505
+ def get_idx(self, token: str, ngram: Optional[int] = None) -> int:
506
+ if ngram:
507
+ if token in self._ngram2word2idx[ngram]:
508
+ return self._ngram2word2idx[ngram][token]
509
+ else:
510
+ return self._ngram2word2idx[1]["<unk>"]
511
+
512
+ for ngram in range(1, self.ngram + 1):
513
+ if token in self._ngram2word2idx[ngram]:
514
+ return self._ngram2word2idx[ngram][token]
515
+
516
+ return self._ngram2word2idx[1]["<unk>"]
517
+
518
+ def _convert_ngram_tokens_to_ids(self, ngram_tokens: List[str]) -> List[int]:
519
+ return [self.get_idx(token) for token in ngram_tokens]
520
+
521
+ def convert_tokens_to_ids(self, tokens: List[List[str]]):
522
+ if not tokens:
523
+ return []
524
+
525
+ if isinstance(tokens, str):
526
+ return self.get_idx(tokens)
527
+
528
+ return [
529
+ self._convert_ngram_tokens_to_ids(ngram_tokens) for ngram_tokens in tokens
530
+ ]
531
+
532
+ def _convert_id_to_token(self, index: int) -> str:
533
+ return self.get_item_for_index(index)
534
+
535
+ def get_item_for_index(self, idx) -> str:
536
+ """Return the token for a given index."""
537
+ for idxs in self._ngram2idx2word.values():
538
+ if idx in idxs:
539
+ return idxs[idx]
540
+
541
+ return self.unk_token
542
+
543
+ def _decode(
544
+ self, token_ids: List[List[int]], skip_special_tokens: bool = False, **kwargs
545
+ ) -> str:
546
+ return "".join(self.convert_ids_to_tokens(token_ids[0]))
547
+
548
+ def debug_decode(self, token_ids: List[List[int]]):
549
+ for n in range(1, self.ngram + 1):
550
+ print(f"{n}-gram: {self.convert_ids_to_tokens(token_ids[n-1])}")
551
+
552
+ def _pad(
553
+ self,
554
+ encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
555
+ max_length: Optional[int] = None,
556
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
557
+ pad_to_multiple_of: Optional[int] = None,
558
+ return_attention_mask: Optional[bool] = None,
559
+ ) -> dict:
560
+ """
561
+ Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
562
+
563
+ Args:
564
+ encoded_inputs:
565
+ Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
566
+ max_length: maximum length of the returned list and optionally padding length (see below).
567
+ Will truncate by taking into account the special tokens.
568
+ padding_strategy: PaddingStrategy to use for padding.
569
+
570
+ - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
571
+ - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
572
+ - PaddingStrategy.DO_NOT_PAD: Do not pad
573
+ The tokenizer padding sides are defined in self.padding_side:
574
+
575
+ - 'left': pads on the left of the sequences
576
+ - 'right': pads on the right of the sequences
577
+ pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
578
+ This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
579
+ `>= 7.5` (Volta).
580
+ return_attention_mask:
581
+ (optional) Set to False to avoid returning attention mask (default: set to model specifics)
582
+ """
583
+ # encoded_inputs == one sample -> List[List[int]]
584
+
585
+ # Load from model defaults
586
+ if return_attention_mask is None:
587
+ return_attention_mask = "attention_mask" in self.model_input_names
588
+
589
+ required_input = encoded_inputs[self.model_input_names[0]]
590
+ # PHA: Check if we have a list of list of list, then we unpack
591
+ if (
592
+ len(required_input) != 0
593
+ and isinstance(required_input[0], list)
594
+ and isinstance(required_input[0][0], list)
595
+ ):
596
+ required_input = required_input[0]
597
+
598
+ if padding_strategy == PaddingStrategy.LONGEST:
599
+ max_length = len(required_input)
600
+
601
+ if (
602
+ max_length is not None
603
+ and pad_to_multiple_of is not None
604
+ and (max_length % pad_to_multiple_of != 0)
605
+ ):
606
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
607
+
608
+ needs_to_be_padded = (
609
+ padding_strategy != PaddingStrategy.DO_NOT_PAD
610
+ and len(required_input[0]) != max_length
611
+ )
612
+
613
+ # Initialize attention mask if not present.
614
+ if return_attention_mask and "attention_mask" not in encoded_inputs:
615
+ if len(required_input) == 0:
616
+ encoded_inputs["attention_mask"] = []
617
+ else:
618
+ encoded_inputs["attention_mask"] = [1] * len(required_input[0])
619
+
620
+ if needs_to_be_padded:
621
+ difference = max_length - len(required_input[0])
622
+
623
+ if self.padding_side == "right":
624
+ if return_attention_mask:
625
+ encoded_inputs["attention_mask"] = (
626
+ encoded_inputs["attention_mask"] + [0] * difference
627
+ )
628
+ if "token_type_ids" in encoded_inputs:
629
+ encoded_inputs["token_type_ids"] = (
630
+ encoded_inputs["token_type_ids"]
631
+ + [self.pad_token_type_id] * difference
632
+ )
633
+ if "special_tokens_mask" in encoded_inputs:
634
+ encoded_inputs["special_tokens_mask"] = (
635
+ encoded_inputs["special_tokens_mask"] + [1] * difference
636
+ )
637
+ for i in range(len(encoded_inputs[self.model_input_names[0]])):
638
+ encoded_inputs[self.model_input_names[0]][i] = (
639
+ required_input[i] + [self.pad_token_id] * difference
640
+ )
641
+ elif self.padding_side == "left":
642
+ if return_attention_mask:
643
+ encoded_inputs["attention_mask"] = [
644
+ 0
645
+ ] * difference + encoded_inputs["attention_mask"]
646
+ if "token_type_ids" in encoded_inputs:
647
+ encoded_inputs["token_type_ids"] = [
648
+ self.pad_token_type_id
649
+ ] * difference + encoded_inputs["token_type_ids"]
650
+ if "special_tokens_mask" in encoded_inputs:
651
+ encoded_inputs["special_tokens_mask"] = [
652
+ 1
653
+ ] * difference + encoded_inputs["special_tokens_mask"]
654
+
655
+ for i in range(len(encoded_inputs[self.model_input_names[0]])):
656
+ encoded_inputs[self.model_input_names[0]][i] = [
657
+ self.pad_token_id
658
+ ] * difference + required_input[i]
659
+ else:
660
+ raise ValueError("Invalid padding strategy:" + str(self.padding_side))
661
+
662
+ return encoded_inputs
663
+
664
+ def pad(
665
+ self,
666
+ encoded_inputs: Union[
667
+ BatchEncoding,
668
+ List[BatchEncoding],
669
+ Dict[str, EncodedInput],
670
+ Dict[str, List[EncodedInput]],
671
+ List[Dict[str, EncodedInput]],
672
+ ],
673
+ padding: Union[bool, str, PaddingStrategy] = True,
674
+ max_length: Optional[int] = None,
675
+ pad_to_multiple_of: Optional[int] = None,
676
+ return_attention_mask: Optional[bool] = None,
677
+ return_tensors: Optional[Union[str, TensorType]] = None,
678
+ verbose: bool = True,
679
+ ) -> BatchEncoding:
680
+ """
681
+ Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length
682
+ in the batch.
683
+
684
+ Padding side (left/right) padding token ids are defined at the tokenizer level (with `self.padding_side`,
685
+
686
+ `self.pad_token_id` and `self.pad_token_type_id`).
687
+
688
+ Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the
689
+ text followed by a call to the `pad` method to get a padded encoding.
690
+
691
+ <Tip>
692
+
693
+ If the `encoded_inputs` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the
694
+ result will use the same type unless you provide a different tensor type with `return_tensors`. In the case of
695
+ PyTorch tensors, you will lose the specific device of your tensors however.
696
+
697
+ </Tip>
698
+
699
+ Args:
700
+ encoded_inputs ([`BatchEncoding`], list of [`BatchEncoding`], `Dict[str, List[int]]`, `Dict[str, List[List[int]]` or `List[Dict[str, List[int]]]`):
701
+ Tokenized inputs. Can represent one input ([`BatchEncoding`] or `Dict[str, List[int]]`) or a batch of
702
+ tokenized inputs (list of [`BatchEncoding`], *Dict[str, List[List[int]]]* or *List[Dict[str,
703
+ List[int]]]*) so you can use this method during preprocessing as well as in a PyTorch Dataloader
704
+ collate function.
705
+
706
+ Instead of `List[int]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors), see
707
+ the note above for the return type.
708
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
709
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
710
+ index) among:
711
+
712
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
713
+ sequence if provided).
714
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
715
+ acceptable input length for the model if that argument is not provided.
716
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
717
+ lengths).
718
+ max_length (`int`, *optional*):
719
+ Maximum length of the returned list and optionally padding length (see above).
720
+ pad_to_multiple_of (`int`, *optional*):
721
+ If set will pad the sequence to a multiple of the provided value.
722
+
723
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
724
+ `>= 7.5` (Volta).
725
+ return_attention_mask (`bool`, *optional*):
726
+ Whether to return the attention mask. If left to the default, will return the attention mask according
727
+ to the specific tokenizer's default, defined by the `return_outputs` attribute.
728
+
729
+ [What are attention masks?](../glossary#attention-mask)
730
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
731
+ If set, will return tensors instead of list of python integers. Acceptable values are:
732
+
733
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
734
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
735
+ - `'np'`: Return Numpy `np.ndarray` objects.
736
+ verbose (`bool`, *optional*, defaults to `True`):
737
+ Whether or not to print more information and warnings.
738
+ """
739
+
740
+ # Problem: The pad function checks if the encoded_inputs is a list or not
741
+ # If it is a list it assumes that we have batches
742
+ # With ngme encoding the input is always a list
743
+
744
+ # If we have a list of dicts, let's convert it in a dict of lists
745
+ # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
746
+ if isinstance(encoded_inputs, (list, tuple)) and isinstance(
747
+ encoded_inputs[0], Mapping
748
+ ):
749
+ encoded_inputs = {
750
+ key: [example[key] for example in encoded_inputs]
751
+ for key in encoded_inputs[0].keys()
752
+ }
753
+
754
+ # The model's main input name, usually `input_ids`, has to be passed for padding
755
+ if self.model_input_names[0] not in encoded_inputs:
756
+ raise ValueError(
757
+ "You should supply an encoding or a list of encodings to this method "
758
+ f"that includes {self.model_input_names[0]}, but you provided {list(encoded_inputs.keys())}"
759
+ )
760
+
761
+ required_input = encoded_inputs[self.model_input_names[0]]
762
+
763
+ if required_input is None or (
764
+ isinstance(required_input, Sized) and len(required_input) == 0
765
+ ):
766
+ if return_attention_mask:
767
+ encoded_inputs["attention_mask"] = []
768
+ return encoded_inputs
769
+
770
+ # If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects
771
+ # and rebuild them afterwards if no return_tensors is specified
772
+ # Note that we lose the specific device the tensor may be on for PyTorch
773
+
774
+ first_element = required_input[0]
775
+ # PHA: First element in ngme is a list of list
776
+ if isinstance(first_element, (list, tuple)):
777
+ # first_element might be an empty list/tuple in some edge cases so we grab the first non empty element.
778
+ for item in required_input:
779
+ if len(item) != 0:
780
+ first_element = item[0]
781
+ break
782
+ # At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
783
+ if not isinstance(first_element, (int, list, tuple)):
784
+ if is_tf_tensor(first_element):
785
+ return_tensors = "tf" if return_tensors is None else return_tensors
786
+ elif is_torch_tensor(first_element):
787
+ return_tensors = "pt" if return_tensors is None else return_tensors
788
+ elif isinstance(first_element, np.ndarray):
789
+ return_tensors = "np" if return_tensors is None else return_tensors
790
+ else:
791
+ raise ValueError(
792
+ f"type of {first_element} unknown: {type(first_element)}. "
793
+ "Should be one of a python, numpy, pytorch or tensorflow object."
794
+ )
795
+
796
+ for key, value in encoded_inputs.items():
797
+ encoded_inputs[key] = to_py_obj(value)
798
+
799
+ # Convert padding_strategy in PaddingStrategy
800
+ padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies(
801
+ padding=padding, max_length=max_length, verbose=verbose
802
+ )
803
+
804
+ required_input = encoded_inputs[self.model_input_names[0]]
805
+
806
+ if required_input:
807
+ if isinstance(required_input[0], (list, tuple)):
808
+ if len(required_input[0]) > 0 and not isinstance(
809
+ required_input[0][0], (list, tuple)
810
+ ):
811
+ encoded_inputs = self._pad(
812
+ encoded_inputs,
813
+ max_length=max_length,
814
+ padding_strategy=padding_strategy,
815
+ pad_to_multiple_of=pad_to_multiple_of,
816
+ return_attention_mask=return_attention_mask,
817
+ )
818
+ return BatchEncoding(encoded_inputs, tensor_type=return_tensors)
819
+
820
+ batch_size = len(required_input)
821
+ assert all(
822
+ len(v) == batch_size for v in encoded_inputs.values()
823
+ ), "Some items in the output dictionary have a different batch size than others."
824
+
825
+ if padding_strategy == PaddingStrategy.LONGEST:
826
+ max_length = max(len(inputs[0]) for inputs in required_input)
827
+ padding_strategy = PaddingStrategy.MAX_LENGTH
828
+
829
+ batch_outputs = {}
830
+ for i in range(batch_size):
831
+ inputs = dict((k, v[i]) for k, v in encoded_inputs.items())
832
+ outputs = self._pad(
833
+ inputs,
834
+ max_length=max_length,
835
+ padding_strategy=padding_strategy,
836
+ pad_to_multiple_of=pad_to_multiple_of,
837
+ return_attention_mask=return_attention_mask,
838
+ )
839
+
840
+ for key, value in outputs.items():
841
+ if key not in batch_outputs:
842
+ batch_outputs[key] = []
843
+ batch_outputs[key].append(value)
844
+
845
+ return BatchEncoding(batch_outputs, tensor_type=return_tensors)
846
+
847
+ def prepare_for_model(
848
+ self,
849
+ ids: List[int],
850
+ pair_ids: Optional[List[int]] = None,
851
+ add_special_tokens: bool = True,
852
+ padding: Union[bool, str, PaddingStrategy] = False,
853
+ truncation: Union[bool, str, TruncationStrategy] = None,
854
+ max_length: Optional[int] = None,
855
+ stride: int = 0,
856
+ pad_to_multiple_of: Optional[int] = None,
857
+ return_tensors: Optional[Union[str, TensorType]] = None,
858
+ return_token_type_ids: Optional[bool] = None,
859
+ return_attention_mask: Optional[bool] = None,
860
+ return_overflowing_tokens: bool = False,
861
+ return_special_tokens_mask: bool = False,
862
+ return_offsets_mapping: bool = False,
863
+ return_length: bool = False,
864
+ verbose: bool = True,
865
+ prepend_batch_axis: bool = False,
866
+ **kwargs,
867
+ ) -> BatchEncoding:
868
+ """
869
+ Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. It
870
+ adds special tokens, truncates sequences if overflowing while taking into account the special tokens and
871
+ manages a moving window (with user defined stride) for overflowing tokens. Please Note, for *pair_ids*
872
+ different than `None` and *truncation_strategy = longest_first* or `True`, it is not possible to return
873
+ overflowing tokens. Such a combination of arguments will raise an error.
874
+ Args:
875
+ ids (`List[int]`):
876
+ Tokenized input ids of the first sequence. Can be obtained from a string by chaining the `tokenize` and
877
+ `convert_tokens_to_ids` methods.
878
+ pair_ids (`List[int]`, *optional*):
879
+ Tokenized input ids of the second sequence. Can be obtained from a string by chaining the `tokenize`
880
+ and `convert_tokens_to_ids` methods.
881
+ """
882
+
883
+ # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
884
+ (
885
+ padding_strategy,
886
+ truncation_strategy,
887
+ max_length,
888
+ kwargs,
889
+ ) = self._get_padding_truncation_strategies(
890
+ padding=padding,
891
+ truncation=truncation,
892
+ max_length=max_length,
893
+ pad_to_multiple_of=pad_to_multiple_of,
894
+ verbose=verbose,
895
+ **kwargs,
896
+ )
897
+
898
+ pair = bool(pair_ids is not None)
899
+
900
+ if len(ids) == 0:
901
+ len_ids = 0
902
+ else:
903
+ len_ids = len(ids[0])
904
+
905
+ if pair and len(pair_ids) == 0:
906
+ len_pair_ids = 0
907
+ elif pair and len(pair_ids) > 0:
908
+ len_pair_ids = len(pair_ids[0])
909
+ else:
910
+ len_pair_ids = 0
911
+
912
+ if return_token_type_ids and not add_special_tokens:
913
+ raise ValueError(
914
+ "Asking to return token_type_ids while setting add_special_tokens to False "
915
+ "results in an undefined behavior. Please set add_special_tokens to True or "
916
+ "set return_token_type_ids to None."
917
+ )
918
+
919
+ if (
920
+ return_overflowing_tokens
921
+ and truncation_strategy == TruncationStrategy.LONGEST_FIRST
922
+ and pair_ids is not None
923
+ ):
924
+ raise ValueError(
925
+ "Not possible to return overflowing tokens for pair of sequences with the "
926
+ "`longest_first`. Please select another truncation strategy than `longest_first`, "
927
+ "for instance `only_second` or `only_first`."
928
+ )
929
+
930
+ # Load from model defaults
931
+ if return_token_type_ids is None:
932
+ return_token_type_ids = "token_type_ids" in self.model_input_names
933
+ if return_attention_mask is None:
934
+ return_attention_mask = "attention_mask" in self.model_input_names
935
+
936
+ encoded_inputs = {}
937
+
938
+ # Compute the total size of the returned encodings
939
+ total_len = (
940
+ len_ids
941
+ + len_pair_ids
942
+ + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0)
943
+ )
944
+
945
+ # Truncation: Handle max sequence length
946
+ overflowing_tokens = []
947
+ if (
948
+ truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE
949
+ and max_length
950
+ and total_len > max_length
951
+ ):
952
+ ids, pair_ids, overflowing_tokens = self.truncate_sequences(
953
+ ids,
954
+ pair_ids=pair_ids,
955
+ num_tokens_to_remove=total_len - max_length,
956
+ truncation_strategy=truncation_strategy,
957
+ stride=stride,
958
+ )
959
+
960
+ if return_overflowing_tokens:
961
+ encoded_inputs["overflowing_tokens"] = overflowing_tokens
962
+ encoded_inputs["num_truncated_tokens"] = total_len - max_length
963
+
964
+ # Add special tokens
965
+ if add_special_tokens:
966
+ sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
967
+ token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
968
+ else:
969
+ sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
970
+ token_type_ids = [0] * len(ids) + ([0] * len(pair_ids) if pair else [])
971
+
972
+ # Build output dictionary
973
+ encoded_inputs["input_ids"] = sequence
974
+ if return_token_type_ids:
975
+ encoded_inputs["token_type_ids"] = token_type_ids
976
+ if return_special_tokens_mask:
977
+ if add_special_tokens:
978
+ encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(
979
+ ids, pair_ids
980
+ )
981
+ else:
982
+ encoded_inputs["special_tokens_mask"] = [0] * len(sequence)
983
+
984
+ # Check lengths
985
+ self._eventual_warn_about_too_long_sequence(
986
+ encoded_inputs["input_ids"], max_length, verbose
987
+ )
988
+
989
+ # Padding
990
+ if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask:
991
+ encoded_inputs = self.pad(
992
+ encoded_inputs,
993
+ max_length=max_length,
994
+ padding=padding_strategy.value,
995
+ pad_to_multiple_of=pad_to_multiple_of,
996
+ return_attention_mask=return_attention_mask,
997
+ )
998
+
999
+ if return_length:
1000
+ encoded_inputs["length"] = len(encoded_inputs["input_ids"])
1001
+
1002
+ batch_outputs = BatchEncoding(
1003
+ encoded_inputs,
1004
+ tensor_type=return_tensors,
1005
+ prepend_batch_axis=prepend_batch_axis,
1006
+ )
1007
+
1008
+ return batch_outputs
1009
+
1010
+ def build_inputs_with_special_tokens(
1011
+ self,
1012
+ token_ids_0: List[List[int]],
1013
+ token_ids_1: Optional[List[List[int]]] = None,
1014
+ ) -> List[List[int]]:
1015
+ """
1016
+ Concatenate nested ngram sequences.
1017
+
1018
+ Args:
1019
+ token_ids_0 (`List[List[int]]`): The first tokenized sequence.
1020
+ token_ids_1 (`List[List[int]]`, *optional*): The second tokenized sequence.
1021
+
1022
+ Returns:
1023
+ `List[List[int]]`: The model input with special tokens.
1024
+ """
1025
+ if token_ids_1 is None or len(token_ids_1) == 0:
1026
+ return token_ids_0
1027
+
1028
+ if len(token_ids_0) == 0:
1029
+ return token_ids_1
1030
+
1031
+ return np.concatenate(
1032
+ (np.array(token_ids_0), np.array(token_ids_1)), axis=1
1033
+ ).tolist()
1034
+
1035
+ def truncate_sequences(
1036
+ self,
1037
+ ids: List[int],
1038
+ pair_ids: Optional[List[int]] = None,
1039
+ num_tokens_to_remove: int = 0,
1040
+ truncation_strategy: Union[str, TruncationStrategy] = "longest_first",
1041
+ stride: int = 0,
1042
+ ) -> Tuple[List[int], List[int], List[int]]:
1043
+ """
1044
+ Truncates a sequence pair in-place following the strategy.
1045
+ Args:
1046
+ ids (`List[int]`):
1047
+ Tokenized input ids of the first sequence. Can be obtained from a string by chaining the `tokenize` and
1048
+ `convert_tokens_to_ids` methods.
1049
+ pair_ids (`List[int]`, *optional*):
1050
+ Tokenized input ids of the second sequence. Can be obtained from a string by chaining the `tokenize`
1051
+ and `convert_tokens_to_ids` methods.
1052
+ num_tokens_to_remove (`int`, *optional*, defaults to 0):
1053
+ Number of tokens to remove using the truncation strategy.
1054
+ truncation_strategy (`str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`):
1055
+ The strategy to follow for truncation. Can be:
1056
+ - `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
1057
+ maximum acceptable input length for the model if that argument is not provided. This will truncate
1058
+ token by token, removing a token from the longest sequence in the pair if a pair of sequences (or a
1059
+ batch of pairs) is provided.
1060
+ - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
1061
+ maximum acceptable input length for the model if that argument is not provided. This will only
1062
+ truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
1063
+ - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the
1064
+ maximum acceptable input length for the model if that argument is not provided. This will only
1065
+ truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
1066
+ - `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths greater
1067
+ than the model maximum admissible input size).
1068
+ stride (`int`, *optional*, defaults to 0):
1069
+ If set to a positive number, the overflowing tokens returned will contain some tokens from the main
1070
+ sequence returned. The value of this argument defines the number of additional tokens.
1071
+ Returns:
1072
+ `Tuple[List[int], List[int], List[int]]`: The truncated `ids`, the truncated `pair_ids` and the list of
1073
+ overflowing tokens. Note: The *longest_first* strategy returns empty list of overflowing tokens if a pair
1074
+ of sequences (or a batch of pairs) is provided.
1075
+ """
1076
+ if num_tokens_to_remove <= 0:
1077
+ return ids, pair_ids, []
1078
+
1079
+ if not isinstance(truncation_strategy, TruncationStrategy):
1080
+ truncation_strategy = TruncationStrategy(truncation_strategy)
1081
+
1082
+ overflowing_tokens = []
1083
+ if truncation_strategy == TruncationStrategy.ONLY_FIRST or (
1084
+ truncation_strategy == TruncationStrategy.LONGEST_FIRST and pair_ids is None
1085
+ ):
1086
+ ids = np.array(ids)
1087
+
1088
+ # PHA: I think we only truncate with longest first
1089
+ if ids.shape[1] > num_tokens_to_remove:
1090
+ window_len = min(ids.shape[1], stride + num_tokens_to_remove)
1091
+ if self.truncation_side == "left":
1092
+ overflowing_tokens = ids[:, :window_len]
1093
+ ids = ids[:, num_tokens_to_remove:]
1094
+ elif self.truncation_side == "right":
1095
+ overflowing_tokens = ids[-window_len:]
1096
+ ids = ids[:, :-num_tokens_to_remove]
1097
+ else:
1098
+ raise ValueError(
1099
+ f"invalid truncation strategy: {self.truncation_side}, use 'left' or 'right'."
1100
+ )
1101
+
1102
+ ids = ids.tolist()
1103
+
1104
+ else:
1105
+ error_msg = (
1106
+ f"We need to remove {num_tokens_to_remove} to truncate the input "
1107
+ f"but the first sequence has a length {len(ids)}. "
1108
+ )
1109
+ if truncation_strategy == TruncationStrategy.ONLY_FIRST:
1110
+ error_msg = (
1111
+ error_msg + "Please select another truncation strategy than "
1112
+ f"{truncation_strategy}, for instance 'longest_first' or 'only_second'."
1113
+ )
1114
+ logger.error(error_msg)
1115
+ elif truncation_strategy == TruncationStrategy.LONGEST_FIRST:
1116
+ logger.warning(
1117
+ "Be aware, overflowing tokens are not returned for the setting you have chosen,"
1118
+ f" i.e. sequence pairs with the '{TruncationStrategy.LONGEST_FIRST.value}' "
1119
+ "truncation strategy. So the returned list will always be empty even if some "
1120
+ "tokens have been removed."
1121
+ )
1122
+ ids = np.array(ids)
1123
+ pair_ids = np.array(pair_ids)
1124
+
1125
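+ # Drop one token per iteration from whichever sequence is currently longer,
+ # so `ids` and `pair_ids` are truncated as evenly as possible.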
+ for _ in range(num_tokens_to_remove):
1126
+ if pair_ids is None or ids.shape[1] > pair_ids.shape[1]:
1127
+ if self.truncation_side == "right":
1128
+ ids = ids[:, :-1]
1129
+ elif self.truncation_side == "left":
1130
+ ids = ids[:, 1:]
1131
+ else:
1132
+ raise ValueError(
1133
+ "invalid truncation strategy:" + str(self.truncation_side)
1134
+ )
1135
+ else:
1136
+ if self.truncation_side == "right":
1137
+ pair_ids = pair_ids[:, :-1]
1138
+ elif self.truncation_side == "left":
1139
+ pair_ids = pair_ids[:, 1:]
1140
+ else:
1141
+ raise ValueError(
1142
+ "invalid truncation strategy:" + str(self.truncation_side)
1143
+ )
1144
+
1145
+ ids = ids.tolist()
1146
+ pair_ids = pair_ids.tolist()
1147
+
1148
+ elif (
1149
+ truncation_strategy == TruncationStrategy.ONLY_SECOND
1150
+ and pair_ids is not None
1151
+ ):
1152
+ raise NotImplementedError(
1153
+ "PHA: I think we only truncate with longest first"
1154
+ )
1155
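+ # The code below is unreachable because of the raise above.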
+ if len(pair_ids) > num_tokens_to_remove:
1156
+ window_len = min(len(pair_ids), stride + num_tokens_to_remove)
1157
+ if self.truncation_side == "right":
1158
+ overflowing_tokens = pair_ids[-window_len:]
1159
+ pair_ids = pair_ids[:-num_tokens_to_remove]
1160
+ elif self.truncation_side == "left":
1161
+ overflowing_tokens = pair_ids[:window_len]
1162
+ pair_ids = pair_ids[num_tokens_to_remove:]
1163
+ else:
1164
+ raise ValueError(
1165
+ "invalid truncation strategy:" + str(self.truncation_side)
1166
+ )
1167
+ else:
1168
+ logger.error(
1169
+ f"We need to remove {num_tokens_to_remove} to truncate the input "
1170
+ f"but the second sequence has a length {len(pair_ids)}. "
1171
+ f"Please select another truncation strategy than {truncation_strategy}, "
1172
+ "for instance 'longest_first' or 'only_first'."
1173
+ )
1174
+
1175
+ return (ids, pair_ids, overflowing_tokens)
1176
+
1177
+ def _token_to_n_order(self, token: str) -> int:
1178
+ """Get N-gram order for a token"""
1179
+ for n_gram, word2idx in self._ngram2word2idx.items():
1180
+ if token in word2idx:
1181
+ return n_gram
1182
+
1183
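+ # The token is not part of any n-gram vocabulary.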
+ return 0
1184
+
1185
+ def create_weight_tensor(self) -> torch.Tensor:
1186
+ unked_freqs = self._frequencies.most_common()
1187
+
1188
+ t = torch.ones(len(self))
1189
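+ # Tokens that never appear in the frequency counter keep the default count of 1.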
+
1190
+ for token, freq in unked_freqs:
1191
+ t[self._ngram2word2idx[self._token_to_n_order(token)][token]] = freq
1192
+
1193
+ # Ensure the only whitespace character is weighted
1194
+ t[self._ngram2word2idx[1][" "]] = 1.0
1195
+
1196
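+ # Map counts to weights in (0, 1]: frequent tokens get weights close to 0,
+ # rare tokens close to 1.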
+ normed_weights = torch.tensor([(1 - (x / (max(t) + 1))).item() for x in t])
1197
+
1198
+ marker_tokens = [self.get_idx("<unk>", n) for n in range(1, self.ngram + 1)]
1199
+ marker_tokens.extend(
1200
+ [self.get_idx("<start>", n) for n in range(1, self.ngram + 1)]
1201
+ )
1202
+ # Instead of explicit ignore indexes, we use the weight vector and set target idxs to 0
1203
+ for marker in marker_tokens:
1204
+ normed_weights[marker] = 0
1205
+
1206
+ return normed_weights
1207
+
1208
+
1209
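+ # NOTE: these tests expect pre-built n-gram vocabulary files at local paths.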
+ class TestTokenizer(unittest.TestCase):
1210
+ def test_one(self):
1211
+ vocab_file = "/home/phmaker/Projects/ngme/vocabs/1-gram-babylm.json"
1212
+
1213
+ t = NGMETokenizer(vocab_file)
1214
+ self.assertEqual(t.get_idx("<unk>", 1), 1)
1215
+
1216
+ result = t("hello world")
1217
+ self.assertEqual(result.input_ids, [16, 3, 11, 11, 8, 2, 21, 8, 9, 11, 12])
1218
+
1219
+ result = t("<unk>")
1220
+ self.assertEqual(result.input_ids, [1, 13, 5, 24, 1])
1221
+
1222
+ result = t(["hello world", "<unk>"])
1223
+ self.assertEqual(
1224
+ result.input_ids,
1225
+ [[16, 3, 11, 11, 8, 2, 21, 8, 9, 11, 12], [1, 13, 5, 24, 1]],
1226
+ )
1227
+
1228
+ def test_three(self):
1229
+ vocab_file = "/home/phmaker/Projects/ngme/vocabs/3-gram-babylm.json"
1230
+
1231
+ t = NGMETokenizer(vocab_file)
1232
+
1233
+ result = t("hello world")
1234
+ self.assertEqual(result.input_ids, [16, 3, 11, 11, 8, 2, 21, 8, 9, 11, 12])
1235
+
1236
+ result = t("hello", return_ngram_sequences=True)
1237
+
1238
+ result = t(["hello world"], return_ngram_sequences=True)
1239
+ two_gram_expected = [[16, 208, 229, 230, 231, 1, 1, 312, 257, 499, 306]]
1240
+
1241
+ self.assertEqual(result["gram_2_sequence"], two_gram_expected)
1242
+ self.assertEqual(t._ngram2idx2word[1][16], "h")
1243
+ self.assertEqual(t._ngram2idx2word[2][208], "he")
1244
+ self.assertEqual(t._ngram2idx2word[2][229], "el")
1245
+
1246
+ def test_unks(self):
1247
+ vocab_file = "/home/phmaker/Projects/ngme/vocabs/2-gram-wiki-en.json"
1248
+ t = NGMETokenizer(vocab_file)
1249
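+ # Smoke test: encoding out-of-vocabulary characters should not raise.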
+ result = t("OciVDjöShG", return_ngram_sequences=True, return_tensors="pt")
1250
+
1251
+ def test_decode(self):
1252
+ vocab_file = "/home/phmaker/Projects/ngme/vocabs/3-gram-babylm.json"
1253
+ t = NGMETokenizer(vocab_file)
1254
+ decoded = t.decode(208)
1255
+ self.assertEqual(decoded, "he")
1256
+
1257
+ def test_padding(self):
1258
+ vocab_file = "/home/phmaker/Projects/ngme/vocabs/3-gram-babylm.json"
1259
+ t = NGMETokenizer(vocab_file)
1260
+ result = t(
1261
+ "hello world",
1262
+ return_tensors="pt",
1263
+ padding="max_length",
1264
+ max_length=20,
1265
+ return_ngram_sequences=True,
1266
+ )
1267
+
1268
+ self.assertEqual(result.input_ids.shape, (1, 20))
1269
+ self.assertEqual(result.gram_2_sequence.shape, (1, 20))
1270
+ self.assertEqual(result.gram_3_sequence.shape, (1, 20))
1271
+
1272
+ def test_truncation(self):
1273
+ vocab_file = "/home/phmaker/Projects/ngme/vocabs/3-gram-babylm.json"
1274
+ t = NGMETokenizer(vocab_file)
1275
+
1276
+ result = t(
1277
+ "hello world",
1278
+ return_tensors="pt",
1279
+ truncation=True,
1280
+ max_length=5,
1281
+ return_ngram_sequences=True,
1282
+ )
1283
+ self.assertEqual(result.input_ids.shape, (1, 5))
1284
+ self.assertEqual(result.gram_2_sequence.shape, (1, 5))
1285
+
1286
+ def test_padding_and_truncation(self):
1287
+ vocab_file = "/home/phmaker/Projects/ngme/vocabs/3-gram-babylm.json"
1288
+ t = NGMETokenizer(vocab_file)
1289
+
1290
+ result = t(
1291
+ ["four", "something longer"],
1292
+ return_tensors="pt",
1293
+ padding="max_length",
1294
+ truncation=True,
1295
+ max_length=5,
1296
+ return_ngram_sequences=True,
1297
+ )
1298
+ self.assertEqual(result.input_ids.shape, (2, 5))
1299
+ self.assertEqual(result.gram_2_sequence.shape, (2, 5))
1300
+
1301
+
1302
+ if __name__ == "__main__":
1303
+ unittest.main()