# Legal-AI-bot / app.py — RAG-driven Irish legal assistant
# (Hugging Face Space header removed; last upstream commit: 65d4889 "Update app.py" by higher5fh)
import asyncio
import logging
import os
import re
import uuid
from contextlib import asynccontextmanager
from datetime import datetime, timedelta
from typing import Optional

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request, Depends, Response, Cookie
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from langchain_classic.chains import LLMChain
from langchain_community.chat_models import ChatOpenAI
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.prompts import PromptTemplate
from openai import OpenAI
from pydantic import BaseModel
from pydantic_settings import BaseSettings
from slowapi import Limiter
from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware
from slowapi.util import get_remote_address
from upstash_redis.asyncio import Redis
load_dotenv()  # pull .env into the process environment before Settings reads it
# ─── SETTINGS ────────────────────────────────────────────────────────────────────
class Settings(BaseSettings):
    """Application configuration, read from environment variables / the .env file."""
    OPENAI_API_KEY: str
    UPSTASH_REDIS_REST_URL: str
    UPSTASH_REDIS_REST_TOKEN: str
    VECTOR_DB_PATH: str = "./chroma_db"   # persistent Chroma directory
    TOP_K: int = 7                        # documents retrieved per query
    SESSION_TIMEOUT_MIN: int = 30         # chat-session lifetime, minutes
    RATE_LIMIT: str = "60/minute"         # slowapi limit applied to /query
    RELEVANCE_THRESHOLD: float = 0.75  # below this = low confidence from vectordb
    # Words that signal the user is asking about the *current* state of the law.
    RECENCY_KEYWORDS: list[str] = [
        "2024", "2025", "latest", "current", "recent", "new law",
        "updated", "amendment", "now", "today", "changed"
    ]

    class Config:
        env_file = ".env"
        extra = "ignore"


settings = Settings()
# ─── LOGGING ─────────────────────────────────────────────────────────────────────
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s %(message)s"
)
logger = logging.getLogger("legal-bot")
# ─── GLOBALS ─────────────────────────────────────────────────────────────────────
# Upstash Redis client; created in the lifespan handler, None until startup.
redis: Optional[Redis] = None
# ─── GREETING DETECTION ──────────────────────────────────────────────────────────
# Greeting phrases, matched whole or as a word-bounded prefix of the query.
GREETINGS = {
    "hi", "hello", "hey", "good morning", "good afternoon", "good evening",
    "howya", "hiya", "sup", "what's up", "greetings", "yo", "morning", "evening"
}
GREETING_RESPONSES = [
    "Hello! πŸ‘‹ I'm your Irish Legal Assistant. I can help you understand Irish law, your rights, "
    "legislation, legal processes, and more. What legal question can I help you with today?",
    "Hi there! πŸ‘‹ Welcome to the Irish Legal AI Assistant. I'm here to help you navigate Irish law "
    "and legal matters. What would you like to know?",
    "Hey! πŸ‘‹ I'm here to help with any Irish legal questions you have β€” from employment law to "
    "tenancy rights, criminal law, family law, and beyond. What's on your mind?"
]
# Off-topic keywords; substring-matched against the lowercased query
# (so "cook" also hits "cookie" — known trade-off, kept for recall).
NON_LEGAL_PATTERNS = [
    "weather", "sport", "football", "recipe", "cook", "movie", "music",
    "news", "joke", "story", "game", "song", "celebrity", "politics",
    "cryptocurrency", "bitcoin", "stock", "fashion", "travel"
]
# Keywords that mark even a very short query as a genuine legal question.
_LEGAL_HINT_WORDS = ["law", "right", "legal", "act", "court", "fine", "penalty", "gdpr", "rent", "tax"]


def _is_greeting_prefix(q: str) -> bool:
    """True when *q* is a greeting, or begins with one followed by a
    non-alphanumeric character ("hello, ..." matches; "high court ..." does not)."""
    for g in GREETINGS:
        if q == g:
            return True
        # q is strictly longer than g here, so q[len(g)] is in range.
        if q.startswith(g) and not q[len(g)].isalnum():
            return True
    return False


