liuxz0801 committed
Commit d675deb
1 Parent(s): 73219b9

Upload 10 files

config.json CHANGED
@@ -24,6 +24,7 @@
   "offset_alibi": 100,
   "pad_token_id": 3,
   "pretraining_tp": 2,
+  "seq_length": 8192,
   "skip_bias_add": true,
   "skip_bias_add_qkv": false,
   "slow_but_exact": false,
@@ -35,6 +36,8 @@
   "flash_attn":true,
   "tie_word_embeddings":false,
   "training_seqlen":8192,
-  "base_seqlen":8192
+  "logn":false,
+  "semi_causal":false,
+  "embed_layernorm":false
 }

generation_config.json ADDED
@@ -0,0 +1,14 @@
+ {
+   "max_length": 8192,
+   "do_sample": false,
+   "use_cache": true,
+   "temperature": 0.3,
+   "top_k": 5,
+   "top_p": 0.85,
+   "repetition_penalty": 1.01,
+   "pad_token_id": 3,
+   "bos_token_id": 1,
+   "eos_token_id": 2,
+   "user_token_id": 20,
+   "bot_token_id": 21
+ }
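
For reference, these defaults are picked up by transformers' GenerationConfig when the model directory is loaded; a minimal sketch, assuming a local copy of this repo at ./telechat-12b (the path is illustrative):

    from transformers import GenerationConfig

    gen_config = GenerationConfig.from_pretrained("./telechat-12b")  # reads generation_config.json
    print(gen_config.temperature, gen_config.top_k, gen_config.top_p)  # 0.3 5 0.85
    # Non-standard keys such as user_token_id and bot_token_id are kept as attributes on the
    # config object; the chat() helper added in modeling_telechat.py below reads them to mark turns.
    print(gen_config.user_token_id, gen_config.bot_token_id)  # 20 21
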
generation_utils.py ADDED
@@ -0,0 +1,162 @@
+ from typing import Optional
+ from collections import deque
+ from queue import Queue
+ import copy
+
+
+ class History:
+
+     def __init__(self, tokenizer, history):
+         '''
+         Initialize from a list of message dicts.
+         '''
+         # use a deque so both ends of the history can be manipulated
+         self.input_history = deque()
+         self.tokenizer = tokenizer
+         if history:
+             self._transfer_from_list(history)
+
+     def _transfer_from_list(self, history):
+         for message in history:
+             content = message.get("content")
+             # the tokenized result may not be equal to what the model generated
+             message.update(self.tokenizer(content))
+             self.input_history.append(message)
+
+     def append(self, message):
+         content = message.get("content")
+         if "input_ids" not in message or "attention_mask" not in message:
+             message.update(self.tokenizer(content))
+         self.input_history.append(message)
+
+     def append_left(self, message):
+         content = message.get("content")
+         if "input_ids" not in message or "attention_mask" not in message:
+             message.update(self.tokenizer(content))
+         self.input_history.appendleft(message)
+
+     def pop(self):
+         x = self.input_history.pop()
+         return x
+
+     def pop_left(self):
+         x = self.input_history.popleft()
+         return x
+
+     def update(self, message):
+         self.input_history.pop()
+         self.append(message)
+
+     def __len__(self):
+         return self.input_history.__len__()
+
+     def __str__(self):
+         return self.input_history.__str__()
+
+     def __copy__(self):
+         new_instance = type(self)(self.tokenizer, [])
+         new_instance.input_history = copy.copy(self.input_history)
+         return new_instance
+
+     def __deepcopy__(self, memodict={}):
+         new_instance = type(self)(self.tokenizer, [])
+         new_instance.input_history = copy.deepcopy(self.input_history)
+         return new_instance
+
+
+ class TelechatIterTextStreamer:
+     """
+     Rewritten with reference to the TextIteratorStreamer in transformers.
+     """
+
+     def __init__(
+             self, tokenizer, history: History = None, skip_prompt: bool = False, timeout: Optional[float] = None,
+             **decode_kwargs
+     ):
+
+         self.tokenizer = tokenizer
+         self.history = history
+         self.skip_prompt = skip_prompt
+         self.timeout = timeout
+         self.decode_kwargs = decode_kwargs
+
+         self.text_queue = Queue()
+         self.cache_time = 0
+         self.text_until = ""
+         self.token_until = []
+         self.stop_signal = None
+         self.next_tokens_are_prompt = True
+
+         self.history.append({"role": "bot", "content": self.text_until})
+
+     def put(self, value):
+         """
+         Put printable text into the queue.
+         """
+         if len(value.shape) > 1 and value.shape[0] > 1:
+             raise ValueError("TextStreamer only supports batch size 1")
+         elif len(value.shape) > 1:
+             value = value[0]
+
+         if self.skip_prompt and self.next_tokens_are_prompt:
+             self.next_tokens_are_prompt = False
+             return
+
+         if value[-1] == self.tokenizer.eos_token_id:
+             return
+
+         # there may be a smarter way to decode
+         self.token_until.extend(value.tolist())
+         text = self.tokenizer.decode(self.token_until, **self.decode_kwargs)
+
+         if self._is_printable(text) or self.cache_time >= 6:
+             output_text = text[len(self.text_until):]
+             self.text_until = text
+         else:
+             self.cache_time += 1
+             return
+
+         self.on_finalized_text(output_text)
+
+     def end(self):
+         """Flushes any remaining cache and prints a newline to stdout."""
+         # Flush the cache, if it exists
+         text = self.tokenizer.decode(self.token_until, **self.decode_kwargs)
+         output_text = text[len(self.text_until):]
+         self.text_until = text
+         self.on_finalized_text(output_text, stream_end=True)
+         self.clear_cache()
+
+     def clear_cache(self):
+         self.cache_time = 0
+         self.token_until = []
+         self.text_until = ""
+         self.history = None
+         self.next_tokens_are_prompt = True
+
+     def on_finalized_text(self, text: str, stream_end: bool = False):
+         """Put the text tuple in the queue."""
+         self.history.update({"role": "bot", "content": self.text_until, "input_ids": self.token_until,
+                              "attention_mask": [1] * len(self.token_until)})
+         self.text_queue.put((text, self.history), timeout=self.timeout)
+         if stream_end:
+             self.text_queue.put((self.stop_signal, self.history), timeout=self.timeout)
+
+     @staticmethod
+     def _is_printable(cp):
+         """Checks whether the accumulated tokens decode to printable text."""
+         if "�" in cp:
+             return False
+         return True
+
+     def __iter__(self):
+         return self
+
+     def __next__(self):
+         value_now, history_until = self.text_queue.get(timeout=self.timeout)
+         if value_now == self.stop_signal:
+             raise StopIteration()
+         else:
+             return value_now, history_until
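
A minimal usage sketch for these helpers, assuming the repo files sit in a local directory ./telechat-12b and are loaded with trust_remote_code (the path and prompt are illustrative); the streaming entry point is the chat() method added to TelechatForCausalLM in modeling_telechat.py below:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_dir = "./telechat-12b"  # placeholder path
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_dir, trust_remote_code=True, torch_dtype=torch.float16
    ).cuda()

    # With stream=True, chat() starts generation on a background thread and returns a
    # TelechatIterTextStreamer; iterating it yields (text_chunk, history) pairs until
    # the stop signal is reached.
    streamer = model.chat(tokenizer, question="Hello", history=[], stream=True)
    for chunk, history in streamer:
        print(chunk, end="", flush=True)
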
modeling_telechat.py CHANGED
@@ -1,4 +1,3 @@
-
  # coding=utf-8
  # Copyright 2022 HuggingFace Inc. team and BigScience workshop.
  #
@@ -34,15 +33,15 @@
  # limitations under the License.


-
-
  """PyTorch TELECHAT model."""

  import warnings
- from typing import Optional, Tuple, Union
+ from typing import Optional, Tuple, Union, List, Dict
+ from threading import Thread

  import torch
  import math
+ import copy
  from torch import nn
  import torch.utils.checkpoint
  from torch.nn import functional as F
@@ -53,8 +52,10 @@ from transformers.modeling_outputs import (
  )
  from transformers.modeling_utils import PreTrainedModel
  from transformers.utils import logging
+ from transformers import GenerationConfig

  from .configuration_telechat import TelechatConfig
+ from .generation_utils import History, TelechatIterTextStreamer

  logger = logging.get_logger(__name__)

@@ -78,63 +79,56 @@ except ImportError:
  flash_attn_unpadded_func = None


-
  class RotaryEmbedding(torch.nn.Module):
      # Extracted from: https://github.com/EleutherAI/gpt-neox
-     def __init__(self, dim ,config, base=10000, precision=torch.half):
+     def __init__(self, dim, config, base=10000):
          super().__init__()
          self.config = config
          self.dim = dim
          self.base = base
-         self.inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float().half() / dim)).cuda()
          self.max_seq_len_cached = None
          self.cos_cached = None
          self.sin_cached = None
-         self.precision = precision

-     def get_mscale(self,scale=1):
+     def get_mscale(self, scale=1):
          if scale <= 1:
              return 1.0
          return 0.1 * math.log(scale) + 1.0

      def get_ntk_alpha(self, true_seq_len):
-         context_value = math.log(true_seq_len / self.config.base_seqlen, 2) + 1
-         # ntk_alpha = 2 ** context_value - 1
+         context_value = math.log(true_seq_len / 4096, 2) + 1
          ntk_alpha = 2 ** math.ceil(context_value) - 1
          ntk_alpha = max(ntk_alpha, 1)
          return ntk_alpha

-     def forward(self, x, seq_dim=0, seq_len=None):
-         if seq_len is None:
-             seq_len = x.shape[seq_dim]
-         seq_len = max(seq_len, self.config.training_seqlen)
+     def forward(self, x, dtype, seq_dim=0):
+         seq_len = x.shape[seq_dim]
+         self.mscale = 1.0
+         if not self.training:
+             seq_len = max(seq_len, self.config.training_seqlen)
+             self.mscale = float(self.get_mscale(seq_len / self.config.training_seqlen))
          ntk_alpha = self.get_ntk_alpha(seq_len)
-         self.mscale = float(self.get_mscale(seq_len / self.config.training_seqlen))
-         if True:
-             base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
-             self.inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, device=x.device).float() / self.dim))
-             self.max_seq_len_cached = seq_len
-             t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
-             freqs = torch.einsum('i,j->ij', t, self.inv_freq)
-             # Different from paper, but it uses a different permutation in order to obtain the same calculation
-             emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
-             if self.precision == torch.bfloat16:
-                 emb = emb.float()
-             # [sx, 1 (b * np), hn]
-             self.cos_cached = self.mscale * emb.cos()[:, None, :].half()
-             self.sin_cached = self.mscale * emb.sin()[:, None, :].half()
-             if self.precision == torch.bfloat16:
-                 self.cos_cached = self.cos_cached.bfloat16()
-                 self.sin_cached = self.sin_cached.bfloat16()
+         base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
+         self.inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, device=x.device).float() / self.dim))
+         self.max_seq_len_cached = seq_len
+         t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
+         freqs = torch.einsum('i,j->ij', t, self.inv_freq)
+         # Different from paper, but it uses a different permutation in order to obtain the same calculation
+         emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+         # if self.precision == torch.bfloat16:
+         emb = emb.float() if dtype == torch.bfloat16 else emb
+         # [sx, 1 (b * np), hn]
+         self.cos_cached = self.mscale * emb.cos()[:, None, :].to(dtype)
+         self.sin_cached = self.mscale * emb.sin()[:, None, :].to(dtype)
          return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]


-
  # rotary pos emb helpers:
  def rotate_half(x):
      x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
      return torch.cat((-x2, x1), dim=x1.ndim - 1)  # dim=-1 triggers a bug in earlier torch versions

+
  def apply_rotary_pos_emb_torch(q, k, cos, sin, offset: int = 0):  # jitting fails with bf16
      cos, sin = cos[offset:q.shape[0] + offset, ...], sin[offset:q.shape[0] + offset, ...]
      return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
@@ -192,7 +186,6 @@ class FlashSelfAttention(torch.nn.Module):
          q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
          cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32,
                                      device=q.device)
-         self.training = False
          if self.training:
              # during training q,k,v always have same seqlen
              assert seqlen_k == seqlen_q
@@ -218,7 +211,6 @@
          return output


-
  def _make_causal_mask(
          input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
  ) -> torch.BoolTensor:
@@ -249,7 +241,6 @@ def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
      return expanded_mask.expand(batch_size, 1, tgt_length, src_length)


-
  def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
      """
      Dropout add function
@@ -332,7 +323,7 @@ class TelechatGelu(nn.Module):


  class TelechatAttention(nn.Module):
-     def __init__(self, config: TelechatConfig ,layer_idx):
+     def __init__(self, config: TelechatConfig, layer_idx):
          super().__init__()
          self.kv_cache = None
          self.layer_idx = layer_idx
@@ -361,16 +352,13 @@
          self.key_value = nn.Linear(self.hidden_size, kv_projection_size * 2, bias=False)
          self.dense = nn.Linear(self.hidden_size, self.hidden_size)
          self.attention_dropout = nn.Dropout(config.attention_dropout)
-         self.rotary_emb = RotaryEmbedding(self.head_dim ,config=config)
+         self.rotary_emb = RotaryEmbedding(self.head_dim, config=config)

          self.core_attention_flash = FlashSelfAttention(
              causal=True, attention_dropout=config.attention_dropout
          )

          self.last_key_layer = None
-         #logn_list = [math.log(i, 4096) if i > 4096 else 1 for i in range(1, 32768)]
-         #self.logn_tensor = torch.tensor(logn_list)[None, :, None, None].half().cuda()
-

      def repeat_kv(self, hidden_states, n_rep):
          slen, batch, num_key_value_heads_per_partition, head_dim = hidden_states.shape
@@ -440,27 +428,26 @@
          seq_len = key_layer.shape[0]
          offset = 0

-         if use_cache and layer_past != None:
-             past_key, past_value = layer_past
+         if use_cache and layer_past != None:
+             past_key, past_value = layer_past
              offset = past_key.shape[0]
              seq_len += offset

-         cos, sin = self.rotary_emb(value_layer, seq_len=seq_len)
+         cos, sin = self.rotary_emb(value_layer, dtype=value_layer.dtype)

          query_layer, key_layer = apply_rotary_fn(query_layer, key_layer, cos, sin, offset=offset)
          if use_cache:
              if layer_past != None:
                  past_key, past_value = layer_past
-                 key_layer = torch.cat((past_key, key_layer[-1, ...].unsqueeze(0)) ,dim=0)
-                 value_layer = torch.cat((past_value ,value_layer[-1 ,...].unsqueeze(0)) ,dim = 0)
-             layer_past = key_layer ,value_layer
+                 key_layer = torch.cat((past_key, key_layer[-1, ...].unsqueeze(0)), dim=0)
+                 value_layer = torch.cat((past_value, value_layer[-1, ...].unsqueeze(0)), dim=0)
+             layer_past = key_layer, value_layer
          s, bz, head, dim = value_layer.shape
          s_key = key_layer.shape[0]
          s_query = query_layer.shape[0]
          query_layer = query_layer.reshape((s_query, bz, head, dim))
          key_layer = key_layer.reshape((s_key, bz, head, dim))

-
          if self.config.flash_attn:
              q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous() for x in
                         (query_layer, key_layer, value_layer)]
@@ -468,22 +455,23 @@
              context_layer = rearrange(context_layer, 'b s h d -> b s (h d)').contiguous()
          else:
              ##[sq, b, np, hn] -> [sq, b * np, hn]
-             query_layer = query_layer.reshape(s_query ,bz * self.num_heads, dim)
+             query_layer = query_layer.reshape(s_query, bz * self.num_heads, dim)
              # [sk, b, np, hn] -> [sk, b * np, hn]
              key_layer = key_layer.reshape(s_key, bz * self.num_heads, dim)
-             matmul_result = self.inv_norm_factor * torch.einsum('bik,bkj->bij', query_layer.transpose(0, 1), key_layer.transpose(0, 1).transpose(1, 2))
+             matmul_result = self.inv_norm_factor * torch.einsum('bik,bkj->bij', query_layer.transpose(0, 1),
+                                                                 key_layer.transpose(0, 1).transpose(1, 2))

              attention_scores = matmul_result.view(bz, self.num_heads, s_query, s_key)

              input_dtype = attention_scores.dtype
-             if input_dtype == torch.float16:
+             if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
                  attention_scores = attention_scores.to(torch.float)
              attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
              attention_probs = F.softmax(attn_weights, dim=-1).to(input_dtype)  ##dtype = torch.float32
              attention_probs = self.attention_dropout(attention_probs)
              attention_probs_reshaped = attention_probs.view(bz * self.num_heads, s_query, s_key)

-             value_layer = value_layer.reshape(s_key ,bz * self.num_heads, dim)
+             value_layer = value_layer.reshape(s_key, bz * self.num_heads, dim)
              context_layer = torch.bmm(attention_probs_reshaped, value_layer.transpose(0, 1))
              context_layer = self._merge_heads(context_layer)

@@ -497,6 +485,7 @@

          return output_tensor, layer_past

+
  class TelechatMLP(nn.Module):
      def __init__(self, config: TelechatConfig):
          super().__init__()
@@ -513,14 +502,14 @@


  class TelechatBlock(nn.Module):
-     def __init__(self, config: TelechatConfig ,layer_idx):
+     def __init__(self, config: TelechatConfig, layer_idx):
          super().__init__()
          hidden_size = config.hidden_size

          self.input_layernorm = MixedFusedRMSNorm(hidden_size, eps=config.layer_norm_epsilon)
          self.num_heads = config.n_head
          self.layer_idx = layer_idx
-         self.self_attention = TelechatAttention(config ,layer_idx)
+         self.self_attention = TelechatAttention(config, layer_idx)
          self.post_attention_layernorm = MixedFusedRMSNorm(hidden_size, eps=config.layer_norm_epsilon)

          self.mlp = TelechatMLP(config)
@@ -611,12 +600,11 @@ class TelechatModel(TelechatPreTrainedModel):
          if self.config.embed_layernorm:
              self.word_embeddings_layernorm = MixedFusedRMSNorm(self.embed_dim, eps=config.layer_norm_epsilon)

-         self.h = nn.ModuleList([TelechatBlock(config ,_) for _ in range(config.num_hidden_layers)])
+         self.h = nn.ModuleList([TelechatBlock(config, _) for _ in range(config.num_hidden_layers)])
          self.ln_f = MixedFusedRMSNorm(self.embed_dim, eps=config.layer_norm_epsilon)
          self.gradient_checkpointing = False
          self.post_init()

-
      def get_input_embeddings(self):
          return self.word_embeddings

@@ -661,7 +649,6 @@ class TelechatModel(TelechatPreTrainedModel):
          use_cache = use_cache if use_cache is not None else self.config.use_cache
          return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-
          if input_ids is not None:
              batch_size, seq_length = input_ids.shape
          elif inputs_embeds is not None:
@@ -670,7 +657,6 @@ class TelechatModel(TelechatPreTrainedModel):
          if past_key_values is None:
              past_key_values = tuple([None] * len(self.h))

-
          if inputs_embeds is None:
              inputs_embeds = self.word_embeddings(input_ids)
          hidden_states = inputs_embeds
@@ -750,7 +736,8 @@

  class TelechatForCausalLM(TelechatPreTrainedModel):
      # _tied_weights_keys = ["lm_head.weight"]
-     _keys_to_ignore_on_load_missing = [ r"lm_head.weight"]
+     _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
+
      def __init__(self, config: TelechatConfig):
          super().__init__(config)
          self.transformer = TelechatModel(config)
@@ -838,3 +825,86 @@ class TelechatForCausalLM(TelechatPreTrainedModel):
              hidden_states=transformer_outputs.hidden_states,
              attentions=transformer_outputs.attentions,
          )
+
+     def chat(self, tokenizer, question: str = '', history: Union[List[Dict], History] = None, stream: bool = False,
+              generation_config: Optional[GenerationConfig] = None, **kwargs):
+         """
+         Args:
+             tokenizer: the tokenizer of telechat
+             question: the question for the model to reply to in this turn
+             history: the history used to format the input for telechat
+             stream: whether to return the full text at the end or to yield text chunk by chunk
+             generation_config: configuration for generation
+             **kwargs: args which will update the generation config or be passed to the model forward
+         """
+         generation_config = generation_config or self.generation_config
+         if not generation_config:
+             logger.error("generation_config is None")
+             raise ValueError("generation_config must not be None")
+         if not question:
+             logger.error("question is empty")
+             raise ValueError("question must not be empty")
+         if history is None:
+             history = []
+
+         # we update and check generation_config here before building inputs.
+
+         generation_config = copy.deepcopy(generation_config)
+         user_id = generation_config.user_token_id
+         bot_id = generation_config.bot_token_id
+         model_kwargs = generation_config.update(**kwargs)
+         generation_config.validate()
+
+         # transfer to History
+         if not isinstance(history, History):
+             history = History(tokenizer, history)
+
+         inputs = self.build_inputs_for_chat(tokenizer, question, history, generation_config, user_id, bot_id)
+         history.append({"role": "user", "content": question})
+         if stream:
+             streamer = TelechatIterTextStreamer(tokenizer, history, skip_prompt=True)
+             Thread(target=self.generate, kwargs=dict(
+                 inputs=inputs.to(self.device), streamer=streamer,
+                 generation_config=generation_config, **model_kwargs
+             )).start()
+             return streamer
+         else:
+             outputs = self.generate(inputs.to(self.device), generation_config=generation_config, **model_kwargs)
+             response = tokenizer.decode(outputs[0][len(inputs[0]):-1])
+             history.append({"role": "bot", "content": response})
+             return response, history
+
+     def build_inputs_for_chat(self, tokenizer, question, history, generation_config, usr_id, bot_id):
+         """
+         Check the history and build the inputs here.
+         """
+         # first tokenize the question
+         q_token = tokenizer(question)
+         qa_history = copy.deepcopy(history)
+
+         # get the max length within which we should build our inputs
+         model_max_length = self.config.seq_length
+         build_max_length = max(0, model_max_length - generation_config.max_new_tokens) \
+             if generation_config.max_new_tokens else max(0, generation_config.max_length)
+         if build_max_length < 3:
+             logger.warning("the model cannot meet the requirements of the input length, please check the config")
+             raise ValueError("")
+
+         # truncate on the left
+         input_tokens = [usr_id] + q_token["input_ids"][-build_max_length + 1:] + [bot_id]
+         length = len(input_tokens)
+
+         while len(qa_history) != 0:
+             message = qa_history.pop()
+             if message["role"] == "user":
+                 tokens = [usr_id] + message["input_ids"]
+             elif message["role"] == "bot":
+                 tokens = [bot_id] + message["input_ids"] + [generation_config.eos_token_id]
+             else:
+                 tokens = []
+             if len(tokens) + length >= build_max_length:
+                 break
+             else:
+                 input_tokens = tokens + input_tokens
+
+         return torch.tensor([input_tokens], dtype=torch.int64)
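
Continuing the loading sketch shown after generation_utils.py above, a brief illustration of the non-streaming path and of the prompt budget that build_inputs_for_chat enforces (the max_new_tokens override is an assumed example value):

    # Non-streaming: chat() returns the decoded reply plus the updated History.
    response, history = model.chat(tokenizer, question="Hello", history=[], stream=False)
    follow_up, history = model.chat(tokenizer, question="Shorter, please.", history=history, stream=False)

    # Budgeting sketch: with config.seq_length = 8192 and max_new_tokens = 2048,
    # build_max_length = 8192 - 2048 = 6144. The current question is truncated from the left
    # to fit that budget, then earlier turns are prepended newest-first until adding another
    # turn would exceed it.
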
tokenization_telechat.py ADDED
@@ -0,0 +1,220 @@
+ import os
+ from shutil import copyfile
+ from typing import Any, Dict, List, Optional, Tuple
+ import sentencepiece as spm
+ from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
+ from transformers.utils import logging
+
+ logger = logging.get_logger(__name__)
+
+ VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
+
+ # TODO: when we get download url from huggingface, refresh the map
+ PRETRAINED_VOCAB_FILES_MAP = {
+     "vocab_file": {},
+     "tokenizer_file": {},
+ }
+
+
+ class TelechatTokenizer(PreTrainedTokenizer):
+
+     vocab_files_names = VOCAB_FILES_NAMES
+     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+     model_input_names = ["input_ids", "attention_mask"]
+
+     def __init__(
+         self,
+         vocab_file,
+         unk_token="<unk>",
+         bos_token="<_start>",
+         eos_token="<_end>",
+         pad_token="<_pad>",
+         sp_model_kwargs: Optional[Dict[str, Any]] = None,
+         add_bos_token=True,
+         add_eos_token=False,
+         clean_up_tokenization_spaces=False,
+         **kwargs,
+     ):
+         self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+         bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
+         eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
+         unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
+         pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
+         self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+         self.sp_model.Load(vocab_file)
+         super().__init__(
+             bos_token=bos_token,
+             eos_token=eos_token,
+             unk_token=unk_token,
+             pad_token=pad_token,
+             add_bos_token=add_bos_token,
+             add_eos_token=add_eos_token,
+             sp_model_kwargs=self.sp_model_kwargs,
+             clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+             **kwargs,
+         )
+         self.vocab_file = vocab_file
+         self.add_bos_token = add_bos_token
+         self.add_eos_token = add_eos_token
+
+
+     def __getstate__(self):
+         state = self.__dict__.copy()
+         state["sp_model"] = None
+         return state
+
+     def __setstate__(self, d):
+         self.__dict__ = d
+         self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+         self.sp_model.Load(self.vocab_file)
+
+     @property
+     def vocab_size(self):
+         """Returns vocab size"""
+         return self.sp_model.get_piece_size()
+
+     def get_vocab(self):
+         """Returns vocab as a dict"""
+         vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+         vocab.update(self.added_tokens_encoder)
+         return vocab
+
+     def _tokenize(self, text):
+         """Returns a tokenized string."""
+         return self.sp_model.encode(text, out_type=str)
+
+     def _convert_token_to_id(self, token):
+         """Converts a token (str) in an id using the vocab."""
+         return self.sp_model.piece_to_id(token)
+
+     def _convert_id_to_token(self, index):
+         """Converts an index (integer) in a token (str) using the vocab."""
+         token = self.sp_model.IdToPiece(index)
+         return token
+
+     def convert_tokens_to_string(self, tokens):
+         """Converts a sequence of tokens (string) in a single string."""
+         current_sub_tokens = []
+         out_string = ""
+         prev_is_special = False
+         for i, token in enumerate(tokens):
+             # make sure that special tokens are not decoded using sentencepiece model
+             if token in self.all_special_tokens:
+                 if not prev_is_special and i != 0:
+                     out_string += " "
+                 out_string += self.sp_model.decode(current_sub_tokens) + token
+                 prev_is_special = True
+                 current_sub_tokens = []
+             else:
+                 current_sub_tokens.append(token)
+                 prev_is_special = False
+         out_string += self.sp_model.decode(current_sub_tokens)
+         return out_string
+
+     def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
+         """
+         Save the vocabulary and special tokens file to a directory.
+
+         Args:
+             save_directory (`str`):
+                 The directory in which to save the vocabulary.
+
+         Returns:
+             `Tuple(str)`: Paths to the files saved.
+         """
+         if not os.path.isdir(save_directory):
+             logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+             return
+         out_vocab_file = os.path.join(
+             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+         )
+
+         if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
+             copyfile(self.vocab_file, out_vocab_file)
+         elif not os.path.isfile(self.vocab_file):
+             with open(out_vocab_file, "wb") as fi:
+                 content_spiece_model = self.sp_model.serialized_model_proto()
+                 fi.write(content_spiece_model)
+
+         return (out_vocab_file,)
+
+     def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+         bos_token_id = [self.bos_token_id] if self.add_bos_token else []
+         eos_token_id = [self.eos_token_id] if self.add_eos_token else []
+
+         output = bos_token_id + token_ids_0 + eos_token_id
+
+         if token_ids_1 is not None:
+             output = output + bos_token_id + token_ids_1 + eos_token_id
+
+         return output
+
+     def get_special_tokens_mask(
+         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+     ) -> List[int]:
+         """
+         Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+         special tokens using the tokenizer `prepare_for_model` method.
+
+         Args:
+             token_ids_0 (`List[int]`):
+                 List of IDs.
+             token_ids_1 (`List[int]`, *optional*):
+                 Optional second list of IDs for sequence pairs.
+             already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                 Whether or not the token list is already formatted with special tokens for the model.
+
+         Returns:
+             `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+         """
+         if already_has_special_tokens:
+             return super().get_special_tokens_mask(
+                 token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+             )
+
+         bos_token_id = [1] if self.add_bos_token else []
+         eos_token_id = [1] if self.add_eos_token else []
+
+         if token_ids_1 is None:
+             return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
+         return (
+             bos_token_id
+             + ([0] * len(token_ids_0))
+             + eos_token_id
+             + bos_token_id
+             + ([0] * len(token_ids_1))
+             + eos_token_id
+         )
+
+     def create_token_type_ids_from_sequences(
+         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+     ) -> List[int]:
+         """
+         Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
+         sequence pair mask has the following format:
+
+         ```
+         0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
+         | first sequence    | second sequence |
+         ```
+
+         if token_ids_1 is None, only returns the first portion of the mask (0s).
+
+         Args:
+             token_ids_0 (`List[int]`):
+                 List of ids.
+             token_ids_1 (`List[int]`, *optional*):
+                 Optional second list of IDs for sequence pairs.
+
+         Returns:
+             `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
+         """
+         bos_token_id = [self.bos_token_id] if self.add_bos_token else []
+         eos_token_id = [self.eos_token_id] if self.add_eos_token else []
+
+         output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
+
+         if token_ids_1 is not None:
+             output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
+
+         return output
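
A minimal sketch of the tokenizer on its own, assuming the SentencePiece file tokenizer.model sits in the working directory (the path and sample text are illustrative):

    from tokenization_telechat import TelechatTokenizer

    tok = TelechatTokenizer(vocab_file="tokenizer.model")   # add_bos_token=True by default
    enc = tok("hello world")
    print(enc["input_ids"])                                 # starts with the <_start> id
    print(tok.convert_ids_to_tokens(enc["input_ids"]))
    print(tok.decode(enc["input_ids"]))
    print(tok.vocab_size)
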
tokenizer_config.json CHANGED
@@ -1,9 +1,9 @@
  {
-   "name_or_path": "ChinaTelecom/telechat3-7b",
+   "name_or_path": "ChinaTelecom/telechat-12b",
    "tokenizer_class": "TelechatTokenizer",
    "auto_map": {
      "AutoTokenizer": [
-       "tokenization_telechat3.TelechatTokenizer",
+       "tokenization_telechat.TelechatTokenizer",
        null
      ]
    },