import json
import os

import gradio as gr
import requests

# Access token and per-model inference endpoint URLs are read from
# environment variables (e.g. Space secrets)
hf_token = os.getenv("HF_TOKEN")
api_url_7b = os.getenv("API_URL_LLAMA_7")
api_url_13b = os.getenv("API_URL_LLAMA_13")
api_url_70b = os.getenv("API_URL_LLAMA_70")

headers = {
    'Content-Type': 'application/json',
}

def predict(message,
            chatbot,
            model="api_url_70b",
            system_prompt="",
            temperature=0.9,
            max_new_tokens=256,
            top_p=0.6,
            repetition_penalty=1.0):
    """Stream a response to `message` from the selected Llama 2 endpoint."""

    # The Model dropdown passes the endpoint's name; resolve it to its URL,
    # falling back to the 70B endpoint
    api_urls = {
        "api_url_7b": api_url_7b,
        "api_url_13b": api_url_13b,
        "api_url_70b": api_url_70b,
    }
    api_url = api_urls.get(model, api_url_70b)

    # Start the Llama 2 chat prompt with the optional system prompt
    if system_prompt != "":
        input_prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n "
    else:
        input_prompt = "<s>[INST] "

    temperature = float(temperature)

    # Clamp temperature to at least 1e-2 to avoid degenerate sampling
    if temperature < 1e-2:
        temperature = 1e-2

    top_p = float(top_p)

    # Append each prior (user, assistant) exchange in Llama 2 chat format,
    # then the new user message
    for interaction in chatbot:
        input_prompt = input_prompt + str(interaction[0]) + " [/INST] " + str(interaction[1]) + " </s><s>[INST] "

    input_prompt = input_prompt + str(message) + " [/INST] "
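    # e.g. with one prior exchange the assembled prompt becomes:
    #   <s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n {user_1} [/INST] {bot_1} </s><s>[INST] {message} [/INST]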

    # Request payload; the shape follows the Hugging Face text-generation-inference
    # streaming generate API
    data = {
        "inputs": input_prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
            "do_sample": True,
        },
    }

    response = requests.post(api_url, headers=headers, data=json.dumps(data), auth=("hf", hf_token), stream=True)

    partial_message = ""
    for line in response.iter_lines():
        if line:  # filter out keep-alive new lines
            # Decode from bytes to string
            decoded_line = line.decode('utf-8')

            # Remove 'data:' prefix 
            if decoded_line.startswith('data:'):
                json_line = decoded_line[5:]  # Exclude the first 5 characters ('data:')
            else:
                gr.Warning(f"This line does not start with 'data:': {decoded_line}")
                continue

            # Load as JSON
            try:
                json_obj = json.loads(json_line)
                if 'token' in json_obj:
                    partial_message = partial_message + json_obj['token']['text'] 
                    yield partial_message
                elif 'error' in json_obj:
                    yield json_obj['error'] + '. Please refresh and try again with a smaller input prompt.'
                else:
                    gr.Warning(f"The key 'token' does not exist in this JSON object: {json_obj}")

            except json.JSONDecodeError:
                gr.Warning(f"This line is not valid JSON: {json_line}")
                continue
            except KeyError as e:
                gr.Warning(f"KeyError: {e} occurred for JSON object: {json_obj}")
                continue
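
# Minimal sketch of calling predict directly (outside Gradio), assuming the
# HF_TOKEN and API_URL_LLAMA_* environment variables are set:
#
#   for partial in predict("Hello!", chatbot=[], model="api_url_7b"):
#       print(partial)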


# Extra UI controls; their order must match predict's parameters after
# (message, chatbot)
additional_inputs = [
    gr.Dropdown(["api_url_7b", "api_url_13b", "api_url_70b"], value="api_url_70b", label="Model", info="Which model to use?"),
    gr.Textbox("", label="Optional system prompt"),
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        minimum=0,
        maximum=4096,
        step=64,
        interactive=True,
        info="The maximum numbers of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.6,
        minimum=0.0,
        maximum=1,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]

chatbot_stream = gr.Chatbot(avatar_images=('user.png', 'bot2.png'), bubble_full_width=False)

chat_interface_stream = gr.ChatInterface(predict,
                 textbox=gr.Textbox(),
                 chatbot=chatbot_stream,
                 additional_inputs=additional_inputs)

# Gradio Demo
with gr.Blocks() as demo:
    with gr.Tab("Llama 70B"):
        chat_interface_stream.render()

# Queueing is required for streamed (generator) responses
demo.queue().launch()