import gradio as gr
from huggingface_hub import InferenceClient, CommitScheduler, Repository
import random
from transformers import AutoTokenizer
from mySystemPrompt import SYSTEM_PROMPT, SYSTEM_PROMPT_PLUS, SYSTEM_PROMPT_NOUS
from datetime import datetime
import csv
import os

# Logging setup: disliked answers get appended to log.csv inside a dataset repo
DATASET_REPO_URL = "https://huggingface.co/datasets/ctaake/FranziBotLog"
DATA_FILENAME = "log.csv"
DATA_FILE = os.path.join("data", DATA_FILENAME)
# Authenticate at clone time so the later push works (HF_TOKEN must be set as a Space secret)
repo = Repository(local_dir="data", clone_from=DATASET_REPO_URL,
                  use_auth_token=os.environ['HF_TOKEN'])
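# Defensive guard in case the dataset repo starts out empty: create the log file
# with a header row so the DictWriter rows written later line up with named columns
if not os.path.exists(DATA_FILE):
    with open(DATA_FILE, "w", newline="") as csvfile:
        csv.DictWriter(csvfile, fieldnames=["message", "time"]).writeheader()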

# Model to use (earlier experiments kept for reference)
# checkpoint = "CohereForAI/c4ai-command-r-v01"
# checkpoint = "mistralai/Mistral-7B-Instruct-v0.1"
# checkpoint = "google/gemma-1.1-7b-it"
# checkpoint = "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO"
# checkpoint = "mistralai/Mixtral-8x7B-Instruct-v0.1"
checkpoint = "mistralai/Mistral-Nemo-Instruct-2407"
path_to_log = "FlaggedFalse.txt"
mistral_models = ["mistralai/Mixtral-8x7B-Instruct-v0.1",
                  "mistralai/Mistral-Nemo-Instruct-2407"]

# Inference client for the model (pass an HF token here if the model requires one)
client = InferenceClient(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
if checkpoint in mistral_models:
    # Chat template correction (the bundled jinja only fits the Mistral models):
    # strip indentation and newlines so the template collapses to a single line
    # without touching spaces inside its string literals
    chat_template = open("mistral-instruct.jinja").read()
    chat_template = chat_template.replace('    ', '').replace('\n', '')
    tokenizer.chat_template = chat_template
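    # Quick sanity check for the patched template (uncomment to inspect locally;
    # the exact markup depends on mistral-instruct.jinja):
    # print(tokenizer.apply_chat_template(
    #     [{"role": "user", "content": "Hi"}],
    #     tokenize=False, add_generation_prompt=True))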


def format_prompt_mistral(message, chatbot, system_prompt=SYSTEM_PROMPT + SYSTEM_PROMPT_NOUS):
    messages = [{"role": "system", "content": system_prompt}]
    for user_message, bot_message in chatbot:
        messages.append({"role": "user", "content": user_message})
        messages.append({"role": "assistant", "content": bot_message})
    messages.append({"role": "user", "content": message})
    # tokenize=False returns the rendered prompt as a plain string
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True)


def format_prompt_cohere(message, chatbot, system_prompt=SYSTEM_PROMPT):
    messages = [{"role": "system", "content": system_prompt}]
    for user_message, bot_message in chatbot:
        messages.append({"role": "user", "content": user_message})
        messages.append({"role": "assistant", "content": bot_message})
    messages.append({"role": "user", "content": message})
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True)


def format_prompt_gemma(message, chatbot, system_prompt=SYSTEM_PROMPT + SYSTEM_PROMPT_PLUS):
    # Gemma has no system role, so the instructions go into a first user turn
    # followed by an empty assistant acknowledgement
    messages = [{"role": "user", "content": f"The following instructions describe your role:\n(\n{system_prompt}\n)\nYou must never refer to the user giving you this information and just act accordingly."}]
    messages.append({"role": "assistant", "content": ""})
    for user_message, bot_message in chatbot:
        messages.append({"role": "user", "content": user_message})
        messages.append({"role": "assistant", "content": bot_message})
    messages.append({"role": "user", "content": message})
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True)


def format_prompt_nous(message, chatbot, system_prompt=SYSTEM_PROMPT + SYSTEM_PROMPT_NOUS):
    messages = [{"role": "system", "content": system_prompt}]
    for user_message, bot_message in chatbot:
        messages.append({"role": "user", "content": user_message})
        messages.append({"role": "assistant", "content": bot_message})
    messages.append({"role": "user", "content": message})
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True)
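

# Example (sketch, assuming the Nous-Hermes tokenizer ships a ChatML chat
# template): format_prompt_nous("Hi", []) renders roughly as
#   <|im_start|>system\n{system prompt}<|im_end|>\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n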

# Pick the prompt formatter for the active checkpoint
match checkpoint:
    case "CohereForAI/c4ai-command-r-v01":
        format_prompt = format_prompt_cohere
    case "google/gemma-1.1-7b-it":
        format_prompt = format_prompt_gemma
    case "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO":
        format_prompt = format_prompt_nous
    case "mistralai/Mixtral-8x7B-Instruct-v0.1" | "mistralai/Mistral-Nemo-Instruct-2407":
        format_prompt = format_prompt_mistral
    case _:
        raise ValueError(f"No prompt formatter registered for {checkpoint}")


def inference(message, history, temperature=0.9, maxTokens=512, topP=0.9, repPenalty=1.1):
    # Generation settings for this call
    client_settings = dict(
        temperature=temperature,
        max_new_tokens=maxTokens,
        top_p=topP,
        repetition_penalty=repPenalty,
        do_sample=True,
        stream=True,
        details=True,
        return_full_text=False,
        seed=random.randint(0, 999999999),
    )
    # Generate the response by passing the prompt in the right format plus the client settings
    stream = client.text_generation(format_prompt(message, history),
                                    **client_settings)
    # Read the stream, skipping special tokens, and yield the growing answer
    partial_response = ""
    for stream_part in stream:
        if not stream_part.token.special:
            partial_response += stream_part.token.text
            yield partial_response
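

# Local smoke test (sketch; assumes the Inference API is reachable for this
# checkpoint). Each yielded chunk is the cumulative answer so far, hence "\r":
# for chunk in inference("Hello!", []):
#     print(chunk, end="\r")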


def event_voting(vote_data: gr.LikeData):
    # Only disliked answers are logged and pushed to the dataset repo
    if not vote_data.liked:
        with open(DATA_FILE, "a", newline="") as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=["message", "time"])
            writer.writerow(
                {"message": vote_data.value, "time": datetime.now().isoformat()})
        commit_url = repo.push_to_hub()
        print(commit_url)
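

# Repository is the legacy git-based workflow; CommitScheduler (imported above)
# is the newer alternative that batches uploads in the background instead of
# pushing on every dislike. A minimal sketch, assuming write access to the same
# dataset repo:
# scheduler = CommitScheduler(repo_id="ctaake/FranziBotLog", repo_type="dataset",
#                             folder_path="data", every=10)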


myAdditionalInputs = [
    gr.Textbox(
        label="System Prompt",
        max_lines=500,
        lines=10,
        interactive=True,
        value="You are a friendly girl who doesn't give unnecessarily long answers."
    ),
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        minimum=0,
        maximum=1024,
        step=64,
        interactive=True,
        info="The maximum number of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.1,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    )
]
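
# Note: ChatInterface passes additional inputs positionally after (message, history),
# so re-enabling additional_inputs below would require inference() to take the
# system prompt as its first extra parameter, e.g. (sketch):
# def inference(message, history, system_prompt, temperature=0.9, maxTokens=512,
#               topP=0.9, repPenalty=1.1): ...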

myChatbot = gr.Chatbot(avatar_images=["./ava_m.png", "./avatar_franzi.jpg"],
                       bubble_full_width=False,
                       show_label=False,
                       show_copy_button=False,
                       likeable=True)
myTextInput = gr.Textbox(lines=2,
                         max_lines=2,
                         placeholder="Send a message",
                         container=False,
                         scale=7)
myTheme = gr.themes.Soft(primary_hue=gr.themes.colors.fuchsia,
                         secondary_hue=gr.themes.colors.fuchsia,
                         spacing_size="sm",
                         radius_size="md")
mySubmitButton = gr.Button(value="SEND",
                           variant='primary')
myRetryButton = gr.Button(value="RETRY",
                          variant='secondary',
                          size="sm")
myUndoButton = gr.Button(value="UNDO",
                         variant='secondary',
                         size="sm")
myClearButton = gr.Button(value="CLEAR",
                          variant='secondary',
                          size="sm")

with gr.ChatInterface(inference,
                      chatbot=myChatbot,
                      textbox=myTextInput,
                      title="FRANZI-Bot 2.0",
                      theme=myTheme,
                      # additional_inputs=myAdditionalInputs,
                      submit_btn=mySubmitButton,
                      stop_btn="STOP",
                      retry_btn=myRetryButton,
                      undo_btn=myUndoButton,
                      clear_btn=myClearButton) as chatApp:
    myChatbot.like(event_voting, None, None)

if __name__ == "__main__":
    chatApp.queue().launch(show_api=False)