Upload baichuan-incBaichuan-13B-Chat--modeling_baichuan.py

#3
by CavioKay - opened
baichuan-incBaichuan-13B-Chat--modeling_baichuan.py ADDED
@@ -0,0 +1,572 @@
# Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved.

import math
from threading import Thread
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss
from transformers import PreTrainedModel
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.utils import logging
from transformers.generation.utils import GenerationConfig

from .configuration_baichuan import BaichuanConfig
from .generation_utils import build_chat_input, TextIterStreamer

logger = logging.get_logger(__name__)


def _get_interleave(n):
    def _get_interleave_power_of_2(n):
        start = (2 ** (-2 ** -(math.log2(n) - 3)))
        ratio = start
        return [start * ratio ** i for i in range(n)]

    if math.log2(n).is_integer():
        return _get_interleave_power_of_2(n)
    else:
        closest_power_of_2 = 2 ** math.floor(math.log2(n))
        return _get_interleave_power_of_2(closest_power_of_2) + \
            _get_interleave(2 * closest_power_of_2)[0::2][:n - closest_power_of_2]


def _fill_with_neg_inf(t):
    """FP16-compatible function that fills a tensor with -inf."""
    return t.float().fill_(float("-inf")).type_as(t)


def _gen_alibi_mask(n_head, max_pos):
    """used in inference only"""
    slopes = torch.Tensor(_get_interleave(n_head))
    alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(max_pos).unsqueeze(0).unsqueeze(0).expand(
        n_head, -1, -1)
    alibi = alibi.view(n_head, 1, max_pos)
    alibi_mask = torch.triu(
        _fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1
    )
    alibi_mask = alibi_mask.unsqueeze(0) + alibi
    return alibi_mask


def _buffered_future_mask(tensor, maxpos, alibi, attn_heads):
    """used in training only"""
    dim = tensor.size(1)
    _future_mask = torch.triu(
        _fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1
    )
    _future_mask = _future_mask.unsqueeze(0) + alibi
    _future_mask = _future_mask.to(tensor)
    return _future_mask[:tensor.shape[0] * attn_heads, :maxpos, :maxpos]


class RMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, epsilon=1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(hidden_size))
        self.epsilon = epsilon

    def forward(self, hidden_states):
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)

        # convert into half-precision
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)

        return self.weight * hidden_states


class MLP(torch.nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ):
        super().__init__()
        self.gate_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False)
        self.up_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


class BaichuanAttention(torch.nn.Module):
    def __init__(self, config: BaichuanConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.max_position_embeddings = config.model_max_length

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size {self.hidden_size} is not divisible by num_heads {self.num_heads}"
            )
        self.W_pack = torch.nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
        self.o_proj = torch.nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

        bsz, q_len, _ = hidden_states.size()

        proj = self.W_pack(hidden_states)
        proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
        query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            if q_len == 1:  # inference with cache
                if len(attention_mask.size()) == 4:
                    attention_mask = attention_mask[:, :, -1:, :]
                else:
                    attention_mask = attention_mask[:, -1:, :]
            attn_weights = attn_weights + attention_mask
            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))

        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)

        attn_output = torch.matmul(attn_weights, value_states)

        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


class BaichuanLayer(torch.nn.Module):
    def __init__(self, config: BaichuanConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = BaichuanAttention(config=config)
        self.mlp = MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.input_layernorm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs


class BaichuanPreTrainedModel(PreTrainedModel):
    config_class = BaichuanConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["BaichuanLayer"]
    _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, torch.nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, torch.nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, BaichuanModel):
            module.gradient_checkpointing = value


class BaichuanModel(BaichuanPreTrainedModel):
    def __init__(self, config: BaichuanConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.n_head = config.num_attention_heads
        self.embed_tokens = torch.nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = torch.nn.ModuleList([BaichuanLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps)

        self.gradient_checkpointing = config.gradient_checkpointing
        self.post_init()
        self.max_cache_pos = config.model_max_length
        self.first_run = True
        self.alibi_mask = None

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def get_alibi_mask(self, tensor, seq_length_with_past):
        if self.training:
            slopes = torch.Tensor(_get_interleave(self.n_head))
            alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(seq_length_with_past).unsqueeze(0).unsqueeze(0).expand(
                self.n_head, -1, -1)
            alibi = alibi.view(self.n_head, 1, seq_length_with_past)
            mask = _buffered_future_mask(tensor, seq_length_with_past, alibi, self.n_head)
        else:
            if self.first_run:
                self.first_run = False
                self.register_buffer("future_mask", _gen_alibi_mask(self.n_head, self.max_cache_pos).to(tensor), persistent=False)
            if seq_length_with_past > self.max_cache_pos:
                self.max_cache_pos = seq_length_with_past
                self.register_buffer("future_mask", _gen_alibi_mask(self.n_head, self.max_cache_pos).to(tensor), persistent=False)
            mask = self.future_mask[:self.n_head, :seq_length_with_past, :seq_length_with_past]
        return mask

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple, BaseModelOutputWithPast]:

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot provide both input_ids and inputs_embeds simultaneously")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You need to provide input_ids or inputs_embeds")

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        seq_length_with_past = seq_length

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if self.training:
            if self.alibi_mask is None or self.alibi_mask.shape[-1] != seq_length_with_past:
                self.alibi_mask = self.get_alibi_mask(inputs_embeds, seq_length_with_past)
            alibi_mask = self.alibi_mask
        else:
            alibi_mask = self.get_alibi_mask(inputs_embeds, seq_length_with_past)

        if attention_mask is not None:
            if len(attention_mask.shape) == 2:
                expanded_mask = attention_mask.to(alibi_mask.dtype)
                expanded_mask = torch.tril(
                    torch.gt(expanded_mask[:, :, None] * expanded_mask[:, None, :], 0)
                ) * torch.eq(expanded_mask[:, :, None] - expanded_mask[:, None, :], 0)
            else:
                expanded_mask = attention_mask
            bsz = inputs_embeds.size(0)
            src_len, tgt_len = alibi_mask.size()[-2:]
            expanded_mask = expanded_mask.unsqueeze(1).expand(bsz, 1, src_len, tgt_len).to(alibi_mask.dtype)
            inverted_mask = 1.0 - expanded_mask
            inverted_mask = inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(alibi_mask.dtype).min)
            attention_mask = inverted_mask + alibi_mask.unsqueeze(0)
        else:
            attention_mask = alibi_mask

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, output_attentions, None)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


class BaichuanForCausalLM(BaichuanPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.model = BaichuanModel(config)
        self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
        **kwargs
    ) -> Union[Tuple, CausalLMOutputWithPast]:

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs
    ):
        if past_key_values:
            input_ids = input_ids[:, -1:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        return tuple(
            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
            for layer_past in past_key_values
        )

    def quantize(self, bits: int):
        try:
            from .quantizer import QLinear
        except ImportError:
            raise ImportError(
                "Needs QLinear to run quantize."
            )

        for layer in self.model.layers:
            layer.self_attn.W_pack = QLinear(
                bits=bits,
                weight=layer.self_attn.W_pack.weight,
                bias=None,
            )
            layer.self_attn.o_proj = QLinear(
                bits=bits,
                weight=layer.self_attn.o_proj.weight,
                bias=None,
            )
            layer.mlp.gate_proj = QLinear(
                bits=bits,
                weight=layer.mlp.gate_proj.weight,
                bias=None,
            )
            layer.mlp.down_proj = QLinear(
                bits=bits,
                weight=layer.mlp.down_proj.weight,
                bias=None,
            )
            layer.mlp.up_proj = QLinear(
                bits=bits,
                weight=layer.mlp.up_proj.weight,
                bias=None,
            )
        return self

    @torch.no_grad()
    def chat(self, tokenizer, messages: List[dict], stream=False,
             generation_config: Optional[GenerationConfig] = None):
        generation_config = generation_config or self.generation_config
        input_ids = build_chat_input(self, tokenizer, messages, generation_config.max_new_tokens)
        if stream:
            streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
            Thread(target=self.generate, kwargs=dict(
                inputs=input_ids, streamer=streamer,
                generation_config=generation_config,
            )).start()
            return streamer
        else:
            outputs = self.generate(input_ids, generation_config=generation_config)
            response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
            return response
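For anyone pulling this file, here is a minimal usage sketch (not part of the upload) of how the `chat` API defined above is typically driven through `transformers` remote-code loading. The repo id, dtype/device settings, and the `{"role": ..., "content": ...}` message format are assumptions inferred from the `chat` and `build_chat_input` signatures, not something fixed by this file:

```python
# Minimal sketch: load Baichuan-13B-Chat with trust_remote_code so this
# modeling_baichuan.py is used, then call the custom chat() method.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig

model_id = "baichuan-inc/Baichuan-13B-Chat"  # assumed upstream repo id
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto", trust_remote_code=True
)
model.generation_config = GenerationConfig.from_pretrained(model_id)

# Assumed message format consumed by build_chat_input in generation_utils.py.
messages = [{"role": "user", "content": "Explain ALiBi positional bias in one sentence."}]

# Non-streaming: chat() decodes and returns the full response string.
print(model.chat(tokenizer, messages))

# Streaming: chat() returns the TextIterStreamer immediately while generation
# runs in a background thread; iterate it to consume text chunks as they arrive.
for chunk in model.chat(tokenizer, messages, stream=True):
    print(chunk, end="", flush=True)
```

With `stream=True`, note that `chat` does not join the generation thread; the caller is expected to drain the streamer to completion.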