|
|
|
|
|
from __future__ import annotations |
|
import json, re, logging, itertools |
|
from collections import Counter |
|
from pathlib import Path |
|
|
|
from models_initialization.mistral_registry import mistral_generate |
|
|
|
# English stopwords loaded from the sibling file "stopwords_en.txt"
# (whitespace-separated). Entries are lower-cased because every lookup in this
# module is made with lower-cased text, so a capitalised entry (e.g. "I") would
# otherwise never match. The encoding is explicit so loading does not depend
# on the platform's locale.
STOPWORDS = {
    word.lower()
    for word in Path(__file__)
    .with_name("stopwords_en.txt")
    .read_text(encoding="utf-8")
    .split()
}

# Matches the first non-empty, non-nested "[...]" span in a string; used to
# pull the JSON list out of an LLM reply that may contain surrounding text.
_JSON_RE = re.compile(r"\[[^\[\]]+\]", re.S)
|
|
|
def _dedupe_keep_order(seq): |
|
seen = set() |
|
for x in seq: |
|
if x.lower() not in seen: |
|
seen.add(x.lower()) |
|
yield x |
|
|
|
def _extract_with_llm(question: str, k: int) -> list[str]:
    """Ask the Mistral model for up to *k* keywords describing *question*.

    The model is prompted to reply with a JSON list of lowercase keywords.
    The first flat ``[...]`` span in its reply is parsed as JSON. Each string
    entry is then lower-cased and stripped of surrounding whitespace, periods,
    commas and quotes. Next, empty entries, non-string entries and stopwords
    are dropped. Finally, case-insensitive duplicates are removed (the first
    one wins).

    Args:
        question: The user question to extract keywords from.
        k: Maximum number of keywords to return.

    Returns:
        At most *k* cleaned keywords, in the order the model produced them.
        The list may be empty if every candidate was filtered out.

    Raises:
        ValueError: If the reply contains no ``[...]`` span, or that span is
            not a valid JSON list.
    """
    prompt = (
        "Extract the **most important keywords** (nouns or noun-phrases) from the question below.\n"
        f"Return a **JSON list** of {k} or fewer lowercase keywords, no commentary.\n\n"
        f"QUESTION:\n{question}"
    )
    # NOTE(review): if max_new_tokens caps the reply length, a long list may be
    # cut off before its closing bracket; that surfaces as the "No JSON list"
    # error below.
    raw = mistral_generate(prompt, max_new_tokens=48, temperature=0.3)
    logging.debug("LLM raw output: %s", raw)

    # The model may wrap the list in other text; take the first bracketed span.
    match = _JSON_RE.search(raw or "")
    if not match:
        raise ValueError("No JSON list detected in LLM output")

    try:
        keywords = json.loads(match.group())
    except ValueError as e:  # json.JSONDecodeError is a ValueError subclass
        raise ValueError("Invalid JSON list") from e
    if not isinstance(keywords, list):
        raise ValueError("Invalid JSON list")

    # Normalise before filtering. Checking stopwords/emptiness on the raw entry
    # would let e.g. "the." through, or turn "'" into an empty keyword.
    # Non-string items (numbers, null) are ignored instead of crashing.
    normalised = (
        item.lower().strip(" \t\r\n.,\"'")
        for item in keywords
        if isinstance(item, str)
    )
    cleaned = list(
        _dedupe_keep_order(kw for kw in normalised if kw and kw not in STOPWORDS)
    )
    return cleaned[:k]
|
|
|
|
|
_WORD_RE = re.compile(r"[A-Za-z][\w\-]+") |
|
|
|
def _fallback_keywords(text: str, k: int) -> list[str]: |
|
tokens = [t.lower() for t in _WORD_RE.findall(text)] |
|
tokens = [t for t in tokens if t not in STOPWORDS and len(t) > 2] |
|
counts = Counter(tokens) |
|
|
|
common_cut = (len(tokens) // 100) + 2 |
|
keywords, _ = zip(*counts.most_common(k + common_cut)) |
|
return list(keywords[:k]) |
|
|
|
def keywords_extractor(question: str, max_keywords: int = 6) -> list[str]:
    """
    Return ≤ `max_keywords` keywords for the given question.

    The LLM-based extractor is tried first. If it raises, or yields no
    keywords, a word-frequency heuristic over the question text is used
    instead.

    Args:
        question: Natural-language question to extract keywords from.
        max_keywords: Upper bound on the number of keywords returned.

    Returns:
        A list of at most ``max_keywords`` lower-cased keywords. A blank (or
        ``None``) question, or a non-positive ``max_keywords``, yields ``[]``
        without calling the model.
    """
    # Nothing to extract: skip the model call entirely. An empty prompt gives
    # the model nothing to ground its answer in.
    if max_keywords <= 0 or not (question and question.strip()):
        return []

    try:
        kw = _extract_with_llm(question, max_keywords)
        if kw:
            return kw
    except Exception as exc:
        # Boundary handler: any LLM or parsing failure degrades to the local
        # heuristic rather than propagating to the caller.
        logging.warning("LLM keyword extraction failed: %s. Falling back.", exc)

    return _fallback_keywords(question, max_keywords)
|
|