import glob
import logging
import os

import chromadb
from openai import OpenAI
from sentence_transformers import SentenceTransformer
|
|
|
|
# --- Configuration -----------------------------------------------------------

# Name of the Chroma collection that holds the document chunks.
COLLECTION_NAME = "poke_wiggle_docs"
# Directory scanned for ``*.txt`` knowledge-base files.
DATA_PATH = "data/raw"

# Sentence-transformers model used to embed both chunks and questions.
EMBEDDING_MODEL = "all-MiniLM-L6-v2"
# Chat model requested through the Hugging Face router endpoint below.
HF_MODEL = "meta-llama/Llama-3.1-8B-Instruct"

# --- Shared clients (all created at import time) -----------------------------

embedding_model = SentenceTransformer(EMBEDDING_MODEL)

# ``os.environ[...]`` raises KeyError on import when HF_TOKEN is not set.
llm_client = OpenAI(
    api_key=os.environ["HF_TOKEN"],
    base_url="https://router.huggingface.co/v1",
)

# NOTE(review): EphemeralClient is presumably non-persistent, so the index
# would be rebuilt in every new process -- confirm against chromadb docs.
chroma_client = chromadb.EphemeralClient()
collection = chroma_client.get_or_create_collection(COLLECTION_NAME)
|
|
|
|
def chunk_text(text, chunk_size=600, overlap=100):
    """Split *text* into overlapping, whitespace-stripped character windows.

    A new window starts every ``chunk_size - overlap`` characters and spans at
    most ``chunk_size`` characters before stripping. Windows that are empty
    after stripping are dropped.

    Args:
        text: The string to split.
        chunk_size: Maximum window length in characters; must be positive.
        overlap: Number of characters shared by consecutive windows; must
            satisfy ``0 <= overlap < chunk_size``.

    Returns:
        A list of non-empty chunk strings in document order.

    Raises:
        ValueError: If ``chunk_size`` or ``overlap`` is out of range.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")
    if not 0 <= overlap < chunk_size:
        # A step of zero or less would never advance through the text.
        raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")

    step = chunk_size - overlap
    windows = (
        text[start:start + chunk_size].strip()
        for start in range(0, len(text), step)
    )
    return [window for window in windows if window]
|
|
|
|
def rebuild_vector_db():
    """Drop and rebuild the collection from the ``*.txt`` files in DATA_PATH.

    Each file is read as UTF-8, split with ``chunk_text`` and stored with an
    id of ``"<file name>_<chunk index>"`` plus ``source`` (file name) and
    ``topic`` (file name without its extension) metadata. The module-level
    ``collection`` is rebound to the freshly created collection.

    Returns:
        The number of chunks indexed, or ``0`` when no chunks were produced.
    """
    global collection

    try:
        chroma_client.delete_collection(COLLECTION_NAME)
    except Exception as exc:
        # Deliberately best-effort: the collection may simply not exist yet.
        # NOTE(review): the exception type for a missing collection appears to
        # differ between chromadb releases -- confirm before narrowing this.
        logging.getLogger(__name__).debug(
            "Could not delete collection %r before rebuild: %s",
            COLLECTION_NAME,
            exc,
        )

    collection = chroma_client.get_or_create_collection(COLLECTION_NAME)

    ids = []
    documents = []
    metadatas = []

    # Sorted so ids and insertion order are stable across runs.
    for file_path in sorted(glob.glob(os.path.join(DATA_PATH, "*.txt"))):
        source = os.path.basename(file_path)
        # splitext strips only the trailing extension; str.replace would also
        # remove ".txt" occurring in the middle of a file name.
        topic = os.path.splitext(source)[0]

        with open(file_path, "r", encoding="utf-8") as f:
            text = f.read()

        for i, chunk in enumerate(chunk_text(text)):
            ids.append(f"{source}_{i}")
            documents.append(chunk)
            metadatas.append({"source": source, "topic": topic})

    if not documents:
        return 0

    embeddings = embedding_model.encode(documents).tolist()

    collection.add(
        ids=ids,
        documents=documents,
        metadatas=metadatas,
        embeddings=embeddings,
    )

    return len(documents)
|
|
|
|
def build_vector_db_if_empty():
    """Populate the collection on first use and return its chunk count.

    Returns:
        The current number of stored chunks when the collection is already
        populated; otherwise the value returned by ``rebuild_vector_db()``.
    """
    # Query the count once so the value tested is the value returned.
    existing = collection.count()
    if existing > 0:
        return existing

    return rebuild_vector_db()
|
|
|
|
def retrieve_context(question, n_results=10):
    """Retrieve the stored chunks closest to *question*.

    Args:
        question: The user question; it is embedded with ``embedding_model``.
        n_results: Maximum number of chunks to retrieve.

    Returns:
        A ``(context, sources)`` tuple. ``context`` joins the retrieved chunks
        with blank lines, each prefixed by ``[Source: <name>]``; ``sources`` is
        the sorted list of distinct source names. Returns ``("", [])`` when
        there is nothing to search or ``n_results`` is below 1.
    """
    available = build_vector_db_if_empty()

    # Never request more neighbours than are stored.
    # NOTE(review): over-asking is believed to warn or raise depending on the
    # chromadb release -- confirm; clamping avoids relying on either.
    top_k = min(n_results, available)
    if top_k < 1:
        return "", []

    query_embedding = embedding_model.encode(question).tolist()

    results = collection.query(
        query_embeddings=[query_embedding],
        n_results=top_k,
    )

    # One query embedding was sent, so only the first result row is relevant.
    documents = results["documents"][0]
    metadatas = results["metadatas"][0]

    context_parts = []
    sources = set()

    for doc, meta in zip(documents, metadatas):
        # Guard against entries whose metadata comes back as None.
        source = (meta or {}).get("source", "unknown")
        sources.add(source)
        context_parts.append(f"[Source: {source}]\n{doc}")

    return "\n\n".join(context_parts), sorted(sources)
|
|
|
|
def ask_llm(question, context):
    """Ask the chat model to answer *question* using only *context*.

    Args:
        question: The user question inserted into the prompt.
        context: Retrieved text inserted into the prompt as the only
            permitted knowledge source.

    Returns:
        The model's reply with surrounding whitespace removed, or an empty
        string when the response carries no text content.
    """
    prompt = f"""
You are a helpful assistant for Poke & Wiggle.

Rules:
- Use ONLY the given context.
- Do not invent facts.
- If the context does not contain the answer, say:
"I do not have enough information in the current knowledge base."
- Be clear, concise, and factual.

Context:
{context}

Question:
{question}

Answer:
"""

    response = llm_client.chat.completions.create(
        model=HF_MODEL,
        messages=[
            {"role": "user", "content": prompt},
        ],
        max_tokens=350,
        temperature=0.1,
    )

    # ``content`` can be None (the SDK types it as optional); calling
    # ``.strip()`` on it directly would raise AttributeError.
    content = response.choices[0].message.content
    return (content or "").strip()
|
|
|
|
def answer_question(question):
    """Answer *question* from the knowledge base.

    Retrieves context with ``retrieve_context`` and passes it to ``ask_llm``.

    Returns:
        An ``(answer, sources)`` tuple: the model's reply and the source
        names reported by the retrieval step.
    """
    retrieved = retrieve_context(question)
    context_text, source_names = retrieved
    reply = ask_llm(question, context_text)
    return reply, source_names