# -*- coding: utf-8 -*-
"""Fujisaki_CPU.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1Damnr0Ha4zZAlKFvne9cu76uuElLNYus

李萌萌's digital urn
----

This is a digital alter ego of 李萌萌 trained with the ChatGLM model. Fill in the question box, or leave it empty, and see what 李萌萌 actually has to say.
A T4-class GPU is already more than capable of this task.

### Install dependencies
"""

import torch

from modeling_chatglm import ChatGLMForConditionalGeneration
from transformers import AutoTokenizer, GenerationConfig

# Load the base ChatGLM-6B weights in full precision so the model can run on CPU.
model = ChatGLMForConditionalGeneration.from_pretrained("THUDM/chatglm-6b").float()
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)

from peft import PeftModel

peft_path = 'ljsabc/Fujisaki_GLM'  # change it to your own adapter if needed
model = PeftModel.from_pretrained(
       model,
       peft_path,
       torch_dtype=torch.float,
    )
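# PeftModel.from_pretrained overlays the LoRA adapter weights from `peft_path`
# onto the frozen ChatGLM-6B base model; only the small low-rank matrices are
# fetched, not another full copy of the 6B weights.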

# Dump the adapter config as a sanity check that the LoRA weights loaded.
print(model.peft_config)
# We have to run in full precision, as some token ids are > 65535.
model.eval()

torch.set_default_tensor_type(torch.FloatTensor)
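# Pinning the default tensor type to fp32 keeps tensors created inside the
# model code consistent with the .float() cast above.

# evaluate() is a single-turn helper: it builds a "Context: ...Answer: " prompt
# (presumably the format the adapter was trained on) and returns only the
# generated answer. The Gradio UI below uses evaluate_stream instead.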
def evaluate(context, temperature, top_p, top_k):
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        #repetition_penalty=1.1,
        num_beams=1,
        do_sample=True,
    )
    with torch.no_grad():
        input_text = f"Context: {context}Answer: " 
        ids = tokenizer.encode(input_text)
        input_ids = torch.LongTensor([ids]).to('cpu')
        out = model.generate(
            input_ids=input_ids,
            max_length=160,
            generation_config=generation_config
        )
        out_text = tokenizer.decode(out[0]).split("Answer: ")[1]
        return out_text
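
# A minimal smoke test of the single-turn path (the prompt string is purely
# illustrative; generation on CPU is slow, so this stays commented out):
# print(evaluate("How have you been lately?", temperature=0.95, top_p=0.975, top_k=25))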
    
def evaluate_stream(msg, history, temperature, top_p):
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        #repetition_penalty=1.1,
        num_beams=1,
        do_sample=True,
    )

    history.append([msg, None])

    context = ""
    if len(history) > 8:
        history.pop(0)

    # Strip the <br> tags that Gradio inserts into past messages.
    for j in range(len(history)):
        history[j][0] = history[j][0].replace("<br>", "")

    # concatenate context
    for h in history[:-1]:
        context += h[0] + "\n" + h[1] + "\n"

    context += history[-1][0]
    context = context.replace('<br>', '')

    # TODO: handle overly long inputs more gracefully. For now, trim the
    # context from the front until it fits within the token budget.
    CUTOFF = 256
    while len(tokenizer.encode(context)) > CUTOFF:
        # Drop 15 characters at a time from the oldest part of the context,
        # leaving headroom for the answer.
        context = context[15:]

    # stream_chat keeps its own turn list; pass an empty one and track ours in `history`.
    h = []
    print("History:", history)
    print("Context:", context)
    for response, h in model.stream_chat(tokenizer, context, h, max_length=CUTOFF, top_p=top_p, temperature=temperature):
        history[-1][1] = response
        yield history, ""
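
# evaluate_stream is a generator that yields (history, "") after each new
# chunk, which is the shape Gradio's streaming callback expects. A hypothetical
# manual run outside the UI:
#
#     history = []
#     for history, _ in evaluate_stream("Hello!", history, 0.95, 0.975):
#         pass
#     print(history[-1][1])   # the fully streamed reply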


import gradio as gr

title = """<h1 align="center">李萌萌(Alter Ego)</h1>
<h3 align="center">这是一个通过ChatGLM模型训练的李萌萌的数字分身,你可以与她聊天,或者直接在文本框按下Enter,来观察李萌萌到底会说些什么。</h3>"""

footer =  """<p align='center'>项目在<a href='https://github.com/ljsabc/Fujisaki' target='_blank'>GitHub</a>上托管,基于清华的<a href='https://huggingface.co/THUDM/chatglm-6b' target='_blank'>THUDM/chatglm-6b</a>项目。</p>
<p align='center'><em>"I'm... a boy." --Chihiro Fujisaki</em></p>"""

with gr.Blocks() as demo:
    gr.HTML(title)
    state = gr.State()
    with gr.Row():
        with gr.Column(scale=2):
            temp = gr.components.Slider(minimum=0, maximum=1.1, value=0.95, label="Temperature",
                info="Sampling temperature. Higher values produce more varied output, but may introduce grammatical errors.")
            top_p = gr.components.Slider(minimum=0.5, maximum=1.0, value=0.975, label="Top-p",
                info="Nucleus sampling: tokens are drawn from the smallest set whose cumulative probability exceeds top-p. Larger values produce more varied output but may introduce grammatical errors; smaller values seem to keep the conversation more coherent.")
            #code = gr.Textbox(label="temp_output", info="decoder output")
            #top_k = gr.components.Slider(minimum=1, maximum=200, step=1, value=25, label="Top k",
            #    info="Top-k sampling: the next token is chosen from the k most likely candidates. Larger values produce more varied output but may introduce grammatical errors; smaller values seem to keep the conversation more coherent.")

        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="Chat")
            msg = gr.Textbox(label="Input", placeholder="How have you been lately?",
                info="Type your message and press [Enter] to send. You can also leave it empty to generate a random reply.")
            clear = gr.Button("Clear chat")

    msg.submit(evaluate_stream, [msg, chatbot, temp, top_p], [chatbot, msg])
    clear.click(lambda: None, None, chatbot, queue=False)
    gr.HTML(footer)

demo.queue()
demo.launch(debug=False)