File size: 6,870 Bytes
1ce6f28
408fb2e
 
3a8e738
 
 
7901b62
b9b48c5
f991d7d
 
b9b48c5
f991d7d
 
 
 
880e945
408fb2e
880e945
dcf6e59
 
 
 
880e945
dcf6e59
 
 
 
880e945
e61245d
b213ce2
 
e61245d
b213ce2
e61245d
 
64f86b5
e61245d
b213ce2
64f86b5
b213ce2
e61245d
 
 
b213ce2
6ac567d
7e4cf08
 
c314120
dcf6e59
e298b33
dcf6e59
408fb2e
71a6f99
 
 
408fb2e
70a5709
e298b33
70a5709
 
71a6f99
70a5709
 
 
 
8244168
408fb2e
70a5709
 
 
 
 
 
 
 
 
408fb2e
70a5709
9ab3033
70a5709
e298b33
70a5709
 
 
 
 
e298b33
 
8312b78
4c1f576
ac31486
25819b2
 
ae4438b
dcf6e59
 
463c3f1
7901b62
463c3f1
7901b62
463c3f1
 
f7c578d
 
dcf6e59
 
f7c578d
 
 
 
 
dcf6e59
a4cd409
 
 
71a6f99
dcf6e59
880e945
dcf6e59
17f0ed4
421426b
f7c578d
 
 
 
 
 
880e945
dcf6e59
 
 
f7c578d
 
71a6f99
f7c578d
 
 
71a6f99
 
 
 
 
 
 
f7c578d
6ac567d
 
e61245d
6ac567d
 
71a6f99
 
bce3dcd
90981c9
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import gradio as gr
import re
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import bitsandbytes
import accelerate
# Hugging Face Hub identifier of the checkpoint this app serves.
model_name_or_path = "teknium/OpenHermes-2.5-Mistral-7B"

# Compute dtype handed to `from_pretrained` via `torch_dtype`.
dtype = torch.bfloat16

# Load the causal LM once at import time. `device_map="auto"` leaves device
# placement to the library, `load_in_4bit=True` requests 4-bit loading, and
# `trust_remote_code=False` refuses custom modelling code from the repo.
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    device_map="auto",
    torch_dtype=dtype,
    trust_remote_code=False,
    load_in_4bit=True,
    revision="main",
)

# Tokenizer from the same checkpoint (fast implementation requested).
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)

# System prompt used by `chat` when the caller supplies a blank one.
BASE_SYSTEM_MESSAGE = "I carefully provide accurate, factual, thoughtful, nuanced answers and am brilliant at reasoning."

def clear_chat(chat_history_state, chat_message):
    """Return blank values for the chat history and the message box.

    Both arguments are ignored; the incoming objects are not mutated.

    Returns:
        A ``(history, message)`` tuple holding a new empty list and an
        empty string.
    """
    return [], ''

def user(message, history):
    """Append a new user turn with an empty reply slot to the history.

    Args:
        message: Text the user just submitted.
        history: List of ``[user_text, assistant_text]`` pairs, or a falsy
            value (``None`` / empty) when the conversation has not started.

    Returns:
        A ``(textbox_value, history)`` tuple: an empty string, and the
        history with ``[message, ""]`` appended. A non-empty *history* is
        extended in place; a falsy one is replaced by a new list.
    """
    if not history:
        history = []
    pending_turn = [message, ""]
    history.append(pending_turn)
    return "", history

def regenerate(chatbot, chat_history_state, system_msg, max_tokens, temperature, top_p, top_k, repetition_penalty):
    """Discard the last assistant reply and generate a new one.

    Args:
        chatbot: Current chatbot value; returned untouched when there is
            nothing to regenerate, otherwise unused.
        chat_history_state: List of ``[user_text, assistant_text]`` pairs.
            The last pair's reply is blanked in place before regeneration.
        system_msg, max_tokens, temperature, top_p, top_k,
        repetition_penalty: Forwarded unchanged to ``chat``.

    Returns:
        ``(chatbot_value, history_state, "")``. With an empty or ``None``
        history the inputs are echoed back; otherwise both of the first two
        entries are the history returned by ``chat``.
    """
    print("Regenerate function called")  # Debug print

    # Guard clause: without a recorded turn there is nothing to redo.
    if not chat_history_state:
        print("Chat history is empty")  # Debug print
        return chatbot, chat_history_state, ""

    # Blank only the assistant half of the most recent turn; the user half
    # is what `chat` uses as the prompt.
    last_turn = chat_history_state[-1]
    print(f"Before: {last_turn}")  # Debug print
    last_turn[1] = ""
    print(f"After: {last_turn}")  # Debug print

    # Generate again from the (now reply-less) last turn.
    new_history, _state, _message = chat(
        chat_history_state,
        system_msg,
        max_tokens,
        temperature,
        top_p,
        top_k,
        repetition_penalty,
    )
    print(f"New history: {new_history}")  # Debug print

    return new_history, new_history, ""


