|
import time |
|
import gradio as gr |
|
import os |
|
import json |
|
import requests |
|
|
|
|
|
# Streaming endpoint of the generation server. The base URL is read from the
# API_URL environment variable; fail with an explicit message when it is
# missing instead of the opaque "NoneType + str" TypeError.
_api_base_url = os.getenv("API_URL")
if not _api_base_url:
    raise RuntimeError(
        "The API_URL environment variable must be set to the base URL of the "
        "text generation server."
    )
# Strip a trailing slash so the joined URL never contains "//generate_stream".
API_URL = _api_base_url.rstrip("/") + "/generate_stream"
|
|
|
def _extract_token_text(raw_line):
    """Return the token text carried by one server-sent-event line, or None.

    Only lines starting with ``data:`` carry a JSON payload; any other line
    (for example an SSE comment or an ``event:`` field) is ignored.
    """
    line = raw_line.decode()
    if not line.startswith("data:"):
        return None
    return json.loads(line[len("data:"):])["token"]["text"]


def predict(inputs, top_p, temperature, top_k, repetition_penalty, history=None):
    """Stream a generated reply for ``inputs`` and yield the growing transcript.

    Args:
        inputs: The user's message. ``"User: "`` is prepended (and a newline
            appended) unless the text already starts with that prefix.
        top_p: Forwarded as the ``top_p`` generation parameter.
        temperature: Forwarded as the ``temperature`` generation parameter.
        top_k: Forwarded as the ``top_k`` generation parameter.
        repetition_penalty: Forwarded as the ``repetition_penalty`` parameter.
        history: Flat list alternating user prompts and model replies. It is
            mutated in place; a fresh list is created when omitted.

    Yields:
        ``(chat, history)`` after every received token, where ``chat`` is a
        list of ``(prompt, reply)`` tuples built from consecutive ``history``
        entries and ``history`` is the updated flat list.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    # A literal [] default would be shared between every call that omits it.
    if history is None:
        history = []

    if not inputs.startswith("User: "):
        inputs = "User: " + inputs + "\n"

    payload = {
        "inputs": inputs,
        "parameters": {
            "details": True,
            "do_sample": True,
            "max_new_tokens": 100,
            "repetition_penalty": repetition_penalty,
            "seed": 0,
            "temperature": temperature,
            "top_k": top_k,
            "top_p": top_p,
        },
    }

    headers = {
        'accept': 'text/event-stream',
        'Content-Type': 'application/json',
    }

    history.append(inputs)

    partial_words = ""
    reply_started = False

    # The context manager releases the streaming connection even when the
    # consumer stops iterating early or an exception is raised mid-stream.
    with requests.post(API_URL, headers=headers, json=payload, stream=True) as response:
        # Surface HTTP failures directly rather than as a JSON/KeyError later.
        response.raise_for_status()

        for chunk in response.iter_lines():
            if not chunk:
                continue

            token_text = _extract_token_text(chunk)
            if token_text is None:
                continue

            partial_words = partial_words + token_text
            if not reply_started:
                # The first token opens a new reply entry (kept with the
                # original leading space); later tokens overwrite that entry.
                history.append(" " + partial_words)
                reply_started = True
            else:
                history[-1] = partial_words

            # Pair up consecutive (prompt, reply) entries for the chat widget.
            chat = [(history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)]
            yield chat, history
|
|
|
|
|
# Page heading, rendered as raw HTML when the page is built.
title = '<h1 align="center">Streaming your Chatbot output with Gradio</h1>'

# Markdown blurb explaining the "User:/Assistant:" prompt format; the fenced
# block inside it is part of the displayed text.
description = (
    "Language models can be conditioned to act like dialogue agents through a "
    "conversational prompt that typically takes the form:\n"
    "```\n"
    "User: <utterance>\n"
    "Assistant: <utterance>\n"
    "User: <utterance>\n"
    "Assistant: <utterance>\n"
    "...\n"
    "```\n"
    "In this app, you can explore the outputs of a 20B large language model.\n"
)
|
|
|
# CSS for the centered, fixed-size, scrollable column that holds the chat widgets.
page_css = "#col_container {width: 700px; height: 400px; overflow: auto; margin-left: auto; margin-right: auto;}"

with gr.Blocks(css=page_css) as demo:
    gr.HTML(title)

    with gr.Column(elem_id="col_container"):
        chatbot = gr.Chatbot(elem_id="chatbot")
        inputs = gr.Textbox(
            placeholder="Hi my name is Joe.",
            label="Type an input and press Enter",
        )
        # Per-session conversation history handed to and returned from predict.
        state = gr.State([])
        b1 = gr.Button()

        # Sampling controls, collapsed by default.
        with gr.Accordion("Parameters", open=False):
            top_p = gr.Slider(
                minimum=0,
                maximum=1.0,
                value=0.95,
                step=0.05,
                interactive=True,
                label="Top-p (nucleus sampling)",
            )
            temperature = gr.Slider(
                minimum=0,
                maximum=5.0,
                value=0.5,
                step=0.1,
                interactive=True,
                label="Temperature",
            )
            top_k = gr.Slider(
                minimum=1,
                maximum=50,
                value=4,
                step=1,
                interactive=True,
                label="Top-k",
            )
            repetition_penalty = gr.Slider(
                minimum=0.1,
                maximum=3.0,
                value=1.03,
                step=0.01,
                interactive=True,
                label="Repetition Penalty",
            )

    # Pressing Enter in the textbox and clicking the button run the same handler
    # with the same inputs and outputs.
    predict_inputs = [inputs, top_p, temperature, top_k, repetition_penalty, state]
    predict_outputs = [chatbot, state]
    for register_event in (inputs.submit, b1.click):
        register_event(predict, predict_inputs, predict_outputs)

    gr.Markdown(description)

    # Turn on the request queue, then start the server in debug mode.
    demo.queue().launch(debug=True)
|
|