# Page-header residue captured with this file ("Spaces: Sleeping"); not part of the program.
import functools
import json

import faiss
import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download
from sentence_transformers import SentenceTransformer
# Hub repository that hosts the prebuilt FAISS indexes and their metadata files.
REPO_ID = "imjj/touchdesigner-rag-wiki-index"

# Dropdown label -> (FAISS index filename, metadata JSON filename,
# sentence-transformers model name used to embed queries for that index).
EMBEDDING_MODELS = {
    "MiniLM Flat": (
        "td_index_mini_flat_l2.faiss",
        "td_metadata_mini_flat_l2.json",
        "sentence-transformers/all-MiniLM-L6-v2",
    ),
    "MiniLM HNSW": (
        "td_index_mini_hnsw_l2.faiss",
        "td_metadata_mini_hnsw_l2.json",
        "sentence-transformers/all-MiniLM-L6-v2",
    ),
    "MPNet Flat": (
        "td_index_mpnet_flat_l2.faiss",
        "td_metadata_mpnet_flat_l2.json",
        "sentence-transformers/all-mpnet-base-v2",
    ),
}

# Default number of chunks returned per query (also the slider's initial value).
TOP_K = 10

# Currently active resources; populated by load_resources() and read by search().
current_index = None
current_chunks = None
current_embedder = None
@functools.lru_cache(maxsize=None)
def _load_index_and_chunks(index_file, metadata_file):
    """Fetch and parse one FAISS index together with its chunk metadata.

    Results are memoized per (index_file, metadata_file) pair so repeated
    queries do not re-read the index or re-parse the JSON. The key space is
    bounded by the entries of EMBEDDING_MODELS, so the cache is unbounded on
    purpose. A failed load raises and is not cached, so it is retried next call.
    """
    index_path = hf_hub_download(repo_id=REPO_ID, filename=index_file)
    metadata_path = hf_hub_download(repo_id=REPO_ID, filename=metadata_file)
    index = faiss.read_index(index_path)
    with open(metadata_path, "r", encoding="utf-8") as f:
        chunks = json.load(f)
    return index, chunks


@functools.lru_cache(maxsize=None)
def _load_embedder(embedding_model):
    """Instantiate a SentenceTransformer once per model name.

    Keyed by model name rather than dropdown label, so index variants that
    share an embedding model also share a single embedder instance.
    """
    return SentenceTransformer(embedding_model)


def load_resources(selected_model):
    """Make the index, metadata and embedder for ``selected_model`` current.

    Updates the module globals ``current_index``, ``current_chunks`` and
    ``current_embedder``. Loading is memoized, so calling this on every query
    is cheap after the first load of each configuration.

    Args:
        selected_model: A key of ``EMBEDDING_MODELS``.

    Raises:
        KeyError: If ``selected_model`` is not a key of ``EMBEDDING_MODELS``.
    """
    global current_index, current_chunks, current_embedder
    index_file, metadata_file, embedding_model = EMBEDDING_MODELS[selected_model]
    # Load everything before touching the globals so a failed download or
    # parse does not leave a half-updated index/metadata/embedder set.
    index, chunks = _load_index_and_chunks(index_file, metadata_file)
    embedder = _load_embedder(embedding_model)
    current_index = index
    current_chunks = chunks
    current_embedder = embedder
# Load the default configuration at import time so the globals are populated
# before the first query is handled.
load_resources("MiniLM Flat")
def search(query, selected_model, top_k=TOP_K):
    """Return the text of the chunks nearest to ``query``.

    Args:
        query: Free-text question from the user.
        selected_model: A key of ``EMBEDDING_MODELS`` choosing index + embedder.
        top_k: Number of neighbours to request; coerced to ``int`` because UI
            widgets may deliver the value as a float.

    Returns:
        The matching chunk texts joined by ``"\\n---\\n"``, or an empty string
        when the query is blank or nothing valid was returned.
    """
    # A blank query has no meaningful embedding; skip the search entirely.
    if not query or not query.strip():
        return ""

    load_resources(selected_model)
    query_vector = np.asarray(current_embedder.encode([query]), dtype="float32")
    # Distances are not shown to the user; only the neighbour ids are needed.
    _distances, indices = current_index.search(query_vector, k=int(top_k))

    chunk_count = len(current_chunks)
    # The range check skips out-of-range ids (including negative ones), so a
    # bad neighbour id can never index outside the metadata list.
    results = [
        current_chunks[idx]["text"]
        for idx in indices[0]
        if 0 <= idx < chunk_count
    ]
    return "\n---\n".join(results)
# UI layout: question and results side by side, then the index selector and
# top-k slider, then the submit button that runs search().
with gr.Blocks() as demo:
    gr.Markdown("# TouchDesigner Wiki Search (RAG Prototype)")
    gr.Markdown("Local retrieval system based on TouchDesigner's offline wiki using FAISS + MiniLM / MPNet embeddings.")

    with gr.Row():
        query_input = gr.Textbox(label="Enter your question about TouchDesigner", lines=10)
        output_text = gr.Textbox(label="Top-k Relevant Wiki Chunks", lines=10)

    with gr.Row():
        model_selector = gr.Dropdown(
            choices=list(EMBEDDING_MODELS.keys()),
            value="MiniLM Flat",
            label="Select Index",
        )
        top_k_input = gr.Slider(minimum=1, maximum=20, value=TOP_K, label="Top K", step=1)

    submit_btn = gr.Button("Submit")
    # Input order must match search(query, selected_model, top_k).
    submit_btn.click(
        fn=search,
        inputs=[query_input, model_selector, top_k_input],
        outputs=output_text,
    )

if __name__ == "__main__":
    demo.launch()