|
import os |
|
import json |
|
import datetime |
|
from typing import List, Dict |
|
|
|
import requests |
|
from fastapi import APIRouter |
|
from pydantic import BaseModel |
|
from dotenv import load_dotenv |
|
|
|
from clients.redis_client import redis_client as _r |
|
from models_initialization.mistral_registry import mistral_generate |
|
from nuse_modules.classifier import classify_question, REVERSE_MAP |
|
from nuse_modules.keyword_extracter import keywords_extractor |
|
from nuse_modules.google_search import search_google_news |
|
|
|
load_dotenv()  # Load environment variables from a local .env file at import time.



# Router holding the /ask endpoint; mounted by the main application elsewhere.
askMe = APIRouter()
|
|
|
|
|
|
|
|
|
class QuestionInput(BaseModel):

    """Request body for the /ask endpoint."""

    # The user's natural-language news question.
    question: str
|
|
|
|
|
|
|
|
|
|
|
|
|
def should_extract_keywords(type_id: int) -> bool:
    """Tell whether the given intent id calls for keyword extraction."""
    keyword_driven_intents = (1, 2, 3, 4, 5, 6, 7, 10, 11, 12)
    return type_id in keyword_driven_intents
|
|
|
|
|
def extract_answer_after_label(text: str) -> str:
    """Return the portion of *text* following the first 'Answer:' label.

    Falls back to the whole text (stripped) when no label is present.
    """
    _, label, tail = text.partition("Answer:")
    return tail.strip() if label else text.strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
def _collect_cached_headlines(date_str: str, categories: List[str]) -> List[Dict]:
    """Gather cached headline articles for each category from Redis.

    Categories with no cache entry, or whose cached payload is not valid
    JSON, are skipped silently (best-effort aggregation).
    """
    headlines: List[Dict] = []
    for cat in categories:
        cached = _r.get(f"headlines:{date_str}:{cat}")
        if not cached:
            continue
        try:
            articles = json.loads(cached)
        except json.JSONDecodeError:
            continue
        for art in articles:
            headlines.append({
                "title": art.get("title"),
                "summary": art.get("summary"),
                "url": art.get("url"),
                "image": art.get("image"),
                "category": cat,
            })
    return headlines


@askMe.post("/ask")
async def ask_question(input: QuestionInput):
    """Answer a news question.

    Flow: classify the question's intent; for a headlines request, return
    today's cached headlines from Redis; otherwise optionally extract
    keywords, search Google News for context, and generate an answer with
    the Mistral model.
    """
    question = input.question.strip()

    qid = classify_question(question)
    category = REVERSE_MAP.get(qid, "unknown")
    print("Intent ID:", qid)
    print("Category:", category)

    # BUG FIX: classify_question returns a numeric intent id (see
    # should_extract_keywords / REVERSE_MAP usage), so the original
    # `qid == "asking_for_headlines"` comparison could never be true.
    # Compare the mapped category name instead.
    if category == "asking_for_headlines":
        # Aware UTC timestamp; datetime.utcnow() is deprecated and naive.
        date_str = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d")
        categories = ["world", "india", "finance", "sports", "entertainment"]
        all_headlines = _collect_cached_headlines(date_str, categories)

        return {
            "question": question,
            # Fixed mojibake ("todayβs") in the original literal.
            "answer": "Here are today's top headlines:",
            "headlines": all_headlines,
        }

    context = ""
    sources: List[Dict] = []

    if should_extract_keywords(qid):
        keywords = keywords_extractor(question)
        print("Raw extracted keywords:", keywords)

        if not keywords:
            return {"error": "Keyword extraction failed."}

        results = search_google_news(keywords)
        print("Found articles:", results)

        # Cap the context at 15k characters to bound the prompt size.
        context = "\n\n".join(
            r.get("snippet") or r.get("description", "") for r in results
        )[:15000]

        # .get() guards against search results missing "title"/"link" keys,
        # which previously raised KeyError.
        sources = [{"title": r.get("title"), "url": r.get("link")} for r in results]

        if not context.strip():
            return {
                "question": question,
                "answer": "Cannot answer — no relevant context found.",
                "sources": sources,
            }

        answer_prompt = (
            "You are a concise news assistant. Answer the user's question clearly using the provided context if relevant. "
            "If the context is not helpful, rely on your own knowledge but do not mention the context.\n\n"
            f"Context:\n{context}\n\n"
            f"Question: {question}\n\nAnswer:"
        )
        answer_raw = mistral_generate(answer_prompt, max_new_tokens=256)

    else:
        answer_prompt = (
            "You are a concise news assistant. Answer the user's question clearly and accurately.\n\n"
            f"Question: {question}\n\nAnswer:"
        )
        answer_raw = mistral_generate(answer_prompt, max_new_tokens=256)

    final_answer = extract_answer_after_label(answer_raw or "") or (
        "Cannot answer — model did not return a valid response."
    )

    return {
        "question": question,
        "answer": final_answer.strip(),
        "sources": sources,
    }
|
|