import os
import html
import gradio as gr
from threading import Lock
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import torch

# ========== Auto-create folders ==========
os.makedirs("chat_history", exist_ok=True)
os.makedirs("system", exist_ok=True)

# ========== Load System Context ==========
context_path = "system/context.txt"
if not os.path.exists(context_path):
    raise FileNotFoundError(f"Missing system context file at {context_path}!")

with open(context_path, "r", encoding="utf-8") as f:
    loaded_context = f.read()

# ========== Simple Chatbot Logic ==========
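# A single lock serializes calls into the model so concurrent requests
# don't interleave generations.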
lock = Lock()

# Provide the folder path, not the file path
model_folder = "model/Mistral-7B-Instruct-v0.3"

# Load the model and tokenizer
model = AutoModelForCausalLM.from_pretrained(model_folder, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_folder)
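# Note: with no device argument the model stays on CPU; device_map="auto"
# (which requires the accelerate package) would place it on available GPUs.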

# Set pad_token to eos_token if pad_token is not available
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Initialize the pipeline for text generation
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
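# The pipeline wraps the model and tokenizer objects loaded above, so
# nothing is reloaded from disk here.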


# ========== Helper Functions ==========

def sanitize_username(username):
    return ''.join(c for c in username if c.isalnum() or c in ('_', '-')).strip()

def user_folder(username):
    return os.path.join("chat_history", username)

def load_latest_history(username):
    folder = user_folder(username)
    if not os.path.exists(folder):
        os.makedirs(folder, exist_ok=True)
        return []
    files = sorted(os.listdir(folder), reverse=True)
    if not files:
        return []
    latest_file = os.path.join(folder, files[0])
    with open(latest_file, "r", encoding="utf-8") as f:
        lines = f.readlines()
    history = []
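    # Each stored line has the form "<name>: <message>"; continuation lines of
    # multi-line messages (and anything else without ": ") are skipped below.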
    for line in lines:
        if ": " in line:
            user, msg = line.split(": ", 1)
            history.append((user.strip(), msg.strip()))
    return history

def save_history(username, history):
    folder = user_folder(username)
    os.makedirs(folder, exist_ok=True)
    filepath = os.path.join(folder, "history.txt")
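    # Everything is appended to a single "history.txt" per user, so the
    # newest-first sort in load_latest_history always finds this one file.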
    
    with open(filepath, "a", encoding="utf-8") as f:
        # Only write the last two new entries (user + Sanny Lin)
        for user, msg in history[-2:]:
            f.write(f"{user}: {msg}\n")

def format_chat(history):
    formatted = ""
    for user, msg in history:
        safe_msg = html.escape(msg)  # escape so message text can't inject HTML
        if user == "Sanny Lin":
            formatted += f"""
            <div style='text-align: left; margin: 5px;'>
                <span class='sanny-message' style='background-color: #e74c3c; color: white; padding: 10px 15px; border-radius: 20px; display: inline-block; max-width: 70%; word-wrap: break-word;'>
                    {safe_msg}
                </span>
            </div>
            """
        else:
            formatted += f"""
            <div style='text-align: right; margin: 5px;'>
                <span style='background-color: #3498db; color: white; padding: 10px 15px; border-radius: 20px; display: inline-block; max-width: 70%; word-wrap: break-word;'>
                    {safe_msg}
                </span>
            </div>
            """
    return formatted


def generate_reply(username, user_message, history):
    with lock:
        if not user_message.strip():
            return format_chat(history), history

        # Keep the prompt bounded: use only the last 30 messages of history
        # (the slice also copies the list, so the caller's state is not
        # mutated in place)
        history = history[-30:]

        # Build an alternating user/assistant message list. Mistral Instruct
        # chat templates expect strictly alternating roles and generally do
        # not accept a separate "system" role, so the system context and the
        # personalized prompt are folded into user turns instead
        messages = []
        for user, msg in history:
            role = "assistant" if user == "Sanny Lin" else "user"
            messages.append({"role": role, "content": msg})

        # Add the new user message with the personalized prompt in front
        user_prompt = f"You are chatting with {username} now. Reply to this message:\n{user_message}"
        messages.append({"role": "user", "content": user_prompt})

        # Prepend the system context to the first user turn
        if messages[0]["role"] == "user":
            messages[0]["content"] = f"{loaded_context}\n\n{messages[0]['content']}"

        # Render the conversation with the model's chat template so the
        # generator sees the full history, not just the latest message
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        generated_output = generator(prompt,
                                     max_new_tokens=512,      # cap the reply length
                                     num_return_sequences=1,
                                     do_sample=True,          # sample for varied replies
                                     temperature=0.5,
                                     top_p=0.5,
                                     return_full_text=False)  # return only the newly generated text

        response = generated_output[0]["generated_text"].strip()

        # Smart truncation: cut off at 4096 characters without splitting a word
        max_length = 4096
        if len(response) > max_length:
            # Find the last space before the cutoff point
            truncated_response = response[:max_length]
            last_space_idx = truncated_response.rfind(" ")

            if last_space_idx != -1:
                response = truncated_response[:last_space_idx]
            else:
                response = truncated_response

        # Add the user message and assistant's response to history
        history.append((username, user_message))
        history.append(("Sanny Lin", response))

        save_history(username, history)

        # Return both the rendered chat and the updated history so the
        # caller can keep its session state in sync
        return format_chat(history), history


# ========== Gradio Interface ==========

with gr.Blocks(theme=gr.themes.Monochrome(), css="""
    @font-face {
        font-family: "DaemonFont";
        src: url('static/daemon.otf') format('opentype');
    }
    body { background-color: #121212 !important; }
    .gradio-container { background-color: #121212 !important; }
    textarea { background-color: #1e1e1e !important; color: white; }
    input { background-color: #1e1e1e !important; color: white; }
    #chat_display { overflow-y: auto; height: calc(100vh - 200px); }
    .sanny-message { font-family: "DaemonFont", sans-serif; }
""") as demo:

    chat_display = gr.HTML(value="", elem_id="chat_display", show_label=False)

    with gr.Row():
        username_box = gr.Textbox(label="Username", placeholder="Enter username...", interactive=True, scale=2)
        user_input = gr.Textbox(placeholder="Type your message...", lines=2, show_label=False, scale=8)
        send_button = gr.Button("Send", scale=1)

    username_state = gr.State("")
    history_state = gr.State([])

    def user_send(user_message, username, history, username_input):
        if not username_input.strip():
            return "<div style='color: red;'>Please enter a valid username first.</div>", history, username

        username_input = sanitize_username(username_input)
        if not username_input:
            # The name contained no usable characters after sanitizing
            return "<div style='color: red;'>Please enter a valid username first.</div>", history, username
        if not username:
            username = username_input

        history = history or load_latest_history(username)

        # generate_reply returns (rendered chat HTML, updated history) so the
        # session state stays in sync with what is saved to disk
        chat_html, history = generate_reply(username, user_message, history)
        return chat_html, history, username

    send_button.click(
        fn=user_send,
        inputs=[user_input, username_state, history_state, username_box],
        outputs=[chat_display, history_state, username_state]
    )

    send_button.click(lambda: "", None, user_input)  # Clear input after send

    demo.load(None, None, None, js="""
    () => {
        // Submit on Enter; Shift+Enter still inserts a newline.
        // Note: these selectors grab the first <textarea> and <button> on
        // the page. With this simple layout that is intended to be the
        // message box and Send button, but it is fragile if the UI grows.
        const textbox = document.querySelector('textarea');
        const sendButton = document.querySelector('button');

        textbox.addEventListener('keydown', function(e) {
            if (e.key === 'Enter' && !e.shiftKey) {
                e.preventDefault();
                sendButton.click();
            }
        });
    }
    """)

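# share=False keeps the app on the local machine; share=True would create a
# temporary public gradio.live link.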
demo.launch(share=False)