from transformers import AutoTokenizer
import gradio as gr
from optimum.onnxruntime import ORTModelForCausalLM

import itertools
import regex as re  # the third-party `regex` module is required for \p{Han} support
import logging



# Special tokens that delimit conversation turns in the prompt format
user_token = "<User>"
eos_token = "<EOS>"
bos_token = "<BOS>"
bot_token = "<Assistant>"


# Log every prompt/response pair so conversations can be inspected
logger = logging.getLogger()
handler = logging.StreamHandler()
formatter = logging.Formatter(
    '%(asctime)s %(name)-12s %(levelname)-8s %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.setLevel(logging.INFO)

max_context_length = 750  # keep only the most recent 750 prompt tokens

def format_prompt(history):
    """Interleave user/assistant turns into a single prompt string.

    Even indices in `history` are user messages, odd indices are
    assistant replies; a trailing bot_token cues the model to answer.
    """
    prompt = bos_token

    for idx, txt in enumerate(history):
        if idx % 2 == 0:
            prompt += f"{user_token}{txt}{eos_token}"
        else:
            prompt += f"{bot_token}{txt}"
    prompt += bot_token
    return prompt
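
# A quick illustration of the prompt layout produced by format_prompt:
#   format_prompt(["你好"]) == "<BOS><User>你好<EOS><Assistant>"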

def remove_spaces_between_chinese(text):
    # The tokenizer's decode output can contain spaces between Chinese
    # (Han) characters and between single spaced-out Latin letters; strip
    # those while leaving normal English word spacing intact.
    rex = r"(?<![a-zA-Z]{2})(?<=[a-zA-Z]{1})[ ]+(?=[a-zA-Z] |.$)|(?<=\p{Han}) +"
    return re.sub(rex, "", text, 0, re.MULTILINE | re.UNICODE)
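
# e.g. remove_spaces_between_chinese("你 好 世 界") -> "你好世界"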

def gradio(model, tokenizer):
    def response(
        user_input,
        chat_history,
        top_k,
        top_p,
        temperature,
        repetition_penalty,
        no_repeat_ngram_size,
    ):
        # Flatten [(user, bot), ...] history pairs into one alternating list
        history = list(itertools.chain(*chat_history))
        history.append(user_input)

        prompt = format_prompt(history)

        input_ids = tokenizer.encode(
            prompt,
            return_tensors="pt",
            add_special_tokens=False,
        )[:, -max_context_length:]  # left-truncate, keeping the newest context

        prompt_length = input_ids.shape[1]

        beam_output = model.generate(
            input_ids,
            pad_token_id=tokenizer.pad_token_id,
            max_new_tokens=250,
            num_beams=1,  # single beam keeps CPU inference fast
            top_k=top_k,
            top_p=top_p,
            no_repeat_ngram_size=no_repeat_ngram_size,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            do_sample=True,
        )
        output = beam_output[0][prompt_length:]  # keep only the newly generated tokens

        generated = remove_spaces_between_chinese(
            tokenizer.decode(
                output, skip_special_tokens=True, clean_up_tokenization_spaces=True
            )
        )
        logger.info(prompt + generated)

        return generated

    bot = gr.Chatbot(show_copy_button=True, show_share_button=True, height=2000)

    with gr.Blocks() as demo:
        gr.Markdown("GPT2 chatbot | Powered by nlp-greyfoss")

        with gr.Accordion("Generation parameters", open=False):
            with gr.Row():
                top_k = gr.Slider(
                    2.0,
                    100.0,
                    label="top_k",
                    step=1,
                    value=50,
                    info="Limit the number of candidate tokens considered during decoding.",
                )
                top_p = gr.Slider(
                    0.1,
                    1.0,
                    label="top_p",
                    value=0.9,
                    info="Control the diversity of the output by selecting tokens with cumulative probabilities up to the Top-P threshold.",
                )
                temperature = gr.Slider(
                    0.1,
                    2.0,
                    label="temperature",
                    value=0.9,
                    info="Control the randomness of the generated text. A higher temperature results in more diverse and unpredictable outputs, while a lower temperature produces more conservative and coherent text.",
                )
                repetition_penalty = gr.Slider(
                    0.1,
                    2.0,
                    label="repetition_penalty",
                    value=1.2,
                    info="Discourage the model from generating repetitive tokens in a sequence.",
                )
                no_repeat_ngram_size = gr.Slider(
                    0,
                    100,
                    label="no_repeat_ngram_size",
                    step=1,
                    value=5,
                    info="Prevent the model from repeating any n-gram of this size that already appears in the context.",
                )

        gr.ChatInterface(
            response,
            chatbot=bot,
            additional_inputs=[
                top_k,
                top_p,
                temperature,
                repetition_penalty,
                no_repeat_ngram_size,
            ],
            retry_btn="🔄 Regenerate",
            undo_btn="↩️ Remove last turn",
            clear_btn="➕ New conversation",
            examples=[
                # "Write an article introducing artificial intelligence."
                ["写一篇介绍人工智能的文章。", 30, 0.9, 0.95, 1.2, 5],
                # "Tell me a joke."
                ["给我讲一个笑话。", 50, 0.8, 0.9, 1.2, 6],
                ["Can you describe spring in English?", 50, 0.9, 1.0, 1, 5],
            ],
        )

    demo.queue().launch()  # queue long-running generate calls, then serve the UI




tokenizer = AutoTokenizer.from_pretrained("greyfoss/gpt2-chatbot-chinese")

# export=True converts the PyTorch checkpoint to ONNX for faster CPU inference
model = ORTModelForCausalLM.from_pretrained("greyfoss/gpt2-chatbot-chinese", export=True)

gradio(model, tokenizer)
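
# To try this locally (a minimal sketch; the filename is hypothetical):
#   pip install transformers gradio regex "optimum[onnxruntime]"
#   python chatbot.py
# Gradio then prints a local URL (http://127.0.0.1:7860 by default) to open.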