import gradio as gr
from huggingface_hub import InferenceClient, Repository
import random
from transformers import AutoTokenizer
from mySystemPrompt import SYSTEM_PROMPT, SYSTEM_PROMPT_PLUS, SYSTEM_PROMPT_NOUS
from datetime import datetime
import csv
import os
# Dataset repo used for logging flagged responses
DATASET_REPO_URL = "https://huggingface.co/datasets/ctaake/FranziBotLog"
DATA_FILENAME = "log.csv"
DATA_FILE = os.path.join("data", DATA_FILENAME)
repo = Repository(local_dir="data", clone_from=DATASET_REPO_URL)
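# Note: log.csv lives inside the locally cloned dataset repo ("data/").
# event_voting() below appends downvoted chatbot messages to it and pushes
# the commit back to the dataset on the Hub.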
# Model which is used (only the last uncommented checkpoint is active)
# checkpoint = "CohereForAI/c4ai-command-r-v01"
# checkpoint = "mistralai/Mistral-7B-Instruct-v0.1"
# checkpoint = "google/gemma-1.1-7b-it"
# checkpoint = "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO"
# checkpoint = "mistralai/Mixtral-8x7B-Instruct-v0.1"
checkpoint = "mistralai/Mistral-Nemo-Instruct-2407"
path_to_log = "FlaggedFalse.txt"
mistral_models = ["mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mistral-Nemo-Instruct-2407"]
# Inference client with the model (and HF token if needed)
client = InferenceClient(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
if checkpoint in mistral_models:
    # Tokenizer chat template correction (only applies to the Mistral models above)
    chat_template = open("mistral-instruct.jinja").read()
    chat_template = chat_template.replace(' ', '').replace('\n', '')
    tokenizer.chat_template = chat_template
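# Prompt formatters: each one builds a chat-style messages list (system prompt,
# prior turns from the Gradio history, then the new user message) and renders it
# to a single prompt string via the tokenizer's chat template.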
def format_prompt_mistral(message, chatbot, system_prompt=SYSTEM_PROMPT + SYSTEM_PROMPT_NOUS):
    messages = [{"role": "system", "content": system_prompt}]
    for user_message, bot_message in chatbot:
        messages.append({"role": "user", "content": user_message})
        messages.append({"role": "assistant", "content": bot_message})
    messages.append({"role": "user", "content": message})
    newPrompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True, return_tensors="pt")
    return newPrompt
def format_prompt_cohere(message, chatbot, system_prompt=SYSTEM_PROMPT):
    messages = [{"role": "system", "content": system_prompt}]
    for user_message, bot_message in chatbot:
        messages.append({"role": "user", "content": user_message})
        messages.append({"role": "assistant", "content": bot_message})
    messages.append({"role": "user", "content": message})
    newPrompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True, return_tensors="pt")
    return newPrompt
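# Gemma's chat template has no system role, so the system prompt is injected as the
# first user turn (followed by an empty assistant turn) before the real conversation.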
def format_prompt_gemma(message, chatbot, system_prompt=SYSTEM_PROMPT + SYSTEM_PROMPT_PLUS):
    messages = [{"role": "user", "content": f"The following instructions describe your role:\n(\n{system_prompt}\n)\nYou must never refer to the user giving you this information and must simply act accordingly."}]
    messages.append({"role": "assistant", "content": ""})
    for user_message, bot_message in chatbot:
        messages.append({"role": "user", "content": user_message})
        messages.append({"role": "assistant", "content": bot_message})
    messages.append({"role": "user", "content": message})
    newPrompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True, return_tensors="pt")
    return newPrompt
def format_prompt_nous(message, chatbot, system_prompt=SYSTEM_PROMPT + SYSTEM_PROMPT_NOUS):
    messages = [{"role": "system", "content": system_prompt}]
    for user_message, bot_message in chatbot:
        messages.append({"role": "user", "content": user_message})
        messages.append({"role": "assistant", "content": bot_message})
    messages.append({"role": "user", "content": message})
    newPrompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True, return_tensors="pt")
    return newPrompt
# Pick the prompt formatter matching the active checkpoint
match checkpoint:
    case "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO":
        format_prompt = format_prompt_nous
    case "mistralai/Mixtral-8x7B-Instruct-v0.1" | "mistralai/Mistral-Nemo-Instruct-2407":
        format_prompt = format_prompt_mistral
    case "google/gemma-1.1-7b-it":
        format_prompt = format_prompt_gemma
    case "CohereForAI/c4ai-command-r-v01":
        format_prompt = format_prompt_cohere
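# Illustrative example (not executed): with a one-turn history,
#   format_prompt("How are you?", [("Hi", "Hello!")])
# returns a single prompt string containing the system prompt, the prior turn and
# the new user message, rendered through the tokenizer's chat template.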
def inference(message, history, temperature=0.9, maxTokens=512, topP=0.9, repPenalty=1.1):
    # Settings for this generation call
    client_settings = dict(
        temperature=temperature,
        max_new_tokens=maxTokens,
        top_p=topP,
        repetition_penalty=repPenalty,
        do_sample=True,
        stream=True,
        details=True,
        return_full_text=False,
        seed=random.randint(0, 999999999),
    )
    # Generate the response by passing the correctly formatted prompt plus the client settings
    stream = client.text_generation(format_prompt(message, history),
                                    **client_settings)
    # Read the streamed tokens
    partial_response = ""
    for stream_part in stream:
        if not stream_part.token.special:
            partial_response += stream_part.token.text
            yield partial_response
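# inference() is a generator: gr.ChatInterface consumes the yielded values and
# streams the progressively longer partial response into the chat window.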
def event_voting(vote_data: gr.LikeData):
    # Only log messages that were downvoted
    if not vote_data.liked:
        with open(DATA_FILE, "a") as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=["message", "time"])
            writer.writerow(
                {"message": vote_data.value, "time": str(datetime.now().isoformat())})
        commit_url = repo.push_to_hub(token=os.environ['HF_TOKEN'])
        print(commit_url)
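# event_voting is attached below via myChatbot.like(); vote_data.value carries the
# text of the chatbot message that was voted on.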
myAdditionalInputs = [
    gr.Textbox(
        label="System Prompt",
        max_lines=500,
        lines=10,
        interactive=True,
        value="You are a friendly girl who doesn't give unnecessarily long answers."
    ),
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        minimum=0,
        maximum=1048,
        step=64,
        interactive=True,
        info="The maximum number of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.1,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    )
]
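# Note: myAdditionalInputs is currently not passed to the ChatInterface below
# (the additional_inputs line is commented out), so inference() falls back to its
# default temperature / max-token / top-p / repetition-penalty values.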
myChatbot = gr.Chatbot(avatar_images=["./ava_m.png", "./avatar_franzi.jpg"],
                       bubble_full_width=False,
                       show_label=False,
                       show_copy_button=False,
                       likeable=True)
myTextInput = gr.Textbox(lines=2,
                         max_lines=2,
                         placeholder="Send a message",
                         container=False,
                         scale=7)
myTheme = gr.themes.Soft(primary_hue=gr.themes.colors.fuchsia,
                         secondary_hue=gr.themes.colors.fuchsia,
                         spacing_size="sm",
                         radius_size="md")
mySubmitButton = gr.Button(value="SEND",
                           variant='primary')
myRetryButton = gr.Button(value="RETRY",
                          variant='secondary',
                          size="sm")
myUndoButton = gr.Button(value="UNDO",
                         variant='secondary',
                         size="sm")
myClearButton = gr.Button(value="CLEAR",
                          variant='secondary',
                          size="sm")
with gr.ChatInterface(inference,
                      chatbot=myChatbot,
                      textbox=myTextInput,
                      title="FRANZI-Bot 2.0",
                      theme=myTheme,
                      # additional_inputs=myAdditionalInputs,
                      submit_btn=mySubmitButton,
                      stop_btn="STOP",
                      retry_btn=myRetryButton,
                      undo_btn=myUndoButton,
                      clear_btn=myClearButton) as chatApp:
    myChatbot.like(event_voting, None, None)

if __name__ == "__main__":
    chatApp.queue().launch(show_api=False)