# app.py — full version with memory + web search + datasets
import os
import json
import threading
import gradio as gr
from huggingface_hub import InferenceClient, snapshot_download
from datasets import load_dataset
from duckduckgo_search import DDGS
# NOTE: never hard-code an API key in source; read it from an environment variable
# (e.g. a Space secret). "CEREBRAS_API_KEY" is an assumed variable name here.
# This module-level client is not used by respond(), which builds its own
# per-user client from the OAuth token.
client = InferenceClient(
    provider="cerebras",
    api_key=os.environ.get("CEREBRAS_API_KEY"),
)
# ---------------- CONFIG ----------------
MODEL_ID = "openai/gpt-oss-120b"  # or swap in another chat model, e.g. an IBM Granite checkpoint
DATA_DIR = "/data" if os.path.isdir("/data") else "./data"
os.makedirs(DATA_DIR, exist_ok=True)
SHORT_TERM_LIMIT = 10
SUMMARY_MAX_TOKENS = 150
MEMORY_LOCK = threading.Lock()
# ---------------- dataset loading ----------------
# ⚠️ Heavy startup, comment out if running on free HF Space
folder = snapshot_download(
"HuggingFaceFW/fineweb",
repo_type="dataset",
local_dir="./fineweb/",
allow_patterns="sample/10BT/*",
)
ds1 = load_dataset("HuggingFaceH4/ultrachat_200k")
ds2 = load_dataset("Anthropic/hh-rlhf")
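# NOTE: the FineWeb snapshot and the two chat datasets above are downloaded at
# startup but are not referenced elsewhere in this file; they are only useful
# if retrieval over them is added later.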
# ---------------- helpers: memory ----------------
def get_user_id(hf_token: gr.OAuthToken | None):
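    """Derive a per-user id from the OAuth token prefix, or 'anon' when not logged in.

    Note: the prefix may change if the token is rotated, so the id is not
    guaranteed to be stable across sessions."""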
if hf_token and getattr(hf_token, "token", None):
return "user_" + hf_token.token[:12]
return "anon"
def memory_file_path(user_id: str):
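    """Return the path of the per-user JSON memory file inside DATA_DIR."""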
return os.path.join(DATA_DIR, f"memory_{user_id}.json")
def load_memory(user_id: str):
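    """Load the user's memory dict ({"short_term": [...], "long_term": str});
    fall back to an empty structure if the file is missing or malformed."""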
p = memory_file_path(user_id)
if os.path.exists(p):
try:
with open(p, "r", encoding="utf-8") as f:
mem = json.load(f)
if isinstance(mem, dict) and "short_term" in mem and "long_term" in mem:
return mem
except Exception as e:
print("load_memory error:", e)
return {"short_term": [], "long_term": ""}
def save_memory(user_id: str, memory: dict):
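    """Write the memory dict to disk as JSON; MEMORY_LOCK serializes concurrent writes."""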
p = memory_file_path(user_id)
try:
with MEMORY_LOCK:
with open(p, "w", encoding="utf-8") as f:
json.dump(memory, f, ensure_ascii=False, indent=2)
except Exception as e:
print("save_memory error:", e)
# ---------------- normalize history ----------------
def normalize_history(history):
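    """Normalize Gradio history entries (message dicts, (user, assistant) pairs,
    or bare strings) into a flat list of {"role": ..., "content": ...} dicts."""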
out = []
    if not history:
        return out
for turn in history:
if isinstance(turn, dict) and "role" in turn and "content" in turn:
out.append({"role": turn["role"], "content": str(turn["content"])})
elif isinstance(turn, (list, tuple)) and len(turn) == 2:
user_msg, assistant_msg = turn
out.append({"role": "user", "content": str(user_msg)})
out.append({"role": "assistant", "content": str(assistant_msg)})
elif isinstance(turn, str):
out.append({"role": "user", "content": turn})
return out
# ---------------- sync completion ----------------
def _get_chat_response_sync(client: InferenceClient, messages, max_tokens=SUMMARY_MAX_TOKENS, temperature=0.3, top_p=0.9):
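    """Run a non-streaming chat completion and return the assistant's text,
    or an empty string if the request or response parsing fails."""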
try:
resp = client.chat_completion(messages, max_tokens=max_tokens, temperature=temperature, top_p=top_p, stream=False)
except Exception as e:
print("sync chat_completion error:", e)
return ""
try:
choices = resp.get("choices") if isinstance(resp, dict) else getattr(resp, "choices", None)
if choices:
c0 = choices[0]
msg = c0.get("message") if isinstance(c0, dict) else getattr(c0, "message", None)
if isinstance(msg, dict):
return msg.get("content", "")
return getattr(msg, "content", "") or str(msg or "")
except Exception:
pass
return ""
# ---------------- web search ----------------
def web_search(query, num_results=3):
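    """Search DuckDuckGo and format the top results as a text block for the
    prompt; returns an error string if the search fails."""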
try:
with DDGS() as ddgs:
results = list(ddgs.text(query, max_results=num_results))
search_context = "🔍 Web Search Results:\n\n"
for i, r in enumerate(results, 1):
title = r.get("title", "")[:200]
body = r.get("body", "")[:200].replace("\n", " ")
href = r.get("href", "")
search_context += f"{i}. {title}\n{body}...\nSource: {href}\n\n"
return search_context
except Exception as e:
return f"❌ Search error: {str(e)}"
# ---------------- summarization ----------------
def summarize_old_messages(client: InferenceClient, old_messages):
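    """Summarize messages that have fallen out of the short-term window so they
    can be stored as long-term memory."""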
text = "\n".join([f"{m['role']}: {m['content']}" for m in old_messages])
    system = {"role": "system", "content": "You are a summarizer. Summarize the following conversation in at most 150 words."}
user = {"role": "user", "content": text}
return _get_chat_response_sync(client, [system, user])
# ---------------- memory tools ----------------
def show_memory(hf_token: gr.OAuthToken | None = None):
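    """Return the raw contents of the current user's memory file, if it exists."""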
user = get_user_id(hf_token)
p = memory_file_path(user)
if not os.path.exists(p):
return "ℹ️ No memory file found for user: " + user
with open(p, "r", encoding="utf-8") as f:
return f.read()
def clear_memory(hf_token: gr.OAuthToken | None = None):
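    """Delete the current user's memory file, if it exists."""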
user = get_user_id(hf_token)
p = memory_file_path(user)
if os.path.exists(p):
os.remove(p)
return f"✅ Memory cleared for {user}"
return "ℹ️ No memory to clear."
# ---------------- main chat ----------------
def respond(message, history: list, system_message, max_tokens, temperature, top_p,
enable_search, enable_persistent_memory, hf_token: gr.OAuthToken = None):
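    """Main chat handler: maintain short- and long-term memory, optionally add
    web search results to the prompt, and stream the model's reply."""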
client = InferenceClient(token=(hf_token.token if hf_token else None), model=MODEL_ID)
user_id = get_user_id(hf_token)
memory = load_memory(user_id) if enable_persistent_memory else {"short_term": [], "long_term": ""}
session_history = normalize_history(history)
combined = memory.get("short_term", []) + session_history
if len(combined) > SHORT_TERM_LIMIT:
to_summarize = combined[:len(combined) - SHORT_TERM_LIMIT]
summary = summarize_old_messages(client, to_summarize)
if summary:
memory["long_term"] = (memory.get("long_term", "") + "\n" + summary).strip()
combined = combined[-SHORT_TERM_LIMIT:]
combined.append({"role": "user", "content": message})
memory["short_term"] = combined
if enable_persistent_memory:
save_memory(user_id, memory)
messages = [{"role": "system", "content": system_message}]
if memory.get("long_term"):
messages.append({"role": "system", "content": "Long-term memory:\n" + memory["long_term"]})
messages.extend(memory["short_term"])
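    # Naive keyword trigger for web search ("tin tức" is Vietnamese for "news").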
if enable_search and any(k in message.lower() for k in ["search", "google", "tin tức", "news", "what is"]):
sr = web_search(message)
messages.append({"role": "user", "content": f"{sr}\n\nBased on search results, answer: {message}"})
response = ""
try:
for chunk in client.chat_completion(messages, max_tokens=int(max_tokens),
stream=True, temperature=float(temperature), top_p=float(top_p)):
choices = chunk.get("choices") if isinstance(chunk, dict) else getattr(chunk, "choices", None)
if not choices: continue
c0 = choices[0]
delta = c0.get("delta") if isinstance(c0, dict) else getattr(c0, "delta", None)
token = None
if delta and (delta.get("content") if isinstance(delta, dict) else getattr(delta, "content", None)):
token = delta.get("content") if isinstance(delta, dict) else getattr(delta, "content", None)
else:
msg = c0.get("message") if isinstance(c0, dict) else getattr(c0, "message", None)
if isinstance(msg, dict):
token = msg.get("content", "")
else:
token = getattr(msg, "content", None) or str(msg or "")
if token:
response += token
yield response
except Exception as e:
yield f"⚠️ Inference error: {e}"
return
memory["short_term"].append({"role": "assistant", "content": response})
memory["short_term"] = memory["short_term"][-SHORT_TERM_LIMIT:]
if enable_persistent_memory:
save_memory(user_id, memory)
# ---------------- Gradio UI ----------------
chatbot = gr.ChatInterface(
respond,
type="messages",
additional_inputs=[
gr.Textbox(value="You are a helpful AI assistant.", label="System message"),
gr.Slider(1, 2048, value=512, step=1, label="Max new tokens"),
gr.Slider(0.1, 4.0, value=0.7, step=0.1, label="Temperature"),
gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p"),
gr.Checkbox(value=True, label="Enable Web Search 🔍"),
gr.Checkbox(value=True, label="Enable Persistent Memory"),
],
)
with gr.Blocks(title="AI Chatbot (full version)") as demo:
gr.Markdown("# 🤖 AI Chatbot with Memory + Web Search + Datasets")
with gr.Sidebar():
gr.LoginButton()
gr.Markdown("### Memory Tools")
gr.Button("👀 Show Memory").click(show_memory, inputs=None, outputs=gr.Textbox(label="Memory"))
gr.Button("🗑️ Clear Memory").click(clear_memory, inputs=None, outputs=gr.Textbox(label="Status"))
chatbot.render()
if __name__ == "__main__":
demo.launch()