# nuse_modules/keyword_extractor.py
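"""Keyword extraction for user questions.

Asks the Mistral model for a JSON list of keywords first; if the model output
is missing or malformed, falls back to a simple frequency-based heuristic.
"""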

from __future__ import annotations
import json, re, logging
from collections import Counter
from pathlib import Path

from models_initialization.mistral_registry import mistral_generate

STOPWORDS = set(Path(__file__).with_name("stopwords_en.txt").read_text().split())

_JSON_RE = re.compile(r"\[[^\[\]]+\]", re.S)  # first [...] block

def _dedupe_keep_order(seq):
    seen = set()
    for x in seq:
        if x.lower() not in seen:
            seen.add(x.lower())
            yield x

def _extract_with_llm(question: str, k: int) -> list[str]:
    prompt = (
        "Extract the **most important keywords** (nouns or noun-phrases) from the question below.\n"
        f"Return a **JSON list** of {k} or fewer lowercase keywords, no commentary.\n\n"
        f"QUESTION:\n{question}"
    )
    raw = mistral_generate(prompt, max_new_tokens=48, temperature=0.3)
    logging.debug("LLM raw output: %s", raw)

    # find the first [...] JSON chunk
    match = _JSON_RE.search(raw or "")
    if not match:
        raise ValueError("No JSON list detected in LLM output")

    try:
        keywords = json.loads(match.group())
        if not isinstance(keywords, list):
            raise ValueError
    except Exception as e:
        raise ValueError("Invalid JSON list") from e
    
    # normalise, drop stopwords and non-strings, dedupe while keeping order
    cleaned = list(
        _dedupe_keep_order(
            kw.lower().strip(" .,\"'")
            for kw in keywords
            if isinstance(kw, str) and kw.strip() and kw.lower() not in STOPWORDS
        )
    )
    return cleaned[:k]


_WORD_RE = re.compile(r"[A-Za-z][\w\-]+")

def _fallback_keywords(text: str, k: int) -> list[str]:
    """Frequency-based heuristic used when the LLM path fails."""
    tokens = [t.lower() for t in _WORD_RE.findall(text)]
    tokens = [t for t in tokens if t not in STOPWORDS and len(t) > 2]
    counts = Counter(tokens)
    if not counts:
        return []
    # most_common() is sorted by descending frequency, so the first k entries suffice
    return [word for word, _ in counts.most_common(k)]

def keywords_extractor(question: str, max_keywords: int = 6) -> list[str]:
    """
    Return ≤ `max_keywords` keywords for the given question.
    """
    try:
        kw = _extract_with_llm(question, max_keywords)
        if kw:
            return kw
    except Exception as exc:
        logging.warning("LLM keyword extraction failed: %s. Falling back.", exc)

    # fallback heuristic
    return _fallback_keywords(question, max_keywords)
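

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The sample question is made up,
# and the LLM path only works if `mistral_registry` is configured for this
# repo; otherwise the call logs a warning and returns the regex fallback.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    sample = "Which transformer architectures work best for long-document summarization?"
    print(keywords_extractor(sample, max_keywords=5))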