# app.py — Fixed version with streaming + memory + web search
import os
import json
import threading

import gradio as gr
from huggingface_hub import InferenceClient
from datasets import load_dataset
from duckduckgo_search import DDGS
from transformers import pipeline
import torch
from cerebras.cloud.sdk import Cerebras

# Read the API key from the environment
CEREBRAS_API_KEY = os.getenv("CEREBRAS_API_KEY")

# Initialize the Cerebras client
client = Cerebras(api_key=CEREBRAS_API_KEY)

# ---------------- CONFIG ----------------
MODEL_ID = "openai/gpt-oss-20b"
DATA_DIR = "/data" if os.path.isdir("/data") else "./data"
os.makedirs(DATA_DIR, exist_ok=True)
SHORT_TERM_LIMIT = 10
SUMMARY_MAX_TOKENS = 150
MEMORY_LOCK = threading.Lock()

# ---------------- SIMPLE STREAMING DATASET ----------------
# Only load what we actually use to avoid errors
print("Loading FineWeb in streaming mode...")
try:
    # Stream the "sample-10BT" sample configuration instead of downloading the full dataset
    fineweb_stream = load_dataset(
        "HuggingFaceFW/fineweb",
        name="sample-10BT",
        split="train",
        streaming=True,
    )
    print("✅ FineWeb streaming loaded")
except Exception as e:
    print(f"FineWeb loading failed: {e}")
    fineweb_stream = None

# Keep other datasets as before for stability
try:
    ds1 = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft[:5000]")  # Small sample
    ds2 = load_dataset("Anthropic/hh-rlhf", split="train[:5000]")  # Small sample
    print("✅ Other datasets loaded")
except Exception as e:
    print(f"Dataset loading error: {e}")
    ds1, ds2 = None, None


# ---------------- SIMPLE FINEWEB SEARCH ----------------
def search_fineweb(query, max_search=1000):
    """Simple FineWeb search - safe version"""
    if not fineweb_stream:
        return "FineWeb not available"
    try:
        query_lower = query.lower()
        found_content = []
        count = 0
        for sample in fineweb_stream:
            if count >= max_search:
                break
            text = sample.get("text", "")
            if len(text) > 50 and query_lower in text.lower():
                content = text[:300] + "..." if len(text) > 300 else text
                found_content.append(content)
                if len(found_content) >= 3:  # Max 3 results
                    break
            count += 1
        if found_content:
            return "📚 FineWeb Results:\n\n" + "\n\n---\n\n".join(found_content)
        return "No relevant FineWeb content found"
    except Exception as e:
        return f"FineWeb search error: {str(e)}"
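
# Optional, quick sanity check for the FineWeb search. Illustrative only: the
# RUN_FINEWEB_SELFTEST flag and the example query are assumptions, not part of
# the original app; set RUN_FINEWEB_SELFTEST=1 to scan a few hundred samples
# before launching the UI.
if os.getenv("RUN_FINEWEB_SELFTEST") == "1":
    print(search_fineweb("open source language models", max_search=200))
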

# ---------------- MEMORY FUNCTIONS (SAME AS BEFORE) ----------------
def get_user_id(hf_token):
    if hf_token and getattr(hf_token, "token", None):
        return "user_" + hf_token.token[:12]
    return "anon"


def memory_file_path(user_id):
    return os.path.join(DATA_DIR, f"memory_{user_id}.json")


def load_memory(user_id):
    p = memory_file_path(user_id)
    if os.path.exists(p):
        try:
            with open(p, "r", encoding="utf-8") as f:
                mem = json.load(f)
            if isinstance(mem, dict) and "short_term" in mem and "long_term" in mem:
                return mem
        except Exception as e:
            print("load_memory error:", e)
    return {"short_term": [], "long_term": ""}


def save_memory(user_id, memory):
    p = memory_file_path(user_id)
    try:
        with MEMORY_LOCK:
            with open(p, "w", encoding="utf-8") as f:
                json.dump(memory, f, ensure_ascii=False, indent=2)
    except Exception as e:
        print("save_memory error:", e)


def normalize_history(history):
    out = []
    if not history:
        return out
    for turn in history:
        if isinstance(turn, dict) and "role" in turn and "content" in turn:
            out.append({"role": turn["role"], "content": str(turn["content"])})
        elif isinstance(turn, (list, tuple)) and len(turn) == 2:
            user_msg, assistant_msg = turn
            out.append({"role": "user", "content": str(user_msg)})
            out.append({"role": "assistant", "content": str(assistant_msg)})
    return out
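
# For reference, a memory file written by save_memory() looks roughly like the
# hypothetical example below: "short_term" holds recent chat turns and
# "long_term" holds an optional free-text summary that respond() injects as
# extra system context.
#
# {
#   "short_term": [
#     {"role": "user", "content": "Hi"},
#     {"role": "assistant", "content": "Hello! How can I help?"}
#   ],
#   "long_term": "User prefers short answers."
# }
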

# ---------------- WEB SEARCH (SAME AS BEFORE) ----------------
def web_search(query, num_results=3):
    try:
        with DDGS() as ddgs:
            results = list(ddgs.text(query, max_results=num_results))
        search_context = "🔍 Web Search Results:\n\n"
        for i, r in enumerate(results, 1):
            title = r.get("title", "")[:200]
            body = r.get("body", "")[:200].replace("\n", " ")
            href = r.get("href", "")
            search_context += f"{i}. {title}\n{body}...\nSource: {href}\n\n"
        return search_context
    except Exception as e:
        return f"Search error: {str(e)}"


# ---------------- MEMORY TOOLS ----------------
def show_memory(hf_token: gr.OAuthToken | None = None):
    user = get_user_id(hf_token)
    p = memory_file_path(user)
    if not os.path.exists(p):
        return f"No memory found for {user}"
    with open(p, "r", encoding="utf-8") as f:
        return f.read()


def clear_memory(hf_token: gr.OAuthToken | None = None):
    user = get_user_id(hf_token)
    p = memory_file_path(user)
    if os.path.exists(p):
        os.remove(p)
        return f"Memory cleared for {user}"
    return "No memory to clear"


# ---------------- MAIN CHAT FUNCTION ----------------
def respond(message, history, system_message, max_tokens, temperature, top_p,
            enable_web_search, enable_fineweb_search, enable_memory,
            hf_token: gr.OAuthToken | None = None):
    # The gr.OAuthToken type hint lets Gradio inject the logged-in user's token;
    # without a login it stays None and memory falls back to the "anon" user.
    try:
        client = InferenceClient(token=(hf_token.token if hf_token else None), model=MODEL_ID)
        user_id = get_user_id(hf_token)

        # Memory handling
        memory = load_memory(user_id) if enable_memory else {"short_term": [], "long_term": ""}
        session_history = normalize_history(history)
        combined = memory.get("short_term", []) + session_history
        combined.append({"role": "user", "content": message})

        # Keep memory manageable
        if len(combined) > SHORT_TERM_LIMIT:
            combined = combined[-SHORT_TERM_LIMIT:]
        memory["short_term"] = combined
        if enable_memory:
            save_memory(user_id, memory)

        # Build messages
        messages = [{"role": "system", "content": system_message}]

        # Add memory context
        if memory.get("long_term"):
            messages.append({"role": "system", "content": f"Memory: {memory['long_term']}"})

        # Add search results if needed
        search_keywords = ["search", "find", "what is", "tell me about", "news", "latest"]
        should_search = any(keyword in message.lower() for keyword in search_keywords)

        context_parts = []
        if enable_web_search and should_search:
            web_results = web_search(message)
            context_parts.append(web_results)
        if enable_fineweb_search and should_search:
            fineweb_results = search_fineweb(message)
            if "not available" not in fineweb_results and "No relevant" not in fineweb_results:
                context_parts.append(fineweb_results)
        if context_parts:
            search_context = "\n\n".join(context_parts)
            messages.append({"role": "system", "content": f"Context:\n{search_context}"})

        messages.extend(memory["short_term"])

        # Generate response
        response = ""
        for chunk in client.chat_completion(
            messages,
            max_tokens=int(max_tokens),
            stream=True,
            temperature=float(temperature),
            top_p=float(top_p),
        ):
            # Stream chunks may arrive as dicts or objects depending on the client
            # version, so read "choices" / "delta" / "content" defensively.
            choices = chunk.get("choices") if isinstance(chunk, dict) else getattr(chunk, "choices", None)
            if choices:
                delta = choices[0].get("delta") if isinstance(choices[0], dict) else getattr(choices[0], "delta", None)
                if delta:
                    token = delta.get("content") if isinstance(delta, dict) else getattr(delta, "content", None)
                    if token:
                        response += token
                        yield response

        # Save response to memory
        memory["short_term"].append({"role": "assistant", "content": response})
        memory["short_term"] = memory["short_term"][-SHORT_TERM_LIMIT:]
        if enable_memory:
            save_memory(user_id, memory)

    except Exception as e:
        yield f"Error: {str(e)}"


# ---------------- GRADIO UI ----------------
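# Note: the additional_inputs below are passed to respond() positionally, in
# order: system_message, max_tokens, temperature, top_p, enable_web_search,
# enable_fineweb_search, enable_memory. hf_token is not a UI input; Gradio
# fills it in (or leaves it None) via the gr.OAuthToken type hint.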
chatbot = gr.ChatInterface(
    respond,
    type="messages",
    additional_inputs=[
        gr.Textbox(
            value="You are KiyAI, a helpful AI assistant with access to web search and knowledge datasets.",
            label="System message",
        ),
        gr.Slider(1, 2048, value=1024, step=1, label="Max tokens"),
        gr.Slider(0.1, 4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p"),
        gr.Checkbox(value=True, label="🌐 Web Search"),
        gr.Checkbox(value=True, label="📚 FineWeb Search"),
        gr.Checkbox(value=True, label="🧠 Memory"),
    ],
)

with gr.Blocks(title="KiyAI - Experimental Version") as demo:
    gr.Markdown("# 🤖 KiyAI, unlock your potential!")
    with gr.Sidebar():
        gr.LoginButton()
        gr.Markdown("### Memory Tools")
        show_btn = gr.Button("👀 Show Memory")
        clear_btn = gr.Button("🗑️ Clear Memory")
        memory_display = gr.Textbox(label="Memory Status", lines=5)
        show_btn.click(show_memory, inputs=None, outputs=memory_display)
        clear_btn.click(clear_memory, inputs=None, outputs=memory_display)
    chatbot.render()

if __name__ == "__main__":
    demo.launch(share=True)