# Jaita's picture
# Rename main.py to main3.py
# 026206e verified
from fastapi import FastAPI
from sentence_transformers import SentenceTransformer
import chromadb
from chromadb.config import Settings
import uuid
from huggingface_hub import InferenceClient
import os
from docx import Document
import google.generativeai as genai
# --- 0. Config ---
# Fail fast at import time if the Gemini key is missing, so the service
# never starts half-configured.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    raise RuntimeError("GEMINI_API_KEY is not set in environment.")
# Configure the SDK
genai.configure(api_key=GEMINI_API_KEY)
# Choose the model
MODEL_NAME = "gemini-2.5-flash-lite"
LLM = genai.GenerativeModel(MODEL_NAME)

app = FastAPI()

# -----------------------------
# 1. SETUP: Embeddings + LLM
# -----------------------------
# Sentence-embedding model shared by ingest and query-time search; both
# sides must use the same model for distances to be meaningful.
EMBED_MODEL = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# -----------------------------
# 2. SETUP: ChromaDB
# -----------------------------
# Persistent on-disk vector store — survives restarts, which is why the
# startup ingest below can skip work when data already exists.
chroma_client = chromadb.PersistentClient(path="./chroma_db")
collection = chroma_client.get_or_create_collection(name="knowledge_base")
# -----------------------------
# Helper: Extract text from docx
# -----------------------------
def extract_docx_text(file_path):
    """Return all paragraph text from a .docx file, joined by newlines.

    Args:
        file_path: Path to a .docx document readable by python-docx.

    Returns:
        One string with each paragraph's text on its own line.
    """
    document = Document(file_path)
    return "\n".join(paragraph.text for paragraph in document.paragraphs)
# -----------------------------
# 3. STARTUP INGEST
# -----------------------------
@app.on_event("startup")
def ingest_documents():
    """One-time startup ingest: chunk ./documents/*.docx into the Chroma KB.

    Skips entirely when the collection already holds data, so restarts do
    not duplicate documents (the Chroma store is persistent). Chunks are
    naive paragraph-gap splits; fragments under 50 chars are dropped as noise.
    """
    print("Checking if KB already has data...")
    if collection.count() > 0:
        print("KB exists. Skipping ingest.")
        return
    print("Empty KB. Ingesting files...")
    docs_dir = "./documents"
    # Robustness fix: a missing documents directory previously crashed app
    # startup with FileNotFoundError from os.listdir; start empty instead.
    if not os.path.isdir(docs_dir):
        print(f"Directory {docs_dir} not found. Starting with empty KB.")
        return
    for fname in os.listdir(docs_dir):
        if not fname.endswith(".docx"):
            continue
        text = extract_docx_text(os.path.join(docs_dir, fname))
        chunks = text.split("\n\n")  # simple chunking for beginners
        for chunk in chunks:
            if len(chunk.strip()) < 50:
                continue  # drop near-empty fragments
            embedding = EMBED_MODEL.encode(chunk).tolist()
            collection.add(
                ids=[str(uuid.uuid4())],
                embeddings=[embedding],
                documents=[chunk],
                metadatas=[{"source": fname}],
            )
    print("Ingest complete.")
# -----------------------------
# 4. LLM for Intent detection
# -----------------------------
def get_intent(query):
    """Classify the user's query into one warehouse-workflow intent label.

    Args:
        query: Raw user question/command text.

    Returns:
        The model's intent label (e.g. "picking"), stripped of whitespace.
    """
    prompt = f"""
Classify the user's intent from the list:
- receiving
- inventory_adjustment
- update_footprint
- picking
- shipping
- trailer_close
User query: "{query}"
Respond ONLY with the intent label.
"""
    # BUG FIX: google.generativeai's GenerativeModel has no text_generation()
    # method — that is the huggingface_hub InferenceClient API and every call
    # raised AttributeError. Use generate_content() and read the .text payload;
    # max_new_tokens maps to generation_config's max_output_tokens.
    resp = LLM.generate_content(
        prompt, generation_config={"max_output_tokens": 10}
    )
    return resp.text.strip()
