File size: 4,962 Bytes
1819a5c
 
 
 
 
 
 
18d9fb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1819a5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
from threading import Thread
import gradio as gr
import torch

# Optional local checkpoint directories; leave empty ('') to use the Hub repos below.
lora_folder = ''
model_folder = ''

# Resolve each checkpoint location once instead of repeating the fallback inline.
_lora_path = lora_folder if lora_folder != '' else "Junity/Genshin-World-Model"
_base_path = model_folder if model_folder != '' else "baichuan-inc/Baichuan-13B-Base"

# Adapter configuration (not referenced anywhere else in the visible code).
config = PeftConfig.from_pretrained(_lora_path, trust_remote_code=True)

# Base causal LM in float16; device placement is delegated to device_map="auto".
model = AutoModelForCausalLM.from_pretrained(
    _base_path,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
)

# Wrap the base model with the LoRA adapter weights.
model = PeftModel.from_pretrained(
    model,
    _lora_path,
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True,
)

# Tokenizer comes from the same location as the base model.
tokenizer = AutoTokenizer.from_pretrained(_base_path, trust_remote_code=True)

# Not used by the visible code; kept so the module-level name still exists.
history = []

# Device that prompt token ids are moved to before generation.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _append_user_turn(textbox, role_name, msg):
    """Return *textbox* with the user's line (optionally labelled) appended.

    An empty *msg* leaves *textbox* unchanged, so the model simply continues
    the existing text (the only mode this model supports well).
    """
    if msg == '':
        return textbox
    speaker = role_name + (":" if role_name != '' else '')
    last = msg[-1]
    if textbox != '':
        # Continuing an existing transcript: put the user's line on a new line
        # and close it with "。\n" unless it already ends in terminal punctuation.
        return (textbox
                + "\n"
                + speaker
                + msg
                + ('。\n' if last not in ['。', '!', '?'] else ''))
    # First line of a fresh transcript.
    # NOTE(review): ':' appears twice in the list below; the original may have
    # meant one ASCII and one full-width colon (possibly lost to an encoding
    # round-trip) — confirm against the upstream file.
    return (speaker
            + msg
            + ('。' if last not in ['。', '!', '?', ')', '}', ':', ':', '('] else '')
            + ('\n' if last in ['。', '!', '?', ')', '}'] else ''))


def respond(role_name, character_name, msg, textbox, temp, rep, max_len, top_p, top_k):
    """Append the user's line to the transcript and stream the model's continuation.

    This is a generator wired to the Submit button. Every yielded value is a
    ``[msg_box, transcript]`` pair: the message box is cleared ("") while the
    transcript grows.

    Args:
        role_name: Speaker label for the user's line; '' means no label.
        character_name: If non-empty, "<character_name>:" is appended so the
            model continues speaking as that character.
        msg: The user's line. '' means "just continue the existing text".
        textbox: The full transcript so far (user-editable in the UI).
        temp: Sampling temperature; values <= 0 switch to greedy decoding.
        rep: Repetition penalty.
        max_len: Maximum number of new tokens to generate.
        top_p: Nucleus-sampling threshold (used only when sampling).
        top_k: Top-k cutoff (used only when sampling).

    Yields:
        ``["", transcript]`` once after the user turn is added, then once per
        streamed chunk of generated text.
    """
    textbox = _append_user_turn(textbox, role_name, msg)
    yield ["", textbox]

    if character_name != '':
        # Insert a line break only when the transcript doesn't already end in one
        # (and never at the very start of an empty transcript).
        if textbox and not textbox.endswith('\n'):
            textbox += '\n'
        textbox += character_name + ':'

    # Keep only the last 3200 prompt tokens — presumably to leave room in the
    # model's context window for max_len new tokens (TODO confirm the limit).
    input_ids = tokenizer.encode(textbox)[-3200:]
    input_ids = torch.LongTensor([input_ids]).to(device)

    # The temperature slider allows 0, but sampling requires a strictly positive
    # temperature; fall back to greedy decoding instead of crashing the worker
    # thread (which would otherwise stall the UI until the streamer times out).
    do_sample = temp > 0
    gen_kwargs = dict(
        input_ids=input_ids,
        repetition_penalty=rep,
        max_new_tokens=int(max_len),
        do_sample=do_sample,
    )
    if do_sample:
        gen_kwargs.update(temperature=temp, top_p=top_p, top_k=int(top_k))

    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs["streamer"] = streamer

    # generate() blocks, so it runs in a worker thread while this generator
    # drains the streamer and pushes partial text to the UI.
    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    for new_text in streamer:
        textbox += new_text
        yield ["", textbox]

    # The streamer is exhausted once generation ends; reap the worker thread.
    thread.join()


# Build the UI declaratively; event wiring happens inside the Blocks context.
with gr.Blocks() as demo:
    gr.Markdown(
        """
        ## Genshin-World-Model
        - 模型地址 [https://huggingface.co/Junity/Genshin-World-Model](https://huggingface.co/Junity/Genshin-World-Model)
        - 此模型不支持要求对方回答什么,只支持续写。
        - 目前运行不了,因为没有钱租卡。
        """
    )
    with gr.Tab("创作") as chat:
        # Speaker labels and the user's next line.
        role_name = gr.Textbox(label="你将扮演的角色(可留空)")
        character_name = gr.Textbox(label="对方的角色(可留空)")
        msg = gr.Textbox(label="你说的话")
    with gr.Row():
        clear = gr.ClearButton()
        sub = gr.Button("Submit", variant="primary")
    with gr.Row():
        # Generation parameters forwarded to respond().
        temp = gr.Slider(minimum=0, maximum=2.0, step=0.1, value=1.5, label="温度(调大则更随机)", interactive=True)
        rep = gr.Slider(minimum=0, maximum=2.0, step=0.1, value=1.0, label="对重复生成的惩罚", interactive=True)
        max_len = gr.Slider(minimum=4, maximum=512, step=4, value=256, label="对方回答的最大长度", interactive=True)
        top_p = gr.Slider(minimum=0, maximum=1.0, step=0.1, value=0.7, label="Top-p(调大则更随机)", interactive=True)
        top_k = gr.Slider(minimum=0, maximum=100, step=1, value=50, label="Top-k(调大则更随机)", interactive=True)
    # The whole transcript; the user may edit it between generations.
    textbox = gr.Textbox(interactive=True, label="全部文本(可修改)")
    # NOTE(review): character_name is not cleared by this button — confirm
    # whether keeping it across clears is intentional.
    clear.add([msg, role_name, textbox])
    # respond() is a generator: it clears msg and streams into textbox.
    sub.click(fn=respond,
              inputs=[role_name, character_name, msg, textbox, temp, rep, max_len, top_p, top_k],
              outputs=[msg, textbox])
    gr.Markdown(
        """
        #### 特别鸣谢 XXXX
        """
    )

# Launch after the Blocks context has closed (Gradio's documented pattern).
# The queue is enabled so generator handlers such as respond() can stream output.
demo.queue().launch()