td-rag-sapce / app.py
J
init app
01968bf
import gradio as gr
import faiss
import numpy as np
import json
from sentence_transformers import SentenceTransformer
from huggingface_hub import hf_hub_download
# Hugging Face Hub repository that hosts the FAISS index and metadata files.
REPO_ID = "imjj/touchdesigner-rag-wiki-index"

# UI label -> (FAISS index filename, metadata JSON filename, embedding model id).
EMBEDDING_MODELS = {
    "MiniLM Flat": (
        "td_index_mini_flat_l2.faiss",
        "td_metadata_mini_flat_l2.json",
        "sentence-transformers/all-MiniLM-L6-v2",
    ),
    "MiniLM HNSW": (
        "td_index_mini_hnsw_l2.faiss",
        "td_metadata_mini_hnsw_l2.json",
        "sentence-transformers/all-MiniLM-L6-v2",
    ),
    "MPNet Flat": (
        "td_index_mpnet_flat_l2.faiss",
        "td_metadata_mpnet_flat_l2.json",
        "sentence-transformers/all-mpnet-base-v2",
    ),
}

# Default number of neighbours requested per query (also the slider default).
TOP_K = 10

# Currently active resources; populated by load_resources().
current_index = current_chunks = current_embedder = None
# Per-process caches so that repeated calls (one per search) do not re-read the
# index and metadata or re-instantiate the embedding model every time.
_INDEX_CACHE = {}     # index label -> (faiss index, chunk metadata)
_EMBEDDER_CACHE = {}  # embedding model id -> SentenceTransformer instance


def load_resources(selected_model):
    """Make the resources for *selected_model* the active module globals.

    On the first request for a given label the FAISS index and its metadata
    JSON are fetched with ``hf_hub_download`` from ``REPO_ID`` and loaded;
    the embedding model is instantiated once per model id, so labels sharing
    an embedder reuse the same instance. Later calls are served from the
    in-memory caches.

    Args:
        selected_model: A key of ``EMBEDDING_MODELS``.

    Raises:
        KeyError: If *selected_model* is not a key of ``EMBEDDING_MODELS``.

    Side effects:
        Rebinds ``current_index``, ``current_chunks`` and ``current_embedder``.
        They are only rebound after everything loaded successfully, so a
        failed load leaves the previously active resources in place.
    """
    global current_index, current_chunks, current_embedder

    index_file, metadata_file, embedding_model = EMBEDDING_MODELS[selected_model]

    if selected_model not in _INDEX_CACHE:
        index_path = hf_hub_download(repo_id=REPO_ID, filename=index_file)
        metadata_path = hf_hub_download(repo_id=REPO_ID, filename=metadata_file)
        index = faiss.read_index(index_path)
        with open(metadata_path, "r", encoding="utf-8") as f:
            chunks = json.load(f)
        _INDEX_CACHE[selected_model] = (index, chunks)

    if embedding_model not in _EMBEDDER_CACHE:
        _EMBEDDER_CACHE[embedding_model] = SentenceTransformer(embedding_model)

    index, chunks = _INDEX_CACHE[selected_model]
    embedder = _EMBEDDER_CACHE[embedding_model]

    # Swap all three together so index, metadata and embedder always match.
    current_index, current_chunks, current_embedder = index, chunks, embedder
# Load the default index at import time so the module-level resources are
# populated before the first query arrives.
load_resources(selected_model="MiniLM Flat")
def search(query, selected_model, top_k=TOP_K):
    """Return the wiki chunks nearest to *query* as one display string.

    Args:
        query: Free-text question entered by the user.
        selected_model: Key of ``EMBEDDING_MODELS`` selecting which index and
            embedder to search with.
        top_k: Number of neighbours to request. Coerced to an integer of at
            least 1, because the value comes from a UI slider and may not be
            a plain positive int.

    Returns:
        The ``"text"`` field of each hit, joined by lines containing ``---``.
        An empty string when *query* is empty or whitespace-only.
    """
    # Nothing meaningful to embed: skip loading and searching altogether.
    if not query or not query.strip():
        return ""

    load_resources(selected_model)

    k = max(1, int(top_k))
    query_vector = np.asarray(current_embedder.encode([query]), dtype="float32")
    _, indices = current_index.search(query_vector, k=k)

    # Out-of-range ids (e.g. the -1 placeholder used when fewer than k hits
    # exist) are dropped by the bounds check.
    chunk_count = len(current_chunks)
    hits = [
        current_chunks[int(idx)]["text"]
        for idx in indices[0]
        if 0 <= idx < chunk_count
    ]
    return "\n---\n".join(hits)
# Gradio UI: a query box and a results box side by side, index and top-k
# controls underneath, and a button that runs search().
# NOTE(review): nesting was reconstructed from a flattened source; the submit
# button is placed below the rows rather than inside the second row — confirm.
with gr.Blocks() as demo:
    gr.Markdown("# TouchDesigner Wiki Search (RAG Prototype)")
    gr.Markdown(
        "Local retrieval system based on TouchDesigner's offline wiki using FAISS + MiniLM / MPNet embeddings."
    )

    with gr.Row():
        query_input = gr.Textbox(
            label="Enter your question about TouchDesigner",
            lines=10,
        )
        output_text = gr.Textbox(
            label="Top-k Relevant Wiki Chunks",
            lines=10,
        )

    with gr.Row():
        model_selector = gr.Dropdown(
            choices=list(EMBEDDING_MODELS),
            value="MiniLM Flat",
            label="Select Index",
        )
        top_k_input = gr.Slider(
            minimum=1,
            maximum=20,
            value=TOP_K,
            label="Top K",
            step=1,
        )

    submit_btn = gr.Button("Submit")
    submit_btn.click(
        fn=search,
        inputs=[query_input, model_selector, top_k_input],
        outputs=output_text,
    )

# Only start the web server when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()