chatty-lms-old / app.py
lewtun's picture
lewtun HF staff
Fix
d0f9fb0
raw
history blame
6.46 kB
# AUTOGENERATED! DO NOT EDIT! File to edit: app.ipynb.
# %% auto 0
# Names exported by `from <module> import *`: the two environment-derived
# settings, the two UI text constants, and the three functions defined below.
__all__ = ['HF_TOKEN', 'ENDPOINT_URL', 'title', 'description', 'get_model_endpoint_params', 'query_chat_api', 'inference_chat']
# %% app.ipynb 0
import gradio as gr
import requests
import json
import requests
import os
from pathlib import Path
from dotenv import load_dotenv
# %% app.ipynb 1
# Pick up settings from a `.env` file when one exists in the working directory.
if Path(".env").is_file():
    load_dotenv(".env")

# Both settings are read from the process environment; each is None when unset.
HF_TOKEN = os.environ.get("HF_TOKEN")
ENDPOINT_URL = os.environ.get("ENDPOINT_URL")
# %% app.ipynb 2
def get_model_endpoint_params(model_id):
    """Resolve the request target for ``model_id``.

    Returns a 3-tuple ``(url, headers, max_new_tokens_supported)``:

    * model ids containing ``"joi"`` use ``ENDPOINT_URL`` with no headers and
      report ``True`` for ``max_new_tokens_supported``;
    * every other id uses the ``api-inference.huggingface.co`` URL for that
      model, with a bearer-token header built from ``HF_TOKEN``, and reports
      ``False``.
    """
    uses_custom_endpoint = "joi" in model_id
    if not uses_custom_endpoint:
        auth_headers = {"Authorization": f"Bearer {HF_TOKEN}", "x-wait-for-model": "1"}
        hosted_url = f"https://api-inference.huggingface.co/models/{model_id}"
        return hosted_url, auth_headers, False

    return ENDPOINT_URL, None, True
# %% app.ipynb 3
def query_chat_api(
    model_id,
    inputs,
    temperature,
    top_p
):
    """Send one text-generation request for ``model_id`` and return the reply.

    Args:
        model_id: Model identifier; resolved to a URL and headers by
            ``get_model_endpoint_params``.
        inputs: The fully formatted prompt string.
        temperature: Sampling temperature forwarded to the endpoint.
        top_p: Nucleus-sampling value forwarded to the endpoint.

    Returns:
        The decoded JSON body when the endpoint answers with HTTP 200,
        otherwise the string ``"Error: "`` followed by the response text.
    """
    endpoint, headers, max_new_tokens_supported = get_model_endpoint_params(model_id)

    parameters = {
        "temperature": temperature,
        "top_p": top_p,
        "do_sample": True,
    }
    if max_new_tokens_supported:
        parameters["max_new_tokens"] = 100
        # This has to be an assignment: `d[key]: 1.03` is only an annotation
        # statement and would leave the key out of the request.
        parameters["repetition_penalty"] = 1.03
        parameters["stop"] = ["Human:"]
    else:
        parameters["max_length"] = 512

    response = requests.post(
        endpoint,
        json={"inputs": inputs, "parameters": parameters},
        headers=headers,
    )
    if response.status_code == 200:
        return response.json()
    return "Error: " + response.text
# %% app.ipynb 5
def _strip_stop_sequence(text, stop="Human:"):
    """Remove one trailing ``stop`` marker and the spaces around it."""
    # `str.rstrip(" Human:")` would strip any trailing run of those
    # characters (eating e.g. the end of "...human"), so match the suffix.
    text = text.rstrip(" ")
    if text.endswith(stop):
        text = text[: -len(stop)].rstrip(" ")
    return text


def _pair_turns(turns):
    """Group a flat [human, assistant, human, ...] list into (human, assistant) tuples."""
    # zip() drops a trailing unanswered human turn, if any.
    return list(zip(turns[::2], turns[1::2]))


