# Captured from the Hugging Face Space file view of app.py:
# last commit "Update app.py" (4787649, verified) by basakerdogan;
# file size 7.93 kB; Hub scan result: "No virus".
import os
from typing import Iterator
import gradio as gr
from text_generation import Client
model_id = 'mistralai/Mistral-7B-Instruct-v0.1'
API_URL = "https://api-inference.huggingface.co/models/" + model_id
# Read-scoped Hugging Face token taken from the environment; None when unset.
HF_TOKEN = os.environ.get('HF_READ_TOKEN')
# Send an Authorization header only when a token is configured. The previous
# `False` default produced the literal header "Bearer False" when unset.
client = Client(
    API_URL,
    headers={'Authorization': f"Bearer {HF_TOKEN}"} if HF_TOKEN else None,
)
# Stop markers: run() ends the stream as soon as a streamed token's text
# contains either of these (end-of-sequence / end-of-turn).
EOS_STRING, EOT_STRING = "</s>", "<EOT>"
def get_prompt(message, chat_history, system_prompt):
    """Assemble the raw prompt string sent to the model.

    The system prompt is wrapped in a ``<<SYS>>`` block inside the first
    ``[INST]``. Each completed ``(user, assistant)`` turn is closed with
    ``</s><s>[INST] `` so the next user text continues the sequence. The new
    *message* is terminated with ``[/INST]`` for the model to complete.

    Whitespace handling:

    - the first user turn is used verbatim;
    - later user turns are stripped;
    - assistant turns are always stripped;
    - *message* is stripped only when there is prior history.

    NOTE(review): confirm the target model's chat template honours the
    ``<<SYS>>`` block.
    """
    pieces = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
    turn_number = 0
    for turn_number, (user_text, bot_text) in enumerate(chat_history, start=1):
        if turn_number > 1:
            user_text = user_text.strip()
        pieces.append(f"{user_text} [/INST] {bot_text.strip()} </s><s>[INST] ")
    # turn_number is still 0 only when chat_history produced no turns.
    final_text = message.strip() if turn_number else message
    pieces.append(f"{final_text} [/INST]")
    return ''.join(pieces)
def run(message, chat_history, system_prompt, max_new_tokens=1024, temperature=0.1, top_p=0.9, top_k=50):
    """Stream the model's reply to *message*, yielding the text accumulated so far.

    Builds the prompt with get_prompt() and requests a sampled
    (``do_sample=True``) streaming generation from ``client``.

    Streaming stops as soon as a streamed token's text contains EOS_STRING or
    EOT_STRING. That token is not appended or yielded. The generator's return
    value (``StopIteration.value``) is the final accumulated text.
    """
    prompt = get_prompt(message, chat_history, system_prompt)
    stream = client.generate_stream(
        prompt,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
    )
    # Loop-invariant: build the stop-marker tuple once, not per token.
    stop_markers = (EOS_STRING, EOT_STRING)
    output = ''
    for response in stream:
        piece = response.token.text
        # Generator expression: short-circuits without building a list per token.
        if any(marker in piece for marker in stop_markers):
            return output
        output += piece
        yield output
    return output
# Default system prompt defining the "Ricky" persona and facts about MCES10
# Software. It is the (non-editable) value of the "System prompt" textbox in the
# Advanced options, and process_example() passes it directly. It is therefore
# embedded verbatim by get_prompt() in every request.
# Note: the literal begins and ends with a newline.
# NOTE(review): the text contains grammar slips (e.g. "There are AI Hub").
# Because it is sent to the model as-is, any edit here changes model behaviour.
DEFAULT_SYSTEM_PROMPT = """
You are Ricky. You are an AI assistant, you are moderately-polite and give only true information.
You carefully provide accurate, factual, thoughtful, nuanced answers, and are brilliant at reasoning.
If you think there might not be a correct answer, you say so. Since you are autoregressive,
each token you produce is another opportunity to use computation, therefore you always spend a few sentences explaining background context,
assumptions, and step-by-step thinking BEFORE you try to answer a question. You are an AI developed by MCES10 Software the website is www.mces10-software.com.
The CEO is MCES10. You are based on the Mistral-7B-Instruct-v0.1. You ask what the person's name is if they say hello. MCES10 Software has made apps named To-List a brilliant to-do list app. Web Development Tutorials also known as W.D.T which teaches you how to code websites.
There are AI Hub which is in development which is the gateway to everything AI.
"""
# Heading rendered at the top of the app.
DESCRIPTION = "Ricky AI"

