import os import glob import chromadb from sentence_transformers import SentenceTransformer from openai import OpenAI COLLECTION_NAME = "poke_wiggle_docs" DATA_PATH = "data/raw" EMBEDDING_MODEL = "all-MiniLM-L6-v2" HF_MODEL = "meta-llama/Llama-3.1-8B-Instruct" embedding_model = SentenceTransformer(EMBEDDING_MODEL) llm_client = OpenAI( base_url="https://router.huggingface.co/v1", api_key=os.environ["HF_TOKEN"] ) chroma_client = chromadb.EphemeralClient() collection = chroma_client.get_or_create_collection(COLLECTION_NAME) def chunk_text(text, chunk_size=600, overlap=100): chunks = [] start = 0 while start < len(text): chunk = text[start:start + chunk_size].strip() if chunk: chunks.append(chunk) start += chunk_size - overlap return chunks def rebuild_vector_db(): global collection try: chroma_client.delete_collection(COLLECTION_NAME) except Exception: pass collection = chroma_client.get_or_create_collection(COLLECTION_NAME) files = sorted(glob.glob(os.path.join(DATA_PATH, "*.txt"))) ids = [] documents = [] metadatas = [] for file_path in files: source = os.path.basename(file_path) topic = source.replace(".txt", "") with open(file_path, "r", encoding="utf-8") as f: text = f.read() chunks = chunk_text(text) for i, chunk in enumerate(chunks): ids.append(f"{source}_{i}") documents.append(chunk) metadatas.append({ "source": source, "topic": topic }) if not documents: return 0 embeddings = embedding_model.encode(documents).tolist() collection.add( ids=ids, documents=documents, metadatas=metadatas, embeddings=embeddings ) return len(documents) def build_vector_db_if_empty(): if collection.count() > 0: return collection.count() return rebuild_vector_db() def retrieve_context(question, n_results=10): build_vector_db_if_empty() query_embedding = embedding_model.encode(question).tolist() results = collection.query( query_embeddings=[query_embedding], n_results=n_results ) documents = results["documents"][0] metadatas = results["metadatas"][0] context_parts = [] sources = [] for doc, meta in zip(documents, metadatas): source = meta.get("source", "unknown") sources.append(source) context_parts.append(f"[Source: {source}]\n{doc}") return "\n\n".join(context_parts), sorted(set(sources)) def ask_llm(question, context): prompt = f""" You are a helpful assistant for Poke & Wiggle. Rules: - Use ONLY the given context. - Do not invent facts. - If the context does not contain the answer, say: "I do not have enough information in the current knowledge base." - Be clear, concise, and factual. Context: {context} Question: {question} Answer: """ response = llm_client.chat.completions.create( model=HF_MODEL, messages=[ {"role": "user", "content": prompt} ], max_tokens=350, temperature=0.1 ) return response.choices[0].message.content.strip() def answer_question(question): context, sources = retrieve_context(question) answer = ask_llm(question, context) return answer, sources