def classify_query(query: str) -> str:
    """
    Classify a user query for routing.

    Returns: 'greeting' | 'non_legal' | 'legal'
    """
    q = query.lower().strip()
    # Greeting check.  A word-bounded prefix test replaces the original bare
    # startswith(), which misclassified e.g. "high court appeal" as a greeting
    # because it starts with the characters "hi".
    if _is_greeting_prefix(q):
        return "greeting"
    # Very short non-question inputs with no legal keyword
    if len(q.split()) <= 2 and "?" not in q and not any(
        kw in q for kw in _LEGAL_HINT_WORDS
    ):
        return "greeting"
    # Non-legal topic
    if any(pattern in q for pattern in NON_LEGAL_PATTERNS):
        return "non_legal"
    return "legal"


def needs_recency_check(query: str) -> bool:
    """Returns True if the query is asking about current/recent law."""
    q = query.lower()
    return any(kw in q for kw in settings.RECENCY_KEYWORDS)
# ─── LIFESPAN ────────────────────────────────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Open the Upstash Redis connection at startup; close it on shutdown."""
    global redis
    redis = Redis(
        url=settings.UPSTASH_REDIS_REST_URL,
        token=settings.UPSTASH_REDIS_REST_TOKEN
    )
    logger.info("Upstash Redis connection established")
    yield  # the application serves requests while suspended here
    await redis.close()
    logger.info("Upstash Redis connection closed")
# ─── APP ─────────────────────────────────────────────────────────────────────────
app = FastAPI(
    title="Irish Legal AI Bot",
    description="RAG-driven Irish legal assistant",
    lifespan=lifespan
)
# Static assets (CSS/JS) served under /static.
app.mount("/static", StaticFiles(directory="static"), name="static")
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# rejected by browsers per the CORS spec (credentialed requests may not use a
# wildcard origin) — confirm the intended origin list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# Per-client-IP rate limiting (slowapi); the limit itself is declared per route.
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_middleware(SlowAPIMiddleware)
# ─── OPENAI CLIENT ───────────────────────────────────────────────────────────────
# Synchronous OpenAI client; called through asyncio.to_thread for moderation.
openai_client = OpenAI(api_key=settings.OPENAI_API_KEY)
# ─── MODERATION ──────────────────────────────────────────────────────────────────
async def moderate_content(text: str) -> bool:
    """Screen *text* with the OpenAI moderation endpoint.

    Returns True when the content is allowed.  The synchronous client call is
    pushed onto a worker thread to avoid blocking the event loop.  Any API or
    parsing failure is treated as disallowed content (fail-closed).
    """
    try:
        verdict = await asyncio.to_thread(
            openai_client.moderations.create, input=text
        )
        allowed = not verdict.results[0].flagged
    except Exception as e:
        logger.error(f"Moderation error: {e}")
        return False
    return allowed
# ─── SESSION MANAGEMENT ──────────────────────────────────────────────────────────
class SessionData(BaseModel):
    """A chat session, stored as JSON in Redis under its session_id."""
    session_id: str
    created_at: datetime
    # Absolute expiry; activity refreshes last_activity but never extends this.
    expires_at: datetime
    last_activity: datetime
    # Rolling list of {"q", "a", "timestamp", "mode"} dicts (capped at 5 in /query).
    history: list
