charent committed on
Commit
8797a7e
1 Parent(s): aa3cf53

Upload 8 files

Browse files
chat_model.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import Tensor, LongTensor
3
+ from transformers import T5ForConditionalGeneration, T5Config
4
+ from transformers import TextIteratorStreamer
5
+ from transformers.generation.configuration_utils import GenerationConfig
6
+
7
class TextToTextModel(T5ForConditionalGeneration):
    '''
    T5ForConditionalGeneration subclass that adds `my_generate`, a thin wrapper
    around `generate()` driven by a small set of named decoding presets.
    '''

    # Decoding presets keyed by `search_type`. Each entry lists the
    # GenerationConfig attributes that `my_generate` sets for that strategy.
    _SEARCH_PRESETS = {
        # num_beams=1 and do_sample=False -> greedy decoding.
        'greedy': {
            'num_beams': 1,
            'do_sample': False,
        },
        # num_beams>1 combined with do_sample=True selects beam-search
        # multinomial sampling rather than plain beam search; the values are
        # kept exactly as originally configured.
        'beam': {
            'top_k': 50,
            'num_beams': 5,
            'do_sample': True,
            'top_p': 0.95,
            'no_repeat_ngram_size': 4,
            'length_penalty': -2.0,
            'early_stopping': True,
        },
        # num_beams=1 and do_sample=True -> multinomial sampling.
        'sampling': {
            'num_beams': 1,
            'do_sample': True,
            'top_k': 50,
            # A higher temperature flattens the distribution towards uniform;
            # a lower one sharpens it around the most likely tokens.
            'temperature': 0.98,
            'top_p': 0.80,
            'no_repeat_ngram_size': 4,
        },
        # penalty_alpha>0 and top_k>1 -> contrastive search.
        'contrastive': {
            'penalty_alpha': 0.5,
            'top_k': 50,
        },
    }

    def __init__(self, config: T5Config) -> None:
        '''
        TextToTextModel inherits from T5ForConditionalGeneration; construction
        is delegated unchanged to the parent class.
        '''
        super().__init__(config)

    @torch.no_grad()
    def my_generate(self,
                    input_ids: LongTensor,
                    attention_mask: LongTensor,
                    max_seq_len: int=256,
                    search_type: str='beam',
                    streamer: TextIteratorStreamer=None,
                    ) -> Tensor:
        '''
        Custom generate method, for convenient calling and testing.

        A fresh GenerationConfig is built on every call, filled from the preset
        selected by `search_type`, and passed to `self.generate`.

        Args:
            input_ids: token ids forwarded to `generate()` as `inputs`.
            attention_mask: attention mask forwarded to `generate()`.
            max_seq_len: stored as `max_new_tokens` in the generation config.
            search_type: one of 'greedy', 'beam', 'sampling', 'contrastive'.
            streamer: optional streamer forwarded to `generate()`.
                NOTE(review): transformers is documented to reject a streamer
                when num_beams > 1, which would affect the default 'beam'
                preset - confirm against the pinned transformers version.

        Returns:
            Whatever `self.generate` returns for the given inputs.

        Raises:
            ValueError: if `search_type` is not one of the supported presets.
                Previously an unknown value was silently ignored and decoding
                ran with the default generation settings.

        How `generate()` picks the decoding strategy:

        - *greedy decoding* by calling [`~generation.GenerationMixin.greedy_search`] if `num_beams=1` and
          `do_sample=False`
        - *contrastive search* by calling [`~generation.GenerationMixin.contrastive_search`] if `penalty_alpha>0.`
          and `top_k>1`
        - *multinomial sampling* by calling [`~generation.GenerationMixin.sample`] if `num_beams=1` and
          `do_sample=True`
        - *beam-search decoding* by calling [`~generation.GenerationMixin.beam_search`] if `num_beams>1` and
          `do_sample=False`
        - *beam-search multinomial sampling* by calling [`~generation.GenerationMixin.beam_sample`] if
          `num_beams>1` and `do_sample=True`
        '''
        # Validate first so a typo in `search_type` fails loudly instead of
        # quietly producing output from an unintended decoding strategy.
        if search_type not in self._SEARCH_PRESETS:
            supported = ', '.join(repr(name) for name in self._SEARCH_PRESETS)
            raise ValueError(
                f'unsupported search_type {search_type!r}; expected one of: {supported}'
            )

        generation_config = GenerationConfig()
        generation_config.remove_invalid_values = True
        # Hard-coded special-token ids; they equal the eos/pad ids shipped in
        # this repository's config.json and tokenizer_config.json.
        generation_config.eos_token_id = 1
        generation_config.pad_token_id = 0
        generation_config.decoder_start_token_id = self.config.decoder_start_token_id
        generation_config.max_new_tokens = max_seq_len
        # NOTE: repetition_penalty is left at its default; a
        # `repetition_penalty = 1.1` setting existed here only as a
        # commented-out line.

        # Attributes are assigned after construction, exactly as before, so
        # any validation still happens inside `generate()`.
        for option, value in self._SEARCH_PRESETS[search_type].items():
            setattr(generation_config, option, value)

        return self.generate(
            inputs=input_ids,
            attention_mask=attention_mask,
            generation_config=generation_config,
            streamer=streamer,
        )