def chat(history, system_message, max_tokens, temperature, top_p, top_k, repetition_penalty):
    """Generate the assistant reply for the latest user turn in *history*.

    Only the most recent user message is sent to the model; earlier turns
    are kept in *history* for display but are not part of the prompt.

    Args:
        history: List of ``[user_text, assistant_text]`` pairs, or a falsy
            value. The last pair's reply slot is appended to in place.
        system_message: System prompt; ``BASE_SYSTEM_MESSAGE`` is used when
            it is ``None``, empty or whitespace-only.
        max_tokens: Maximum number of tokens to generate for the reply
            (the prompt does not count against this budget).
        temperature: Sampling temperature. A value of 0 (or less) selects
            greedy decoding, in which case ``top_p``/``top_k`` are unused.
        top_p: Nucleus-sampling probability mass.
        top_k: Top-k sampling cutoff.
        repetition_penalty: Penalty applied to repeated tokens.

    Returns:
        ``(history, history, "")`` — the updated history twice (chatbot
        display and state) and an empty string for the message box.
    """
    print(f"Chat function called with history: {history}")
    history = history or []

    # Use BASE_SYSTEM_MESSAGE if system_message is empty (or None).
    system_message_to_use = (system_message or "").strip() or BASE_SYSTEM_MESSAGE

    # The latest user message.
    user_prompt = history[-1][0] if history else ""
    print(f"User prompt used for generation: {user_prompt}")  # Debug print

    # Build the model input in ChatML, the format this model was tuned on.
    prompt_template = (
        f"<|im_start|>system\n{system_message_to_use}<|im_end|>\n"
        f"<|im_start|>user\n{user_prompt}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
    # Follow the model's device instead of assuming CUDA, and keep the
    # attention mask so `generate` does not have to guess it.
    inputs = tokenizer(prompt_template, return_tensors='pt').to(model.device)

    # `max_new_tokens` bounds the reply only; `max_length` would also count
    # the prompt and could leave no room for an answer.
    generation_kwargs = {
        "max_new_tokens": int(max_tokens),
        "repetition_penalty": repetition_penalty,
    }
    if temperature and temperature > 0:
        # Sampling parameters only take effect when sampling is enabled.
        generation_kwargs.update(
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            top_k=int(top_k),
        )
    else:
        # Temperature 0 is invalid for sampling; decode greedily instead.
        generation_kwargs["do_sample"] = False

    # Generate the output.
    output = model.generate(**inputs, **generation_kwargs)

    # Decode only the newly generated tokens. Splitting the full text on the
    # word "assistant" would truncate any reply that contains that word.
    prompt_length = inputs["input_ids"].shape[-1]
    assistant_response = tokenizer.decode(
        output[0][prompt_length:], skip_special_tokens=True
    ).strip()
    print(f"Generated assistant response: {assistant_response}")  # Debug print

    # Update the history.
    if history:
        history[-1][1] += assistant_response
    else:
        history.append(["", assistant_response])

    print(f"Updated history: {history}")
    return history, history, ""


# Initial content of the "System Message" textbox.
start_message = ""

# UI layout and event wiring for the chat demo.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown("""
    ## OpenHermes-V2.5 Finetuned on Mistral 7B
    **Space created by [@artificialguybr](https://twitter.com/artificialguybr). Model by [@Teknium1](https://twitter.com/Teknium1). Thanks HF for GPU!**
    **OpenHermes-V2.5 is currently SOTA in some benchmarks for 7B models.**
    **Hermes 2 model was trained on 900,000 instructions, and surpasses all previous versions of Hermes 13B and below, and matches 70B on some benchmarks! Hermes 2 changes the game with strong multiturn chat skills, system prompt capabilities, and uses ChatML format. It's quality, diversity and scale is unmatched in the current OS LM landscape. Not only does it do well in benchmarks, but also in unmeasured capabilities, like Roleplaying, Tasks, and more.**
    """)
    with gr.Row():
        chatbot = gr.Chatbot(elem_id="chatbot")
    with gr.Row():
        message = gr.Textbox(
            label="What do you want to chat about?",
            placeholder="Ask me anything.",
            lines=3,
        )
    with gr.Row():
        submit = gr.Button(value="Send message", variant="secondary", scale=1)
        clear = gr.Button(value="New topic", variant="secondary", scale=0)
        stop = gr.Button(value="Stop", variant="secondary", scale=0)
        regen_btn = gr.Button(value="Regenerate", variant="secondary", scale=0)
    with gr.Accordion("Show Model Parameters", open=False):
        with gr.Row():
            with gr.Column():
                max_tokens = gr.Slider(20, 512, label="Max Tokens", step=20, value=500)
                temperature = gr.Slider(0.0, 2.0, label="Temperature", step=0.1, value=0.7)
                top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.95)
                top_k = gr.Slider(1, 100, label="Top K", step=1, value=40)
                repetition_penalty = gr.Slider(1.0, 2.0, label="Repetition Penalty", step=0.1, value=1.1)

        system_msg = gr.Textbox(
            start_message, label="System Message", interactive=True, visible=True, placeholder="System prompt. Provide instructions which you want the model to remember.", lines=5)

    # Per-session conversation history (list of [user, assistant] pairs).
    chat_history_state = gr.State()

    # "New topic": reset the stored history and the message box, and blank
    # the chatbot display. The reset handler is registered exactly once.
    clear.click(clear_chat, inputs=[chat_history_state, message], outputs=[chat_history_state, message], queue=False)
    clear.click(lambda: None, None, chatbot, queue=False)

    # Inputs shared by both generation paths (submit and regenerate).
    generation_inputs = [chat_history_state, system_msg, max_tokens, temperature, top_p, top_k, repetition_penalty]

    # "Send message": record the user turn first, then generate the reply.
    submit_click_event = submit.click(
        fn=user, inputs=[message, chat_history_state], outputs=[message, chat_history_state], queue=True
    ).then(
        fn=chat, inputs=generation_inputs, outputs=[chatbot, chat_history_state, message], queue=True
    )

    # "Regenerate": drop the last reply and generate it again.
    regen_click_event = regen_btn.click(
        fn=regenerate,
        inputs=[chatbot, *generation_inputs],
        outputs=[chatbot, chat_history_state, message],
        queue=True,
    )

    # "Stop" cancels whichever generation is running. It is wired after both
    # events exist so that a regenerate can be interrupted as well.
    stop.click(fn=None, inputs=None, outputs=None, cancels=[submit_click_event, regen_click_event], queue=False)


demo.queue(max_size=128, concurrency_count=2)
demo.launch()