async def get_session(
    session_id: str = Cookie(default=None),
    response: Response = None
) -> SessionData:
    """FastAPI dependency: load the session named by the session_id cookie, or
    mint and persist a fresh one.

    A valid existing session gets its last_activity refreshed and is re-stored
    with its *remaining* TTL only, so the absolute expires_at is never pushed
    out.  A logically expired record is deleted.  Any Redis/parse error falls
    through to creating a new session.

    NOTE(review): datetime.utcnow() and SessionData.parse_raw() are deprecated
    in current stdlib/pydantic — confirm the pinned versions support them.
    """
    if session_id:
        try:
            raw = await redis.get(session_id)
            if raw:
                data = SessionData.parse_raw(raw)
                if datetime.utcnow() <= data.expires_at:
                    data.last_activity = datetime.utcnow()
                    # Re-store under the remaining lifetime so the absolute
                    # expiry is honoured.
                    remaining = int((data.expires_at - datetime.utcnow()).total_seconds())
                    await redis.setex(session_id, remaining, data.json())
                    return data
                else:
                    # Record outlived its logical expiry — purge it.
                    await redis.delete(session_id)
        except Exception as e:
            logger.error(f"Session fetch error: {e}")
    # No (valid) session — create a new one with a fresh TTL.
    new_id = str(uuid.uuid4())
    now = datetime.utcnow()
    data = SessionData(
        session_id=new_id,
        created_at=now,
        expires_at=now + timedelta(minutes=settings.SESSION_TIMEOUT_MIN),
        last_activity=now,
        history=[]
    )
    await redis.setex(new_id, settings.SESSION_TIMEOUT_MIN * 60, data.json())
    if response:
        # Cross-site cookie: SameSite=None requires the Secure flag (set below).
        response.set_cookie(
            key="session_id",
            value=new_id,
            httponly=True,
            secure=True,
            samesite="None",
            path="/"
        )
    return data
async def save_session(session: SessionData):
    """Persist *session* back to Redis under its remaining lifetime.

    The remaining TTL is floored to whole seconds *before* the positivity
    check.  The original checked the float remainder first, so a session with
    e.g. 0.4s left was written via SETEX with ttl=0, which Redis rejects as an
    invalid expire time.  Sessions at/past expiry are simply not re-persisted.
    """
    remaining = int((session.expires_at - datetime.utcnow()).total_seconds())
    if remaining > 0:
        await redis.setex(session.session_id, remaining, session.json())
# ─── VECTOR STORE ────────────────────────────────────────────────────────────────
# OpenAI embeddings + the persistent Chroma store holding the legal corpus.
embeddings = OpenAIEmbeddings(openai_api_key=settings.OPENAI_API_KEY)
vectordb = Chroma(
    embedding_function=embeddings,
    persist_directory=settings.VECTOR_DB_PATH
)
class RetrievalResult(BaseModel):
    """Outcome of one vector-store lookup, consumed by generate_answer()."""
    # Concatenated, score-labelled snippets fed to the prompt.
    context: str
    # The 'source' metadata value of each retrieved document.
    sources: list[str]
    confidence: str  # "high" | "medium" | "low" | "none"
    # Best (lowest) Chroma distance among the hits; lower = more similar.
    top_score: float
    has_dated_content: bool
    doc_years: list[str]  # years found in retrieved docs
# Compiled once at import: word-bounded 4-digit years 1980-1999 and 2000-2029.
_YEAR_RE = re.compile(r"\b(19[89]\d|20[012]\d)\b")


def extract_years_from_text(text: str) -> list[str]:
    """Return every plausible legislation year (1980-2029) found in *text*,
    in order of appearance, duplicates included.

    The original re-imported `re` and re-spelled the pattern on every call;
    the pattern is now hoisted and pre-compiled at module level.
    """
    return _YEAR_RE.findall(text)
