# Spaces: Runtime error
"""Python Application Script for AI chatbot using LLAMA CPP."""
import logging | |
import os | |
import gradio as gr | |
from llama_cpp import Llama | |
# Setting up environment
# The logging module matches level names case-sensitively, so the value is
# normalised to upper case (LOG_LEVEL=debug works) and an unset or blank
# variable falls back to WARNING instead of raising at start-up.
log_level = os.environ.get("LOG_LEVEL", "").strip().upper() or "WARNING"
logging.basicConfig(encoding='utf-8', level=log_level)

# Default System Prompt (overridable through the DEFAULT_SYSTEM variable)
DEFAULT_SYSTEM_PROMPT = os.environ.get(
    "DEFAULT_SYSTEM", "You are Dolphin, a helpful AI assistant."
)

# Model Path
model_path = "model.gguf"
logging.debug("Model Path: %s", model_path)

# The model is loaded once at import time; generate() reuses this instance.
logging.info("Loading Moddel")
llm = Llama(model_path=model_path, n_ctx=4000, n_threads=2, chat_format="chatml")
def generate(
    message: str,
    history: list[tuple[str, str]],
    system_prompt: str,
    temperature: float = 0.1,
    max_tokens: int = 512,
    top_p: float = 0.95,
    repetition_penalty: float = 1.0,
):
    """Stream a chat completion for ``message`` from the module-level ``llm``.

    The system prompt, the earlier turns in ``history`` and the new user
    message are assembled into a list of role/content messages and sent to
    ``llm.create_chat_completion`` with ``stream=True``. This function is a
    generator: each time a streamed chunk carries text, it yields the whole
    response accumulated so far.

    :param message: The new user prompt.
    :param history: The history of the chat session as (user, assistant)
        pairs. An empty or ``None`` history is treated as "no prior turns".
    :param system_prompt: The system prompt of the model.
    :param temperature: The temperature parameter for the model.
    :param max_tokens: The maximum amount of tokens to use for the model.
    :param top_p: The top p value for the model.
    :param repetition_penalty: The repetition penalty for the model.
    :yields: The response text accumulated up to the current chunk.
    """
    logging.info("Generating Text")
    logging.debug("message: %s", message)
    logging.debug("history: %s", history)
    # The parameter is named system_prompt; logging the undefined name
    # `system` raised NameError on every call.
    logging.debug("system: %s", system_prompt)
    logging.debug("temperature: %s", temperature)
    logging.debug("max_tokens: %s", max_tokens)
    logging.debug("top_p: %s", top_p)
    logging.debug("repetion_penalty: %s", repetition_penalty)

    # Formatting Prompt
    logging.info("Formatting Prompt")
    formatted_prompt = [{"role": "system", "content": system_prompt}]
    for user_prompt, bot_response in history or []:
        formatted_prompt.append({"role": "user", "content": user_prompt})
        formatted_prompt.append({"role": "assistant", "content": bot_response})
    formatted_prompt.append({"role": "user", "content": message})
    logging.debug("Formatted Prompt: %s", formatted_prompt)

    # Generating Response
    logging.info("Generating Response")
    stream_response = llm.create_chat_completion(
        messages=formatted_prompt,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        repeat_penalty=repetition_penalty,
        stream=True,
    )

    # Parsing Response
    logging.info("Parsing Response")
    response = ""
    for chunk in stream_response:
        # Deltas without a "content" key (or with empty content) add no text.
        content = chunk["choices"][0]["delta"].get("content")
        if not content:
            continue
        response += content
        logging.debug("Response: %s", response)
        yield response
def _make_slider(label, info, *, value, minimum, maximum, step):
    """Build one interactive ``gr.Slider`` for a generation setting."""
    return gr.Slider(
        label=label,
        value=value,
        minimum=minimum,
        maximum=maximum,
        step=step,
        interactive=True,
        info=info,
    )


# Extra controls for the chat UI. The list order mirrors the extra parameters
# of generate(): system prompt, temperature, max tokens, top-p, repetition
# penalty.
additional_inputs = [
    gr.Textbox(
        label="System Prompt",
        max_lines=1,
        interactive=True,
        value=DEFAULT_SYSTEM_PROMPT,
    ),
    _make_slider(
        "Temperature",
        "Higher values produce more diverse outputs",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
    ),
    _make_slider(
        "Max new tokens",
        "The maximum numbers of new tokens",
        value=256,
        minimum=0,
        maximum=1048,
        step=64,
    ),
    _make_slider(
        "Top-p (nucleus sampling)",
        "Higher values sample more low-probability tokens",
        value=0.90,
        minimum=0.0,
        maximum=1,
        step=0.05,
    ),
    _make_slider(
        "Repetition penalty",
        "Penalize repeated tokens",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
    ),
]
# No example prompts are offered in the UI.
examples = []

# Chat display widget: custom avatars, copy button and like/dislike enabled.
logging.info("Creating Chatbot")
mychatbot = gr.Chatbot(
    avatar_images=["user.png", "botsc.png"],
    bubble_full_width=False,
    show_label=False,
    show_copy_button=True,
    likeable=True,
)

# Chat interface wiring generate() to the chatbot and the extra controls.
logging.info("Creating Chat Interface")
iface = gr.ChatInterface(
    title="LLAMA CPP Template",
    fn=generate,
    chatbot=mychatbot,
    additional_inputs=additional_inputs,
    examples=examples,
    concurrency_limit=20,
)

# Serve on all network interfaces with the API page hidden.
logging.info("Starting Application")
iface.launch(server_name="0.0.0.0", show_api=False)