File size: 4,669 Bytes
7a42908
 
2bf22c3
b5dde4f
85b9fdf
7a42908
9e2e59a
 
7a42908
eed9231
7a42908
 
 
 
 
 
 
639f2fe
cce1926
7a42908
cce1926
 
 
7a42908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eed9231
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c376d29
7a42908
 
 
 
 
 
 
 
2c12f3b
7a42908
 
639f2fe
7a42908
 
9e2e59a
7a42908
 
 
 
 
 
 
 
 
 
 
 
0426b89
 
7a42908
9e2e59a
7a42908
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import time
import gradio as gr
import os 
import json 
import requests

# Streaming endpoint of the text-generation backend.
# The base URL is read from the API_URL environment variable. Fail at import
# time with an explicit message rather than the opaque
# "unsupported operand type(s) for +: 'NoneType' and 'str'" TypeError that a
# missing variable would otherwise produce.
_api_base_url = os.getenv("API_URL")
if not _api_base_url:
    raise RuntimeError(
        "The API_URL environment variable must be set to the base URL of the "
        "text-generation server."
    )
# Drop a trailing slash so the joined URL never contains "//generate_stream".
API_URL = _api_base_url.rstrip("/") + "/generate_stream"

def predict_old(inputs, top_p, temperature, top_k, repetition_penalty, history=None):
    """Generate a reply by downloading the whole event stream, then replaying it.

    Non-streaming variant: the HTTP response is read in full and split into
    events afterwards, so nothing is yielded until the server has finished.

    Args:
        inputs: The user's message. "User: " is prepended and a newline is
            appended when the message does not already start with "User: ".
        top_p: Sent as the ``top_p`` request parameter.
        temperature: Sent as the ``temperature`` request parameter.
        top_k: Sent as the ``top_k`` request parameter.
        repetition_penalty: Sent as the ``repetition_penalty`` request parameter.
        history: Flat list alternating prompt, reply, prompt, reply, ...
            It is extended in place. ``None`` starts a fresh conversation.

    Yields:
        ``(chat, history)`` after each token, where ``chat`` is the history
        regrouped into ``(prompt, reply)`` tuples.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    # A list literal as the default would be shared between calls.
    if history is None:
        history = []

    if not inputs.startswith("User: "):
        inputs = "User: " + inputs + "\n"

    payload = {
        "inputs": inputs,
        "parameters": {
            "details": True,
            "do_sample": True,
            "max_new_tokens": 100,
            "repetition_penalty": repetition_penalty,
            "seed": 0,
            "temperature": temperature,
            "top_k": top_k,
            "top_p": top_p,
        },
    }
    headers = {
        "accept": "text/event-stream",
        "Content-Type": "application/json",
    }

    history.append(inputs)
    response = requests.post(API_URL, headers=headers, json=payload)
    # Without this an error response contains no "data" events and the
    # generator would finish silently without yielding anything.
    response.raise_for_status()
    # requests falls back to ISO-8859-1 for text/* bodies that declare no
    # charset, which would garble non-ASCII tokens; event streams are UTF-8.
    response.encoding = "utf-8"

    partial_words = ""
    reply_started = False
    for event in response.text.split("\n\n"):
        if not event.startswith("data:"):
            continue
        partial_words += json.loads(event[len("data:"):])["token"]["text"]
        time.sleep(0.05)  # pace the replay so the text appears gradually

        # Track the reply slot explicitly: relying on the event's position in
        # the split list would overwrite the prompt if the first segment were
        # not a data event.
        if reply_started:
            history[-1] = partial_words
        else:
            history.append(" " + partial_words)
            reply_started = True

        # Regroup the flat history into (prompt, reply) tuples.
        chat = list(zip(history[::2], history[1::2]))
        yield chat, history


def predict(inputs, top_p, temperature, top_k, repetition_penalty, history=None):
    """Stream a reply token by token from the generation endpoint.

    Posts the prompt to ``API_URL`` with ``stream=True`` and consumes the
    response line by line. Each ``data:<json>`` line is expected to carry the
    next piece of text under ``["token"]["text"]``.

    Args:
        inputs: The user's message. "User: " is prepended and a newline is
            appended when the message does not already start with "User: ".
        top_p: Sent as the ``top_p`` request parameter.
        temperature: Sent as the ``temperature`` request parameter.
        top_k: Sent as the ``top_k`` request parameter.
        repetition_penalty: Sent as the ``repetition_penalty`` request parameter.
        history: Flat list alternating prompt, reply, prompt, reply, ...
            It is extended in place. ``None`` starts a fresh conversation.

    Yields:
        ``(chat, history)`` after each received token, where ``chat`` is the
        history regrouped into ``(prompt, reply)`` tuples.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    # A list literal as the default would be shared between calls.
    if history is None:
        history = []

    if not inputs.startswith("User: "):
        inputs = "User: " + inputs + "\n"

    payload = {
        "inputs": inputs,
        "parameters": {
            "details": True,
            "do_sample": True,
            "max_new_tokens": 100,
            "repetition_penalty": repetition_penalty,
            "seed": 0,
            "temperature": temperature,
            "top_k": top_k,
            "top_p": top_p,
        },
    }
    headers = {
        "accept": "text/event-stream",
        "Content-Type": "application/json",
    }

    history.append(inputs)
    partial_words = ""
    reply_started = False

    # API_URL is the only endpoint constant this module defines. The context
    # manager releases the connection even if the consumer stops iterating.
    with requests.post(API_URL, headers=headers, json=payload, stream=True) as response:
        # Surface HTTP errors here; an error body has no "data:" lines and
        # would otherwise be skipped without any feedback.
        response.raise_for_status()

        for raw_line in response.iter_lines():
            line = raw_line.decode()
            # Skip blank separators and any non-data lines of the event stream.
            if not line.startswith("data:"):
                continue

            partial_words += json.loads(line[len("data:"):])["token"]["text"]
            time.sleep(0.05)

            # The first token opens the reply slot; later tokens replace it
            # with the accumulated text.
            if reply_started:
                history[-1] = partial_words
            else:
                history.append(" " + partial_words)
                reply_started = True

            # Regroup the flat history into (prompt, reply) tuples.
            chat = list(zip(history[::2], history[1::2]))
            yield chat, history


