This model was obtained by fine-tuning baichuan-13b-base with the [Firefly](https://github.com/yangjianxin1/Firefly) project. The training data consists of roughly one million multi-turn dialogues: the moss data shared by the project plus 20k school math samples.

See the [Firefly](https://github.com/yangjianxin1/Firefly) project for more details.

Technical write-up: [Firefly enhances the multi-turn dialogue capability of Baichuan-13B](https://mp.weixin.qq.com/s/djO8Tg3emmy6wzw_rTUlcw)
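
The inference scripts below reconstruct the data organization format Firefly uses at training time: the conversation starts with the bos token and every utterance, whether from the user or the assistant, is terminated by the eos token, i.e. `<s>input1</s>target1</s>input2</s>target2</s>...`. A minimal sketch of packing a dialogue into one token sequence under that assumption (the `build_firefly_input` helper and the sample dialogue are illustrative, not part of the project's API):

```python
from transformers import AutoTokenizer


def build_firefly_input(tokenizer, dialogue):
    """Pack a list of (user, assistant) turns into one token sequence:
    <bos> user1 <eos> assistant1 <eos> user2 <eos> assistant2 <eos> ...
    Illustrative helper; mirrors the concatenation done in the chat scripts below."""
    token_ids = [tokenizer.bos_token_id]
    for user, assistant in dialogue:
        token_ids += tokenizer(user, add_special_tokens=False).input_ids
        token_ids.append(tokenizer.eos_token_id)
        token_ids += tokenizer(assistant, add_special_tokens=False).input_ids
        token_ids.append(tokenizer.eos_token_id)
    return token_ids


tokenizer = AutoTokenizer.from_pretrained('YeungNLP/firefly-baichuan-13b', trust_remote_code=True)
ids = build_firefly_input(tokenizer, [('Hello', 'Hi! How can I help you?')])
print(tokenizer.decode(ids))
```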

Training loss:

![training loss](firefly-baichuan-13b-loss.jpg)


C-Eval leaderboard:

| Model                            | C-Eval | STEM  | Social Science | Humanities | Other |
|----------------------------------|--------|-------|----------------|------------|-------|
| Baichuan-13B-Chat (official)     | 52.05  | 42.23 | 65.27          | 58.61      | 51.32 |
| **firefly-baichuan-13b**         | 51.36  | 44.24 | 61.65          | 54.63      | 51.68 |
| chatglm2-6b (official)           | 50.45  | 41.91 | 60.73          | 59.24      | 47.82 |
| **firefly-chatglm2-6b**          | 49.13  | 43.6  | 58.83          | 54.48      | 45.03 |
| openbuddy-llama2-13b-v11.1-bf16  | 43.36  | 39.79 | 50.28          | 44.78      | 42.13 |
| chinese-alpaca-2-13b (HIT)       | 41.86  | 36.52 | 49.7           | 47.97      | 38.33 |
| openbuddy-llama2-13b-v8.1-fp16   | 41.62  | 38.82 | 44.66          | 40.28      | 45.32 |
| chinese-alpaca-2-7b (HIT)        | 41.48  | 35.01 | 50.08          | 43.02      | 43.87 |
| belle-llama2-13B-chat-0.4M       | 41.11  | 40.04 | 44.71          | 42.09      | 38.82 |
| ziya-llama-13b                   | 39.1   | -     | -              | -          | -     |
| llama-2-13b-chat (official)      | 36.38  | 33.68 | 46.38          | 34.47      | 34.1  |
| llama-2-7b-chat (official)       | 35.86  | 32.85 | 40.04          | 37.37      | 36.01 |
| flagalpha/Llama2-Chinese-7b-Chat | 34.54  | 35.21 | 37.9           | 33.11      | 31.7  |
| yayi-13b-llama2                  | 34.15  | 36.48 | 30.64          | 32.67      | 34.6  |
| yayi-7b-llama2                   | 30.18  | 25.88 | 38.23          | 34.56      | 26.31 |
| linly-llama2-7b                  | 28.35  | 26.06 | 33.47          | 29.71      | 26.53 |
| linly-llama2-13b                 | 27.86  | 27.67 | 26.95          | 27.93      | 28.95 |


Single-turn dialogue:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
"""
单轮对话,不具有对话历史的记忆功能
"""


def main():
    model_name = 'YeungNLP/firefly-baichuan-13b'

    max_new_tokens = 500
    top_p = 0.9
    temperature = 0.35
    repetition_penalty = 1.0
    device = 'cuda'
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        device_map='auto'
    ).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True,
        # llama does not support the fast tokenizer
        use_fast=False if model.config.model_type == 'llama' else True
    )
    # QWenTokenizer is a special case: pad_token_id, bos_token_id and eos_token_id are all None; its eod_id corresponds to the <|endoftext|> token
    if tokenizer.__class__.__name__ == 'QWenTokenizer':
        tokenizer.pad_token_id = tokenizer.eod_id
        tokenizer.bos_token_id = tokenizer.eod_id
        tokenizer.eos_token_id = tokenizer.eod_id

    text = input('User:')
    while True:
        text = text.strip()
        # chatglm uses its official data organization format
        if model.config.model_type == 'chatglm':
            text = '[Round 1]\n\n问:{}\n\n答:'.format(text)
            input_ids = tokenizer(text, return_tensors="pt", add_special_tokens=False).input_ids.to(device)
        # To stay compatible with qwen-7b (its tokenizer splits the eos_token string, so the
        # eos_token_id cannot be recovered from text), concatenate the special-token ids directly
        else:
            input_ids = tokenizer(text, return_tensors="pt", add_special_tokens=False).input_ids.to(device)
            bos_token_id = torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long).to(device)
            eos_token_id = torch.tensor([[tokenizer.eos_token_id]], dtype=torch.long).to(device)
            input_ids = torch.concat([bos_token_id, input_ids, eos_token_id], dim=1)
        with torch.no_grad():
            outputs = model.generate(
                input_ids=input_ids, max_new_tokens=max_new_tokens, do_sample=True,
                top_p=top_p, temperature=temperature, repetition_penalty=repetition_penalty,
                eos_token_id=tokenizer.eos_token_id
            )
        outputs = outputs.tolist()[0][len(input_ids[0]):]
        response = tokenizer.decode(outputs)
        response = response.strip().replace(tokenizer.eos_token, "").strip()
        print("Firefly:{}".format(response))
        text = input('User:')


if __name__ == '__main__':
    main()
```


Multi-turn dialogue:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch


def main():
    model_name = 'YeungNLP/firefly-baichuan-13b'

    device = 'cuda'
    max_new_tokens = 500    # maximum number of tokens generated per dialogue turn
    history_max_len = 1000  # maximum number of history tokens the model keeps
    top_p = 0.9
    temperature = 0.35
    repetition_penalty = 1.0

    # load the model
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        device_map='auto'
    ).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True,
        # llama does not support the fast tokenizer
        use_fast=False if model.config.model_type == 'llama' else True
    )
    # QWenTokenizer is a special case: pad_token_id, bos_token_id and eos_token_id are all None; its eod_id corresponds to the <|endoftext|> token
    if tokenizer.__class__.__name__ == 'QWenTokenizer':
        tokenizer.pad_token_id = tokenizer.eod_id
        tokenizer.bos_token_id = tokenizer.eod_id
        tokenizer.eos_token_id = tokenizer.eod_id

    # accumulate the token ids of the entire conversation history
    if model.config.model_type != 'chatglm':
        history_token_ids = torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long)
    else:
        history_token_ids = torch.tensor([[]], dtype=torch.long)

    # start the conversation
    utterance_id = 0    # tracks the current dialogue turn, to fit chatglm's data organization format
    user_input = input('User:')
    while True:
        utterance_id += 1
        # chatglm uses its official data organization format
        if model.config.model_type == 'chatglm':
            user_input = '[Round {}]\n\n问:{}\n\n答:'.format(utterance_id, user_input)
            user_input_ids = tokenizer(user_input, return_tensors="pt", add_special_tokens=False).input_ids
        # Firefly's data organization format.
        # To stay compatible with qwen-7b (its tokenizer splits the eos_token string, so the
        # eos_token_id cannot be recovered from text), concatenate the eos id directly
        else:
            input_ids = tokenizer(user_input, return_tensors="pt", add_special_tokens=False).input_ids
            eos_token_id = torch.tensor([[tokenizer.eos_token_id]], dtype=torch.long)
            user_input_ids = torch.concat([input_ids, eos_token_id], dim=1)
        history_token_ids = torch.concat((history_token_ids, user_input_ids), dim=1)
        model_input_ids = history_token_ids[:, -history_max_len:].to(device)
        with torch.no_grad():
            outputs = model.generate(
                input_ids=model_input_ids, max_new_tokens=max_new_tokens, do_sample=True, top_p=top_p,
                temperature=temperature, repetition_penalty=repetition_penalty, eos_token_id=tokenizer.eos_token_id
            )
        model_input_ids_len = model_input_ids.size(1)
        response_ids = outputs[:, model_input_ids_len:]
        history_token_ids = torch.concat((history_token_ids, response_ids.cpu()), dim=1)
        response = tokenizer.batch_decode(response_ids)
        print("Firefly:" + response[0].strip().replace(tokenizer.eos_token, ""))
        user_input = input('User:')


if __name__ == '__main__':
    main()
```
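
Loading the 13B model in float16, as in the scripts above, takes roughly 26 GB of GPU memory. On smaller cards a quantized load through bitsandbytes is a common workaround; below is a minimal sketch, assuming transformers >= 4.30 with bitsandbytes installed (the NF4 settings are illustrative defaults, not something this project prescribes). Note that a model loaded with `device_map='auto'` should not be moved again with `.to(device)`:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_name = 'YeungNLP/firefly-baichuan-13b'

# 4-bit NF4 quantization: weights are stored in 4 bits, matmuls are computed in float16
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=torch.float16,
)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    quantization_config=quantization_config,
    device_map='auto',  # layers are dispatched automatically; do not call .to(device) afterwards
).eval()
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
```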