ysharma's picture
ysharma HF staff
update the header with emojis
57a8590
raw
history blame contribute delete
No virus
3.69 kB
import time
import gradio as gr
import os
import json
import requests
# Streaming endpoint: the base URL of the generation server is read from the
# API_URL environment variable and "/generate_stream" is appended to it.
_api_base_url = os.getenv("API_URL")
if not _api_base_url:
    # Fail with an explicit message rather than the opaque
    # "unsupported operand type(s) for +: 'NoneType' and 'str'".
    raise RuntimeError(
        "The API_URL environment variable must be set to the base URL "
        "of the text generation server."
    )
# Drop any trailing slash so the joined URL never contains "//generate_stream".
API_URL = _api_base_url.rstrip("/") + "/generate_stream"
def predict(inputs, top_p, temperature, top_k, repetition_penalty, history=None):
    """Stream a model reply for ``inputs`` and yield the growing transcript.

    The user text is prefixed with ``"User: "`` (unless it already starts with
    it) and POSTed to ``API_URL`` together with the sampling parameters. The
    response is consumed as an event stream; the text of every token event is
    appended to the reply, and the updated conversation is yielded after each
    token.

    Args:
        inputs: Text typed by the user.
        top_p: Sent as the ``top_p`` generation parameter.
        temperature: Sent as the ``temperature`` generation parameter.
        top_k: Sent as the ``top_k`` generation parameter.
        repetition_penalty: Sent as the ``repetition_penalty`` parameter.
        history: Flat list of alternating user turns and model replies. It is
            extended in place; ``None`` starts a fresh conversation.

    Yields:
        ``(chat, history)`` where ``chat`` is a list of ``(user, reply)``
        tuples built from consecutive pairs of ``history``.

    Raises:
        requests.HTTPError: If the endpoint answers with an error status.
    """
    # A mutable default list would be shared by every call that omits it.
    if history is None:
        history = []

    if not inputs.startswith("User: "):
        inputs = "User: " + inputs + "\n"

    payload = {
        "inputs": inputs,
        "parameters": {
            "details": True,
            "do_sample": True,
            "max_new_tokens": 100,
            "repetition_penalty": repetition_penalty,
            "seed": 0,  # the same fixed seed is sent with every request
            "temperature": temperature,
            "top_k": top_k,
            "top_p": top_p,
        },
    }
    headers = {
        'accept': 'text/event-stream',
        'Content-Type': 'application/json',
    }

    partial_words = ""
    # stream=True keeps the connection open while tokens arrive; the ``with``
    # block releases it even if the consumer stops iterating early. The
    # (connect, read) timeout prevents a stalled server from hanging the
    # handler forever.
    with requests.post(
        API_URL, headers=headers, json=payload, stream=True, timeout=(10, 120)
    ) as response:
        response.raise_for_status()

        # Record the user turn and a reply placeholder together, and only once
        # the request has been accepted, so ``history`` always holds complete
        # (user, reply) pairs.
        history.append(inputs)
        history.append(partial_words)

        for chunk in response.iter_lines():
            # Blank lines only separate events.
            if not chunk:
                continue
            line = chunk.decode("utf-8")
            # Only "data:" lines carry a JSON payload.
            if not line.startswith("data:"):
                continue
            event = json.loads(line[len("data:"):])
            partial_words += event['token']['text']
            history[-1] = partial_words
            # Pair up consecutive entries: (user turn, model reply).
            chat = list(zip(history[::2], history[1::2]))
            yield chat, history
def reset_textbox():
    """Build the Gradio update that sets a component's value to an empty string."""
    cleared = gr.update(value='')
    return cleared
# Page heading, rendered through gr.HTML.
title = '<h1 align="center">🔥Streaming your 🤖Chatbot output with Gradio🚀</h1>'

# Explanatory text, rendered through gr.Markdown. It is written as adjacent
# literals, one per output line, so the fenced prompt example is easy to read
# in the source.
description = (
    "Language models can be conditioned to act like dialogue agents through "
    "a conversational prompt that typically takes the form:\n"
    "```\n"
    "User: <utterance>\n"
    "Assistant: <utterance>\n"
    "User: <utterance>\n"
    "Assistant: <utterance>\n"
    "...\n"
    "```\n"
    "In this app, you can explore the outputs of a 20B large language model.\n"
)
# Assemble the page: heading, chat column with sampling controls, event wiring
# and the description text.
with gr.Blocks(
    css="""#col_container {width: 700px; margin-left: auto; margin-right: auto;}
#chatbot {height: 400px; overflow: auto;}"""
) as demo:
    gr.HTML(title)

    with gr.Column(elem_id="col_container"):
        chatbot = gr.Chatbot(elem_id="chatbot")
        inputs = gr.Textbox(
            placeholder="Hi my name is Joe.",
            label="Type an input and press Enter",
        )
        # Conversation history carried between calls to predict.
        state = gr.State([])
        b1 = gr.Button()

        # Sampling controls, passed to predict in the order:
        # inputs, top_p, temperature, top_k, repetition_penalty.
        with gr.Accordion("Parameters", open=False):
            top_p = gr.Slider(
                minimum=0,
                maximum=1.0,
                value=0.95,
                step=0.05,
                interactive=True,
                label="Top-p (nucleus sampling)",
            )
            temperature = gr.Slider(
                minimum=0,
                maximum=5.0,
                value=0.5,
                step=0.1,
                interactive=True,
                label="Temperature",
            )
            top_k = gr.Slider(
                minimum=1,
                maximum=50,
                value=4,
                step=1,
                interactive=True,
                label="Top-k",
            )
            repetition_penalty = gr.Slider(
                minimum=0.1,
                maximum=3.0,
                value=1.03,
                step=0.01,
                interactive=True,
                label="Repetition Penalty",
            )

    # Pressing Enter in the textbox and clicking the button both run predict.
    inputs.submit(
        fn=predict,
        inputs=[inputs, top_p, temperature, top_k, repetition_penalty, state],
        outputs=[chatbot, state],
    )
    b1.click(
        fn=predict,
        inputs=[inputs, top_p, temperature, top_k, repetition_penalty, state],
        outputs=[chatbot, state],
    )

    # The same two triggers also clear the textbox.
    b1.click(fn=reset_textbox, inputs=[], outputs=[inputs])
    inputs.submit(fn=reset_textbox, inputs=[], outputs=[inputs])

    gr.Markdown(description)

demo.queue().launch(debug=True)