Spaces:
Runtime error
Runtime error
from huggingface_hub import InferenceClient | |
import gradio as gr | |
# Hosted inference client for the chat model used by generate().
client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.1")

# Avatar asset for Val. "/file=" is the URL route Gradio uses to serve local
# files from inside HTML (e.g. <img src="/file=...">); it is NOT a filesystem
# path. Handing "/file=..." to a component constructor points at a file that
# does not exist, so pass the plain relative path instead.
# NOTE(review): this component is never rendered or referenced elsewhere in
# this file; assumes val_speaking_transparent.gif is present in the app's
# working directory (the Space repo root) — confirm or remove.
val_image = gr.Image("val_speaking_transparent.gif")
# Landing content passed to gr.Chatbot(placeholder=...) and shown while the
# chat is empty. NOTE(review): Gradio presumably renders this as Markdown/HTML,
# so every line is kept flush-left (a line indented 4+ spaces could turn into a
# Markdown code block) and no blank lines are placed inside the <div>.
# NOTE(review): the greeting is hard-coded to a single user ("Jennifer").
PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
<h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">Hi Jennifer, welcome to Treasury and Finance</h1>
<p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Ask me anything about working here...</p>
<p>You might find these links of interest:</p>
<a href="https://www.dtf.vic.gov.au/funds-programs-and-policies/victorian-public-service-enterprise-agreement-2020">Read about our enterprise agreement</a>
<a href="https://www.vic.gov.au/make-content-accessible">Read guidance to making accessible content</a>
<a href="https://www.vic.gov.au/victorian-government-directory">Here's the Victorian Government directory</a>
</div>
"""
# Markdown list of useful onboarding links. It is currently not passed to the
# UI; the ChatInterface call only references it in comments.
DESCRIPTION = (
    "\n"
    "You might find these links of interest\n"
    "- [Read about our enterprise agreement]"
    "(https://www.dtf.vic.gov.au/funds-programs-and-policies/victorian-public-service-enterprise-agreement-2020)\n"
    "- [Here's the Victorian Government directory]"
    "(https://www.vic.gov.au/victorian-government-directory)\n"
    "- [Read guidance to making accessible content]"
    "(https://www.vic.gov.au/make-content-accessible)\n"
)

# Heading passed to gr.ChatInterface(title=...).
TITLE = "Hi I'm Val the Voyager, welcome onboard!"
def format_prompt(message, history):
    """Serialize the conversation into a single Mistral-Instruct style prompt.

    Args:
        message: The new user message to be answered.
        history: Iterable of ``(user_message, bot_reply)`` pairs, oldest first.

    Returns:
        ``"<s>"`` followed by one ``"[INST] user [/INST] reply</s> "`` segment per
        past exchange, ending with ``"[INST] message [/INST]"`` for the new turn.
    """
    past_turns = "".join(
        f"[INST] {user_turn} [/INST] {bot_turn}</s> "
        for user_turn, bot_turn in history
    )
    return f"<s>{past_turns}[INST] {message} [/INST]"
def generate(
    prompt,
    history,
    temperature=0.9,
    max_new_tokens=256,
    top_p=0.95,
    repetition_penalty=1.0,
):
    """Stream the model's reply to ``prompt`` given the earlier chat ``history``.

    Used as the ``fn`` of ``gr.ChatInterface``; the trailing arguments receive
    the values of ``additional_inputs`` in order.

    Args:
        prompt: The new user message.
        history: Earlier ``(user_message, bot_reply)`` pairs, forwarded to
            ``format_prompt``.
        temperature: Sampling temperature; values below 0.01 are raised to 0.01.
        max_new_tokens: Token budget for the reply; coerced to an int >= 1.
        top_p: Nucleus-sampling mass; clamped into [0.01, 0.99].
        repetition_penalty: Passed through unchanged.

    Yields:
        The reply accumulated so far, after each streamed (non-special) token.
    """
    # Sampling is enabled below, so a zero temperature is raised to a small
    # positive value.
    temperature = max(float(temperature), 1e-2)
    # NOTE(review): Text Generation Inference rejects top_p outside the open
    # interval (0, 1), and the UI slider can produce 0.0 or 1.0.
    top_p = min(max(float(top_p), 1e-2), 0.99)
    # The slider can yield 0 (rejected by the server) or a float; send a
    # positive int.
    max_new_tokens = max(int(max_new_tokens), 1)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        # Fixed seed: the same prompt and settings reproduce the same sample.
        seed=42,
    )

    formatted_prompt = format_prompt(prompt, history)
    stream = client.text_generation(
        formatted_prompt,
        **generate_kwargs,
        stream=True,
        details=True,
        return_full_text=False,
    )

    output = ""
    for response in stream:
        # Skip special tokens (e.g. the "</s>" end-of-sequence marker). Otherwise
        # they show up in the chat and are fed back into the next prompt via
        # history, where format_prompt already appends its own "</s>".
        if response.token.special:
            continue
        output += response.token.text
        yield output
# Extra controls shown under the chat box. ChatInterface passes their values,
# in this order, as generate()'s temperature, max_new_tokens, top_p and
# repetition_penalty arguments.
additional_inputs = [
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        # 0 new tokens is rejected by the inference server. 64..1024 keeps every
        # position, including the default 256, on the 64-token step grid
        # (the previous maximum of 1048 was off-grid).
        minimum=64,
        maximum=1024,
        step=64,
        interactive=True,
        info="The maximum number of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        # NOTE(review): Text Generation Inference requires 0 < top_p < 1, so the
        # endpoints 0.0 and 1.0 are excluded; 0.90 stays on the 0.05 grid.
        minimum=0.05,
        maximum=0.95,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]
# Starter prompts passed to ChatInterface(examples=...); each entry is a
# one-element list holding the message text.
EXAMPLES = [
    ["What should I do on my first day?"],
    ["Ask me what an acronym stands for"],
    ["How can I check my leave allowance?"],
    ["Where can I find a floor map of 1 Macarthur?"],
    ["How can I find out about DTF's Disability network?"],
]

# Chat display: bubble layout, copy button and like/dislike enabled, sharing
# disabled, and PLACEHOLDER shown while the conversation is empty.
# (DESCRIPTION could be surfaced via label=/show_label= here or via
# description= on the interface; neither is currently enabled.)
chatbot = gr.Chatbot(
    layout="bubble",
    placeholder=PLACEHOLDER,
    likeable=True,
    show_copy_button=True,
    show_share_button=False,
)

demo = gr.ChatInterface(
    fn=generate,
    chatbot=chatbot,
    additional_inputs=additional_inputs,
    examples=EXAMPLES,
    cache_examples=False,
    title=TITLE,
)
demo.launch(show_api=False)