File size: 3,187 Bytes
b183ad6
 
 
 
 
 
 
 
 
5da40b8
 
 
 
 
 
 
 
 
 
 
b183ad6
 
 
072a08d
b183ad6
082fbfd
b183ad6
 
 
 
 
 
 
 
c402338
 
b183ad6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
082fbfd
b183ad6
 
 
082fbfd
b183ad6
 
 
 
 
 
 
 
 
 
 
082fbfd
 
b183ad6
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# ウェブUIの起動
import os
import itertools

import torch
from transformers import AutoTokenizer
import ctranslate2
import gradio as gr

DESCRIPTION="""
## 概要
- これは、とある研究発表のために作られたチャットルーム(スペース)です。アクセス過多の場合は、少し時間をおいてから再度アクセスしてください。
- 詳細設定にて、AIが生成する文章のテイストを調整することが出来ます。
- AIの名前は「ベータ」です。
## たのむぞ
- あまり個人情報を入力しないでください。
- 会話内容は収集しておりません。
"""


device = "cuda" if torch.cuda.is_available() else "cpu"
generator = ctranslate2.Generator("./FixedStar-BETA-7b-ct2", device=device)
tokenizer = AutoTokenizer.from_pretrained(
    "./tokenizer", use_fast=True)

def inference_func(prompt, max_length=64, sampling_temperature=0.7):
    tokens = tokenizer.convert_ids_to_tokens(
        tokenizer.encode(prompt, add_special_tokens=False)
    )
    results = generator.generate_batch(
        [tokens],
        max_length=max_length,
        sampling_topk=20,
        sampling_temperature=sampling_temperature,
        repetition_penalty=1.1,
        end_token=[26168, 27, 208, 14719, 9078, 18482, 27, 208],
        include_prompt_in_result=False,
    )
    output = tokenizer.decode(results[0].sequences_ids[0])
    return output

def make_prompt(message, chat_history, max_context_size: int = 10):
    contexts = chat_history + [[message, ""]]
    contexts = list(itertools.chain.from_iterable(contexts))
    if max_context_size > 0:
        context_size = max_context_size - 1
    else:
        context_size = 100000
    contexts = contexts[-context_size:]
    prompt = []
    for idx, context in enumerate(reversed(contexts)):
        if idx % 2 == 0:
            prompt = [f"ASSISTANT: {context}"] + prompt
        else:
            prompt = [f"USER: {context}"] + prompt
    prompt = "\n".join(prompt)
    return prompt


def interact_func(message, chat_history, max_context_size, max_length, sampling_temperature):
    prompt = make_prompt(message, chat_history, max_context_size)
    print(f"prompt: {prompt}")
    generated = inference_func(prompt, max_length, sampling_temperature)
    print(f"generated: {generated}")
    chat_history.append((message, generated))
    return "", chat_history

with gr.Blocks(theme="monochrome") as demo:
    with gr.Accordion("Configs", open=False):
        # max_context_size = the number of turns * 2
        max_context_size = gr.Number(value=20, label="記憶する会話ターン数", precision=0)
        max_length = gr.Number(value=64, label="最大文字数", precision=0)
        sampling_temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.1, label="創造性")
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("消す")
    msg.submit(
        interact_func,
        [msg, chatbot, max_context_size, max_length, sampling_temperature],
        [msg, chatbot],
    )
    clear.click(lambda: None, None, chatbot, queue=False)

gr.Markdown(DESCRIPTION)

if __name__ == "__main__":
    demo.launch(debug=True, share=True)