chat_model_config.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from transformers import T5Config
2
+
3
class TextToTextModelConfig(T5Config):
    '''
    Configuration class for this repository's custom model code.

    It is a plain T5Config subclass that adds no fields; it exists so that the
    `auto_map` entry in config.json can point `AutoConfig` at a class defined
    in this repository.
    '''

    # Re-declared explicitly so the config identifies itself as a T5 model.
    model_type = 't5'
config.json ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "charent/ChatLM-Chinese-0.2B",
3
+ "architectures": [
4
+ "TextToTextModel"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "chat_model_config.TextToTextModelConfig",
8
+ "AutoModelForSeq2SeqLM": "chat_model.TextToTextModel"
9
+ },
10
+ "classifier_dropout": 0.0,
11
+ "d_ff": 3072,
12
+ "d_kv": 64,
13
+ "d_model": 768,
14
+ "decoder_start_token_id": 0,
15
+ "dense_act_fn": "relu",
16
+ "dropout_rate": 0.1,
17
+ "eos_token_id": 1,
18
+ "feed_forward_proj": "relu",
19
+ "initializer_factor": 1.0,
20
+ "is_encoder_decoder": true,
21
+ "is_gated_act": false,
22
+ "layer_norm_epsilon": 1e-06,
23
+ "model_type": "t5",
24
+ "num_decoder_layers": 10,
25
+ "num_heads": 12,
26
+ "num_layers": 10,
27
+ "pad_token_id": 0,
28
+ "relative_attention_max_distance": 128,
29
+ "relative_attention_num_buckets": 32,
30
+ "torch_dtype": "float32",
31
+ "transformers_version": "4.36.2",
32
+ "use_cache": true,
33
+ "vocab_size": 29298
34
+ }
generation_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "decoder_start_token_id": 0,
4
+ "eos_token_id": 1,
5
+ "pad_token_id": 0,
6
+ "transformers_version": "4.36.2"
7
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:054caeae92bcc13f0b6e7a12f86e75c8e18117279ecd89c4aa1f8ac74c95c02a
3
+ size 750794624
special_tokens_map.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "eos_token": "[EOS]",
3
+ "pad_token": "[PAD]",
4
+ "unk_token": "[UNK]"
5
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "[PAD]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "[EOS]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "[SEP]",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "3": {
28
+ "content": "[BOS]",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "4": {
36
+ "content": "[CLS]",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ },
43
+ "5": {
44
+ "content": "[MASK]",
45
+ "lstrip": false,
46
+ "normalized": false,
47
+ "rstrip": false,
48
+ "single_word": false,
49
+ "special": true
50
+ },
51
+ "6": {
52
+ "content": "[UNK]",
53
+ "lstrip": false,
54
+ "normalized": false,
55
+ "rstrip": false,
56
+ "single_word": false,
57
+ "special": true
58
+ }
59
+ },
60
+ "clean_up_tokenization_spaces": true,
61
+ "eos_token": "[EOS]",
62
+ "model_max_length": 1000000000000000019884624838656,
63
+ "pad_token": "[PAD]",
64
+ "tokenizer_class": "PreTrainedTokenizerFast",
65
+ "unk_token": "[UNK]"
66
+ }