raghavNCI commited on
Commit
f00f379
Β·
1 Parent(s): 206e141

too many changes, hope this works

Browse files
app.py CHANGED
@@ -1,6 +1,6 @@
1
  from fastapi import FastAPI
2
- from routes import router # routes.py must be in same folder
3
- from question import askMe
4
  from dotenv import load_dotenv
5
  from cache_init import fetch_and_cache_articles
6
 
 
1
  from fastapi import FastAPI
2
+ from routes.category import router  # router now lives in routes/category.py
3
+ from routes.question import askMe
4
  from dotenv import load_dotenv
5
  from cache_init import fetch_and_cache_articles
6
 
clients/__init__.py ADDED
File without changes
redis_client.py β†’ clients/redis_client.py RENAMED
File without changes
models_initialization/__init__.py ADDED
File without changes
models_initialization/bart_large_registry.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ from dotenv import load_dotenv
4
+
5
+ load_dotenv()
6
+
7
+ HF_TOKEN = os.getenv("HF_TOKEN")
8
+ HF_BART_MODEL = "facebook/bart-large-mnli"
9
+ HF_API_URL = f"https://api-inference.huggingface.co/models/{HF_BART_MODEL}"
10
+
11
+ HEADERS = {
12
+ "Authorization": f"Bearer {HF_TOKEN}",
13
+ "Content-Type": "application/json"
14
+ }
15
+
16
def run_zero_shot_classification(prompt: str, candidate_labels: list[str]) -> str:
    """Classify *prompt* against *candidate_labels* via the hosted zero-shot endpoint.

    Returns the highest-confidence label, or "" on any error or empty result.
    """
    request_body = {
        "inputs": prompt,
        "parameters": {"candidate_labels": candidate_labels},
    }
    try:
        resp = requests.post(HF_API_URL, headers=HEADERS, json=request_body, timeout=20)
        resp.raise_for_status()
        result = resp.json()
    except Exception as err:  # best-effort: any network/HTTP/decode failure yields ""
        print("[BART Zero-Shot Error]:", str(err))
        return ""

    if isinstance(result, dict):
        labels = result.get("labels") or []
        if labels:
            # The API returns labels sorted by confidence, best first.
            return labels[0]
    return ""
34
+
35
def run_entailment_check(premise: str, hypothesis: str) -> bool:
    """Ask the hosted bart-large-mnli model whether *premise* entails *hypothesis*.

    Returns True only when an explicit "entailment" label scores above 0.5;
    returns False on any error, missing response field, or low confidence.

    NOTE(review): the hosted zero-shot pipeline normally expects
    {"inputs": <text>, "parameters": {"candidate_labels": [...]}}; the
    premise/hypothesis dict used here assumes the raw NLI input format is
    accepted by this endpoint — confirm against the Inference API docs.
    """
    payload = {
        "inputs": {
            "premise": premise,
            "hypothesis": hypothesis
        }
    }

    try:
        response = requests.post(HF_API_URL, headers=HEADERS, json=payload, timeout=20)
        response.raise_for_status()
        data = response.json()
        # Guard every field we touch: previously "scores" was indexed without
        # being checked, so a partial response fell into the broad except
        # below instead of being handled deliberately.
        if isinstance(data, dict):
            labels = data.get("labels") or []
            scores = data.get("scores") or []
            if "entailment" in labels and len(scores) == len(labels):
                entailment_score = scores[labels.index("entailment")]
                return entailment_score > 0.5
    except Exception as e:
        print("[BART Entailment Error]:", str(e))

    return False
models_initialization/mistral_registry.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import requests
4
+ from dotenv import load_dotenv
5
+
6
+ load_dotenv()
7
+
8
+ HF_TOKEN = os.getenv("HF_TOKEN")
9
+ HF_API_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.3"
10
+
11
+ HEADERS = {
12
+ "Authorization": f"Bearer {HF_TOKEN}",
13
+ "Content-Type": "application/json"
14
+ }
15
+
16
def mistral_generate(prompt: str, max_new_tokens=128, temperature=0.7) -> str:
    """Generate text from the hosted Mistral-7B-Instruct model.

    Parameters mirror the HF text-generation API. Returns the stripped
    generated text, or "" on any error or unexpected response shape.

    NOTE(review): the Inference API may echo the prompt at the start of
    "generated_text" unless "return_full_text": False is requested —
    confirm callers expect the current behavior before changing it.
    """
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature
        }
    }

    try:
        # json= lets requests serialize the payload itself — identical bytes
        # on the wire to the previous data=json.dumps(payload), one fewer
        # moving part.
        response = requests.post(HF_API_URL, headers=HEADERS, json=payload, timeout=30)
        response.raise_for_status()
        result = response.json()
        if isinstance(result, list) and result:
            return result[0].get("generated_text", "").strip()
    except Exception as e:
        print("Mistral API error:", e)

    return ""