# Page heading, rendered as raw HTML at the top of the demo.
title = '<h1 align="center">Streaming your Chatbot output with Gradio</h1>'

# Markdown blurb for the page. It is assembled line by line so the fenced
# example of the prompt format stays readable in the source; the trailing
# empty item yields the final newline.
description = "\n".join(
    [
        "Language models can be conditioned to act like dialogue agents "
        "through a conversational prompt that typically takes the form:",
        "```",
        "User: <utterance>",
        "Assistant: <utterance>",
        "User: <utterance>",
        "Assistant: <utterance>",
        "...",
        "```",
        "In this app, you can explore the outputs of a large language models.",
        "",
    ]
)

# Build the demo page: heading, chat window, prompt box, sampling controls,
# and the event wiring that feeds predict()'s output into the chat window.
with gr.Blocks(css="#chatbot {height: 400px; overflow: auto;}") as demo:
    gr.HTML(title)
    chatbot = gr.Chatbot(elem_id="chatbot")
    inputs = gr.Textbox(
        placeholder="Hi my name is Joe.",
        label="Type an input and press Enter",
    )
    state = gr.State([])  # conversation history passed to predict() as `history`
    b1 = gr.Button()

    # Sampling controls; the accordion is created with open=False.
    with gr.Accordion("Parameters", open=False):
        top_p = gr.Slider(
            minimum=0,
            maximum=1.0,
            value=0.95,
            step=0.05,
            interactive=True,
            label="Top-p (nucleus sampling)",
        )
        temperature = gr.Slider(
            minimum=0,
            maximum=5.0,
            value=0.5,
            step=0.1,
            interactive=True,
            label="Temperature",
        )
        top_k = gr.Slider(
            minimum=1,
            maximum=50,
            value=4,
            step=1,
            interactive=True,
            label="Top-k",
        )
        repetition_penalty = gr.Slider(
            minimum=0.1,
            maximum=3.0,
            value=1.03,
            step=0.01,
            interactive=True,
            label="Repetition Penalty",
        )

    # Submitting the textbox and clicking the button run the same call with
    # the same inputs and outputs; the argument order mirrors predict()'s
    # signature.
    for register in (inputs.submit, b1.click):
        register(
            predict,
            [inputs, top_p, temperature, top_k, repetition_penalty, state],
            [chatbot, state],
        )

    gr.Markdown(description)
    demo.queue().launch(debug=True)