# Generation limits: the "Max New Tokens" slider spans 1..MAX_MAX_NEW_TOKENS
# and starts at DEFAULT_MAX_NEW_TOKENS; generate() rejects larger requests.
MAX_MAX_NEW_TOKENS = 4_096
DEFAULT_MAX_NEW_TOKENS = 256

# Threshold enforced by check_input_token_length() before generating.
MAX_INPUT_TOKEN_LENGTH = 4_000
def clear_and_save_textbox(message):
    """Empty the input box and stash its text for the next chained event.

    Returns a ``(new_textbox_value, saved_message)`` pair.
    """
    cleared = ''
    return cleared, message
def display_input(message, history=None):
    """Append a pending ``(message, '')`` turn to *history* and return it.

    The empty assistant slot is a placeholder that generate() later fills in.
    A supplied list is modified in place. When *history* is omitted or None, a
    fresh list is created. The old ``history=[]`` default was a single shared
    list, so turns leaked between calls.
    """
    if history is None:
        history = []
    history.append((message, ''))
    return history
def delete_prev_fn(history=None):
    """Remove the most recent turn from *history*.

    Returns ``(history, message)``. *message* is the user text of the removed
    turn, or '' when there was nothing to remove or that text was None/empty.
    A supplied list is modified in place. None or an omitted argument is
    treated as an empty history, replacing the old shared mutable default.
    """
    if history is None:
        history = []
    try:
        message, _ = history.pop()
    except IndexError:
        # Nothing to undo/retry.
        message = ''
    return history, message or ''
def generate(message, history_with_input, system_prompt, max_new_tokens, temperature, top_p, top_k):
    """Stream chat-history updates while the model answers *message*.

    *history_with_input* is the chatbot history whose last entry is the
    pending ``(message, '')`` placeholder appended by display_input(). That
    entry is dropped, so the prompt is built only from completed turns.

    Each yielded value is the completed history plus ``(message, partial_reply)``.
    At least one update is always yielded: if the model stops immediately, the
    reply is ''.

    Raises:
        ValueError: if *max_new_tokens* exceeds MAX_MAX_NEW_TOKENS.
    """
    if max_new_tokens > MAX_MAX_NEW_TOKENS:
        raise ValueError(
            f"max_new_tokens={max_new_tokens} exceeds the maximum of {MAX_MAX_NEW_TOKENS}"
        )
    history = history_with_input[:-1]
    produced = False
    for response in run(message, history, system_prompt, max_new_tokens, temperature, top_p, top_k):
        produced = True
        yield history + [(message, response)]
    if not produced:
        # Still replace the placeholder so the UI shows a finished (empty) turn.
        yield history + [(message, '')]
def process_example(message):
    """Run *message* to completion with fixed sampling settings.

    Uses DEFAULT_SYSTEM_PROMPT, 1024 new tokens, temperature 1, top-p 0.95 and
    top-k 50, starting from an empty history. Drains the streaming generator
    and returns ``('', final_history)``: a cleared textbox value and the last
    history update.
    """
    updates = generate(message, [], DEFAULT_SYSTEM_PROMPT, 1024, 1, 0.95, 50)
    for final_history in updates:
        continue
    return '', final_history
