"""FastAPI router that answers user questions with Mistral-7B-Instruct via the
Hugging Face Inference API, optionally grounding answers in Google News results."""

import os
import requests
import json
from fastapi import APIRouter
from pydantic import BaseModel
from typing import List
from redis_client import redis_client as r
from dotenv import load_dotenv
from urllib.parse import quote

from nuse_modules.classifier import classify_question, REVERSE_MAP
from nuse_modules.keyword_extracter import keywords_extractor
from nuse_modules.google_search import search_google_news

load_dotenv()

HF_TOKEN = os.getenv("HF_TOKEN")
HF_API_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.3"
HEADERS = {
    "Authorization": f"Bearer {HF_TOKEN}",
    "Content-Type": "application/json"
}

askMe = APIRouter()


class QuestionInput(BaseModel):
    question: str


def should_extract_keywords(type_id: int) -> bool:
    # Intent types whose answers need fresh news context, i.e. keyword
    # extraction followed by a Google News search.
    return type_id in {1, 2, 3, 4, 5, 6, 7, 10}


def extract_answer_after_label(text: str) -> str:
    """
    Extracts everything after the first 'Answer:' label.
    Assumes 'Answer:' appears once and is followed by the relevant content.
    """
    if "Answer:" in text:
        return text.split("Answer:", 1)[1].strip()
    return text.strip()


def mistral_generate(prompt: str, max_new_tokens: int = 128) -> str:
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": 0.7
        }
    }
    try:
        response = requests.post(HF_API_URL, headers=HEADERS, data=json.dumps(payload), timeout=30)
        response.raise_for_status()
        result = response.json()
        if isinstance(result, list) and len(result) > 0:
            return result[0].get("generated_text", "").strip()
        return ""
    except Exception:
        # Network errors, timeouts, and non-2xx responses all collapse to an
        # empty string; the caller treats that as "no answer".
        return ""


@askMe.post("/ask")
async def ask_question(input: QuestionInput):
    question = input.question

    qid = classify_question(question)
    print("Intent ID:", qid)
    print("Category:", REVERSE_MAP.get(qid, "unknown"))

    context = ""
    sources = []

    if should_extract_keywords(qid):
        keywords = keywords_extractor(question)
        print("Raw extracted keywords:", keywords)

        if not keywords:
            return {"error": "Keyword extraction failed."}

        results = search_google_news(keywords)
        print("Found articles:", results)

        # Build the model context from article snippets, capped at 15,000 characters.
        context = "\n\n".join([
            item.get("snippet") or item.get("description", "")
            for item in results
        ])[:15000]

        sources = [
            {"title": item["title"], "url": item["link"]}
            for item in results
        ]

        if not context.strip():
            return {
                "question": question,
                "answer": "Cannot answer - no relevant context found.",
                "sources": sources
            }

    answer_prompt = (
        f"You are a concise news assistant. Answer the user's question clearly using the context below if relevant. "
        f"If the context is not helpful, you may rely on your own knowledge, but do not mention the context or question again.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {question}\n\n"
        f"Answer:"
    )
    answer_raw = mistral_generate(answer_prompt, max_new_tokens=256)

    if not answer_raw:
        final_answer = "Cannot answer - model did not return a valid response."
    else:
        final_answer = extract_answer_after_label(answer_raw)

    return {
        "question": question,
        "answer": final_answer.strip(),
        "sources": sources
    }
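

# Example of wiring this router into an application (illustrative sketch; the
# module name "ask_me" below is an assumption - adjust it to this file's name):
#
#   from fastapi import FastAPI
#   from ask_me import askMe
#
#   app = FastAPI()
#   app.include_router(askMe)
#
# The endpoint can then be called with, e.g.:
#
#   curl -X POST http://localhost:8000/ask \
#        -H "Content-Type: application/json" \
#        -d '{"question": "What happened in the markets today?"}'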