File size: 4,694 Bytes
8e9fd43
 
 
2c42748
8e9fd43
 
 
 
587894c
fbd5aba
8e9fd43
 
 
 
fbd5aba
8e9fd43
c1ba890
8e9fd43
 
 
 
2f3b9d0
 
 
 
 
 
2c39b8a
587894c
2c39b8a
 
587894c
2c39b8a
 
 
 
 
 
 
 
 
 
 
587894c
 
2c39b8a
43cf665
 
 
 
 
2f3b9d0
 
 
 
 
 
 
 
 
 
 
 
 
9952378
2f3b9d0
 
 
ed08d36
2f3b9d0
 
ed08d36
 
 
 
 
 
 
 
 
 
 
 
c1ba890
8e9fd43
 
 
ed08d36
2f3b9d0
9ff5d3a
 
 
2f3b9d0
 
587894c
 
 
d229ca3
587894c
2f3b9d0
d229ca3
ed08d36
 
 
587894c
ed08d36
 
 
2f3b9d0
 
8e9fd43
43cf665
 
8e9fd43
43cf665
 
 
8e9fd43
fbd5aba
 
 
 
 
 
5da06ea
ed08d36
2f3b9d0
43cf665
 
 
 
 
2f3b9d0
 
 
 
8573cc3
2f3b9d0
8e9fd43
 
 
fbd5aba
8e9fd43
 
 
 
2c42748
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# app/routes/question.py
import os
import requests
from fastapi import APIRouter
from pydantic import BaseModel
from typing import List
from redis_client import redis_client as r
from dotenv import load_dotenv
from urllib.parse import quote
import json

# Load variables from a .env file (python-dotenv) before the keys below are read.
load_dotenv()

# API credentials; each is None when the variable is not set.
GNEWS_API_KEY = os.environ.get("GNEWS_API_KEY")
HF_TOKEN = os.environ.get("HF_TOKEN")

# Router that exposes the question-answering endpoint.
askMe = APIRouter()

class QuestionInput(BaseModel):
    """Request body for ``POST /ask``: the user's free-text question."""

    # The question to answer; it is also used to derive search keywords.
    question: str

# Hugging Face Inference API endpoint for the hosted Mistral instruct model.
HF_API_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.3"

# Headers sent with every Mistral request. A missing HF_TOKEN is rendered as
# the literal text "None" in the bearer value.
HEADERS = {
    "Authorization": "Bearer {}".format(HF_TOKEN),
    "Content-Type": "application/json",
}

def extract_last_keywords(raw: str, max_keywords=8):
    """Pull a comma-separated keyword list out of raw model output.

    Lines are scanned bottom-up, so the model's completion wins over any
    echoed prompt text. A line qualifies when it:

    * is at least 10 characters long and does not start with "extract"
      (case-insensitive), which skips the echoed instruction line;
    * contains at least two commas;
    * splits into 1..max_keywords non-empty items of at most three words each.

    A leading "Keyword:" / "Keywords:" label is dropped, and surrounding
    double quotes are removed from each item.

    Args:
        raw: Text returned by the model; may be empty.
        max_keywords: Maximum number of items a qualifying line may contain.

    Returns:
        The keywords from the last qualifying line, or [] if none qualifies.
    """
    for line in reversed(raw.strip().split("\n")):
        line = line.strip()
        # Skip blank/short lines and the echoed "Extract ..." instruction.
        if not line or len(line) < 10 or line.lower().startswith("extract"):
            continue

        # A keyword list has several comma-separated items.
        if line.count(",") < 2:
            continue

        # Drop a "Keywords:" label so it does not end up in the first keyword.
        label, sep, rest = line.partition(":")
        if sep and label.strip().lower() in ("keyword", "keywords"):
            line = rest

        # Strip quotes *before* discarding empties, so a bare `""` item is
        # dropped instead of becoming an empty keyword (which would match
        # every article and produce an empty search phrase).
        items = [kw.strip().strip('"').strip() for kw in line.split(",")]
        items = [kw for kw in items if kw]

        # Reject sentence fragments (e.g. an echoed question) and overlong lists.
        if 1 <= len(items) <= max_keywords and all(len(kw.split()) <= 3 for kw in items):
            return items

    return []


def is_relevant(article, keywords):
    """Return True if any keyword occurs in the article's title or content.

    Matching is a case-insensitive substring test. A missing or ``None``
    title/content field is treated as an empty string (rather than the text
    "None"). Blank keywords are ignored, because ``"" in text`` would
    otherwise mark every article as relevant.

    Args:
        article: Mapping with optional "title" and "content" entries.
        keywords: Iterable of keyword strings.

    Returns:
        True when at least one non-blank keyword is found, else False.
    """
    title = article.get("title") or ""
    content = article.get("content") or ""
    haystack = f"{title} {content}".lower()
    return any(kw.lower() in haystack for kw in keywords if kw.strip())


