File size: 1,922 Bytes
aefa1e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# nuse_modules/keyword_extractor.py

import os
import requests
import json

# Hugging Face access token, read from the environment once at import time.
# NOTE: when HF_TOKEN is unset this is None and the Authorization header
# below becomes the literal string "Bearer None".
HF_TOKEN = os.environ.get("HF_TOKEN")

# Inference endpoint used by every request in this module.
HF_API_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.3"

# Headers sent with each request: bearer auth plus a JSON content type.
HEADERS = {
    "Authorization": "Bearer {}".format(HF_TOKEN),
    "Content-Type": "application/json",
}


def mistral_generate(prompt: str, max_new_tokens=128) -> str:
    """Send *prompt* to the endpoint at ``HF_API_URL`` and return its text.

    The request is a JSON POST carrying the prompt under ``"inputs"`` and
    generation parameters (``max_new_tokens`` and a fixed temperature of
    0.7), sent with ``HEADERS`` and a 30-second timeout.

    Args:
        prompt: Text sent as the ``"inputs"`` field.
        max_new_tokens: Value forwarded as the ``"max_new_tokens"`` parameter.

    Returns:
        The stripped ``"generated_text"`` of the first element when the
        decoded response is a non-empty list; otherwise an empty string.
        Any exception raised along the way is printed and also results in
        an empty string, so this function never raises.
    """
    request_body = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": 0.7
        }
    }
    try:
        reply = requests.post(
            HF_API_URL,
            headers=HEADERS,
            data=json.dumps(request_body),
            timeout=30,
        )
        # Turn HTTP error statuses into exceptions handled below.
        reply.raise_for_status()
        decoded = reply.json()
        # Anything other than a non-empty list is treated as "no output".
        if not isinstance(decoded, list) or not decoded:
            return ""
        return decoded[0].get("generated_text", "").strip()
    except Exception as e:
        # Best-effort call: report the failure and fall back to "".
        print("[mistral_generate error]", str(e))
        return ""


def extract_last_keywords(raw: str, max_keywords: int = 8) -> list[str]:
    """Pull a comma-separated keyword list out of raw model output.

    Lines of *raw* are scanned from last to first, and the first line that
    looks like a keyword list is returned. A line qualifies when it:

    * is at least 10 characters long after stripping,
    * does not start with "extract" (case-insensitive), which presumably
      filters out an echoed instruction line,
    * contains at least two commas, and
    * splits into between 1 and *max_keywords* non-empty keywords, each of
      at most three whitespace-separated words.

    Args:
        raw: Text to search, typically the model's full reply.
        max_keywords: Largest number of keywords a line may contain and
            still be accepted.

    Returns:
        The keywords with surrounding whitespace and double quotes removed,
        or an empty list when no line qualifies.
    """
    for line in reversed(raw.strip().split("\n")):
        line = line.strip()
        # The length check also rejects blank lines.
        if len(line) < 10 or line.lower().startswith("extract"):
            continue
        if line.count(",") < 2:
            continue

        # Clean each fragment *before* discarding empties, so a token that
        # is only quotes (e.g. `""`) cannot survive as an empty keyword and
        # whitespace inside the quotes is trimmed as well.
        cleaned = (piece.strip().strip('"').strip() for piece in line.split(","))
        parts = [kw for kw in cleaned if kw]

        if 1 <= len(parts) <= max_keywords and all(
            len(kw.split()) <= 3 for kw in parts
        ):
            return parts

    return []


def keywords_extractor(question: str) -> list[str]:
    """Ask the model for the key terms of *question* and parse its reply.

    Builds an instruction prompt around the question, sends it through
    ``mistral_generate`` with a 32-token budget, and parses the reply with
    ``extract_last_keywords``. Both the raw reply and the parsed list are
    printed for debugging.

    Args:
        question: The question to extract keywords from.

    Returns:
        The parsed keywords, or an empty list when nothing usable came back.
    """
    instruction = (
        "Extract the 3–6 most important keywords from the following question. "
        "Return only the keywords, comma-separated (no explanations):"
    )
    # A blank line separates the instruction from the question itself.
    prompt = f"{instruction}\n\n{question}"

    model_reply = mistral_generate(prompt, max_new_tokens=32)
    parsed = extract_last_keywords(model_reply)

    print("Raw extracted keywords:", model_reply)
    print("Parsed keywords:", parsed)

    return parsed