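# Gradio chat demo for georgesung/llama2_7b_chat_uncensored: the model is loaded
# in 8-bit, wrapped in a text-generation pipeline, and the full
# "### HUMAN: / ### RESPONSE:" prompt is rebuilt from the chat history each turn.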
from transformers import LlamaForCausalLM, LlamaTokenizer, pipeline
import torch

import gradio as gr

# LLM helper functions
def get_response_text(data):
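    # The pipeline echoes the full prompt back in "generated_text", so keep only
    # the text after the last "### RESPONSE:" marker.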
    text = data[0]["generated_text"]

    assistant_text_index = text.rfind('### RESPONSE:')
    if assistant_text_index != -1:
        text = text[assistant_text_index+len('### RESPONSE:'):].strip()

    return text

def get_llm_response(prompt, pipe):
    raw_output = pipe(prompt)
    text = get_response_text(raw_output)
    return text

# Load LLM
model_id = "georgesung/llama2_7b_chat_uncensored"
tokenizer = LlamaTokenizer.from_pretrained(model_id)
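# 8-bit loading requires the bitsandbytes package, and device_map="auto" requires
# accelerate; together they place the quantized layers on the available GPU(s).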
model = LlamaForCausalLM.from_pretrained(model_id, device_map="auto", load_in_8bit=True)

# Llama tokenizer missing pad token
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
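# Note: model.resize_token_embeddings() is not called here, so the new [PAD] id
# must never reach the model; with single-sequence generation no padding is applied.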

pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_length=4096,  # Llama-2 default context window
    do_sample=True,   # required for temperature/top_p to take effect
    temperature=0.7,
    top_p=0.95,
    repetition_penalty=1.15
)

with gr.Blocks() as demo:
    gr.Markdown("""
    # Chat with llama2_7b_chat_uncensored
    NOTICE: I will pause this space on Monday, July 24, around noon UTC, since it costs $$ to run :)

    If you wish to run this space yourself, you can duplicate it and run it on a T4 small instance.
    """)
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    def hist_to_prompt(history):
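        # Rebuild the model's expected prompt: alternating "### HUMAN:" and
        # "### RESPONSE:" blocks, leaving the final RESPONSE open for the model
        # to complete.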
        prompt = ""
        for human_text, bot_text in history:
            prompt += f"### HUMAN:\n{human_text}\n\n### RESPONSE:\n"
            if bot_text:
                prompt += f"{bot_text}\n\n"
        return prompt

    def get_bot_response(text):
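        # (Unused by the demo below; duplicates the suffix extraction done in
        # get_response_text above.)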
        bot_text_index = text.rfind('### RESPONSE:')
        if bot_text_index != -1:
            text = text[bot_text_index + len('### RESPONSE:'):].strip()
        return text

    def user(user_message, history):
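        # Clear the textbox and append the user's message with an empty bot slot;
        # bot() fills in the reply in the chained callback below.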
        return "", history + [[user_message, None]]

    def bot(history):
        # Rebuild the full prompt from the conversation so far, generate the
        # model's reply, and store it in the last (still-empty) history slot.
        hist_text = hist_to_prompt(history)
        print(hist_text)  # log the prompt for debugging
        bot_message = get_llm_response(hist_text, pipe) + tokenizer.eos_token
        history[-1][1] = bot_message  # add bot message to overall history

        return history

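    # Event wiring: submitting the textbox runs user() immediately (queue=False),
    # then bot() generates the reply; the Clear button resets the chat history.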
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)

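# Enable the request queue, so long-running generation calls are handled in order,
# then launch the app.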
demo.queue()
demo.launch()