# NOTE: "Spaces: Sleeping" was a Hugging Face Spaces status banner captured by
# a page scrape, not part of the program; kept only as a comment so the file parses.
"""Gradio chat app wrapping a text-generation model with JSON-persisted settings."""
import gradio as gr
from transformers import pipeline, AutoModelForCausalLM
import os
import json
import time
import logging
from threading import Lock

# Constants referenced throughout the file but previously never defined
# (load_config/load_model raised NameError on first use).
MODEL_NAME = "mistralai/Mixtral-8x7B-Instruct-v0.1"  # default model id
CONFIG_FILE = "chatbot_config.json"                   # persisted settings path
CACHE_DIR = os.environ.get("HF_CACHE_DIR", "model_cache")  # HF weight cache

# Timestamped logs so model-load and generation timings are traceable.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# NOTE(review): the original built a second Mixtral-8x7B text-generation
# pipeline here at import time and ran a throwaway "Who are you?" prompt.
# That duplicated the model EnhancedChatbot loads below and forced a massive
# weight download before the UI could even start, so it has been removed;
# EnhancedChatbot.load_model() is now the single place the model is created.
class EnhancedChatbot:
    """Thread-safe wrapper around a text-generation pipeline with config persisted to CONFIG_FILE."""

    def __init__(self):
        self.model = None                 # transformers pipeline, set by load_model()
        self.config = self.load_config()
        self.model_lock = Lock()          # serializes generation across Gradio worker threads
        self.load_model()

    def load_config(self):
        """Return settings loaded from CONFIG_FILE, falling back to defaults.

        A corrupt or unreadable config file is logged and ignored rather than
        crashing startup.
        """
        if os.path.exists(CONFIG_FILE):
            try:
                with open(CONFIG_FILE, 'r') as f:
                    return json.load(f)
            except (OSError, json.JSONDecodeError) as e:
                logging.warning(f"Could not read {CONFIG_FILE} ({e}); using defaults")
        return {
            "model_name": MODEL_NAME,
            "max_tokens": 512,
            "temperature": 0.7,
            "top_p": 0.95,
            "system_message": "You are a friendly and helpful AI assistant.",
            "gpu_layers": 0,
        }

    def save_config(self):
        """Write the current settings to CONFIG_FILE as pretty-printed JSON."""
        with open(CONFIG_FILE, 'w') as f:
            json.dump(self.config, f, indent=2)

    def load_model(self):
        """(Re)load the generation pipeline named in the config.

        Raises:
            Any exception transformers raises for a bad model id; it is
            logged before being re-raised so the caller sees the failure.
        """
        try:
            # BUGFIX: the original passed ctransformers-only kwargs
            # (model_type=, gpu_layers=) to transformers'
            # AutoModelForCausalLM.from_pretrained, which raises TypeError.
            # A text-generation pipeline is the supported transformers path;
            # device_map="auto" places layers on GPU when any were requested.
            self.model = pipeline(
                "text-generation",
                model=self.config["model_name"],
                model_kwargs={"cache_dir": CACHE_DIR},
                device_map="auto" if self.config.get("gpu_layers", 0) else None,
            )
            logging.info(f"Model loaded successfully: {self.config['model_name']}")
        except Exception as e:
            logging.error(f"Error loading model: {str(e)}")
            raise

    def generate_response(self, message, history, system_message, max_tokens, temperature, top_p):
        """Generate one assistant reply for `message` given tuple-style chat `history`.

        Returns the stripped generated continuation as a str.
        """
        prompt = f"{system_message}\n\n"
        for user_msg, assistant_msg in history:
            prompt += f"Human: {user_msg}\nAssistant: {assistant_msg}\n"
        prompt += f"Human: {message}\nAssistant: "
        start_time = time.time()
        with self.model_lock:
            outputs = self.model(
                prompt,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                return_full_text=False,  # only the new continuation, not the prompt
            )
        # A transformers pipeline returns [{"generated_text": ...}].
        generated_text = outputs[0]["generated_text"]
        response_time = time.time() - start_time
        logging.info(f"Response generated in {response_time:.2f} seconds")
        return generated_text.strip()
# Single shared chatbot instance used by every Gradio callback below.
chatbot = EnhancedChatbot()
def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Gradio chat callback: yield one full reply, or a fixed apology on any failure."""
    try:
        yield chatbot.generate_response(
            message, history, system_message, max_tokens, temperature, top_p
        )
    except Exception as exc:
        # Boundary handler: log the real cause, show the user a generic message.
        logging.error(f"Error generating response: {str(exc)}")
        yield "I apologize, but I encountered an error while processing your request. Please try again."
def update_model_config(model_name, gpu_layers):
    """Persist new model settings, hot-swap the loaded model, and report the change."""
    cfg = chatbot.config
    cfg["model_name"] = model_name
    cfg["gpu_layers"] = gpu_layers
    chatbot.save_config()
    chatbot.load_model()
    return f"Model updated to {model_name} with {gpu_layers} GPU layers."
def update_system_message(system_message):
    """Store a new system prompt in the persisted config and report the change."""
    chatbot.config["system_message"] = system_message
    chatbot.save_config()
    return f"System message updated: {system_message}"
# UI layout: a Chat tab (ChatInterface wired to respond) and a Settings tab
# for changing the model and the system prompt at runtime.
with gr.Blocks() as demo:
    gr.Markdown("# Enhanced AI Chatbot")

    with gr.Tab("Chat"):
        chatbot_interface = gr.ChatInterface(
            respond,
            additional_inputs=[
                gr.Textbox(value=chatbot.config["system_message"], label="System message"),
                gr.Slider(minimum=1, maximum=2048, value=chatbot.config["max_tokens"], step=1, label="Max new tokens"),
                gr.Slider(minimum=0.1, maximum=4.0, value=chatbot.config["temperature"], step=0.1, label="Temperature"),
                gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=chatbot.config["top_p"],
                    step=0.05,
                    label="Top-p (nucleus sampling)",
                ),
            ],
        )

    with gr.Tab("Settings"):
        with gr.Group():
            gr.Markdown("### Model Settings")
            model_name_input = gr.Textbox(value=chatbot.config["model_name"], label="Model name")
            gpu_layers_input = gr.Slider(minimum=0, maximum=8, value=chatbot.config["gpu_layers"], step=1, label="GPU layers")
            update_model_button = gr.Button("Update model")
            # BUGFIX: .click(outputs=...) needs a component (or None), not the
            # string "text" — a status box now receives the returned message.
            model_status = gr.Textbox(label="Status", interactive=False)
            update_model_button.click(
                update_model_config,
                inputs=[model_name_input, gpu_layers_input],
                outputs=model_status,
            )
        with gr.Group():
            gr.Markdown("### System Message Settings")
            system_message_input = gr.Textbox(value=chatbot.config["system_message"], label="System message")
            update_system_message_button = gr.Button("Update system message")
            system_status = gr.Textbox(label="Status", interactive=False)
            update_system_message_button.click(
                update_system_message,
                inputs=[system_message_input],
                outputs=system_status,
            )

# Launch the web UI only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()