# File size: 2,975 Bytes
# 1960c63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import re

import gradio as gr
import torch
from transformers import (AutoConfig, AutoModel, AutoModelForSeq2SeqLM,
                          AutoTokenizer, LlamaForCausalLM, LlamaTokenizer)
from vllm import LLM, SamplingParams

# Hugging Face model id loaded by vLLM in main().
model_id = "georgesung/llama2_7b_chat_uncensored"

# Prompt template markers used by hist_to_prompt() and get_bot_response().
# system_header / system_footer / user_footer are optional: a falsy value means
# the marker is left out of the prompt. user_header and response_header are
# always emitted.
# NOTE(review): "input_header" is not read anywhere in this file.
prompt_config = dict(
    system_header=None,
    system_footer=None,
    user_header="### HUMAN:",
    user_footer=None,
    input_header=None,
    response_header="### RESPONSE:",
)

def get_llm_response_chat(prompt, model=None, params=None):
    """Generate one completion for ``prompt`` and return its text.

    Args:
        prompt: Fully formatted prompt string (e.g. from ``hist_to_prompt``).
        model: Object exposing ``generate(prompt, params)`` and
            ``get_tokenizer()`` (a vLLM ``LLM`` in this script). When ``None``,
            the module-level ``llm`` is used; it must already be set.
        params: Sampling parameters forwarded to ``model.generate``. When
            ``None``, the module-level ``sampling_params`` is used.

    Returns:
        Text of the first candidate of the first request, with a single
        trailing EOS token removed if present.
    """
    # Fall back to module globals so existing one-argument callers keep working.
    if model is None:
        model = llm
    if params is None:
        params = sampling_params

    outputs = model.generate(prompt, params)
    text = outputs[0].outputs[0].text

    # Strip one trailing EOS token. Guard against a missing/empty token:
    # str.endswith(None) raises TypeError, and text[:-0] would yield "".
    eos_token = model.get_tokenizer().eos_token
    if eos_token and text.endswith(eos_token):
        text = text[: len(text) - len(eos_token)]
    return text

def hist_to_prompt(history, system_message="", config=None):
    """Render a chat history into a single prompt string.

    Args:
        history: Sequence of ``(human_text, bot_text)`` pairs, as kept by the
            Gradio ``Chatbot``; ``bot_text`` is falsy for a turn that has no
            reply yet.
        system_message: Text placed under ``system_header`` when the template
            defines one. (The file previously read an undefined global
            ``SYSTEM_MESSAGE`` here, which raised NameError.)
        config: Prompt template dict with the keys of the module-level
            ``prompt_config``; that dict is used when ``config`` is ``None``.
            ``user_header`` and ``response_header`` are required; the
            header/footer keys for the system block and ``user_footer`` are
            optional.

    Returns:
        The prompt string. Each turn ends with the response header, so a
        history whose last ``bot_text`` is empty ends with
        ``"<response_header>\\n"`` for the model to continue from.
    """
    cfg = prompt_config if config is None else config
    # Optional footers: a falsy value means "emit nothing".
    system_footer = cfg.get("system_footer") or ""
    user_footer = cfg.get("user_footer") or ""

    parts = []
    if cfg.get("system_header"):
        parts.append(f"{cfg['system_header']}\n{system_message}{system_footer}\n\n")

    for human_text, bot_text in history:
        parts.append(f"{cfg['user_header']}\n{human_text}{user_footer}\n\n")
        parts.append(f"{cfg['response_header']}\n")
        if bot_text:
            parts.append(f"{bot_text}\n\n")
    return "".join(parts)

def get_bot_response(text):
    """Return the part of ``text`` after the last response header.

    The header is ``prompt_config['response_header']``. If it occurs in
    ``text``, everything after its last occurrence is returned with
    surrounding whitespace stripped; otherwise ``text`` is returned as-is
    (not stripped).
    """
    header = prompt_config['response_header']
    pos = text.rfind(header)
    if pos == -1:
        return text
    return text[pos + len(header):].strip()

def main():
    """Load the model with vLLM and serve a Gradio chat UI.

    ``llm`` and ``sampling_params`` are published as module globals because
    ``get_llm_response_chat`` looks them up at module scope; as plain locals
    of this function they were invisible to it and every chat turn raised
    ``NameError``.
    """
    global llm, sampling_params

    # RE llama tokenizer:
    # RuntimeError: Failed to load the tokenizer.
    # If you are using a LLaMA-based model, use 'hf-internal-testing/llama-tokenizer' instead of the original tokenizer.
    llm = LLM(model=model_id, tokenizer='hf-internal-testing/llama-tokenizer')

    # Very low temperature / top_p: close to greedy decoding.
    sampling_params = SamplingParams(temperature=0.01, top_p=0.1, top_k=40, max_tokens=2048)

    with gr.Blocks() as demo:
        gr.Markdown(
        """
        # Let's chat
        """)

        chatbot = gr.Chatbot()
        msg = gr.Textbox()
        clear = gr.Button("Clear")

        def user(user_message, history):
            # Clear the textbox and append the new turn with no reply yet.
            return "", history + [[user_message, None]]

        def bot(history):
            # Render the whole conversation, then fill in the reply for the last turn.
            prompt = hist_to_prompt(history)
            history[-1][1] = get_llm_response_chat(prompt)
            return history

        # On submit: record the user turn (unqueued), then generate the reply.
        msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
            bot, chatbot, chatbot
        )
        # Setting the Chatbot value to None clears the conversation.
        clear.click(lambda: None, None, chatbot, queue=False)

    demo.queue()
    demo.launch()



# Script entry point: load the model and launch the chat UI only when run directly.
if __name__ == "__main__":
    main()