Poke-Wiggle-Chatbot / chatbot.py
AnarPythonMaster's picture
Update chatbot.py
ee2042a verified
Raw
History Blame Contribute Delete
3.43 kB
import os
import glob
import chromadb
from sentence_transformers import SentenceTransformer
from openai import OpenAI
# Name of the Chroma collection that stores the knowledge-base chunks.
COLLECTION_NAME = "poke_wiggle_docs"
# Directory scanned for *.txt knowledge-base files. This is a relative path,
# so it resolves against the process's current working directory.
DATA_PATH = "data/raw"
# sentence-transformers model used to embed both document chunks and queries.
EMBEDDING_MODEL = "all-MiniLM-L6-v2"
# Chat model requested from the Hugging Face router.
HF_MODEL = "meta-llama/Llama-3.1-8B-Instruct"
# Loaded once at import time; shared by indexing and retrieval.
embedding_model = SentenceTransformer(EMBEDDING_MODEL)
# OpenAI-compatible client pointed at the Hugging Face inference router.
# os.environ["HF_TOKEN"] is read at import time, so importing this module
# raises KeyError when HF_TOKEN is not set.
llm_client = OpenAI(
base_url="https://router.huggingface.co/v1",
api_key=os.environ["HF_TOKEN"]
)
# In-memory (non-persistent) Chroma client. The index does not survive a
# restart; retrieve_context() rebuilds it from DATA_PATH when the
# collection is empty.
chroma_client = chromadb.EphemeralClient()
# Module-level collection handle; rebuild_vector_db() rebinds it via `global`.
collection = chroma_client.get_or_create_collection(COLLECTION_NAME)
def chunk_text(text, chunk_size=600, overlap=100):
    """Split *text* into overlapping, whitespace-stripped character windows.

    Consecutive windows start ``chunk_size - overlap`` characters apart, so
    each one shares ``overlap`` characters with the previous window.

    - Windows that are empty after stripping are dropped.
    - Splitting stops as soon as a window reaches the end of *text*. Any
      later window would be a strict suffix of that one and would only add
      a duplicate chunk.

    Args:
        text: The string to split.
        chunk_size: Window length in characters; must be positive.
        overlap: Characters shared by consecutive windows; must satisfy
            ``0 <= overlap < chunk_size``.

    Returns:
        A list of non-empty chunk strings. The list is empty for empty or
        all-whitespace *text*.

    Raises:
        ValueError: If ``chunk_size`` is not positive, or ``overlap`` is
            outside ``[0, chunk_size)``. A non-positive step would never
            advance through the text.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            f"overlap must be in [0, chunk_size), got {overlap} "
            f"with chunk_size={chunk_size}"
        )

    step = chunk_size - overlap
    chunks = []
    for start in range(0, len(text), step):
        end = start + chunk_size
        chunk = text[start:end].strip()
        if chunk:
            chunks.append(chunk)
        if end >= len(text):
            # This window already covers the tail of the text; stop here.
            break
    return chunks
# Upper bound on records per collection.add() call. chromadb enforces a
# maximum batch size per add(), so a large corpus is added in slices.
_ADD_BATCH_SIZE = 1000


def _load_corpus_chunks():
    """Read every ``*.txt`` file in DATA_PATH and split it with chunk_text.

    Returns:
        ``(ids, documents, metadatas)``: parallel lists with one entry per
        chunk, ordered by sorted file name and then by chunk index.
    """
    ids = []
    documents = []
    metadatas = []
    for file_path in sorted(glob.glob(os.path.join(DATA_PATH, "*.txt"))):
        source = os.path.basename(file_path)
        # splitext drops only the trailing extension. str.replace(".txt", "")
        # would also strip ".txt" appearing elsewhere in the file name.
        topic = os.path.splitext(source)[0]
        with open(file_path, "r", encoding="utf-8") as f:
            text = f.read()
        for i, chunk in enumerate(chunk_text(text)):
            # File names are unique within one directory, so "<file>_<index>"
            # is unique across the whole corpus.
            ids.append(f"{source}_{i}")
            documents.append(chunk)
            metadatas.append({"source": source, "topic": topic})
    return ids, documents, metadatas


def rebuild_vector_db():
    """Drop the collection, recreate it, and index every data file.

    Rebinds the module-level ``collection`` to the freshly created one.

    Returns:
        The number of chunks indexed. Returns 0 if DATA_PATH holds no
        ``*.txt`` file with non-blank content; the collection is then left
        empty.
    """
    global collection
    try:
        chroma_client.delete_collection(COLLECTION_NAME)
    except Exception:
        # Best effort: presumably raised when the collection does not exist.
        # NOTE(review): the exception type differs across chromadb versions;
        # narrow this catch once the pinned version is known.
        pass
    collection = chroma_client.get_or_create_collection(COLLECTION_NAME)

    ids, documents, metadatas = _load_corpus_chunks()
    if not documents:
        return 0

    embeddings = embedding_model.encode(documents).tolist()
    for start in range(0, len(documents), _ADD_BATCH_SIZE):
        stop = start + _ADD_BATCH_SIZE
        collection.add(
            ids=ids[start:stop],
            documents=documents[start:stop],
            metadatas=metadatas[start:stop],
            embeddings=embeddings[start:stop],
        )
    return len(documents)
def build_vector_db_if_empty():
    """Ensure the collection is populated, building it on first use.

    Returns:
        The number of chunks in the collection. This is the existing count
        if the collection was already non-empty; otherwise it is the value
        returned by ``rebuild_vector_db()``.
    """
    # Query the count once: avoids a second round trip and any chance of
    # the check and the returned value disagreeing.
    count = collection.count()
    if count > 0:
        return count
    return rebuild_vector_db()
def retrieve_context(question, n_results=10):
    """Embed *question* and fetch the most similar chunks from the collection.

    The index is built first if the collection is empty.

    Args:
        question: The user's question text.
        n_results: Maximum number of chunks to retrieve. It is clamped to
            the number of chunks actually stored.

    Returns:
        ``(context, sources)``:

        - ``context`` joins the retrieved chunks with blank lines. Each chunk
          is prefixed with a ``[Source: <file>]`` line.
        - ``sources`` is the sorted, de-duplicated list of source names.

        Returns ``("", [])`` when nothing is indexed.
    """
    available = build_vector_db_if_empty()
    if not available:
        # Nothing indexed (e.g. no *.txt files): skip embedding and querying.
        return "", []

    query_embedding = embedding_model.encode(question).tolist()
    results = collection.query(
        query_embeddings=[query_embedding],
        # Requesting more results than are stored makes chromadb warn
        # (older versions raised), so never ask for more than exist.
        n_results=min(n_results, available),
    )
    documents = results["documents"][0]
    metadatas = results["metadatas"][0]

    context_parts = []
    sources = []
    for doc, meta in zip(documents, metadatas):
        # Guard against records returned without a metadata dict (None).
        source = (meta or {}).get("source", "unknown")
        sources.append(source)
        context_parts.append(f"[Source: {source}]\n{doc}")
    return "\n\n".join(context_parts), sorted(set(sources))
def ask_llm(question, context):
    """Ask the chat model to answer *question* using only *context*.

    The prompt instructs the model to rely solely on the supplied context.
    If the context lacks the answer, the model is told to reply with a
    fixed "not enough information" sentence. Decoding uses a low
    temperature (0.1) and at most 350 new tokens.

    Args:
        question: The user's question text.
        context: Retrieved knowledge-base text, as built by retrieve_context.

    Returns:
        The model's reply with surrounding whitespace stripped. Returns an
        empty string if the response carries no text content.
    """
    prompt = f"""
You are a helpful assistant for Poke & Wiggle.
Rules:
- Use ONLY the given context.
- Do not invent facts.
- If the context does not contain the answer, say:
"I do not have enough information in the current knowledge base."
- Be clear, concise, and factual.
Context:
{context}
Question:
{question}
Answer:
"""
    response = llm_client.chat.completions.create(
        model=HF_MODEL,
        messages=[
            {"role": "user", "content": prompt}
        ],
        max_tokens=350,
        temperature=0.1
    )
    # message.content is nullable in the OpenAI-compatible schema; calling
    # .strip() on None would raise AttributeError.
    content = response.choices[0].message.content
    return (content or "").strip()
def answer_question(question):
    """Answer *question* with retrieval-augmented generation.

    Returns:
        ``(answer, sources)``:

        - ``answer`` is the model's reply, grounded in the retrieved chunks.
        - ``sources`` is the source list produced by retrieve_context.
    """
    retrieved, cited = retrieve_context(question)
    return ask_llm(question, retrieved), cited