def inference_chat(
    model_id,
    text_input,
    temperature,
    top_p,
    history=None,
):
    """Run one chat turn and return the updates for the chat view and state.

    Args:
        model_id: Model identifier; ids containing ``"joi"`` use the
            ``langchain_default.json`` prompt template, all others use
            ``anthropic_hhh_single.json``.
        text_input: The new human message.
        temperature: Sampling temperature passed to ``query_chat_api``.
        top_p: Nucleus-sampling value passed to ``query_chat_api``.
        history: Flat list of earlier turns, alternating human / assistant
            and starting with a human turn. ``None`` means no earlier turns.
            The list passed in is not modified.

    Returns:
        A dict keyed by the module-level ``chatbot`` and ``state`` components:
        ``chatbot`` maps to a list of ``(human, assistant)`` tuples and
        ``state`` to the flat history. When the API call does not yield a
        ``generated_text``, the error is shown as the reply to this turn and
        ``state`` is left as it was, so the failed turn is not fed into later
        prompts.
    """
    # Copy so that neither a shared default nor the caller's list is mutated,
    # and so a failed request cannot leave an unanswered turn in the state.
    previous_turns = list(history) if history else []

    if "joi" in model_id:
        prompt_filename = "langchain_default.json"
    else:
        prompt_filename = "anthropic_hhh_single.json"
    print(prompt_filename)
    with open(f"prompt_templates/{prompt_filename}", "r", encoding="utf-8") as f:
        prompt_template = json.load(f)

    # Even positions are human turns, odd positions are assistant turns.
    history_input = "\n".join(
        f"{'Human' if idx % 2 == 0 else 'Assistant'}: {text}"
        for idx, text in enumerate(previous_turns)
    ).rstrip("\n")
    inputs = prompt_template["prompt"].format(human_input=text_input, history=history_input)

    turns = previous_turns + [text_input]
    print(f"History: {turns}")
    print(f"Inputs: {inputs}")

    output = query_chat_api(model_id, inputs, temperature, top_p)
    if isinstance(output, list):
        output = output[0] if output else {}

    if not isinstance(output, dict) or "generated_text" not in output:
        # query_chat_api returns an "Error: ..." string on non-200 responses.
        if isinstance(output, str):
            error_text = output
        else:
            error_text = f"Error: unexpected response {output!r}"
        chat = _pair_turns(previous_turns) + [(text_input, error_text)]
        return {chatbot: chat, state: previous_turns}

    reply = _strip_stop_sequence(output["generated_text"])
    turns.append(" " + reply)
    return {chatbot: _pair_turns(turns), state: turns}
# %% app.ipynb 21
# HTML heading passed to gr.Markdown at the top of the page.
title = """<h1 align="center">Chatty Language Models</h1>"""
# Markdown introduction passed to gr.Markdown directly below the title.
description = """Pretrained language models can be conditioned to act like dialogue agents through a conversational prompt that typically takes the form:
```
Human: <utterance>
Assistant: <utterance>
Human: <utterance>
Assistant: <utterance>
...
```
In this app, you can explore the outputs of several language models conditioned on different conversational prompts. The models are trained on different datasets and have different objectives, so they will have different personalities and strengths.
"""
# %% app.ipynb 23
# Build the page: a settings column (model, temperature, top-p) next to a chat
# column (chat output, chat input, Clear / Submit buttons), then launch it.
# NOTE(review): `#component-21` in the CSS looks like an auto-generated Gradio
# element id; if so, it depends on how many components are created and in what
# order - confirm before adding, removing, or reordering components.
with gr.Blocks(
    css="""
.message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px}
#component-21 > div.wrap.svelte-w6rprc {height: 600px;}
"""
) as iface:
    # Conversation history; passed into and returned from inference_chat.
    state = gr.State([])
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Row():
        # Left column: generation settings.
        with gr.Column(scale=1):
            model_id = gr.Dropdown(
                choices=["google/flan-t5-xl" ,"Rallio67/joi_20B_instruct_alpha"],
                value="google/flan-t5-xl",
                label="Model",
                interactive=True,
            )
            # Disabled prompt-template selector; inference_chat currently picks
            # the template from the model id instead.
            # prompt_template = gr.Dropdown(
            # choices=[
            # "langchain_default",
            # "openai_chatgpt",
            # "deepmind_sparrow",
            # "deepmind_gopher",
            # "anthropic_hhh",
            # ],
            # value="langchain_default",
            # label="Prompt Template",
            # interactive=True,
            # )
            temperature = gr.Slider(
                minimum=0.0,
                maximum=2.0,
                value=1.0,
                step=0.1,
                interactive=True,
                label="Temperature",
            )
            top_p = gr.Slider(
                minimum=0.,
                maximum=1.0,
                value=0.8,
                step=0.05,
                interactive=True,
                label="Top-p (nucleus sampling)",
            )
        # Right column: the conversation itself.
        with gr.Column(scale=1.8):
            with gr.Row():
                chatbot = gr.Chatbot(
                    label="Chat Output",
                )
            with gr.Row():
                chat_input = gr.Textbox(lines=1, label="Chat Input")
                # Submitting the textbox runs one chat turn.
                chat_input.submit(
                    inference_chat,
                    [
                        model_id,
                        chat_input,
                        temperature,
                        top_p,
                        state,
                    ],
                    [chatbot, state],
                )
            with gr.Row():
                clear_button = gr.Button(value="Clear", interactive=True)
                # Reset the textbox, the chat view and the stored history.
                clear_button.click(
                    lambda: ("", [], []),
                    [],
                    [chat_input, chatbot, state],
                    queue=False,
                )
                submit_button = gr.Button(
                    value="Submit", interactive=True, variant="primary"
                )
                # Same handler, inputs and outputs as chat_input.submit above.
                submit_button.click(
                    inference_chat,
                    [
                        model_id,
                        chat_input,
                        temperature,
                        top_p,
                        state,
                    ],
                    [chatbot, state],
                )
iface.launch()