# ── app.py ───────────────────────────────────────────────────────────
import os, logging, textwrap
import gradio as gr
from transformers import pipeline, AutoTokenizer
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import FAISS
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
import concurrent.futures

# ── CONFIG ──────────────────────────────────────────────────────────
DOCS_DIR       = "document"
INDEX_DIR      = "faiss_index"
EMB_MODEL      = "KBLab/sentence-bert-swedish-cased"
# LLM_MODEL    = "tiiuae/falcon-rw-1b"                # poor results
# LLM_MODEL    = "google/flan-t5-base"                # poor results
# LLM_MODEL    = "bigscience/bloom-560m"              # poor results
# LLM_MODEL    = "NbAiLab/nb-gpt-j-6B"                # restricted access
# LLM_MODEL    = "datificate/gpt2-small-swedish"      # not available on Hugging Face
# LLM_MODEL    = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# timpal0l/mdeberta-v3-base-squad2 is small and may work for Swedish
# LLM_MODEL    = "AI-Sweden-Models/gpt-sw3-1.3B"      # available in several sizes: 126M, 356M, 1.3B, 6.7B, 20B, 40B
LLM_MODEL      = "AI-Sweden-Models/gpt-sw3-356M"
# LLM_MODEL    = "AI-Sweden-Models/Llama-3-8B-instruct"  # possibly too large
# https://www.ai.se/en/ai-labs/natural-language-understanding/models-resources

CHUNK_SIZE     = 400
CHUNK_OVERLAP  = 40
CTX_TOK_MAX    = 750          # leaves headroom for the question + answer
MAX_NEW_TOKENS = 512
K              = 5
DEFAULT_TEMP   = 0.8
GEN_TIMEOUT    = 180  # generation timeout in seconds

# ── LOGGING ──────────────────────────────────────────────────────────
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
log = logging.getLogger(__name__)

# ── 1) Index (build or load) ────────────────────────────────────────
emb = HuggingFaceEmbeddings(model_name=EMB_MODEL)
INDEX_PATH = os.path.join(INDEX_DIR, "index.faiss")
if os.path.isfile(INDEX_PATH):
    log.info(f"πŸ”„ Laddar index frΓ₯n {INDEX_DIR}")
    vs = FAISS.load_local(INDEX_DIR, emb)
else:
    splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
    docs, pdfs = [], []
    for fn in os.listdir(DOCS_DIR):
        if fn.lower().endswith(".pdf"):
            chunks = splitter.split_documents(PyPDFLoader(os.path.join(DOCS_DIR, fn)).load())
            for c in chunks:
                c.metadata["source"] = fn
            docs.extend(chunks)
            pdfs.append(fn)
    vs = FAISS.from_documents(docs, emb)
    vs.save_local(INDEX_DIR)
    log.info(f"✅ Byggt index – {len(pdfs)} PDF / {len(docs)} chunkar")
retriever = vs.as_retriever(search_kwargs={"k": K})

# ── 2) LLM pipeline & tokenizer ─────────────────────────────────────
log.info("🚀 Initierar LLM …")
gen_pipe  = pipeline("text-generation", model=LLM_MODEL, device=-1, max_new_tokens=MAX_NEW_TOKENS)
tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL)
log.info("βœ… LLM klar")

# ── 3) Helper functions ─────────────────────────────────────────────
def build_prompt(query: str, docs):
    """
    Packs as many retrieved chunks as will fit within CTX_TOK_MAX tokens.
    """
    context_parts = []
    total_ctx_tok = 0
    for d in docs:
        tok_len = len(tokenizer.encode(d.page_content))
        if total_ctx_tok + tok_len > CTX_TOK_MAX:
            break
        context_parts.append(d.page_content)
        total_ctx_tok += tok_len

    context = "\n\n---\n\n".join(context_parts)
    # The prompt is kept in Swedish on purpose: the assistant is meant to answer in Swedish.
    return textwrap.dedent(f"""\
        Du är en hjälpsam assistent som svarar på svenska.
        Kontext (hämtat ur PDF-dokument):

        {context}

        Fråga: {query}
        Svar (svenska):""").strip()

def test_retrieval(q):  # quick retrieval check without running the LLM
    docs = retriever.invoke(q)
    return "\n\n".join([f"{i+1}. ({d.metadata['source']}) {d.page_content[:160]}…" for i, d in enumerate(docs)]) or "🚫 Inga träffar"

