Triang-jyed-driung committed
Commit 35429bb · Parent: 3a6cb37

Added chat template and attention mask

__init__.py ADDED
(empty file)
added_tokens.json CHANGED
@@ -1,3 +1,3 @@
 {
-  "<s>": 0
+  "<|rwkv_tokenizer_end_of_text|>": 0
 }
config.json CHANGED
@@ -4,6 +4,7 @@
   ],
   "attention_hidden_size": 768,
   "auto_map": {
+    "AutoModel": "modeling_rwkv7.Rwkv7Model",
     "AutoConfig": "configuration_rwkv7.Rwkv7Config",
     "AutoModelForCausalLM": "modeling_rwkv7.Rwkv7ForCausalLM"
   },
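The new "AutoModel" entry in auto_map lets the bare Rwkv7Model backbone be resolved through the Auto classes alongside the existing causal-LM mapping. A minimal loading sketch, with "<repo-id>" standing in for this model's actual Hub id (not shown in this diff):

# Sketch only: "<repo-id>" is a placeholder; trust_remote_code pulls in the custom RWKV-7 classes mapped above.
from transformers import AutoModel, AutoModelForCausalLM

backbone = AutoModel.from_pretrained("<repo-id>", trust_remote_code=True)              # -> Rwkv7Model
lm_model = AutoModelForCausalLM.from_pretrained("<repo-id>", trust_remote_code=True)   # -> Rwkv7ForCausalLM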
generation_config.json CHANGED
@@ -1,12 +1,13 @@
 {
   "chat_format": "chatml",
+  "bos_token_id": 0,
   "eos_token_id": 0,
   "pad_token_id": 0,
-  "max_window_size": 4096,
+  "max_window_size": 2147483647,
   "max_new_tokens": 4096,
   "do_sample": true,
-  "top_k": 0,
-  "top_p": 0.1,
-  "repetition_penalty": 1.0,
+  "top_k": 65536,
+  "top_p": 1.0,
+  "temperature": 1.0,
   "transformers_version": "4.31.1"
 }
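The updated defaults effectively disable truncation sampling (a top_k of 65536, at least as large as the vocabulary, and top_p of 1.0 keep the full distribution; temperature 1.0 leaves logits unscaled) and raise max_window_size to INT32_MAX. A sketch of how these defaults are picked up, again with "<repo-id>" as a placeholder Hub id:

# Sketch only: inspects the generation defaults shipped with the model.
from transformers import GenerationConfig

gen_cfg = GenerationConfig.from_pretrained("<repo-id>")
print(gen_cfg.bos_token_id, gen_cfg.top_k, gen_cfg.top_p, gen_cfg.temperature)
# model.generate(...) uses these values unless they are overridden per call.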
hf_rwkv_tokenizer.py CHANGED
@@ -145,7 +145,7 @@ class Rwkv6Tokenizer(PreTrainedTokenizer):
     model_input_names = ["input_ids", "attention_mask"]
 
     def __init__(
-        self, vocab_file, bos_token="<s>", eos_token="<s>", unk_token="<s>", **kwargs
+        self, vocab_file, bos_token="<|rwkv_tokenizer_end_of_text|>", eos_token="<|rwkv_tokenizer_end_of_text|>", unk_token="<|rwkv_tokenizer_end_of_text|>", **kwargs
     ):
         if not os.path.isfile(vocab_file):
             raise ValueError(
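This keeps the tokenizer defaults consistent with the renamed id-0 special token in added_tokens.json, special_tokens_map.json, and tokenizer_config.json below. A quick consistency check, assuming the tokenizer is loaded from this repo ("<repo-id>" is a placeholder):

# Sketch only: verifies that the special tokens resolve to token id 0.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("<repo-id>", trust_remote_code=True)
assert tok.bos_token == tok.eos_token == "<|rwkv_tokenizer_end_of_text|>"
assert tok.convert_tokens_to_ids(tok.eos_token) == 0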
modeling_rwkv7.py CHANGED
@@ -317,7 +317,7 @@ class Rwkv7SelfAttention(nn.Module):
         self.ln_x = nn.GroupNorm(H, C, eps=self.head_size * 1e-5)


-    def forward(self, hidden, state=None, v_first=None, use_cache=False, seq_mode=True):
+    def forward(self, hidden, state=None, v_first=None, use_cache=False, seq_mode=True, attention_mask=None):
         # Mix hidden with the previous timestep to produce key, value, receptance
         if hidden.size(1) == 1 and state is not None:
             shifted = state[0][self.layer_id]
@@ -371,6 +371,8 @@ class Rwkv7SelfAttention(nn.Module):
             rwkv7_attn_triton(r, w, k, v, -kk, kk*a, self.head_size)

         xx = torch.nn.functional.group_norm(xx.view(B*T,H*N), num_groups=H, weight=self.ln_x.weight, bias=self.ln_x.bias, eps = self.ln_x.eps).view(B,T,H*N)
+        if attention_mask is not None:
+            xx *= attention_mask.unsqueeze(-1)
         #x = x + ((r * k * self.r_k).view(B,T,H,N).sum(dim=-1, keepdim=True) * v.view(B,T,H,N)).view(B,T,H*N)
         xx = xx + ((r.view(B,T,H,-1)*k.view(B,T,H,-1)*self.r_k).sum(dim=-1, keepdim=True) * v.view(B,T,H,-1)).view(B,T,C)
         xx = self.output(xx * g)
@@ -435,11 +437,15 @@ class Rwkv7Block(nn.Module):
         self.attention = Rwkv7SelfAttention(config, layer_id)
         self.feed_forward = Rwkv7FeedForward(config, layer_id)

-    def forward(self, hidden, state=None, v_first=None, use_cache=False, output_attentions=False, seq_mode=True):
-        attention, state, v_first = self.attention(self.ln1(hidden), state=state, v_first=v_first, use_cache=use_cache, seq_mode=seq_mode)
+    def forward(self, hidden, state=None, v_first=None, use_cache=False, output_attentions=False, seq_mode=True, attention_mask=None):
+        attention, state, v_first = self.attention(
+            self.ln1(hidden) if attention_mask is None else self.ln1(hidden) * attention_mask.unsqueeze(-1),
+            state=state, v_first=v_first, use_cache=use_cache, seq_mode=seq_mode, attention_mask=attention_mask)
         hidden = hidden + attention

-        feed_forward, state = self.feed_forward(self.ln2(hidden), state=state)
+        feed_forward, state = self.feed_forward(
+            self.ln2(hidden) if attention_mask is None else self.ln2(hidden) * attention_mask.unsqueeze(-1),
+            state=state)
         hidden = hidden + feed_forward

         outputs = (hidden, state, v_first)
@@ -743,13 +749,15 @@ class Rwkv7Model(Rwkv7PreTrainedModel):

         seq_mode = inputs_embeds.shape[1] > 1
         hidden_states = self.pre_ln(inputs_embeds)
+        if attention_mask is not None:
+            hidden_states *= attention_mask.unsqueeze(-1)
         v_first = None

         all_self_attentions = () if output_attentions else None
         all_hidden_states = () if output_hidden_states else None
         for idx, block in enumerate(self.blocks):
             hidden_states, state, v_first, attentions = block(
-                hidden_states, state=state, v_first=v_first, use_cache=use_cache, output_attentions=output_attentions, seq_mode=seq_mode
+                hidden_states, state=state, v_first=v_first, use_cache=use_cache, output_attentions=output_attentions, seq_mode=seq_mode, attention_mask=attention_mask,
             )

             if output_hidden_states:
@@ -759,6 +767,8 @@ class Rwkv7Model(Rwkv7PreTrainedModel):
             all_self_attentions = all_self_attentions + (attentions,)

         hidden_states = self.ln_out(hidden_states)
+        if attention_mask is not None:
+            hidden_states *= attention_mask.unsqueeze(-1)

         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
@@ -846,6 +856,7 @@ class Rwkv7ForCausalLM(Rwkv7PreTrainedModel, GenerationMixin):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            attention_mask=attention_mask,
         )
         hidden_states = outputs[0]
 
special_tokens_map.json CHANGED
@@ -1,5 +1,5 @@
 {
-  "bos_token": "<s>",
-  "eos_token": "<s>",
-  "unk_token": "<s>"
+  "bos_token": "<|rwkv_tokenizer_end_of_text|>",
+  "eos_token": "<|rwkv_tokenizer_end_of_text|>",
+  "unk_token": "<|rwkv_tokenizer_end_of_text|>"
 }
tokenizer_config.json CHANGED
@@ -2,7 +2,7 @@
   "add_prefix_space": false,
   "added_tokens_decoder": {
     "0": {
-      "content": "<s>",
+      "content": "<|rwkv_tokenizer_end_of_text|>",
       "lstrip": false,
       "normalized": false,
       "rstrip": false,
@@ -16,11 +16,12 @@
       null
     ]
   },
-  "bos_token": "<s>",
+  "bos_token": "<|rwkv_tokenizer_end_of_text|>",
   "clean_up_tokenization_spaces": false,
-  "eos_token": "<s>",
+  "eos_token": "<|rwkv_tokenizer_end_of_text|>",
   "model_max_length": 1000000000000000019884624838656,
   "tokenizer_class": "Rwkv6Tokenizer",
-  "unk_token": "<s>",
-  "use_fast": false
+  "unk_token": "<|rwkv_tokenizer_end_of_text|>",
+  "use_fast": false,
+  "chat_template": "{{ '<|rwkv_tokenizer_end_of_text|>' }}{% for message in messages %}{% if message['role'] == 'user' %}{{'User: ' + message['content'] + '\n\n'}}{% elif message['role'] == 'system' %}{{'System: ' + message['content'] + '\n\n'}}{% elif message['role'] == 'assistant' %}{{'Assistant: ' + message['content'] + '\n\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
 }
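The new chat_template prepends the end-of-text token, renders each turn as "Role: content" followed by a blank line, and appends "Assistant:" when a generation prompt is requested. A rendering sketch, with "<repo-id>" as a placeholder Hub id:

# Sketch only: shows what the template above produces for a single user turn.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("<repo-id>", trust_remote_code=True)
messages = [{"role": "user", "content": "Hello"}]
text = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(text)  # "<|rwkv_tokenizer_end_of_text|>User: Hello\n\nAssistant:"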