Ubuntu committed on
Commit 6c84452
1 Parent(s): e55133d
Files changed (2)
  1. config.json +1 -7
  2. modeling_glm.py +0 -1304
config.json CHANGED
@@ -2,13 +2,7 @@
   "architectures": [
     "GlmForCausalLM"
   ],
-  "auto_map": {
-    "AutoModel": "modeling_glm.GlmModel",
-    "AutoModelForCausalLM": "modeling_glm.GlmForCausalLM",
-    "AutoModelForSeq2SeqLM": "modeling_glm.GlmForTokenClassification",
-    "AutoModelForSequenceClassification": "modeling_glm.GlmForSequenceClassification"
-  },
-  "rotary_percent": 1.0,
+  "partial_rotary_factor": 1.0,
   "attention_bias": false,
   "attention_dropout": 0.0,
   "eos_token_id": [
modeling_glm.py DELETED
@@ -1,1304 +0,0 @@
1
- import math
2
- from typing import List, Optional, Tuple, Union
3
-
4
- import torch
5
- import torch.nn as nn
6
-
7
- from transformers.activations import ACT2FN
8
- from transformers.cache_utils import Cache, DynamicCache, StaticCache
9
- from transformers.generation import GenerationMixin
10
- from transformers.modeling_attn_mask_utils import AttentionMaskConverter
11
- from transformers.modeling_flash_attention_utils import (
12
- FlashAttentionKwargs,
13
- _flash_attention_forward,
14
- )
15
- from transformers.modeling_outputs import (
16
- BaseModelOutputWithPast,
17
- CausalLMOutputWithPast,
18
- SequenceClassifierOutputWithPast,
19
- TokenClassifierOutput,
20
- )
21
- from transformers.modeling_utils import PreTrainedModel
22
- from transformers.processing_utils import Unpack
23
- from transformers.utils import (
24
- add_code_sample_docstrings,
25
- add_start_docstrings,
26
- add_start_docstrings_to_model_forward,
27
- is_flash_attn_greater_or_equal_2_10,
28
- logging,
29
- replace_return_docstrings,
30
- )
31
- from transformers.models.glm.configuration_glm import GlmConfig
32
-
33
- logger = logging.get_logger(__name__)
34
-
35
- _CHECKPOINT_FOR_DOC = "THUDM/glm-edge-4b-chat"
36
- _CONFIG_FOR_DOC = "GlmConfig"
37
-
38
-
39
- class GlmRMSNorm(nn.Module):
40
- def __init__(self, hidden_size, eps=1e-6):
41
- """
42
- GlmRMSNorm is equivalent to T5LayerNorm
43
- """
44
- super().__init__()
45
- self.weight = nn.Parameter(torch.ones(hidden_size))
46
- self.variance_epsilon = eps
47
-
48
- def forward(self, hidden_states):
49
- input_dtype = hidden_states.dtype
50
- hidden_states = hidden_states.to(torch.float32)
51
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
52
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
53
- return self.weight * hidden_states.to(input_dtype)
54
-
55
- def extra_repr(self):
56
- return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
57
-
58
-
59
- class GlmRotaryEmbedding(nn.Module):
60
- def __init__(self, dim, max_position_embeddings=2048, base=10000, rotary_percent=0.5, device=None):
61
- super().__init__()
62
- self.rotary_percent = rotary_percent
63
- self.dim = dim * rotary_percent
64
- self.max_position_embeddings = max_position_embeddings
65
- self.base = base
66
-
67
- inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
68
- self.register_buffer("inv_freq", inv_freq)
69
-
70
- def forward(self, x, position_ids=None):
71
- batch_size, seq_len, head_dim = x.shape
72
- device = x.device
73
- dtype = x.dtype
74
-
75
- seq_idx = torch.arange(0, self.max_position_embeddings, device=device).float()
76
- idx_theta = torch.outer(seq_idx, self.inv_freq)
77
-
78
- if position_ids is not None:
79
- idx_theta = idx_theta[position_ids[0]]
80
- else:
81
- idx_theta = idx_theta[:seq_len]
82
- if self.rotary_percent == 0.5:
83
- idx_theta = torch.cat([idx_theta, idx_theta], dim=-1) # for glm-4-9b
84
-
85
- device_type = device.type
86
- device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
87
- with torch.autocast(device_type=device_type, enabled=False):
88
- cos = torch.cos(idx_theta).to(dtype=dtype)
89
- sin = torch.sin(idx_theta).to(dtype=dtype)
90
-
91
- cos = cos[None, :, :].expand(batch_size, seq_len, -1)
92
- sin = sin[None, :, :].expand(batch_size, seq_len, -1)
93
-
94
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
95
-
96
-
97
- class GlmMLP(nn.Module):
98
- def __init__(self, config):
99
- super().__init__()
100
-
101
- self.config = config
102
- self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
103
- self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
104
-
105
- self.activation_fn = ACT2FN[config.hidden_act]
106
-
107
- def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
108
- up_states = self.gate_up_proj(hidden_states)
109
-
110
- gate, up_states = up_states.chunk(2, dim=-1)
111
- up_states = up_states * self.activation_fn(gate)
112
-
113
- return self.down_proj(up_states)
114
-
115
-
116
- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
117
- """
118
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
119
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
120
- """
121
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
122
- if n_rep == 1:
123
- return hidden_states
124
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
125
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
126
-
127
-
128
- def rotate_half(x):
129
- """Rotates half the hidden dims of the input."""
130
- x1 = x[..., 0::2]
131
- x2 = x[..., 1::2]
132
- return torch.stack((-x2, x1), dim=-1).flatten(-2)
133
-
134
-
135
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1, rotary_percent=0.5):
136
- """
137
- Applies Rotary Position Embedding to the query and key tensors.
138
- rotary_percent is for glm-4-9b(0.5) or glm-edge(1.0)
139
- """
140
- cos = cos.unsqueeze(unsqueeze_dim)
141
- sin = sin.unsqueeze(unsqueeze_dim)
142
-
143
- # Interleave them instead of usual shape
144
- cos = cos[..., : int(cos.shape[-1] * rotary_percent)].repeat_interleave(2, dim=-1)
145
- sin = sin[..., : int(sin.shape[-1] * rotary_percent)].repeat_interleave(2, dim=-1)
146
-
147
- # Keep rotary_percent(half or not) for later concatenation
148
- rotary_dim = int(q.shape[-1] * rotary_percent)
149
- q, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
150
- k, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
151
-
152
- # Apply rotary embeddings on the first half or full tensor
153
- q_embed = (q * cos) + (rotate_half(q) * sin)
154
- k_embed = (k * cos) + (rotate_half(k) * sin)
155
-
156
- # Concatenate back to full shape
157
- q_embed = torch.cat([q_embed, q_pass], dim=-1)
158
- k_embed = torch.cat([k_embed, k_pass], dim=-1)
159
- return q_embed, k_embed
160
-
161
-
162
- class GlmAttention(nn.Module):
163
- """Multi-headed attention from 'Attention Is All You Need' paper"""
164
-
165
- def __init__(self, config: GlmConfig, layer_idx: Optional[int] = None):
166
- super().__init__()
167
- self.config = config
168
- self.layer_idx = layer_idx
169
- if layer_idx is None:
170
- logger.warning_once(
171
- f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
172
- "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
173
- "when creating this class."
174
- )
175
-
176
- self.attention_dropout = config.attention_dropout
177
- self.hidden_size = config.hidden_size
178
- self.num_heads = config.num_attention_heads
179
- self.head_dim = self.hidden_size // self.num_heads
180
- self.num_key_value_heads = config.num_key_value_heads
181
- self.num_key_value_groups = self.num_heads // self.num_key_value_heads
182
- self.is_causal = True
183
- self.scaling = 1 / math.sqrt(self.head_dim)
184
- self.rotary_percent = config.rotary_percent if hasattr(config, "rotary_percent") else 0.5
185
-
186
- if (self.head_dim * self.num_heads) != self.hidden_size:
187
- raise ValueError(
188
- f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
189
- f" and `num_heads`: {self.num_heads})."
190
- )
191
-
192
- self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
193
- self.k_proj = nn.Linear(
194
- self.hidden_size,
195
- self.num_key_value_heads * self.head_dim,
196
- bias=config.attention_bias,
197
- )
198
- self.v_proj = nn.Linear(
199
- self.hidden_size,
200
- self.num_key_value_heads * self.head_dim,
201
- bias=config.attention_bias,
202
- )
203
- self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
204
-
205
- def forward(
206
- self,
207
- hidden_states: torch.Tensor,
208
- attention_mask: Optional[torch.Tensor] = None,
209
- position_ids: Optional[torch.LongTensor] = None,
210
- past_key_value: Optional[Cache] = None,
211
- output_attentions: bool = False,
212
- use_cache: bool = False,
213
- cache_position: Optional[torch.LongTensor] = None,
214
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
215
- **kwargs,
216
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
217
- bsz, q_len, _ = hidden_states.size()
218
-
219
- query_states = self.q_proj(hidden_states)
220
- key_states = self.k_proj(hidden_states)
221
- value_states = self.v_proj(hidden_states)
222
-
223
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
224
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
225
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
226
-
227
- cos, sin = position_embeddings
228
-
229
- query_states, key_states = apply_rotary_pos_emb(
230
- query_states,
231
- key_states,
232
- cos,
233
- sin,
234
- rotary_percent=self.rotary_percent,
235
- )
236
- if past_key_value is not None:
237
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
238
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
239
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
240
-
241
- key_states = repeat_kv(key_states, self.num_key_value_groups)
242
- value_states = repeat_kv(value_states, self.num_key_value_groups)
243
-
244
- attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
245
-
246
- if attention_mask is not None: # no matter the length, we just slice it
247
- causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
248
- attn_weights = attn_weights + causal_mask
249
-
250
- # upcast attention to fp32
251
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
252
- attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
253
- attn_output = torch.matmul(attn_weights, value_states)
254
-
255
- if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
256
- raise ValueError(
257
- f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
258
- f" {attn_output.size()}"
259
- )
260
-
261
- attn_output = attn_output.transpose(1, 2).contiguous()
262
-
263
- attn_output = attn_output.view(bsz, q_len, -1)
264
- attn_output = self.o_proj(attn_output)
265
-
266
- if not output_attentions:
267
- attn_weights = None
268
-
269
- return attn_output, attn_weights, past_key_value
270
-
271
-
272
- class GlmFlashAttention2(GlmAttention):
273
- """
274
- Glm flash attention module. This module inherits from `GlmAttention` as the weights of the module stays
275
- untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
276
- flash attention and deal with padding tokens in case the input contains any of them.
277
- """
278
-
279
- def __init__(self, *args, **kwargs):
280
- super().__init__(*args, **kwargs)
281
-
282
- # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
283
- # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
284
- # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
285
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
286
-
287
- def forward(
288
- self,
289
- hidden_states: torch.Tensor,
290
- attention_mask: Optional[torch.LongTensor] = None,
291
- position_ids: Optional[torch.LongTensor] = None,
292
- past_key_value: Optional[Cache] = None,
293
- output_attentions: bool = False,
294
- use_cache: bool = False,
295
- cache_position: Optional[torch.LongTensor] = None,
296
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
297
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
298
- output_attentions = False
299
-
300
- bsz, q_len, _ = hidden_states.size()
301
-
302
- query_states = self.q_proj(hidden_states)
303
- key_states = self.k_proj(hidden_states)
304
- value_states = self.v_proj(hidden_states)
305
-
306
- # Flash attention requires the input to have the shape
307
- # batch_size x seq_length x head_dim x hidden_dim
308
- # therefore we just need to keep the original shape
309
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
310
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
311
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
312
-
313
- cos, sin = position_embeddings
314
- query_states, key_states = apply_rotary_pos_emb(
315
- query_states,
316
- key_states,
317
- cos,
318
- sin,
319
- rotary_percent=self.rotary_percent,
320
- )
321
-
322
- if past_key_value is not None:
323
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
324
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
325
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
326
-
327
- # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
328
- # to be able to avoid many of these transpose/reshape/view.
329
- query_states = query_states.transpose(1, 2)
330
- key_states = key_states.transpose(1, 2)
331
- value_states = value_states.transpose(1, 2)
332
-
333
- dropout_rate = self.attention_dropout if self.training else 0.0
334
-
335
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
336
- # therefore the input hidden states gets silently casted in float32. Hence, we need
337
- # cast them back in the correct dtype just to be sure everything works as expected.
338
- # This might slowdown training & inference so it is recommended to not cast the LayerNorms
339
- # in fp32. (GlmRMSNorm handles it correctly)
340
-
341
- input_dtype = query_states.dtype
342
- if input_dtype == torch.float32:
343
- if torch.is_autocast_enabled():
344
- target_dtype = torch.get_autocast_gpu_dtype()
345
- # Handle the case where the model is quantized
346
- elif hasattr(self.config, "_pre_quantization_dtype"):
347
- target_dtype = self.config._pre_quantization_dtype
348
- else:
349
- target_dtype = self.q_proj.weight.dtype
350
-
351
- logger.warning_once(
352
- f"The input hidden states seems to be silently casted in float32, this might be related to"
353
- f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
354
- f" {target_dtype}."
355
- )
356
-
357
- query_states = query_states.to(target_dtype)
358
- key_states = key_states.to(target_dtype)
359
- value_states = value_states.to(target_dtype)
360
-
361
- attn_output = _flash_attention_forward(
362
- query_states,
363
- key_states,
364
- value_states,
365
- attention_mask,
366
- q_len,
367
- position_ids=position_ids,
368
- dropout=dropout_rate,
369
- softmax_scale=self.scaling,
370
- sliding_window=getattr(self, "sliding_window", None),
371
- use_top_left_mask=self._flash_attn_uses_top_left_mask,
372
- is_causal=self.is_causal,
373
- )
374
-
375
- attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
376
- attn_output = self.o_proj(attn_output)
377
-
378
- if not output_attentions:
379
- attn_weights = None
380
-
381
- return attn_output, attn_weights, past_key_value
382
-
383
-
384
- class GlmSdpaAttention(GlmAttention):
385
- """
386
- Glm attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
387
- `GlmAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
388
- SDPA API.
389
- """
390
-
391
- # Adapted from GlmAttention.forward
392
- def forward(
393
- self,
394
- hidden_states: torch.Tensor,
395
- attention_mask: Optional[torch.Tensor] = None,
396
- position_ids: Optional[torch.LongTensor] = None,
397
- past_key_value: Optional[Cache] = None,
398
- output_attentions: bool = False,
399
- use_cache: bool = False,
400
- cache_position: Optional[torch.LongTensor] = None,
401
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
402
- **kwargs,
403
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
404
- if output_attentions:
405
- # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
406
- logger.warning_once(
407
- "GlmModel is using GlmSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
408
- 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
409
- )
410
- return super().forward(
411
- hidden_states=hidden_states,
412
- attention_mask=attention_mask,
413
- position_ids=position_ids,
414
- past_key_value=past_key_value,
415
- output_attentions=output_attentions,
416
- use_cache=use_cache,
417
- cache_position=cache_position,
418
- position_embeddings=position_embeddings,
419
- )
420
-
421
- bsz, q_len, _ = hidden_states.size()
422
-
423
- query_states = self.q_proj(hidden_states)
424
- key_states = self.k_proj(hidden_states)
425
- value_states = self.v_proj(hidden_states)
426
-
427
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
428
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
429
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
430
-
431
- cos, sin = position_embeddings
432
- query_states, key_states = apply_rotary_pos_emb(
433
- query_states,
434
- key_states,
435
- cos,
436
- sin,
437
- rotary_percent=self.rotary_percent,
438
- )
439
-
440
- if past_key_value is not None:
441
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
442
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
443
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
444
-
445
- key_states = repeat_kv(key_states, self.num_key_value_groups)
446
- value_states = repeat_kv(value_states, self.num_key_value_groups)
447
-
448
- causal_mask = attention_mask
449
- if attention_mask is not None:
450
- causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
451
-
452
- # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
453
- # Reference: https://github.com/pytorch/pytorch/issues/112577.
454
- if query_states.device.type == "cuda" and causal_mask is not None:
455
- query_states = query_states.contiguous()
456
- key_states = key_states.contiguous()
457
- value_states = value_states.contiguous()
458
-
459
- # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
460
- # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
461
- is_causal = True if causal_mask is None and q_len > 1 else False
462
-
463
- attn_output = torch.nn.functional.scaled_dot_product_attention(
464
- query_states,
465
- key_states,
466
- value_states,
467
- attn_mask=causal_mask,
468
- dropout_p=self.attention_dropout if self.training else 0.0,
469
- is_causal=is_causal,
470
- scale=self.scaling,
471
- )
472
-
473
- attn_output = attn_output.transpose(1, 2).contiguous()
474
- attn_output = attn_output.view(bsz, q_len, -1)
475
-
476
- attn_output = self.o_proj(attn_output)
477
-
478
- return attn_output, None, past_key_value
479
-
480
-
481
- GLM_ATTENTION_CLASSES = {
482
- "eager": GlmAttention,
483
- "flash_attention_2": GlmFlashAttention2,
484
- "sdpa": GlmSdpaAttention,
485
- }
486
-
487
-
488
- class GlmDecoderLayer(nn.Module):
489
- def __init__(self, config: GlmConfig, layer_idx: Optional[int] = None):
490
- super().__init__()
491
- self.hidden_size = config.hidden_size
492
-
493
- self.self_attn = GLM_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
494
-
495
- self.mlp = GlmMLP(config)
496
- self.input_layernorm = GlmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
497
- self.post_attention_layernorm = GlmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
498
-
499
- def forward(
500
- self,
501
- hidden_states: torch.Tensor,
502
- attention_mask: Optional[torch.Tensor] = None,
503
- position_ids: Optional[torch.LongTensor] = None,
504
- past_key_value: Optional[Cache] = None,
505
- output_attentions: Optional[bool] = False,
506
- use_cache: Optional[bool] = False,
507
- cache_position: Optional[torch.LongTensor] = None,
508
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
509
- **kwargs,
510
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
511
- """
512
- Args:
513
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
514
- attention_mask (`torch.FloatTensor`, *optional*):
515
- attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
516
- query_sequence_length, key_sequence_length)` if default attention is used.
517
- output_attentions (`bool`, *optional*):
518
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
519
- returned tensors for more detail.
520
- use_cache (`bool`, *optional*):
521
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
522
- (see `past_key_values`).
523
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
524
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
525
- Indices depicting the position of the input sequence tokens in the sequence
526
- position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
527
- Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
528
- with `head_dim` being the embedding dimension of each attention head.
529
- kwargs (`dict`, *optional*):
530
- Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
531
- into the model
532
- """
533
- residual = hidden_states
534
-
535
- hidden_states = self.input_layernorm(hidden_states)
536
-
537
- # Self Attention
538
- hidden_states, self_attn_weights, present_key_value = self.self_attn(
539
- hidden_states=hidden_states,
540
- attention_mask=attention_mask,
541
- position_ids=position_ids,
542
- past_key_value=past_key_value,
543
- output_attentions=output_attentions,
544
- use_cache=use_cache,
545
- cache_position=cache_position,
546
- position_embeddings=position_embeddings,
547
- **kwargs,
548
- )
549
- hidden_states = residual + hidden_states
550
-
551
- # Fully Connected
552
- residual = hidden_states
553
- hidden_states = self.post_attention_layernorm(hidden_states)
554
- hidden_states = self.mlp(hidden_states)
555
- hidden_states = residual + hidden_states
556
-
557
- outputs = (hidden_states,)
558
-
559
- if output_attentions:
560
- outputs += (self_attn_weights,)
561
-
562
- if use_cache:
563
- outputs += (present_key_value,)
564
-
565
- return outputs
566
-
567
-
568
- GLM_START_DOCSTRING = r"""
569
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
570
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
571
- etc.)
572
-
573
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
574
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
575
- and behavior.
576
-
577
- Parameters:
578
- config ([`GlmConfig`]):
579
- Model configuration class with all the parameters of the model. Initializing with a config file does not
580
- load the weights associated with the model, only the configuration. Check out the
581
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
582
- """
583
-
584
-
585
- @add_start_docstrings(
586
- "The bare Glm Model outputting raw hidden-states without any specific head on top.",
587
- GLM_START_DOCSTRING,
588
- )
589
- class GlmPreTrainedModel(PreTrainedModel):
590
- config_class = GlmConfig
591
- base_model_prefix = "model"
592
- supports_gradient_checkpointing = True
593
- _no_split_modules = ["GlmDecoderLayer"]
594
- _skip_keys_device_placement = ["past_key_values"]
595
- _supports_flash_attn_2 = True
596
- _supports_sdpa = True
597
- _supports_cache_class = True
598
- _supports_quantized_cache = True
599
- _supports_static_cache = True
600
-
601
- def _init_weights(self, module):
602
- std = self.config.initializer_range
603
- if isinstance(module, nn.Linear):
604
- module.weight.data.normal_(mean=0.0, std=std)
605
- if module.bias is not None:
606
- module.bias.data.zero_()
607
- elif isinstance(module, nn.Embedding):
608
- module.weight.data.normal_(mean=0.0, std=std)
609
- if module.padding_idx is not None:
610
- module.weight.data[module.padding_idx].zero_()
611
-
612
-
613
- GLM_INPUTS_DOCSTRING = r"""
614
- Args:
615
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
616
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
617
- it.
618
-
619
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
620
- [`PreTrainedTokenizer.__call__`] for details.
621
-
622
- [What are input IDs?](../glossary#input-ids)
623
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
624
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
625
-
626
- - 1 for tokens that are **not masked**,
627
- - 0 for tokens that are **masked**.
628
-
629
- [What are attention masks?](../glossary#attention-mask)
630
-
631
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
632
- [`PreTrainedTokenizer.__call__`] for details.
633
-
634
- If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
635
- `past_key_values`).
636
-
637
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
638
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
639
- information on the default strategy.
640
-
641
- - 1 indicates the head is **not masked**,
642
- - 0 indicates the head is **masked**.
643
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
644
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
645
- config.n_positions - 1]`.
646
-
647
- [What are position IDs?](../glossary#position-ids)
648
- past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
649
- Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
650
- blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
651
- returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
652
-
653
- Two formats are allowed:
654
- - a [`~cache_utils.Cache`] instance, see our
655
- [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
656
- - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
657
- shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
658
- cache format.
659
-
660
- The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
661
- legacy cache format will be returned.
662
-
663
- If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
664
- have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
665
- of shape `(batch_size, sequence_length)`.
666
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
667
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
668
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
669
- model's internal embedding lookup matrix.
670
- use_cache (`bool`, *optional*):
671
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
672
- `past_key_values`).
673
- output_attentions (`bool`, *optional*):
674
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
675
- tensors for more detail.
676
- output_hidden_states (`bool`, *optional*):
677
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
678
- more detail.
679
- return_dict (`bool`, *optional*):
680
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
681
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
682
- Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
683
- this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
684
- the complete sequence length.
685
- """
686
-
687
-
688
- @add_start_docstrings(
689
- "The bare Glm Model outputting raw hidden-states without any specific head on top.",
690
- GLM_START_DOCSTRING,
691
- )
692
- class GlmModel(GlmPreTrainedModel):
693
- """
694
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`GlmDecoderLayer`]
695
-
696
- Args:
697
- config: GlmConfig
698
- """
699
-
700
- def __init__(self, config: GlmConfig):
701
- super().__init__(config)
702
- self.padding_idx = config.pad_token_id
703
- self.vocab_size = config.vocab_size
704
- self.rotary_percent = config.rotary_percent if hasattr(config, "rotary_percent") else 0.5
705
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
706
- self.layers = nn.ModuleList(
707
- [GlmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
708
- )
709
- self.norm = GlmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
710
- self.rotary_emb = GlmRotaryEmbedding(
711
- dim=config.head_dim,
712
- max_position_embeddings=config.max_position_embeddings,
713
- base=config.rope_theta,
714
- rotary_percent=self.rotary_percent,
715
- )
716
- self.gradient_checkpointing = False
717
-
718
- # Initialize weights and apply final processing
719
- self.post_init()
720
-
721
- def get_input_embeddings(self):
722
- return self.embed_tokens
723
-
724
- def set_input_embeddings(self, value):
725
- self.embed_tokens = value
726
-
727
- @add_start_docstrings_to_model_forward(GLM_INPUTS_DOCSTRING)
728
- def forward(
729
- self,
730
- input_ids: torch.LongTensor = None,
731
- attention_mask: Optional[torch.Tensor] = None,
732
- position_ids: Optional[torch.LongTensor] = None,
733
- past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
734
- inputs_embeds: Optional[torch.FloatTensor] = None,
735
- use_cache: Optional[bool] = None,
736
- output_attentions: Optional[bool] = None,
737
- output_hidden_states: Optional[bool] = None,
738
- return_dict: Optional[bool] = None,
739
- cache_position: Optional[torch.LongTensor] = None,
740
- **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
741
- ) -> Union[Tuple, BaseModelOutputWithPast]:
742
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
743
- output_hidden_states = (
744
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
745
- )
746
-
747
- use_cache = use_cache if use_cache is not None else self.config.use_cache
748
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
749
-
750
- if (input_ids is None) ^ (inputs_embeds is not None):
751
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
752
-
753
- if self.gradient_checkpointing and self.training and use_cache:
754
- logger.warning_once(
755
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
756
- )
757
- use_cache = False
758
-
759
- if inputs_embeds is None:
760
- inputs_embeds = self.embed_tokens(input_ids)
761
-
762
- # kept for BC (non `Cache` `past_key_values` inputs)
763
- return_legacy_cache = False
764
- if use_cache and not isinstance(past_key_values, Cache):
765
- return_legacy_cache = True
766
- if past_key_values is None:
767
- past_key_values = DynamicCache()
768
- else:
769
- past_key_values = DynamicCache.from_legacy_cache(past_key_values)
770
- logger.warning_once(
771
- "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
772
- "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
773
- "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
774
- )
775
-
776
- if cache_position is None:
777
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
778
- cache_position = torch.arange(
779
- past_seen_tokens,
780
- past_seen_tokens + inputs_embeds.shape[1],
781
- device=inputs_embeds.device,
782
- )
783
- if position_ids is None:
784
- position_ids = cache_position.unsqueeze(0)
785
-
786
- causal_mask = self._update_causal_mask(
787
- attention_mask,
788
- inputs_embeds,
789
- cache_position,
790
- past_key_values,
791
- output_attentions,
792
- )
793
- hidden_states = inputs_embeds
794
- position_embeddings = self.rotary_emb(hidden_states, position_ids)
795
-
796
- # decoder layers
797
- all_hidden_states = () if output_hidden_states else None
798
- all_self_attns = () if output_attentions else None
799
- next_decoder_cache = None
800
-
801
- for decoder_layer in self.layers:
802
- if output_hidden_states:
803
- all_hidden_states += (hidden_states,)
804
-
805
- if self.gradient_checkpointing and self.training:
806
- layer_outputs = self._gradient_checkpointing_func(
807
- decoder_layer.__call__,
808
- hidden_states,
809
- causal_mask,
810
- position_ids,
811
- past_key_values,
812
- output_attentions,
813
- use_cache,
814
- cache_position,
815
- position_embeddings,
816
- )
817
- else:
818
- layer_outputs = decoder_layer(
819
- hidden_states,
820
- attention_mask=causal_mask,
821
- position_ids=position_ids,
822
- past_key_value=past_key_values,
823
- output_attentions=output_attentions,
824
- use_cache=use_cache,
825
- cache_position=cache_position,
826
- position_embeddings=position_embeddings,
827
- **flash_attn_kwargs,
828
- )
829
-
830
- hidden_states = layer_outputs[0]
831
-
832
- if use_cache:
833
- next_decoder_cache = layer_outputs[2 if output_attentions else 1]
834
-
835
- if output_attentions:
836
- all_self_attns += (layer_outputs[1],)
837
-
838
- hidden_states = self.norm(hidden_states)
839
-
840
- # add hidden states from the last decoder layer
841
- if output_hidden_states:
842
- all_hidden_states += (hidden_states,)
843
-
844
- next_cache = next_decoder_cache if use_cache else None
845
- if return_legacy_cache:
846
- next_cache = next_cache.to_legacy_cache()
847
-
848
- if not return_dict:
849
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
850
- return BaseModelOutputWithPast(
851
- last_hidden_state=hidden_states,
852
- past_key_values=next_cache,
853
- hidden_states=all_hidden_states,
854
- attentions=all_self_attns,
855
- )
856
-
857
- def _update_causal_mask(
858
- self,
859
- attention_mask: torch.Tensor,
860
- input_tensor: torch.Tensor,
861
- cache_position: torch.Tensor,
862
- past_key_values: Cache,
863
- output_attentions: bool,
864
- ):
865
- if self.config._attn_implementation == "flash_attention_2":
866
- if attention_mask is not None and 0.0 in attention_mask:
867
- return attention_mask
868
- return None
869
-
870
- # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
871
- # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
872
- # to infer the attention mask.
873
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
874
- using_static_cache = isinstance(past_key_values, StaticCache)
875
-
876
- # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
877
- if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
878
- if AttentionMaskConverter._ignore_causal_mask_sdpa(
879
- attention_mask,
880
- inputs_embeds=input_tensor,
881
- past_key_values_length=past_seen_tokens,
882
- is_training=self.training,
883
- ):
884
- return None
885
-
886
- dtype, device = input_tensor.dtype, input_tensor.device
887
- sequence_length = input_tensor.shape[1]
888
- if using_static_cache:
889
- target_length = past_key_values.get_max_cache_shape()
890
- else:
891
- target_length = (
892
- attention_mask.shape[-1]
893
- if isinstance(attention_mask, torch.Tensor)
894
- else past_seen_tokens + sequence_length + 1
895
- )
896
-
897
- # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
898
- causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
899
- attention_mask,
900
- sequence_length=sequence_length,
901
- target_length=target_length,
902
- dtype=dtype,
903
- device=device,
904
- cache_position=cache_position,
905
- batch_size=input_tensor.shape[0],
906
- )
907
-
908
- if (
909
- self.config._attn_implementation == "sdpa"
910
- and attention_mask is not None
911
- and attention_mask.device.type == "cuda"
912
- and not output_attentions
913
- ):
914
- # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
915
- # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
916
- # Details: https://github.com/pytorch/pytorch/issues/110213
917
- min_dtype = torch.finfo(dtype).min
918
- causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
919
-
920
- return causal_mask
921
-
922
- @staticmethod
923
- def _prepare_4d_causal_attention_mask_with_cache_position(
924
- attention_mask: torch.Tensor,
925
- sequence_length: int,
926
- target_length: int,
927
- dtype: torch.dtype,
928
- device: torch.device,
929
- cache_position: torch.Tensor,
930
- batch_size: int,
931
- **kwargs,
932
- ):
933
- """
934
- Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
935
- `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
936
-
937
- Args:
938
- attention_mask (`torch.Tensor`):
939
- A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
940
- `(batch_size, 1, query_length, key_value_length)`.
941
- sequence_length (`int`):
942
- The sequence length being processed.
943
- target_length (`int`):
944
- The target length: when generating with static cache, the mask should be as long as the static cache,
945
- to account for the 0 padding, the part of the cache that is not filled yet.
946
- dtype (`torch.dtype`):
947
- The dtype to use for the 4D attention mask.
948
- device (`torch.device`):
949
- The device to place the 4D attention mask on.
950
- cache_position (`torch.Tensor`):
951
- Indices depicting the position of the input sequence tokens in the sequence.
952
- batch_size (`torch.Tensor`):
953
- Batch size.
954
- """
955
- if attention_mask is not None and attention_mask.dim() == 4:
956
- # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
957
- causal_mask = attention_mask
958
- else:
959
- min_dtype = torch.finfo(dtype).min
960
- causal_mask = torch.full(
961
- (sequence_length, target_length),
962
- fill_value=min_dtype,
963
- dtype=dtype,
964
- device=device,
965
- )
966
- if sequence_length != 1:
967
- causal_mask = torch.triu(causal_mask, diagonal=1)
968
- causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
969
- causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
970
- if attention_mask is not None:
971
- causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
972
- mask_length = attention_mask.shape[-1]
973
- padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
974
- padding_mask = padding_mask == 0
975
- causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
976
- padding_mask, min_dtype
977
- )
978
-
979
- return causal_mask
980
-
981
-
982
- class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin):
983
- _tied_weights_keys = ["lm_head.weight"]
984
-
985
- def __init__(self, config: GlmConfig):
986
- super().__init__(config)
987
- self.model = GlmModel(config)
988
- self.vocab_size = config.vocab_size
989
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
990
-
991
- # Initialize weights and apply final processing
992
- self.post_init()
993
-
994
- def get_input_embeddings(self):
995
- return self.model.embed_tokens
996
-
997
- def set_input_embeddings(self, value):
998
- self.model.embed_tokens = value
999
-
1000
- def get_output_embeddings(self):
1001
- return self.lm_head
1002
-
1003
- def set_output_embeddings(self, new_embeddings):
1004
- self.lm_head = new_embeddings
1005
-
1006
- def set_decoder(self, decoder):
1007
- self.model = decoder
1008
-
1009
- def get_decoder(self):
1010
- return self.model
1011
-
1012
- @add_start_docstrings_to_model_forward(GLM_INPUTS_DOCSTRING)
1013
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1014
- def forward(
1015
- self,
1016
- input_ids: torch.LongTensor = None,
1017
- attention_mask: Optional[torch.Tensor] = None,
1018
- position_ids: Optional[torch.LongTensor] = None,
1019
- past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1020
- inputs_embeds: Optional[torch.FloatTensor] = None,
1021
- labels: Optional[torch.LongTensor] = None,
1022
- use_cache: Optional[bool] = None,
1023
- output_attentions: Optional[bool] = None,
1024
- output_hidden_states: Optional[bool] = None,
1025
- return_dict: Optional[bool] = None,
1026
- cache_position: Optional[torch.LongTensor] = None,
1027
- num_logits_to_keep: int = 0,
1028
- **loss_kwargs,
1029
- ) -> Union[Tuple, CausalLMOutputWithPast]:
1030
- r"""
1031
- Args:
1032
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1033
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1034
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1035
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1036
-
1037
- num_logits_to_keep (`int`, *optional*):
1038
- Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
1039
- `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
1040
- token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
1041
-
1042
- Returns:
1043
-
1044
- Example:
1045
-
1046
- ```python
1047
- >>> from transformers import AutoTokenizer, GlmForCausalLM
1048
-
1049
- >>> model = GlmForCausalLM.from_pretrained("google/glm-7b")
1050
- >>> tokenizer = AutoTokenizer.from_pretrained("google/glm-7b")
1051
-
1052
- >>> prompt = "What is your favorite condiment?"
1053
- >>> inputs = tokenizer(prompt, return_tensors="pt")
1054
-
1055
- >>> # Generate
1056
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1057
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1058
- "What is your favorite condiment?"
1059
- ```"""
1060
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1061
- output_hidden_states = (
1062
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1063
- )
1064
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1065
-
1066
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1067
- outputs = self.model(
1068
- input_ids=input_ids,
1069
- attention_mask=attention_mask,
1070
- position_ids=position_ids,
1071
- past_key_values=past_key_values,
1072
- inputs_embeds=inputs_embeds,
1073
- use_cache=use_cache,
1074
- output_attentions=output_attentions,
1075
- output_hidden_states=output_hidden_states,
1076
- return_dict=return_dict,
1077
- cache_position=cache_position,
1078
- )
1079
-
1080
- hidden_states = outputs[0]
1081
- # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
1082
- logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
1083
-
1084
- loss = None
1085
- if labels is not None:
1086
- loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
1087
-
1088
- if not return_dict:
1089
- output = (logits,) + outputs[1:]
1090
- return (loss,) + output if loss is not None else output
1091
-
1092
- return CausalLMOutputWithPast(
1093
- loss=loss,
1094
- logits=logits,
1095
- past_key_values=outputs.past_key_values,
1096
- hidden_states=outputs.hidden_states,
1097
- attentions=outputs.attentions,
1098
- )
1099
-
1100
-
1101
- @add_start_docstrings(
1102
- """
1103
- The Glm Model transformer with a sequence classification head on top (linear layer).
1104
-
1105
- [`GlmForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1106
- (e.g. GPT-2) do.
1107
-
1108
- Since it does classification on the last token, it requires to know the position of the last token. If a
1109
- `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1110
- no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1111
- padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1112
- each row of the batch).
1113
- """,
1114
- GLM_START_DOCSTRING,
1115
- )
1116
- class GlmForSequenceClassification(GlmPreTrainedModel):
1117
- def __init__(self, config: GlmConfig):
1118
- super().__init__(config)
1119
- self.num_labels = config.num_labels
1120
- self.model = GlmModel(config)
1121
- self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1122
-
1123
- # Initialize weights and apply final processing
1124
- self.post_init()
1125
-
1126
- def get_input_embeddings(self):
1127
- return self.model.embed_tokens
1128
-
1129
- def set_input_embeddings(self, value):
1130
- self.model.embed_tokens = value
1131
-
1132
- @add_start_docstrings_to_model_forward(GLM_INPUTS_DOCSTRING)
1133
- def forward(
1134
- self,
1135
- input_ids: Optional[torch.LongTensor] = None,
1136
- attention_mask: Optional[torch.Tensor] = None,
1137
- position_ids: Optional[torch.LongTensor] = None,
1138
- past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1139
- inputs_embeds: Optional[torch.FloatTensor] = None,
1140
- labels: Optional[torch.LongTensor] = None,
1141
- use_cache: Optional[bool] = None,
1142
- output_attentions: Optional[bool] = None,
1143
- output_hidden_states: Optional[bool] = None,
1144
- return_dict: Optional[bool] = None,
1145
- ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1146
- r"""
1147
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1148
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1149
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1150
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1151
- """
1152
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1153
-
1154
- transformer_outputs = self.model(
1155
- input_ids,
1156
- attention_mask=attention_mask,
1157
- position_ids=position_ids,
1158
- past_key_values=past_key_values,
1159
- inputs_embeds=inputs_embeds,
1160
- use_cache=use_cache,
1161
- output_attentions=output_attentions,
1162
- output_hidden_states=output_hidden_states,
1163
- return_dict=return_dict,
1164
- )
1165
- hidden_states = transformer_outputs[0]
1166
- logits = self.score(hidden_states)
1167
-
1168
- if input_ids is not None:
1169
- batch_size = input_ids.shape[0]
1170
- else:
1171
- batch_size = inputs_embeds.shape[0]
1172
-
1173
- if self.config.pad_token_id is None and batch_size != 1:
1174
- raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1175
- if self.config.pad_token_id is None:
1176
- sequence_lengths = -1
1177
- else:
1178
- if input_ids is not None:
1179
- # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1180
- sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1181
- sequence_lengths = sequence_lengths % input_ids.shape[-1]
1182
- sequence_lengths = sequence_lengths.to(logits.device)
1183
- else:
1184
- sequence_lengths = -1
1185
-
1186
- pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1187
-
1188
- loss = None
1189
- if labels is not None:
1190
- loss = self.loss_function(
1191
- logits=logits,
1192
- labels=labels,
1193
- pooled_logits=pooled_logits,
1194
- config=self.config,
1195
- )
1196
-
1197
- if not return_dict:
1198
- output = (pooled_logits,) + transformer_outputs[1:]
1199
- return ((loss,) + output) if loss is not None else output
1200
-
1201
- return SequenceClassifierOutputWithPast(
1202
- loss=loss,
1203
- logits=pooled_logits,
1204
- past_key_values=transformer_outputs.past_key_values,
1205
- hidden_states=transformer_outputs.hidden_states,
1206
- attentions=transformer_outputs.attentions,
1207
- )
1208
-
1209
-
1210
- @add_start_docstrings(
1211
- """
1212
- The Glm Model transformer with a token classification head on top (a linear layer on top of the hidden-states
1213
- output) e.g. for Named-Entity-Recognition (NER) tasks.
1214
- """,
1215
- GLM_START_DOCSTRING,
1216
- )
1217
- class GlmForTokenClassification(GlmPreTrainedModel):
1218
- def __init__(self, config: GlmConfig):
1219
- super().__init__(config)
1220
- self.num_labels = config.num_labels
1221
- self.model = GlmModel(config)
1222
- if getattr(config, "classifier_dropout", None) is not None:
1223
- classifier_dropout = config.classifier_dropout
1224
- elif getattr(config, "hidden_dropout", None) is not None:
1225
- classifier_dropout = config.hidden_dropout
1226
- else:
1227
- classifier_dropout = 0.1
1228
- self.dropout = nn.Dropout(classifier_dropout)
1229
- self.score = nn.Linear(config.hidden_size, config.num_labels)
1230
-
1231
- # Initialize weights and apply final processing
1232
- self.post_init()
1233
-
1234
- def get_input_embeddings(self):
1235
- return self.model.embed_tokens
1236
-
1237
- def set_input_embeddings(self, value):
1238
- self.model.embed_tokens = value
1239
-
1240
- @add_start_docstrings_to_model_forward(GLM_INPUTS_DOCSTRING)
1241
- @add_code_sample_docstrings(
1242
- checkpoint=_CHECKPOINT_FOR_DOC,
1243
- output_type=TokenClassifierOutput,
1244
- config_class=_CONFIG_FOR_DOC,
1245
- )
1246
- def forward(
1247
- self,
1248
- input_ids: Optional[torch.LongTensor] = None,
1249
- attention_mask: Optional[torch.Tensor] = None,
1250
- position_ids: Optional[torch.LongTensor] = None,
1251
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1252
- inputs_embeds: Optional[torch.FloatTensor] = None,
1253
- labels: Optional[torch.LongTensor] = None,
1254
- use_cache: Optional[bool] = None,
1255
- output_attentions: Optional[bool] = None,
1256
- output_hidden_states: Optional[bool] = None,
1257
- return_dict: Optional[bool] = None,
1258
- ) -> Union[Tuple, TokenClassifierOutput]:
1259
- r"""
1260
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1261
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1262
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1263
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1264
- """
1265
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1266
-
1267
- outputs = self.model(
1268
- input_ids,
1269
- attention_mask=attention_mask,
1270
- position_ids=position_ids,
1271
- past_key_values=past_key_values,
1272
- inputs_embeds=inputs_embeds,
1273
- use_cache=use_cache,
1274
- output_attentions=output_attentions,
1275
- output_hidden_states=output_hidden_states,
1276
- return_dict=return_dict,
1277
- )
1278
- sequence_output = outputs[0]
1279
- sequence_output = self.dropout(sequence_output)
1280
- logits = self.score(sequence_output)
1281
-
1282
- loss = None
1283
- if labels is not None:
1284
- loss = self.loss_function(logits, labels, self.config)
1285
-
1286
- if not return_dict:
1287
- output = (logits,) + outputs[2:]
1288
- return ((loss,) + output) if loss is not None else output
1289
-
1290
- return TokenClassifierOutput(
1291
- loss=loss,
1292
- logits=logits,
1293
- hidden_states=outputs.hidden_states,
1294
- attentions=outputs.attentions,
1295
- )
1296
-
1297
-
1298
- __all__ = [
1299
- "GlmPreTrainedModel",
1300
- "GlmModel",
1301
- "GlmForCausalLM",
1302
- "GlmForSequenceClassification",
1303
- "GlmForTokenClassification",
1304
- ]
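The deleted implementation applied rotary embeddings to an even/odd interleaved split of each head (`rotate_half` over `x[..., 0::2]` and `x[..., 1::2]`), gated by `rotary_percent` (0.5 for glm-4-9b, 1.0 for glm-edge). The sketch below is a standalone check, under the assumption that `"partial_rotary_factor": 1.0` in the updated config means "rotate every head dimension": it verifies that this interleaved formulation rotates each `(x[2i], x[2i+1])` channel pair by its angle, i.e. it matches an explicit 2x2 rotation per pair.

```python
# Hedged sketch: checks that the deleted file's interleaved RoPE equals a
# per-pair 2x2 rotation when the full head dimension is rotated
# (rotary_percent = 1.0, assumed to correspond to partial_rotary_factor = 1.0).
import torch

def rotate_half(x):
    # even/odd interleaved variant used by the deleted modeling_glm.py
    x1 = x[..., 0::2]
    x2 = x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

head_dim = 8
q = torch.randn(head_dim)
theta = torch.rand(head_dim // 2)            # one angle per channel pair
cos = torch.cos(theta).repeat_interleave(2)  # mirrors apply_rotary_pos_emb
sin = torch.sin(theta).repeat_interleave(2)

q_rope = q * cos + rotate_half(q) * sin

# Reference: rotate each (q[2i], q[2i+1]) pair with an explicit 2x2 matrix.
pairs = q.view(-1, 2)
rot = torch.stack(
    [torch.cos(theta), -torch.sin(theta), torch.sin(theta), torch.cos(theta)], dim=-1
).view(-1, 2, 2)
q_ref = torch.einsum("pij,pj->pi", rot, pairs).reshape(-1)

assert torch.allclose(q_rope, q_ref, atol=1e-6)
print("interleaved RoPE matches per-pair rotation")
```

With `rotary_percent=0.5` (the glm-4-9b case), `apply_rotary_pos_emb` above applies the same rotation only to the first half of each head's channels and concatenates the untouched second half back on, which is the behaviour a `partial_rotary_factor` below 1.0 is assumed to select in the built-in implementation.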