File size: 4,746 Bytes
779a991
 
 
 
 
052f9cb
779a991
ed2a31a
 
 
 
779a991
 
 
 
052f9cb
 
 
 
779a991
6c70efd
779a991
 
 
 
6c70efd
9cc1eea
779a991
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef87e14
7c5b6ea
 
 
7f0b753
7c5b6ea
 
aa0ae43
7c5b6ea
d2704e9
779a991
ef87e14
779a991
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4eb928
779a991
a4eb928
 
779a991
 
 
 
 
 
 
 
 
 
 
aadd9ad
 
779a991
7169471
656f8a7
779a991
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
052f9cb
 
 
779a991
 
916f14e
779a991
 
 
b9f4649
5a59c8e
779a991
 
 
b9f4649
779a991
b9f4649
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import json
import gradio as gr
import os
import requests 

# We get the token and the models API url 
# HF_TOKEN authenticates requests against the inference endpoints;
# each API_URL_* variable holds the endpoint URL for one model.
# NOTE(review): os.getenv returns None when a variable is unset — the app
# assumes these are configured (e.g. as Space secrets); verify at deploy time.
hf_token = os.getenv("HF_TOKEN")
llama_7b = os.getenv("API_URL_LLAMA_7")
llama_13b = os.getenv("API_URL_LLAMA_13")
zephyr_7b = os.getenv("API_URL_ZEPHYR_7")

# All requests post JSON payloads.
headers = {
    'Content-Type': 'application/json',
}

"""
Chat Function
"""
def chat(message, 
            chatbot, 
            model= llama_13b,
            system_prompt = "", 
            temperature = 0.9,
            max_new_tokens = 256,
            top_p = 0.6,
            repetition_penalty = 1.0
            ):
    """Stream a chat completion from the selected inference endpoint.

    Parameters:
        message: The user's latest message.
        chatbot: Conversation history as (user, assistant) pairs
            (Gradio Chatbot state).
        model: Either a dropdown key ("llama_7b", "llama_13b", "zephyr_7b")
            or, by default, the llama-13b endpoint URL itself.
        system_prompt: Optional system prompt inserted in a <<SYS>> section.
        temperature / max_new_tokens / top_p / repetition_penalty:
            Generation parameters forwarded to the endpoint.

    Yields:
        The partial assistant response, growing as tokens arrive, so the
        Gradio UI can render the stream incrementally.
    """
    # Build the Llama-2 style prompt, optionally with a system section.
    if system_prompt != "":
        input_prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n "
    else:
        input_prompt = f"<s>[INST] "

    temperature = float(temperature)

    # The endpoint rejects temperatures below 1e-2; clamp to the minimum.
    if temperature < 1e-2:
        temperature = 1e-2

    top_p = float(top_p)

    # Replay the conversation history into the prompt.
    for interaction in chatbot:
        input_prompt = input_prompt + str(interaction[0]) + " [/INST] " + str(interaction[1]) + " </s><s>[INST] "

    input_prompt = input_prompt + str(message) + " [/INST] "

    data = {
        "inputs": input_prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
            "do_sample": True,
        },
    }

    # Map a dropdown key to its endpoint URL; a raw URL passes through.
    if model == "zephyr_7b":
        model = zephyr_7b
    elif model == "llama_7b":
        model = llama_7b
    elif model == "llama_13b":
        model = llama_13b

    # BUG FIX: '"MODEL" + model' raised TypeError when the endpoint env var
    # was unset (model is None); the f-string prints safely in all cases.
    print(f"MODEL {model}")

    response = requests.post(model, headers=headers, data=json.dumps(data), auth=("hf", hf_token), stream=True)

    partial_message = ""
    for line in response.iter_lines():
        if not line:  # filter out keep-alive new lines
            continue

        # Decode from bytes to string
        decoded_line = line.decode('utf-8')

        # Remove 'data:' prefix 
        if decoded_line.startswith('data:'):
            json_line = decoded_line[5:]  # Exclude the first 5 characters ('data:')
        else:
            gr.Warning(f"This line does not start with 'data:': {decoded_line}")
            continue

        # Load as JSON; keep the try body tight around the parse.
        try:
            json_obj = json.loads(json_line)
        except json.JSONDecodeError:
            gr.Warning(f"This line is not valid JSON: {json_line}")
            continue

        try:
            if 'token' in json_obj:
                partial_message = partial_message + json_obj['token']['text']
                # BUG FIX: the original 'return' ended generation after the
                # FIRST token; 'yield' streams the growing response instead,
                # matching the '#yield' notes left in the original code.
                yield partial_message
            elif 'error' in json_obj:
                yield json_obj['error'] + '. Please refresh and try again with an appropriate smaller input prompt.'
                return
            else:
                gr.Warning(f"The key 'token' does not exist in this JSON object: {json_obj}")
        except KeyError as e:
            gr.Warning(f"KeyError: {e} occurred for JSON object: {json_obj}")
            continue




# Extra controls rendered under the chat interface; they map one-to-one
# onto the keyword parameters of chat() after (message, chatbot).
additional_inputs = [
    # Which backend endpoint to query.
    gr.Dropdown(
        choices=["llama_7b", "llama_13b", "zephyr_7b"],
        label="Model",
        info="Which model do you want to use?",
    ),
    # Optional system prompt prepended to the conversation.
    gr.Textbox("", label="Optional system prompt"),
    gr.Slider(
        minimum=0.0,
        maximum=1.0,
        value=0.9,
        step=0.05,
        label="Temperature",
        info="Higher values produce more diverse outputs",
        interactive=True,
    ),
    gr.Slider(
        minimum=0,
        maximum=4096,
        value=256,
        step=64,
        label="Max new tokens",
        info="The maximum numbers of new tokens",
        interactive=True,
    ),
    gr.Slider(
        minimum=0.0,
        maximum=1,
        value=0.6,
        step=0.05,
        label="Top-p (nucleus sampling)",
        info="Higher values sample more low-probability tokens",
        interactive=True,
    ),
    # NOTE(review): this default (1.2) differs from chat()'s signature
    # default (1.0) — confirm which value is intended.
    gr.Slider(
        minimum=1.0,
        maximum=2.0,
        value=1.2,
        step=0.05,
        label="Repetition penalty",
        info="Penalize repeated tokens",
        interactive=True,
    ),
]

# UI copy for the game prototype.
title = "Find the password 🔒"
description = "In this game prototype, your goal is to discuss with the intercom to find the correct password"

# Chat widget with custom avatars; bubbles sized to their content.
chatbot = gr.Chatbot(
    avatar_images=('user.png', 'bot2.png'),
    bubble_full_width=False,
)

# Wire the streaming chat function to the prebuilt chat UI.
chat_interface = gr.ChatInterface(
    chat,
    chatbot=chatbot,
    textbox=gr.Textbox(),
    title=title,
    description=description,
    additional_inputs=additional_inputs,
)

# Gradio Demo 
with gr.Blocks() as demo:
    chat_interface.render()

demo.launch(debug=True)