nuse_modules/classifier.py CHANGED
@@ -1,12 +1,10 @@
1
  import os
2
  import re
3
- import requests
4
  from dotenv import load_dotenv
5
 
6
- load_dotenv()
7
 
8
- HF_TOKEN = os.getenv("HF_TOKEN")
9
- HF_ZERO_SHOT_MODEL = "facebook/bart-large-mnli"
10
 
11
  # Map readable categories to numeric IDs
12
  QUESTION_TYPES = {
@@ -60,39 +58,9 @@ def rule_based_classify(prompt: str) -> int:
60
  return -1
61
 
62
 
63
- # ---------- Step 2: HF Zero-Shot Fallback ----------
64
- def zero_shot_classify(prompt: str) -> int:
65
- candidate_labels = list(QUESTION_TYPES.keys())
66
- payload = {
67
- "inputs": prompt,
68
- "parameters": {
69
- "candidate_labels": candidate_labels
70
- }
71
- }
72
-
73
- headers = {
74
- "Authorization": f"Bearer {HF_TOKEN}",
75
- "Content-Type": "application/json"
76
- }
77
-
78
- url = f"https://api-inference.huggingface.co/models/{HF_ZERO_SHOT_MODEL}"
79
-
80
- try:
81
- res = requests.post(url, headers=headers, json=payload, timeout=20)
82
- res.raise_for_status()
83
- data = res.json()
84
- if isinstance(data, dict) and "labels" in data:
85
- top_label = data["labels"][0]
86
- return QUESTION_TYPES.get(top_label, -1)
87
- except Exception as e:
88
- print("[HF Classifier Error]", str(e))
89
-
90
- return -1
91
-
92
-
93
  # ---------- Public Hybrid Classifier ----------
94
  def classify_question(prompt: str) -> int:
95
  rule_result = rule_based_classify(prompt)
96
  if rule_result != -1:
97
  return rule_result
98
- return zero_shot_classify(prompt)
 
1
  import os
2
  import re
 
3
  from dotenv import load_dotenv
4
 
5
+ from models_initialization.bart_large_registry import run_zero_shot_classification
6
 
7
+ load_dotenv()
 
8
 
9
  # Map readable categories to numeric IDs
10
  QUESTION_TYPES = {
 
58
  return -1
59
 
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
# ---------- Public Hybrid Classifier ----------
def classify_question(prompt: str) -> int:
    """Classify *prompt* into a numeric question-type ID.

    Tries the cheap rule-based pass first and falls back to the hosted
    zero-shot model only when no rule matches. Returns -1 when neither
    stage can classify the prompt.
    """
    rule_result = rule_based_classify(prompt)
    if rule_result != -1:
        return rule_result
    # BUG FIX: run_zero_shot_classification returns a label *string* (or ""
    # on failure), but this function's contract is a numeric ID — the old
    # in-module fallback mapped the label through QUESTION_TYPES. Restore
    # that mapping so callers keep receiving ints.
    label = run_zero_shot_classification(prompt, candidate_labels=list(QUESTION_TYPES.keys()))
    return QUESTION_TYPES.get(label, -1)
nuse_modules/keyword_extracter.py CHANGED
@@ -4,34 +4,7 @@ import os
4
  import requests
5
  import json
6
 
7
- HF_TOKEN = os.getenv("HF_TOKEN")
8
-
9
- HF_API_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.3"
10
- HEADERS = {
11
- "Authorization": f"Bearer {HF_TOKEN}",
12
- "Content-Type": "application/json"
13
- }
14
-
15
-
16
- def mistral_generate(prompt: str, max_new_tokens=128) -> str:
17
- payload = {
18
- "inputs": prompt,
19
- "parameters": {
20
- "max_new_tokens": max_new_tokens,
21
- "temperature": 0.7
22
- }
23
- }
24
- try:
25
- response = requests.post(HF_API_URL, headers=HEADERS, data=json.dumps(payload), timeout=30)
26
- response.raise_for_status()
27
- result = response.json()
28
- if isinstance(result, list) and len(result) > 0:
29
- return result[0].get("generated_text", "").strip()
30
- except Exception as e:
31
- print("[mistral_generate error]", str(e))
32
-
33
- return ""
34
-
35
 
36
  def extract_last_keywords(raw: str, max_keywords: int = 8) -> list[str]:
37
  segments = raw.strip().split("\n")
 
4
  import requests
5
  import json
6
 
7
+ from models_initialization.mistral_registry import mistral_generate
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  def extract_last_keywords(raw: str, max_keywords: int = 8) -> list[str]:
10
  segments = raw.strip().split("\n")
routes/__init__.py ADDED
File without changes
routes.py β†’ routes/category.py RENAMED
@@ -1,7 +1,7 @@
1
  import os
2
  from fastapi import APIRouter
3
  from dotenv import load_dotenv
4
- from redis_client import redis_client as r
5
 
6
  load_dotenv()
7
 
 
1
  import os
2
  from fastapi import APIRouter
3
  from dotenv import load_dotenv
4
+ from clients.redis_client import redis_client as r
5
 
6
  load_dotenv()
7
 
question.py β†’ routes/question.py RENAMED
@@ -3,24 +3,16 @@ import requests
3
  import json
4
  from fastapi import APIRouter
5
  from pydantic import BaseModel
6
- from typing import List
7
- from redis_client import redis_client as r
8
  from dotenv import load_dotenv
9
- from urllib.parse import quote
10
 
 
11
  from nuse_modules.classifier import classify_question, REVERSE_MAP
12
  from nuse_modules.keyword_extracter import keywords_extractor
13
  from nuse_modules.google_search import search_google_news
14
 
15
  load_dotenv()
16
 
17
- HF_TOKEN = os.getenv("HF_TOKEN")
18
- HF_API_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.3"
19
- HEADERS = {
20
- "Authorization": f"Bearer {HF_TOKEN}",
21
- "Content-Type": "application/json"
22
- }
23
-
24
  askMe = APIRouter()
25
 
26
  class QuestionInput(BaseModel):
@@ -41,26 +33,6 @@ def extract_answer_after_label(text: str) -> str:
41
  return text.strip()
42
 
43
 
44
- def mistral_generate(prompt: str, max_new_tokens=128):
45
- payload = {
46
- "inputs": prompt,
47
- "parameters": {
48
- "max_new_tokens": max_new_tokens,
49
- "temperature": 0.7
50
- }
51
- }
52
- try:
53
- response = requests.post(HF_API_URL, headers=HEADERS, data=json.dumps(payload), timeout=30)
54
- response.raise_for_status()
55
- result = response.json()
56
- if isinstance(result, list) and len(result) > 0:
57
- return result[0].get("generated_text", "").strip()
58
- else:
59
- return ""
60
- except Exception:
61
- return ""
62
-
63
-
64
  @askMe.post("/ask")
65
  async def ask_question(input: QuestionInput):
66
  question = input.question
 
3
  import json
4
  from fastapi import APIRouter
5
  from pydantic import BaseModel
6
+ from clients.redis_client import redis_client as r
 
7
  from dotenv import load_dotenv
 
8
 
9
+ from models_initialization.mistral_registry import mistral_generate
10
  from nuse_modules.classifier import classify_question, REVERSE_MAP
11
  from nuse_modules.keyword_extracter import keywords_extractor
12
  from nuse_modules.google_search import search_google_news
13
 
14
  load_dotenv()
15
 
 
 
 
 
 
 
 
16
  askMe = APIRouter()
17
 
18
  class QuestionInput(BaseModel):
 
33
  return text.strip()
34
 
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  @askMe.post("/ask")
37
  async def ask_question(input: QuestionInput):
38
  question = input.question