|
|
|
import os |
|
import requests |
|
from fastapi import APIRouter |
|
from pydantic import BaseModel |
|
from typing import List |
|
from redis_client import redis_client as r |
|
from dotenv import load_dotenv |
|
from urllib.parse import quote |
|
import json |
|
|
|
# Populate os.environ from a .env file before the environment lookups below.
load_dotenv()

# API credentials taken from the environment; each is None when unset.
GNEWS_API_KEY = os.environ.get("GNEWS_API_KEY")
HF_TOKEN = os.environ.get("HF_TOKEN")

# Router on which this module registers its endpoint(s).
askMe = APIRouter()
|
|
|
class QuestionInput(BaseModel):
    # Request body accepted by the /ask endpoint.

    # Free-form question to be answered using news articles as context.
    question: str
|
|
|
# Hugging Face Inference API endpoint for the Mistral 7B Instruct model.
HF_API_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.3"

# Headers sent with every Hugging Face request. Authorization is attached only
# when HF_TOKEN is configured: interpolating an unset token would send the
# literal credential "Bearer None" instead of no credential at all.
HEADERS = {"Content-Type": "application/json"}
if HF_TOKEN:
    HEADERS["Authorization"] = f"Bearer {HF_TOKEN}"
|
|
|
def extract_last_keywords(raw: str, max_keywords: int = 8) -> List[str]:
    """Pull a comma-separated keyword list out of raw model output.

    Lines are scanned from the bottom up (the model's reply may be preceded
    by an echo of the prompt), and the first line that looks like a keyword
    list wins. A line qualifies when it:

    * is at least 10 characters long and does not start with "extract"
      (the echoed instruction line),
    * contains at least two commas, and
    * yields between 1 and ``max_keywords`` non-empty items of at most three
      words each, after surrounding whitespace and quotes are removed.

    Args:
        raw: Text returned by the language model (``None``/"" yield []).
        max_keywords: Maximum number of keywords accepted from a single line.

    Returns:
        The keywords from the last qualifying line, or [] if none qualifies.
    """
    if not raw:
        return []

    for line in reversed(raw.strip().split("\n")):
        line = line.strip()
        if not line or len(line) < 10 or line.lower().startswith("extract"):
            continue
        if line.count(",") < 2:
            continue

        # Remove whitespace and either quote style, THEN drop empties: an item
        # such as '""' must not survive as an empty keyword, because "" is a
        # substring of every string and would match any article downstream.
        parts = [kw.strip().strip("\"'").strip() for kw in line.split(",")]
        parts = [kw for kw in parts if kw]

        if 1 <= len(parts) <= max_keywords and all(len(kw.split()) <= 3 for kw in parts):
            return parts

    return []
|
|
|
|
|
def is_relevant(article: dict, keywords: List[str]) -> bool:
    """Return True if any keyword occurs in the article's title or content.

    Matching is a case-insensitive substring test. Missing or ``None``
    title/content fields count as empty text (not as the string "None"), and
    blank keywords are ignored because "" would match every article.
    """
    title = article.get("title") or ""
    content = article.get("content") or ""
    text = f"{title} {content}".lower()
    return any(kw.lower() in text for kw in keywords if kw.strip())
|
|
|
|
|
def mistral_generate(prompt: str, max_new_tokens=128):
    """Generate a completion for ``prompt`` via the Hugging Face Inference API.

    Args:
        prompt: Text sent as the model input.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The generated completion with surrounding whitespace stripped, or ""
        when the request fails or the response is not a non-empty list whose
        first element is an object. Errors are printed rather than raised so
        callers can treat "" as "no answer".
    """
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": 0.7,
            # By default the text-generation task prepends the prompt to
            # ``generated_text``; request only the newly generated text.
            "return_full_text": False,
        },
    }

    try:
        response = requests.post(HF_API_URL, headers=HEADERS, json=payload, timeout=30)
        response.raise_for_status()
        result = response.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers a response body that is not valid JSON.
        print("Mistral API error:", str(e))
        return ""

    # Validate the response shape explicitly instead of relying on a broad
    # exception handler to absorb unexpected structures.
    if not (isinstance(result, list) and result):
        return ""
    print("Mistral Result", result)
    first = result[0]
    if not isinstance(first, dict):
        return ""
    return (first.get("generated_text") or "").strip()