def retrieve_context(query: str) -> RetrievalResult:
    """Run a scored similarity search against Chroma and bundle the results.

    similarity_search_with_score returns (Document, distance) pairs where a
    LOWER distance means a closer match; the confidence bucket is derived from
    the single best distance.  Failures never propagate — they degrade to a
    confidence="none" result with the worst-case top_score of 1.0.
    """
    try:
        docs = vectordb.similarity_search_with_score(query, k=settings.TOP_K)
        if not docs:
            return RetrievalResult(
                context="No relevant legal context found.",
                sources=[],
                confidence="none",
                top_score=0.0,
                has_dated_content=False,
                doc_years=[]
            )
        # Sort by relevance score ascending (lower = more similar in Chroma)
        docs.sort(key=lambda x: x[1])
        top_score = docs[0][1]
        all_years = []
        snippets = []
        sources = []
        for i, (doc, score) in enumerate(docs):
            # Collect any legislation years mentioned in this snippet.
            years = extract_years_from_text(doc.page_content)
            all_years.extend(years)
            # Human-readable relevance tag embedded into the prompt context.
            relevance_label = (
                "High" if score < 0.5 else
                "Medium" if score < 0.75 else
                "Low"
            )
            snippets.append(
                f"[Source {i+1} | Relevance: {relevance_label} | Score: {score:.3f}]\n"
                f"{doc.page_content.strip()}"
            )
            # Fall back to a positional label when no 'source' metadata exists.
            sources.append(doc.metadata.get("source", f"Source {i+1}"))
        # Determine confidence from the best distance alone.
        if top_score < 0.4:
            confidence = "high"
        elif top_score < 0.65:
            confidence = "medium"
        elif top_score < settings.RELEVANCE_THRESHOLD:
            confidence = "low"
        else:
            confidence = "none"
        unique_years = sorted(set(all_years), reverse=True)
        has_dated = bool(unique_years)
        return RetrievalResult(
            context="\n\n".join(snippets),
            sources=sources,
            confidence=confidence,
            top_score=top_score,
            has_dated_content=has_dated,
            doc_years=unique_years
        )
    except Exception as e:
        logger.error(f"Vector retrieval error: {e}")
        return RetrievalResult(
            context="Context retrieval failed.",
            sources=[],
            confidence="none",
            top_score=1.0,
            has_dated_content=False,
            doc_years=[]
        )
