|
|
|
|
|
import os |
|
import requests |
|
import json |
|
|
|
# Hugging Face access token, read from the environment (None when unset).
HF_TOKEN = os.getenv("HF_TOKEN")

# Inference API endpoint for the Mistral-7B-Instruct-v0.3 model.
HF_API_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.3"

# Request headers shared by every API call. The Authorization header is only
# added when a token is configured; formatting it unconditionally would send
# the literal string "Bearer None" when HF_TOKEN is unset.
HEADERS = {"Content-Type": "application/json"}
if HF_TOKEN:
    HEADERS["Authorization"] = f"Bearer {HF_TOKEN}"
|
|
|
|
|
def mistral_generate(prompt: str, max_new_tokens: int = 128, temperature: float = 0.7) -> str:
    """Generate text for *prompt* through the Hugging Face Inference API.

    The call is best-effort: network, HTTP-status and JSON-decoding failures
    are printed and an empty string is returned instead of raising.

    Args:
        prompt: Text sent as the request's ``inputs`` field.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature (defaults to the previously
            hard-coded 0.7).

    Returns:
        The stripped ``generated_text`` of the first result, or ``""`` on
        failure or when the response does not contain generated text.
    """
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
        },
    }

    try:
        response = requests.post(HF_API_URL, headers=HEADERS, json=payload, timeout=30)
        response.raise_for_status()
        result = response.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers a non-JSON body on requests versions whose
        # JSONDecodeError is not a RequestException subclass.
        print("[mistral_generate error]", str(e))
        return ""

    # Expected shape is [{"generated_text": ...}]; also accept a bare dict,
    # and report an API-level {"error": ...} payload instead of hiding it.
    if isinstance(result, list):
        result = result[0] if result else None
    if isinstance(result, dict):
        text = result.get("generated_text")
        if isinstance(text, str):
            return text.strip()
        if "error" in result:
            print("[mistral_generate error]", result["error"])

    return ""
|
|
|
|
|
def extract_last_keywords(raw: str, max_keywords: int = 8) -> list[str]:
    """Return the keywords from the last comma-separated line of *raw*.

    Lines are scanned from the bottom up, so the model's answer is preferred
    over any earlier text. A line qualifies when it:

    * does not start with "extract" (case-insensitive), which skips the
      instruction line;
    * is at least 10 characters long;
    * contains at least two commas; and
    * yields between 1 and ``max_keywords`` non-empty keywords of at most
      three words each.

    Each keyword is stripped of surrounding whitespace and double quotes.
    A single period ending the line is dropped.

    Args:
        raw: Raw (possibly multi-line) model output.
        max_keywords: Lines yielding more keywords than this are rejected.

    Returns:
        The keywords of the last qualifying line, or ``[]`` if none qualifies.
    """
    for line in reversed(raw.strip().split("\n")):
        line = line.strip()
        if len(line) < 10 or line.lower().startswith("extract"):
            continue
        if line.count(",") < 2:
            continue

        # Models often close the list like a sentence ("a, b, c."); drop the
        # terminator so it does not stick to the last keyword.
        line = line.removesuffix(".")

        # Normalize first, then filter, so quoted blanks ('""') and padding
        # inside quotes ('" foo "') never reach the result.
        parts = [kw.strip().strip('"').strip() for kw in line.split(",")]
        parts = [p for p in parts if p]

        if 1 <= len(parts) <= max_keywords and all(len(p.split()) <= 3 for p in parts):
            return parts

    return []
|
|
|
|
|
def keywords_extractor(question: str) -> list[str]:
    """Ask the model for the most important keywords of *question*.

    The model's raw output is parsed with ``extract_last_keywords``. Both the
    raw output and the parsed list are printed for debugging.

    Args:
        question: The user question to extract keywords from.

    Returns:
        The parsed keywords. Returns ``[]`` without calling the API when
        *question* is empty or whitespace-only.
    """
    if not (question and question.strip()):
        # Nothing to extract from; skip the remote call entirely.
        return []

    prompt = (
        "Extract the 3–6 most important keywords from the following question. "
        "Return only the keywords, comma-separated (no explanations):\n\n"
        f"{question}"
    )

    raw_output = mistral_generate(prompt, max_new_tokens=32)
    keywords = extract_last_keywords(raw_output)

    print("Raw extracted keywords:", raw_output)
    print("Parsed keywords:", keywords)

    return keywords
|
|