def chat_fn(q, temp, max_new_tokens, k, ctx_tok_max, history):
    history = history or []
    history.append({"role": "user", "content": q})

    # Retrieve chunks together with their similarity scores
    docs_and_scores = vs.similarity_search_with_score(q, k=int(k))
    docs = [doc for doc, score in docs_and_scores]
    scores = [score for doc, score in docs_and_scores]

    if not docs:
        history.append({"role": "assistant", "content": "🚫 Hittade inget relevant."})
        return history, history

    # Show which chunks are used, together with their similarity scores
    chunk_info = "\n\n".join([
        f"{i+1}. ({d.metadata['source']}) score={scores[i]:.3f}\n{d.page_content[:160]}…"
        for i, d in enumerate(docs)
    ])
    # gr.Chatbot(type="messages") expects "user"/"assistant" roles, so the chunk
    # overview is appended as an assistant message rather than a "system" one.
    history.append({"role": "assistant", "content": f"🔎 Chunkar som används:\n{chunk_info}"})

    def build_prompt_dynamic(query, docs, ctx_tok_max):
        context_parts = []
        total_ctx_tok = 0
        for d in docs:
            tok_len = len(tokenizer.encode(d.page_content))
            if total_ctx_tok + tok_len > int(ctx_tok_max):
                break
            context_parts.append(d.page_content)
            total_ctx_tok += tok_len
        context = "\n\n---\n\n".join(context_parts)
        return textwrap.dedent(f"""\
            Du är en hjälpsam assistent som svarar på svenska.
            Kontext (hämtat ur PDF-dokument):

            {context}

            Fråga: {query}
            Svar (svenska):""").strip()

    prompt = build_prompt_dynamic(q, docs, ctx_tok_max)
    log.info(f"Prompt tokens={len(tokenizer.encode(prompt))}  temp={temp}  max_new_tokens={max_new_tokens} k={k} ctx_tok_max={ctx_tok_max}")

    def generate():
        return gen_pipe(
            prompt,
            temperature=float(temp),
            max_new_tokens=int(max_new_tokens),
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            do_sample=True,
            return_full_text=False
        )[0]["generated_text"]

    try:
        with concurrent.futures.ThreadPoolExecutor() as executor:
            future = executor.submit(generate)
            ans = future.result(timeout=GEN_TIMEOUT)  # Timeout in seconds
    except concurrent.futures.TimeoutError:
        ans = f"⏰ Ingen respons frΓ₯n modellen inom {GEN_TIMEOUT} sekunder."
    except Exception as e:
        log.exception("Genererings‑fel")
        ans = f"❌ Fel vid generering: {type(e).__name__}: {e}\n\nPrompt:\n{prompt}"

    src_hint = docs[0].metadata["source"] if docs else "Ingen källa"
    history.append({"role": "assistant", "content": f"**(Källa: {src_hint})**\n\n{ans}"})
    return history, history

# ── 4) Gradio UI ────────────────────────────────────────────────────
with gr.Blocks() as demo:
    gr.Markdown("# πŸ“š Svensk RAG‑chat")
    gr.Markdown(f"**PDF‑filer:** {', '.join(os.listdir(DOCS_DIR)) or '–'}")
    gr.Markdown(f"**LLM-modell som anvΓ€nds:** `{LLM_MODEL}`", elem_id="llm-info")

    with gr.Row():
        q_test = gr.Textbox(label="🔎 Test Retrieval")
        b_test = gr.Button("Testa")
        o_test = gr.Textbox(label="Chunkar")

    with gr.Row():
        q_in   = gr.Textbox(label="Fråga", placeholder="Ex: Vad är förvaltningsöverlämnande?")
        temp   = gr.Slider(0, 1, value=DEFAULT_TEMP, step=0.05, label="Temperatur")
        max_new_tokens = gr.Slider(32, 1024, value=MAX_NEW_TOKENS, step=8, label="Max svarslängd (tokens)")
        k      = gr.Slider(1, 10, value=K, step=1, label="Antal chunkar (K)")
        ctx_tok_max = gr.Slider(100, 2000, value=CTX_TOK_MAX, step=50, label="Max kontexttokens")
        b_send = gr.Button("Skicka")
        b_stop = gr.Button("Stoppa")  # stop button to cancel an in-flight request

    chat      = gr.Chatbot(type="messages", label="Chat")
    chat_hist = gr.State([])

    b_test.click(test_retrieval, inputs=[q_test], outputs=[o_test])
    send_event = b_send.click(
        chat_fn,
        inputs=[q_in, temp, max_new_tokens, k, ctx_tok_max, chat_hist],
        outputs=[chat, chat_hist]
    )
    b_stop.click(None, cancels=[send_event])

if __name__ == "__main__":
    demo.launch(share=True)   # remove share=True to keep the app private