import os
import html
import gradio as gr
from threading import Lock
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import torch
# ========== Auto-create folders ==========
os.makedirs("chat_history", exist_ok=True)
os.makedirs("system", exist_ok=True)
# ========== Load System Context ==========
context_path = "system/context.txt"
if not os.path.exists(context_path):
    raise FileNotFoundError(f"Missing system context file at {context_path}!")
with open(context_path, "r", encoding="utf-8") as f:
    loaded_context = f.read()
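# context.txt holds the system prompt (the bot's persona and ground rules);
# it is injected as the leading "system" message on every generation below.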
# ========== Simple Chatbot Logic ==========
lock = Lock()
# Provide the folder path, not the file path
model_folder = "model/Mistral-7B-Instruct-v0.3"
# Load the model and tokenizer
model = AutoModelForCausalLM.from_pretrained(model_folder, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_folder)
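# Note (assumption about the deployment): a 7B model in bfloat16 needs roughly
# 15 GB of memory, and this call places it on CPU by default; with the
# `accelerate` package installed, passing device_map="auto" to from_pretrained
# would let transformers spread the weights across any available GPUs.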
# Set pad_token to eos_token if pad_token is not available
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Initialize the pipeline for text generation
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
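# The pipeline reuses whatever device the model was loaded on. In recent
# transformers releases it also accepts a list of {"role", "content"} dicts
# and applies the tokenizer's chat template automatically, which is how
# generate_reply below feeds it the conversation.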
# ========== Helper Functions ==========
def sanitize_username(username):
    # Keep only alphanumerics, underscores, and hyphens
    return ''.join(c for c in username if c.isalnum() or c in ('_', '-')).strip()

def user_folder(username):
    return os.path.join("chat_history", username)

def load_latest_history(username):
    folder = user_folder(username)
    if not os.path.exists(folder):
        os.makedirs(folder, exist_ok=True)
        return []
    files = sorted(os.listdir(folder), reverse=True)
    if not files:
        return []
    latest_file = os.path.join(folder, files[0])
    with open(latest_file, "r", encoding="utf-8") as f:
        lines = f.readlines()
    history = []
    for line in lines:
        if ": " in line:
            user, msg = line.split(": ", 1)
            history.append((user.strip(), msg.strip()))
    return history
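# On disk each turn is one "speaker: message" line, which is why the loader
# above can rebuild (speaker, message) tuples with a single split on ": ".
# A message containing a newline would be read back as separate turns; the
# format assumes single-line messages.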
def save_history(username, history):
    folder = user_folder(username)
    os.makedirs(folder, exist_ok=True)
    filepath = os.path.join(folder, "history.txt")
    with open(filepath, "a", encoding="utf-8") as f:
        # Only append the two newest entries (user + Sanny Lin) so earlier
        # lines are not duplicated on every call
        for user, msg in history[-2:]:
            f.write(f"{user}: {msg}\n")
def format_chat(history):
    formatted = ""
    for user, msg in history:
        # Escape user-supplied text so stray HTML in a message cannot break the page
        safe_msg = html.escape(msg)
        if user == "Sanny Lin":
            formatted += f"""
<div style='text-align: left; margin: 5px;'>
    <span class='sanny-message' style='background-color: #e74c3c; color: white; padding: 10px 15px; border-radius: 20px; display: inline-block; max-width: 70%; word-wrap: break-word;'>
        {safe_msg}
    </span>
</div>
"""
        else:
            formatted += f"""
<div style='text-align: right; margin: 5px;'>
    <span style='background-color: #3498db; color: white; padding: 10px 15px; border-radius: 20px; display: inline-block; max-width: 70%; word-wrap: break-word;'>
        {safe_msg}
    </span>
</div>
"""
    return formatted
def generate_reply(username, user_message, history):
    with lock:
        if not user_message.strip():
            return format_chat(history), history
        # Keep only the 30 most recent messages as conversation context
        history = history[-30:]
        # Lead with the system context on every turn, folding in the note about
        # who is chatting; Mistral-style chat templates expect at most one
        # system message, at the start of the conversation
        system_prompt = f"{loaded_context}\nYou are chatting with {username} now."
        messages = [{"role": "system", "content": system_prompt}]
        # Replay the trimmed history with the proper chat roles
        for user, msg in history:
            role = "user" if user == username else "assistant"
            messages.append({"role": role, "content": msg})
        # Add the new user message at the end
        messages.append({"role": "user", "content": user_message})
        # Passing the messages in chat format lets the pipeline apply the
        # model's chat template and generate only the assistant's reply
        generated_output = generator(
            messages,
            max_new_tokens=512,
            num_return_sequences=1,
            do_sample=True,          # Sample with a low temperature for mildly varied replies
            temperature=0.5,
            top_p=0.5,
            top_k=0,                 # 0 disables top-k filtering
            typical_p=1,
            repetition_penalty=1.0,  # 1.0 leaves token scores unchanged
        )
        # With chat-format input, generated_text is the full conversation with
        # the new assistant turn appended last
        response = generated_output[0]["generated_text"][-1]["content"]
        # Smart truncation: cut off at 4096 characters without splitting a word
        max_chars = 4096
        if len(response) > max_chars:
            truncated = response[:max_chars]
            last_space_idx = truncated.rfind(" ")
            response = truncated[:last_space_idx] if last_space_idx != -1 else truncated
        # Record both turns, persist them, and hand the updated state back
        history.append((username, user_message))
        history.append(("Sanny Lin", response))
        save_history(username, history)
        return format_chat(history), history
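# For reference, this is the output shape the chat-format pipeline call above
# produces (sketch, not executed here):
#   out = generator([{"role": "user", "content": "Hi"}], max_new_tokens=8)
#   out[0]["generated_text"]
#   # -> [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "..."}]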
# ========== Gradio Interface ==========
with gr.Blocks(theme=gr.themes.Monochrome(), css="""
@font-face {
font-family: "DaemonFont";
src: url('static/daemon.otf') format('opentype');
}
body { background-color: #121212 !important; }
.gradio-container { background-color: #121212 !important; }
textarea { background-color: #1e1e1e !important; color: white; }
input { background-color: #1e1e1e !important; color: white; }
#chat_display { overflow-y: auto; height: calc(100vh - 200px); }
.sanny-message {
font-family: "DaemonFont", sans-serif;
}
""") as demo:
    chat_display = gr.HTML(value="", elem_id="chat_display", show_label=False)
    with gr.Row():
        username_box = gr.Textbox(label="Username", placeholder="Enter username...", interactive=True, scale=2)
        user_input = gr.Textbox(placeholder="Type your message...", lines=2, show_label=False, scale=8)
        send_button = gr.Button("Send", scale=1)
    username_state = gr.State("")
    history_state = gr.State([])
    def user_send(user_message, username, history, username_input):
        if not username_input.strip():
            return "<div style='color: red;'>Please enter a valid username first.</div>", history, username
        # Lock in the first sanitized username for the rest of the session
        username = username or sanitize_username(username_input)
        history = history or load_latest_history(username)
        chat_html, history = generate_reply(username, user_message, history)
        return chat_html, history, username

    send_button.click(
        fn=user_send,
        inputs=[user_input, username_state, history_state, username_box],
        outputs=[chat_display, history_state, username_state]
    )
    send_button.click(lambda: "", None, user_input)  # Clear input after send
    demo.load(None, None, None, js="""
    () => {
        const textbox = document.querySelector('textarea');
        const sendButton = document.querySelector('button');
        textbox.addEventListener('keydown', function(e) {
            if (e.key === 'Enter' && !e.shiftKey) {
                e.preventDefault();
                sendButton.click();
            }
        });
    }
    """)
demo.launch(share=False)