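"""RAG-as-a-Service demo: an in-memory FAISS index over MiniLM sentence embeddings, served through a Gradio UI."""
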
import gradio as gr
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer

# --- minimal core (in-memory only) ---
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
_model = SentenceTransformer(MODEL_NAME)
_dim = int(_model.encode(["_probe_"], convert_to_numpy=True).shape[1])  # 384

_index = faiss.IndexFlatIP(_dim)  # cosine via L2-normalized IP
_ids, _texts, _metas = [], [], []

def _normalize(v: np.ndarray) -> np.ndarray:
    n = np.linalg.norm(v, axis=1, keepdims=True) + 1e-12
    return (v / n).astype("float32")

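# Character-level sliding-window chunker: whitespace-normalize the text, then emit
# (chunk_text, start, end) windows of `size` characters with `overlap` characters of overlap.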
def _chunk(text: str, size: int, overlap: int):
    t = " ".join((text or "").split())
    n = len(t); s = 0; out = []
    if overlap >= size: overlap = max(size - 1, 0)
    while s < n:
        e = min(s + size, n)
        out.append((t[s:e], s, e))
        if e == n: break
        s = max(e - overlap, 0)
    return out

def reset():
    global _index, _ids, _texts, _metas
    _index = faiss.IndexFlatIP(_dim)
    _ids, _texts, _metas = [], [], []
    return gr.update(value="Index reset."), gr.update(value=0)

def load_sample():
    # Two sample documents, one per line (matches the ingest format).
    docs = [
        "PySpark scales ETL across clusters.",
        "FAISS powers fast vector similarity search used in retrieval.",
    ]
    return "\n".join(docs)

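# Ingest: treat each non-empty line as one document, chunk it, embed the chunks with MiniLM,
# L2-normalize, and add them to the FAISS index alongside the parallel id/text/meta lists.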
def ingest(docs_text, size, overlap):
    if not docs_text.strip():
        return "Provide at least one line of text.", len(_ids)
    # one document per line
    lines = [ln.strip() for ln in docs_text.splitlines() if ln.strip()]
    rows = []
    for i, ln in enumerate(lines):
        pid = f"doc-{len(_ids)}-{i}"
        for ctext, s, e in _chunk(ln, size, overlap):
            rows.append((f"{pid}::offset:{s}-{e}", ctext, {"parent_id": pid, "start": s, "end": e}))
    if not rows:
        return "No chunks produced.", len(_ids)
    vecs = _normalize(_model.encode([r[1] for r in rows], convert_to_numpy=True))
    _index.add(vecs)
    for rid, txt, meta in rows:
        _ids.append(rid); _texts.append(txt); _metas.append(meta)
    return f"Ingested docs={len(lines)} chunks={len(rows)}", len(_ids)

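# Answer: embed the query, retrieve the top-k chunks by cosine similarity, and build a simple
# extractive answer from the best match (no LLM generation in this minimal demo).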
def answer(q, k, max_context_chars):
    if _index.ntotal == 0:
        return {"answer": "Index is empty. Ingest first.", "matches": []}
    qv = _normalize(_model.encode([q], convert_to_numpy=True))
    D, I = _index.search(qv, int(k))

    matches = []
    for i, s in zip(I[0].tolist(), D[0].tolist()):
        if i < 0:
            continue
        matches.append({
            "id": _ids[i],
            "score": float(s),
            "text": _texts[i],
            "meta": _metas[i]
        })

    if not matches:
        out = "No relevant context."
    else:
        # Use only the top match for the answer, truncated to the "Max context chars" budget.
        top = matches[0]["text"][: int(max_context_chars)]
        out = f"Based on retrieved context:\n- {top}"

    return {"answer": out, "matches": matches}

with gr.Blocks(title="RAG-as-a-Service") as demo:
    gr.Markdown("### RAG-as-a-Service - Gradio\nIn-memory FAISS + MiniLM; one-line-per-doc ingest; quick answers.")

    with gr.Row():
        with gr.Column():
            docs = gr.Textbox(label="Documents (one per line)", lines=6, placeholder="One document per line…")
            with gr.Row():
                chunk_size = gr.Slider(64, 1024, value=256, step=16, label="Chunk size")
                overlap = gr.Slider(0, 256, value=32, step=8, label="Overlap")
            with gr.Row():
                ingest_btn = gr.Button("Ingest")
                sample_btn = gr.Button("Load sample")
                reset_btn = gr.Button("Reset")
            ingest_status = gr.Textbox(label="Ingest status", interactive=False)
            index_size = gr.Number(label="Index size", interactive=False, value=0)
        with gr.Column():
            q = gr.Textbox(label="Query", placeholder="Ask something...")
            k = gr.Slider(1, 10, value=5, step=1, label="Top-K")
            max_chars = gr.Slider(200, 4000, value=1000, step=100, label="Max context chars")
            run = gr.Button("Answer")
            out = gr.JSON(label="Answer + matches")

    ingest_btn.click(
        ingest,
        [docs, chunk_size, overlap],
        [ingest_status, index_size],
        api_name="ingest",   # exposes POST /api/ingest
    )
    sample_btn.click(load_sample, None, docs)
    reset_btn.click(
        reset,
        None,
        [ingest_status, index_size],
        api_name="reset",    # exposes POST /api/reset (optional)
    )
    run.click(
        answer,
        [q, k, max_chars],
        out,
        api_name="answer",   # exposes POST /api/answer
    )

if __name__ == "__main__":
    demo.launch(share=True)
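
# Minimal client sketch (an assumption, not part of the original app): with this demo running
# locally and the `gradio_client` package installed, the named endpoints above can be called
# programmatically, e.g.:
#
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860")  # placeholder URL for a local launch
#   client.predict("PySpark scales ETL across clusters.", 256, 32, api_name="/ingest")
#   client.predict("What does FAISS do?", 5, 1000, api_name="/answer")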