# ─── PROMPTS ─────────────────────────────────────────────────────────────────────
# Full RAG prompt — used on high/medium confidence, i.e. a good vector match.
# {recency_note} is either empty or RECENCY_SUPPLEMENT_NOTE (see below).
RAG_PROMPT = PromptTemplate(
    input_variables=["context", "question", "history", "recency_note"],
    template=(
        "You are a knowledgeable and experienced Irish legal expert assistant. "
        "Answer the user's question thoroughly using the verified legal context below. "
        "Do NOT fabricate information. Prioritise the most recent law if multiple versions appear.\n\n"
        "{recency_note}"
        "VERIFIED LEGAL CONTEXT:\n{context}\n\n"
        "CONVERSATION HISTORY:\n{history}\n\n"
        "USER QUESTION:\n{question}\n\n"
        "Structure your response exactly as follows:\n\n"
        "πŸ” DIRECT ANSWER\n"
        "Clear, plain-English answer in 2-3 sentences.\n\n"
        "βš–οΈ LEGAL BASIS\n"
        "Cite the specific Irish Acts, Statutory Instruments, or case law. "
        "Always include section numbers and year of the Act "
        "(e.g., 'Section 12 of the Residential Tenancies Act 2004 as amended by the "
        "Residential Tenancies (Amendment) Act 2019...').\n\n"
        "πŸ“‹ DETAILED EXPLANATION\n"
        "Full legal position β€” conditions, exceptions, thresholds, penalties, rights. "
        "Use bullet points where helpful. If the law was recently amended, clearly state "
        "the OLD rule vs the NEW rule with effective dates.\n\n"
        "βœ… PRACTICAL NEXT STEPS\n"
        "2-4 concrete actions, relevant deadlines, or official bodies: "
        "Citizens Information, Courts Service, Data Protection Commission, "
        "Workplace Relations Commission, FLAC, MABS, etc.\n\n"
    )
)
# Fallback GPT prompt — low/no vector match; leans on model knowledge, optionally
# seeded with whatever partial context the retrieval did produce.
FALLBACK_PROMPT = PromptTemplate(
    input_variables=["question", "history", "partial_context"],
    template=(
        "You are a highly knowledgeable Irish legal expert assistant. "
        "The user's question could not be fully answered from your legal database. "
        "Use your expert knowledge of Irish law to provide the most accurate, "
        "up-to-date answer possible. Always prioritise current Irish law as of 2024-2025.\n\n"
        "PARTIAL CONTEXT FROM DATABASE (may be incomplete):\n{partial_context}\n\n"
        "CONVERSATION HISTORY:\n{history}\n\n"
        "USER QUESTION:\n{question}\n\n"
        "Structure your response as follows:\n\n"
        "πŸ” DIRECT ANSWER\n"
        "Clear answer in 2-3 sentences.\n\n"
        "βš–οΈ LEGAL BASIS\n"
        "Cite the relevant Irish legislation, Acts, or EU regulations that apply in Ireland. "
        "Be specific with section numbers and amendment years where possible.\n\n"
        "πŸ“‹ DETAILED EXPLANATION\n"
        "Explain the current legal position fully. Note any recent changes or amendments "
        "from 2023-2025. Use bullet points where helpful.\n\n"
        "βœ… PRACTICAL NEXT STEPS\n"
        "2-4 concrete actions or bodies to contact.\n\n"
    )
)
# Recency-aware supplement — injected into the RAG prompt's {recency_note} slot
# when needs_recency_check() fires on the user's question.
RECENCY_SUPPLEMENT_NOTE = (
    "⚠️ RECENCY ALERT: The user is asking about current or recent law. "
    "If the context contains BOTH old and new versions of a law, you MUST use the most recent one. "
    "Explicitly state the effective date of the current law. "
    "If the context only has older information, clearly flag this and advise the user to verify "
    "on gov.ie or citizensinformation.ie for the latest position.\n\n"
)
# ─── LLM CHAINS ──────────────────────────────────────────────────────────────────
# Deterministic (temperature=0) GPT-4 Turbo instance shared by both chains.
llm = ChatOpenAI(
    temperature=0,
    openai_api_key=settings.OPENAI_API_KEY,
    model="gpt-4-turbo",
    max_tokens=1800
)
rag_chain = LLMChain(llm=llm, prompt=RAG_PROMPT)            # strong vector match
fallback_chain = LLMChain(llm=llm, prompt=FALLBACK_PROMPT)  # weak or no match
# ─── INTELLIGENT ANSWER ENGINE ───────────────────────────────────────────────────
async def generate_answer(
    query: str,
    history_text: str,
    retrieval: RetrievalResult,
    recency_needed: bool
) -> tuple[str, str]:
    """
    Route the question to the appropriate chain based on retrieval confidence.

    Returns (answer, mode_used)
    mode_used: 'rag' | 'fallback' | 'hybrid'
    """
    confidence = retrieval.confidence
    score = retrieval.top_score
    if confidence in ("high", "medium"):
        # Strong vector match — answer strictly from the retrieved context.
        logger.info(f"Mode: RAG | Confidence: {confidence} | Score: {score:.3f}")
        payload = {
            "context": retrieval.context,
            "question": query,
            "history": history_text,
            "recency_note": RECENCY_SUPPLEMENT_NOTE if recency_needed else ""
        }
        answer = await asyncio.to_thread(rag_chain.run, payload)
        return answer, "rag"
    # Weak or missing match — the fallback chain leans on model knowledge,
    # seeded with the partial context when one exists.
    if confidence == "low":
        logger.info(f"Mode: HYBRID | Confidence: {confidence} | Score: {score:.3f}")
        mode = "hybrid"
        partial = retrieval.context
    else:
        logger.info(f"Mode: FALLBACK | Confidence: {confidence} | Score: {score:.3f}")
        mode = "fallback"
        partial = "No relevant context found in the legal database."
    payload = {
        "question": query,
        "history": history_text,
        "partial_context": partial
    }
    answer = await asyncio.to_thread(fallback_chain.run, payload)
    return answer, mode
# ─── PYDANTIC MODELS ─────────────────────────────────────────────────────────────
class QueryRequest(BaseModel):
    """POST /query request body."""
    # User question; max length 2000 enforced in the handler.
    query: str
class QueryResponse(BaseModel):
    """POST /query response body."""
    answer: str
    session_id: str
    # 'source' metadata of each retrieved document (empty for non-RAG modes).
    sources: list[str]
    mode: str  # rag | fallback | hybrid | greeting | deflect
    confidence: str  # high | medium | low | none | n/a
