# AUTOGENERATED! DO NOT EDIT! File to edit: app.ipynb.

# %% auto 0
__all__ = ['HF_TOKEN', 'title', 'description', 'query_chat_api', 'inference_chat']

# %% app.ipynb 0
import gradio as gr
import requests
import json
import os
from pathlib import Path

from dotenv import load_dotenv
# %% app.ipynb 1
# Load the Hugging Face API token from a local .env file if one is present
# (e.g. for local development); otherwise fall back to the environment.
if Path(".env").is_file():
    load_dotenv(".env")

HF_TOKEN = os.getenv("HF_TOKEN")
# %% app.ipynb 2
def query_chat_api(
    model_id,
    inputs,
    temperature,
    top_p,
):
    """Query the Hugging Face Inference API with the given prompt and sampling parameters."""
    API_URL = f"https://api-inference.huggingface.co/models/{model_id}"
    # The x-wait-for-model header asks the API to wait for the model to load
    # rather than returning an error while it spins up.
    headers = {"Authorization": f"Bearer {HF_TOKEN}", "x-wait-for-model": "1"}

    payload = {
        "inputs": inputs,
        "parameters": {
            "temperature": temperature,
            "top_p": top_p,
            "do_sample": True,
            "max_length": 512,
        },
    }

    response = requests.post(API_URL, json=payload, headers=headers)
    if response.status_code == 200:
        return response.json()
    else:
        return "Error: " + response.text
# %% app.ipynb 5
def inference_chat(
    model_id,
    prompt_template,
    text_input,
    temperature,
    top_p,
    history=None,
):
    # Avoid a mutable default argument: Gradio passes the session state in,
    # but fall back to a fresh list if nothing is supplied.
    if history is None:
        history = []

    with open(f"prompt_templates/{prompt_template}.json", "r") as f:
        template = json.load(f)

    history.append(text_input)
    inputs = template["prompt"].format(human_input=text_input)

    output = query_chat_api(model_id, inputs, temperature, top_p)
    history.append(" " + output[0]["generated_text"])

    # Pair up the flat history into a list of (user, bot) tuples for the Chatbot component.
    chat = [(history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)]
    return {chatbot: chat, state: history}
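
# Shape of a prompt template file, inferred from the `format(human_input=...)`
# call above (the contents shown here are a hypothetical placeholder, not the
# actual template shipped with the app):
#
#     prompt_templates/langchain_default.json
#     {
#         "prompt": "The following is a conversation ...\nHuman: {human_input}\nAI:"
#     }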
# %% app.ipynb 15
title = """<h1 align="center">Chatty Language Models</h1>"""
description = """Language models can be conditioned to act like dialogue agents through a conversational prompt that typically takes the form:

```
User: <utterance>
Assistant: <utterance>
User: <utterance>
Assistant: <utterance>
...
```

In this app, you can explore the outputs of several language models conditioned on different conversational prompts. The models are trained on different datasets and have different objectives, so they exhibit different personalities and strengths.

So far, the following prompts are available:

* `langchain_default`: The default prompt used in the [LangChain library](https://github.com/hwchase17/langchain/blob/bc53c928fc1b221d0038b839d111039d31729def/langchain/chains/conversation/prompt.py#L4). Around 67 tokens long.
* `openai_chatgpt`: The prompt used in the OpenAI ChatGPT model. Around 261 tokens long.
* `deepmind_sparrow`: The prompt used in the DeepMind Sparrow model (Table 7 of [their paper](https://arxiv.org/abs/2209.14375)). Around 880 tokens long.
* `deepmind_gopher`: The prompt used in the DeepMind Gopher model (Table A30 of [their paper](https://arxiv.org/abs/2112.11446)). Around 791 tokens long.
* `anthropic_hhh`: The prompt used in the [Anthropic HHH models](https://gist.github.com/jareddk/2509330f8ef3d787fc5aaac67aab5f11#file-hhh_prompt-txt). A whopping 6,341 tokens long!

As you can see, most of these prompts exceed the maximum context size of models like Flan-T5, so an error usually means the Inference API has timed out.
"""
# %% app.ipynb 16
with gr.Blocks(
    css="""
    .message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px}
    #component-21 > div.wrap.svelte-w6rprc {height: 600px;}
    """
) as iface:
    # Per-session chat history, stored as a flat list of alternating
    # user/bot utterances (see inference_chat above).
    state = gr.State([])

    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Row():
        with gr.Column(scale=1):
            model_id = gr.Dropdown(
                choices=["google/flan-t5-xl"],
                value="google/flan-t5-xl",
                label="Model",
                interactive=True,
            )
            prompt_template = gr.Dropdown(
                choices=[
                    "langchain_default",
                    "openai_chatgpt",
                    "deepmind_sparrow",
                    "deepmind_gopher",
                    "anthropic_hhh",
                ],
                value="langchain_default",
                label="Prompt Template",
                interactive=True,
            )
            temperature = gr.Slider(
                minimum=0.0,
                maximum=2.0,
                value=1.0,
                step=0.1,
                interactive=True,
                label="Temperature",
            )
            top_p = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.8,
                step=0.05,
                interactive=True,
                label="Top-p (nucleus sampling)",
            )
        with gr.Column(scale=1.8):
            with gr.Row():
                chatbot = gr.Chatbot(
                    label="Chat Output",
                )
            with gr.Row():
                chat_input = gr.Textbox(lines=1, label="Chat Input")
                # Pressing Enter in the textbox triggers the same handler
                # as the Submit button below.
                chat_input.submit(
                    inference_chat,
                    [
                        model_id,
                        prompt_template,
                        chat_input,
                        temperature,
                        top_p,
                        state,
                    ],
                    [chatbot, state],
                )

            with gr.Row():
                # Reset the textbox, the chat display, and the session state.
                clear_button = gr.Button(value="Clear", interactive=True)
                clear_button.click(
                    lambda: ("", [], []),
                    [],
                    [chat_input, chatbot, state],
                    queue=False,
                )

                submit_button = gr.Button(
                    value="Submit", interactive=True, variant="primary"
                )
                submit_button.click(
                    inference_chat,
                    [
                        model_id,
                        prompt_template,
                        chat_input,
                        temperature,
                        top_p,
                        state,
                    ],
                    [chatbot, state],
                )

iface.launch()