|
|
|
def fetch_gnews_articles(query: str, max_results: int = 5, lang: str = "en") -> List[dict]:
    """Search the GNews API for articles matching ``query``.

    Args:
        query: GNews search expression (quoted phrases joined by AND/OR here).
        max_results: Sent as the ``max`` parameter (articles requested).
        lang: Sent as the ``lang`` parameter.

    Returns:
        The ``articles`` list from the response (full text requested via
        ``expand=content``), or [] on any network, HTTP, or parse error, or if
        the response lacks an ``articles`` list.
    """
    params = {
        "q": query,
        "lang": lang,
        "max": max_results,
        "expand": "content",
        "token": GNEWS_API_KEY,
    }
    # Log the query, never the full URL: the URL carries the API token.
    print("GNews query:", query)

    try:
        # ``params`` lets requests handle URL encoding of every value.
        response = requests.get("https://gnews.io/api/v4/search", params=params, timeout=10)
        response.raise_for_status()
        data = response.json()
    except (requests.RequestException, ValueError) as e:
        message = str(e)
        # HTTPError messages embed the request URL, token included; redact it.
        if GNEWS_API_KEY:
            message = message.replace(GNEWS_API_KEY, "***")
        print("GNews API error:", message)
        return []

    articles = data.get("articles") if isinstance(data, dict) else None
    return articles if isinstance(articles, list) else []
|
|
|
@askMe.post("/ask")
async def ask_question(input: QuestionInput):
    """Answer a question using recent news articles as context.

    Pipeline:
      1. Ask the model for 3–6 keywords describing the question.
      2. Search GNews for articles containing all keywords; if none are
         found, retry matching any keyword.
      3. Keep articles that actually mention a keyword and pass their content
         (capped at 15,000 characters) to the model as answer context.

    The helper functions perform blocking HTTP calls via ``requests``, so they
    run in the event loop's default executor; calling them directly inside
    this coroutine would block every other request on the loop meanwhile.

    Returns:
        ``{"error": ...}`` if keyword extraction fails; otherwise a dict with
        ``question``, ``answer`` and ``sources`` (title/url of the articles
        that passed the relevance filter and supplied the context).
    """
    import asyncio

    loop = asyncio.get_running_loop()
    question = input.question

    # --- 1. Keyword extraction ---
    keyword_prompt = (
        f"Extract the 3–6 most important keywords from the following question. "
        f"Return only the keywords, comma-separated (no explanations):\n\n"
        f"{question}"
    )
    raw_keywords = await loop.run_in_executor(None, mistral_generate, keyword_prompt, 32)
    keywords = extract_last_keywords(raw_keywords)
    print("Raw extracted keywords:", keywords)
    if not keywords:
        return {"error": "Keyword extraction failed."}

    # --- 2. Article search: strict AND first, then a looser OR fallback ---
    # Embedded double quotes would break the quoted-phrase search syntax.
    phrases = ['"' + kw.replace('"', "") + '"' for kw in keywords]
    articles = await loop.run_in_executor(None, fetch_gnews_articles, " AND ".join(phrases))
    if not articles:
        articles = await loop.run_in_executor(None, fetch_gnews_articles, " OR ".join(phrases))
    print("Fetched articles:", len(articles))

    # --- 3. Build context from relevant articles only ---
    relevant_articles = [a for a in articles if is_relevant(a, keywords)]
    context = "\n\n".join(a.get("content") or "" for a in relevant_articles)[:15000]
    if not context.strip():
        return {
            "question": question,
            "answer": "Cannot answer — no relevant context found.",
            "sources": [],
        }

    answer_prompt = (
        f"You are a concise news assistant. Answer the user's question clearly using the context below if relevant. "
        f"If the context is not helpful, you may rely on your own knowledge, but do not mention the context or question again.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {question}\n\n"
        f"Answer:"
    )
    answer = await loop.run_in_executor(None, mistral_generate, answer_prompt, 256)
    # If the backend echoes the prompt ahead of the completion, strip it so the
    # full prompt and context are never returned as part of the answer.
    if answer.startswith(answer_prompt):
        answer = answer[len(answer_prompt):]
    answer = answer.strip()
    if not answer:
        answer = "Cannot answer — model did not return a valid response."
    print("Answer:", answer)

    return {
        "question": question,
        "answer": answer,
        # Cite only the articles that passed the relevance filter and fed the
        # context; .get() tolerates articles missing a title or url.
        "sources": [
            {"title": a.get("title"), "url": a.get("url")}
            for a in relevant_articles
        ],
    }
|
|