class SessionStatusResponse(BaseModel):
    """GET /session/status response body."""
    status: str  # "new" | "expired" | "active"
    # Seconds remaining; -2 when there is no live session.
    ttl: int
    session_id: Optional[str]
    created_at: Optional[datetime]
    expires_at: Optional[datetime]
    last_activity: Optional[datetime]
    history_count: Optional[int]
class SessionHistoryResponse(BaseModel):
    """GET /session/history response body."""
    history: list
    session_id: str
class HealthResponse(BaseModel):
    """GET /health response body."""
    status: str  # "ok" | "degraded"
    timestamp: datetime
    redis: str     # "ok" | "error"
    vectordb: str  # "ok" | "error"
# ─── ROUTES ──────────────────────────────────────────────────────────────────────
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the single-page chat UI from the repository root."""
    return FileResponse("index.html")
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Liveness probe: ping Redis and issue a tiny vector-store query.

    Either dependency failing downgrades the overall status to "degraded"
    rather than returning an error response.
    """
    redis_status = "ok"
    try:
        await redis.ping()
    except Exception:
        redis_status = "error"
    vectordb_status = "ok"
    try:
        # k=1: cheapest possible similarity search, run off the event loop.
        await asyncio.to_thread(vectordb.similarity_search, "test", 1)
    except Exception:
        vectordb_status = "error"
    return HealthResponse(
        status="ok" if redis_status == "ok" and vectordb_status == "ok" else "degraded",
        timestamp=datetime.utcnow(),
        redis=redis_status,
        vectordb=vectordb_status
    )
@app.post("/query", response_model=QueryResponse)
@limiter.limit(settings.RATE_LIMIT)
async def handle_query(
    request: Request,
    req: QueryRequest,
    session: SessionData = Depends(get_session),
    response: Response = None
):
    """Main Q&A endpoint.

    Pipeline: validate input → moderate input → classify (greeting /
    non-legal / legal) → vector retrieval → RAG-or-fallback generation →
    moderate output → persist the last 5 exchanges to the session.
    """
    import random  # function-scoped; only needed for picking a greeting reply
    query = req.query.strip()
    # ── Input validation
    if not query:
        raise HTTPException(status_code=400, detail="Query cannot be empty.")
    if len(query) > 2000:
        raise HTTPException(status_code=400, detail="Query too long. Max 2000 characters.")
    # ── Moderate input (fail-closed: a moderation API error rejects the query)
    if not await moderate_content(query):
        raise HTTPException(status_code=400, detail="Content policy violation.")
    # ── Classify query
    query_type = classify_query(query)
    # ── Handle greeting (not recorded in the session history)
    if query_type == "greeting":
        greeting_reply = random.choice(GREETING_RESPONSES)
        return QueryResponse(
            answer=greeting_reply,
            session_id=session.session_id,
            sources=[],
            mode="greeting",
            confidence="n/a"
        )
    # ── Handle non-legal topics
    if query_type == "non_legal":
        return QueryResponse(
            answer=(
                "I'm specifically designed to assist with Irish legal matters only. πŸ›οΈ\n\n"
            ),
            session_id=session.session_id,
            sources=[],
            mode="deflect",
            confidence="n/a"
        )
    # ── Check if user needs current/latest law info
    recency_needed = needs_recency_check(query)
    # ── Retrieve from vector DB (sync Chroma call moved off the event loop)
    retrieval = await asyncio.to_thread(retrieve_context, query)
    # ── Format the last 3 exchanges as conversation history for the prompt
    history_text = "No prior conversation."
    if session.history:
        recent = session.history[-3:]
        history_text = "\n".join(
            [f"User: {h['q']}\nAssistant: {h['a']}" for h in recent]
        )
    # ── Generate answer
    try:
        answer, mode_used = await generate_answer(
            query=query,
            history_text=history_text,
            retrieval=retrieval,
            recency_needed=recency_needed
        )
    except Exception as e:
        logger.error(f"LLM error: {e}")
        raise HTTPException(status_code=500, detail="Failed to generate a response.")
    # ── Moderate output (a flagged answer is replaced, not rejected)
    if not await moderate_content(answer):
        answer = "I'm unable to provide a response to that query due to content policy restrictions."
    # ── Update session history (ring buffer of the 5 most recent exchanges)
    session.history.append({
        "q": query,
        "a": answer,
        "timestamp": datetime.utcnow().isoformat(),
        "mode": mode_used
    })
    if len(session.history) > 5:
        session.history.pop(0)
    await save_session(session)
    logger.info(
        f"Query handled | session={session.session_id} | "
        f"mode={mode_used} | confidence={retrieval.confidence} | "
        f"sources={len(retrieval.sources)} | recency={recency_needed}"
    )
    return QueryResponse(
        answer=answer,
        session_id=session.session_id,
        sources=retrieval.sources,
        mode=mode_used,
        confidence=retrieval.confidence
    )