def check_input_token_length(message, chat_history, system_prompt, chars_per_token=4):
    """Reject a request whose estimated input size exceeds MAX_INPUT_TOKEN_LENGTH.

    No tokenizer is loaded in this app, so the size is estimated. The
    character counts of the system prompt, the new message and every text in
    *chat_history* are summed, then divided by *chars_per_token* (rounded up).
    The previous version added the *number* of history turns, not their text,
    so long conversations were never rejected and the system prompt was ignored.

    NOTE(review): ~4 characters per token is only a rough rule of thumb; load
    the model tokenizer if an exact count is required.

    The UI appends ``(message, '')`` to the history before this check runs, so
    the new message is counted twice. The estimate errs on the high side.

    Raises:
        gr.Error: shown to the user when the estimate exceeds the limit.
    """
    total_chars = len(message or '') + len(system_prompt or '')
    for turn in chat_history or []:
        # Entries may be None (e.g. an empty assistant slot).
        total_chars += sum(len(text or '') for text in turn)
    input_token_length = -(-total_chars // chars_per_token)  # ceiling division
    if input_token_length > MAX_INPUT_TOKEN_LENGTH:
        raise gr.Error(f"The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again.")
with gr.Blocks(theme='Taithrah/Minimal') as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Group():
        chatbot = gr.Chatbot(label='RickyAI based on Mistral-7B-Instruct-v0.1')
        with gr.Row():
            textbox = gr.Textbox(
                container=False,
                show_label=False,
                placeholder='Hi, Ricky',
                scale=10,
            )
            submit_button = gr.Button('Submit', variant='primary', scale=1, min_width=0)
    with gr.Row():
        retry_button = gr.Button('Retry', variant='secondary')
        undo_button = gr.Button('Undo', variant='secondary')
        clear_button = gr.Button('Clear', variant='secondary')

    # Carries the message being answered between the steps of an event chain.
    saved_input = gr.State()

    with gr.Accordion(label='Advanced options', open=False):
        system_prompt = gr.Textbox(label='System prompt', value=DEFAULT_SYSTEM_PROMPT, lines=5, interactive=False)
        max_new_tokens = gr.Slider(label='Max New Tokens', minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
        temperature = gr.Slider(label='Temperature', minimum=0.1, maximum=4.0, step=0.1, value=0.1)
        top_p = gr.Slider(label='Top-P (nucleus sampling)', minimum=0.05, maximum=1.0, step=0.05, value=0.9)
        top_k = gr.Slider(label='Top-K', minimum=1, maximum=1000, step=1, value=10)

    # Argument order matches generate(message, history_with_input, system_prompt,
    # max_new_tokens, temperature, top_p, top_k).
    generation_inputs = [
        saved_input,
        chatbot,
        system_prompt,
        max_new_tokens,
        temperature,
        top_p,
        top_k,
    ]

    def _echo_check_and_generate(dependency):
        """Chain the shared tail of every "answer a message" flow onto *dependency*.

        *dependency* must leave the message to answer in ``saved_input``. The
        chain echoes it into the chatbot, validates the input length, and streams
        the reply only if validation succeeded.
        """
        return dependency.then(
            fn=display_input,
            inputs=[saved_input, chatbot],
            outputs=chatbot,
            api_name=False,
            queue=False,
        ).then(
            fn=check_input_token_length,
            inputs=[saved_input, chatbot, system_prompt],
            api_name=False,
            queue=False,
        ).success(
            fn=generate,
            inputs=generation_inputs,
            outputs=chatbot,
            api_name=False,
        )

    # Pressing Enter in the textbox and clicking Submit behave identically.
    for submit_trigger in (textbox.submit, submit_button.click):
        _echo_check_and_generate(
            submit_trigger(
                fn=clear_and_save_textbox,
                inputs=textbox,
                outputs=[textbox, saved_input],
                api_name=False,
                queue=False,
            )
        )

    # Retry: drop the last turn, then re-answer its user message. It now runs
    # the same length check as Submit (it previously skipped it).
    _echo_check_and_generate(
        retry_button.click(
            fn=delete_prev_fn,
            inputs=chatbot,
            outputs=[chatbot, saved_input],
            api_name=False,
            queue=False,
        )
    )

    # Undo: drop the last turn and put its user message back in the textbox.
    undo_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=lambda x: x,
        inputs=[saved_input],
        outputs=textbox,
        api_name=False,
        queue=False,
    )

    clear_button.click(
        fn=lambda: ([], ''),
        outputs=[chatbot, saved_input],
        queue=False,
        api_name=False,
    )
# Enable the request queue (at most 32 pending events); Blocks.queue() returns
# the same Blocks object. Then start the server with show_api=False.
demo.queue(max_size=32)
demo.launch(show_api=False)