def mistral_generate(prompt: str, max_new_tokens=128, temperature=0.7, return_full_text=False):
    """Generate a completion for *prompt* with the hosted Mistral model.

    Args:
        prompt: Text sent as the model input.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature (default 0.7, as before).
        return_full_text: When False (the default), ask the API to return
            only the completion. When True, the API returns prompt +
            completion.

    Returns:
        The stripped generated text. Returns "" if the request fails, the
        body is not JSON, or the response is not a non-empty list of objects.
    """
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            # The text-generation API echoes the prompt by default, which would
            # leak the whole instruction + news context into the answer.
            "return_full_text": return_full_text,
        },
    }
    try:
        response = requests.post(HF_API_URL, headers=HEADERS, json=payload, timeout=30)
        response.raise_for_status()
        result = response.json()
    except (requests.RequestException, ValueError) as e:
        # Best-effort: callers treat "" as "no answer", but record why.
        print("Mistral API error:", str(e))
        return ""

    print("Mistral Result", result)
    if isinstance(result, list) and result and isinstance(result[0], dict):
        return (result[0].get("generated_text") or "").strip()
    return ""

def fetch_gnews_articles(query: str, max_results: int = 5, lang: str = "en") -> List[dict]:
    """Search GNews for *query* and return the matching article dicts.

    Args:
        query: GNews search expression (for example quoted phrases joined
            with AND/OR). It is percent-encoded before being put in the URL.
        max_results: Value of the GNews ``max`` parameter (default 5).
        lang: Value of the GNews ``lang`` parameter (default "en").

    Returns:
        The "articles" list from the response. Returns [] on any request or
        parsing error (best-effort), or when the field is absent or null.
    """
    encoded_query = quote(query)
    base_url = (
        f"https://gnews.io/api/v4/search?q={encoded_query}"
        f"&lang={quote(lang)}&max={max_results}&expand=content"
    )
    # Log the request *without* the token so the API key never reaches logs.
    print("GNews URL:", base_url)
    gnews_url = f"{base_url}&token={GNEWS_API_KEY}"
    try:
        response = requests.get(gnews_url, timeout=10)
        response.raise_for_status()
        data = response.json()
    except Exception as e:
        message = str(e)
        # HTTPError messages embed the full request URL, token included.
        if GNEWS_API_KEY:
            message = message.replace(GNEWS_API_KEY, "***")
        print("GNews API error:", message)
        return []

    if not isinstance(data, dict):
        return []
    return data.get("articles") or []

@askMe.post("/ask")
async def ask_question(input: QuestionInput):
    """Answer a news question using GNews articles as context for Mistral.

    Every step of the pipeline makes blocking ``requests`` calls, so it runs
    in FastAPI's threadpool. Running it directly inside this coroutine would
    stall the event loop for every other request.

    Returns:
        On success, {"question", "answer", "sources"}. When no keywords can
        be extracted, {"error": ...}.
    """
    from fastapi.concurrency import run_in_threadpool

    return await run_in_threadpool(_answer_question, input.question)


def _answer_question(question: str) -> dict:
    """Run the keyword -> article search -> answer pipeline synchronously."""
    # Cap on the article text sent to the model as context.
    max_context_chars = 15000

    # Step 1: Ask Mistral to extract keywords
    keyword_prompt = (
        f"Extract the 3–6 most important keywords from the following question. "
        f"Return only the keywords, comma-separated (no explanations):\n\n"
        f"{question}"
    )
    raw_keywords = mistral_generate(keyword_prompt, max_new_tokens=32)
    keywords = extract_last_keywords(raw_keywords)

    print("Raw keyword output:", raw_keywords)
    print("Extracted keywords:", keywords)

    if not keywords:
        return {"error": "Keyword extraction failed."}

    # Step 2: Fetch articles requiring every keyword, then fall back to any keyword
    articles = fetch_gnews_articles(" AND ".join(f'"{kw}"' for kw in keywords))
    if not articles:
        articles = fetch_gnews_articles(" OR ".join(f'"{kw}"' for kw in keywords))

    print("Fetched articles:", len(articles))

    relevant_articles = [a for a in articles if is_relevant(a, keywords)]

    context = "\n\n".join(
        a.get("content") or "" for a in relevant_articles
    )[:max_context_chars]

    if not context.strip():
        return {
            "question": question,
            "answer": "Cannot answer – no relevant context found.",
            "sources": []
        }

    # Step 3: Ask Mistral to answer using the context
    answer_prompt = (
        f"You are a concise news assistant. Answer the user's question clearly using the context below if relevant. "
        f"If the context is not helpful, you may rely on your own knowledge, but do not mention the context or question again.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {question}\n\n"
        f"Answer:"
    )
    answer = mistral_generate(answer_prompt, max_new_tokens=256)
    if not answer:
        answer = "Cannot answer – model did not return a valid response."

    print("Answer:", answer)

    return {
        "question": question,
        "answer": answer.strip(),
        # Cite only the articles that actually fed the context, and tolerate
        # articles missing a title or URL instead of raising KeyError.
        "sources": [
            {"title": a.get("title"), "url": a.get("url")}
            for a in relevant_articles
        ]
    }