# -----------------------------
# 5. Hybrid Search (vector + keyword)
# -----------------------------
def hybrid_search(query, intent, top_k=3):
    """Rank KB chunks for `query` by vector similarity plus a keyword boost.

    Embeds the query, fetches the `top_k` nearest chunks from Chroma,
    converts distances to similarities (1 - distance), and adds a small
    bonus (+0.05) when the intent phrase appears in a chunk.

    Args:
        query: User query text to embed and search with.
        intent: Intent label; underscores become spaces for keyword matching.
        top_k: Number of nearest chunks to retrieve.

    Returns:
        List of (chunk_text, score) tuples, best score first.
    """
    query_vec = EMBED_MODEL.encode(query).tolist()
    hits = collection.query(query_embeddings=[query_vec], n_results=top_k)
    texts = hits["documents"][0]
    distances = hits["distances"][0]

    intent_phrase = intent.replace("_", " ")
    scored = [
        (text, (1 - dist) + (0.05 if intent_phrase in text.lower() else 0))
        for text, dist in zip(texts, distances)
    ]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)
# -----------------------------
# 6. LLM Format (rephrase KB)
# -----------------------------
def format_with_llm(answer):
    """Rephrase a KB answer clearly and politely via the LLM.

    Args:
        answer: Raw KB chunk text to rewrite; no new facts should be added.

    Returns:
        The rewritten answer text from the model.
    """
    prompt = f"""
Rewrite this answer clearly and politely without adding new information:
{answer}
"""
    # BUG FIX: GenerativeModel exposes generate_content(), not
    # text_generation() (an InferenceClient method that raised
    # AttributeError here); max_new_tokens maps to max_output_tokens.
    resp = LLM.generate_content(
        prompt, generation_config={"max_output_tokens": 150}
    )
    return resp.text
# -----------------------------
# 7. RAG Fallback
# -----------------------------
def rag_fallback(query, docs):
    """Answer `query` using only the retrieved chunks as grounding context.

    Args:
        query: User question text.
        docs: (chunk_text, score) tuples as returned by hybrid_search.

    Returns:
        The model's answer text; the prompt instructs it to say "not found"
        when the context does not contain the answer (checked by the caller).
    """
    context = "\n\n".join([d for d, _ in docs])
    prompt = f"""
Use ONLY the information below to answer the question.
If the answer is not found, say "not found".
Context:
{context}
Question: {query}
Answer:
"""
    # BUG FIX: GenerativeModel exposes generate_content(), not
    # text_generation() (an InferenceClient method that raised
    # AttributeError here); max_new_tokens maps to max_output_tokens.
    resp = LLM.generate_content(
        prompt, generation_config={"max_output_tokens": 200}
    )
    return resp.text
# -----------------------------
# 8. INCIDENT NUMBER GENERATOR
# -----------------------------
def generate_incident():
    """Return a pseudo-random incident ticket id, e.g. "INC1A2B3C4D".

    The suffix is the first 8 hex digits of a fresh UUID4, uppercased.
    """
    token = uuid.uuid4().hex[:8].upper()
    return f"INC{token}"
# -----------------------------
# 9. MAIN CHAT ENDPOINT
# -----------------------------
@app.post("/chat")
def chat(query: str):
    """Answer a user query: direct KB hit, RAG fallback, or incident creation.

    Flow: classify intent -> hybrid search -> if top similarity >= 0.89,
    rephrase the chunk directly; else try grounded RAG; else open an incident.

    Args:
        query: User's question, passed as a query parameter.

    Returns:
        JSON-serializable dict with "answer", "intent", and one of
        "confidence", "mode", or "incident" depending on the path taken.
    """
    # Step 2: Detect intent
    intent = get_intent(query)
    # Step 3-4: Hybrid search
    docs = hybrid_search(query, intent)
    # Robustness fix: an empty knowledge base (or zero hits) previously
    # raised IndexError on docs[0]; treat "no hits" as unresolved.
    if not docs:
        incident = generate_incident()
        return {
            "answer": f"I couldn't find this information. I've created incident {incident}.",
            "incident": incident,
            "intent": intent
        }
    top_answer, top_score = docs[0]
    # Step 5: High confidence (>= 0.89) — answer straight from the KB chunk
    if top_score >= 0.89:
        reply = format_with_llm(top_answer)
        return {"answer": reply, "intent": intent, "confidence": top_score}
    # Step 6: RAG fallback — accept only substantive, non-"not found" answers
    rag_answer = rag_fallback(query, docs)
    if "not found" not in rag_answer.lower() and len(rag_answer.split()) > 5:
        return {"answer": rag_answer, "intent": intent, "mode": "RAG"}
    # Step 7: Still not resolved -> create incident
    incident = generate_incident()
    return {
        "answer": f"I couldn't find this information. I've created incident {incident}.",
        "incident": incident,
        "intent": intent
    }