@app.get("/session/status", response_model=SessionStatusResponse)
async def get_session_status(session_id: str = Cookie(default=None)):
    """Report the cookie's session state without creating or refreshing one.

    ttl=-2 marks "no live session" (apparently echoing Redis's TTL sentinel
    for a missing key).  Unlike the get_session dependency, this endpoint
    never mints a new session.
    """
    if not session_id:
        return SessionStatusResponse(
            status="new", ttl=-2, session_id=None,
            created_at=None, expires_at=None,
            last_activity=None, history_count=None
        )
    try:
        raw = await redis.get(session_id)
    except Exception as e:
        logger.error(f"Redis error on status check: {e}")
        raise HTTPException(status_code=500, detail="Session store unavailable.")
    if not raw:
        return SessionStatusResponse(
            status="expired", ttl=-2, session_id=session_id,
            created_at=None, expires_at=None,
            last_activity=None, history_count=None
        )
    data = SessionData.parse_raw(raw)
    now = datetime.utcnow()
    # The record may still exist in Redis while being logically expired.
    if now > data.expires_at:
        return SessionStatusResponse(
            status="expired", ttl=-2, session_id=session_id,
            created_at=data.created_at, expires_at=data.expires_at,
            last_activity=data.last_activity, history_count=len(data.history)
        )
    return SessionStatusResponse(
        status="active",
        ttl=int((data.expires_at - now).total_seconds()),
        session_id=session_id,
        created_at=data.created_at,
        expires_at=data.expires_at,
        last_activity=data.last_activity,
        history_count=len(data.history)
    )
@app.get("/session/history", response_model=SessionHistoryResponse)
async def get_session_history(session: SessionData = Depends(get_session)):
    """Return the stored Q&A history (the dependency mints a fresh, empty
    session when the cookie is absent or expired)."""
    return SessionHistoryResponse(
        history=session.history,
        session_id=session.session_id
    )
@app.delete("/session/clear")
async def clear_session(session: SessionData = Depends(get_session)):
    """Delete the current session's Redis record.

    NOTE(review): the session_id cookie itself is not cleared, and the
    get_session dependency mints a brand-new session when the cookie is
    already invalid — confirm that is the intended behaviour.
    """
    try:
        await redis.delete(session.session_id)
        logger.info(f"Session cleared: {session.session_id}")
        return JSONResponse({"message": "Session cleared successfully."})
    except Exception as e:
        logger.error(f"Session clear error: {e}")
        raise HTTPException(status_code=500, detail="Failed to clear session.")
# ─── EXCEPTION HANDLERS ──────────────────────────────────────────────────────────
@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded):
    """Convert slowapi rate-limit violations into a plain 429 JSON reply."""
    return JSONResponse(
        status_code=429,
        content={"detail": "Too many requests. Please slow down."}
    )
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Last-resort handler: log the full traceback, return an opaque 500."""
    logger.error(f"Unhandled exception: {exc}", exc_info=True)
    return JSONResponse(
        status_code=500,
        content={"detail": "An internal server error occurred."}
    )
# ─── LAUNCH ──────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    import uvicorn
    # Port taken from the PORT env var, defaulting to 7860.
    port = int(os.environ.get("PORT", 7860))
    # Import string ("app:app") is required for multi-worker mode.
    uvicorn.run("app:app", host="0.0.0.0", port=port, workers=4, log_level="info")