# Ghostdemo / app.py
# Author: ghostai1 — "Update app.py" (commit 5700c07, verified)
import json
import numpy as np
import gradio as gr
from sentence_transformers import SentenceTransformer
from transformers import T5Tokenizer, T5ForConditionalGeneration
# 1) Load the embedding model for retrieval
# MiniLM sentence encoder used to embed both documents and queries.
embedder = SentenceTransformer('all-MiniLM-L6-v2')
# 2) Load the T5 model and tokenizer for translation
# t5-small runs locally (CPU-friendly); both tokenizer and model share the
# same checkpoint so prompts and generation stay consistent.
tokenizer = T5Tokenizer.from_pretrained('t5-small')
model = T5ForConditionalGeneration.from_pretrained('t5-small')
# 3) Load JSON knowledge bases
def load_docs(path, fmt):
    """Load a JSON array from *path* and format each item with *fmt*.

    Args:
        path: Filesystem path to a JSON file containing a list of objects.
        fmt: Callable mapping one decoded JSON object to a display string.

    Returns:
        list[str]: One formatted string per item in the JSON array.

    Raises:
        FileNotFoundError: If *path* does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Explicit encoding avoids the platform-dependent default
    # (e.g. cp1252 on Windows) mangling non-ASCII document text.
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return [fmt(item) for item in data]
# Build the combined document store from both JSON knowledge bases.
# Each formatter flattens one record into a single searchable line of text.
call_docs = load_docs(
    'call_center_data.json',
    # One line per support ticket: id, customer, issue, and its resolution.
    lambda d: f"Ticket {d['ticket_id']} ({d['customer']}): {d['description']}. Resolution: {d['resolution']}"
)
truck_docs = load_docs(
    'trucking_data.json',
    # One line per shipment: id, route, current status, and last update time.
    lambda d: f"Shipment {d['shipment_id']} from {d['origin']} to {d['destination']} is {d['status']}. Last update: {d['last_update']}"
)
# Single flat corpus used for retrieval; indices into `docs` are stable.
docs = call_docs + truck_docs
# 4) Pre-compute embeddings for all documents
# normalize_embeddings=True makes every vector unit-length, so a plain dot
# product against a (likewise normalized) query vector equals cosine
# similarity — no per-query renormalization needed at retrieval time.
doc_embeddings = embedder.encode(
    docs,
    convert_to_numpy=True,
    normalize_embeddings=True
)
def rag_translate(query, top_k):
    """Retrieve the top-k most relevant documents and translate them.

    Args:
        query: Free-text question from the user.
        top_k: Number of documents to retrieve. May arrive as a float from
            the Gradio slider; it is coerced to int and clamped to >= 1.

    Returns:
        tuple[str, str]: (retrieved English text, combined Spanish/French
        translations — or an error message if translation fails).
    """
    # 4.1 Retrieval: embed the query and rank documents by cosine
    # similarity (plain dot product works because both the document and
    # query embeddings are normalized).
    q_emb = embedder.encode(
        [query],
        convert_to_numpy=True,
        normalize_embeddings=True
    )[0]
    sims = doc_embeddings @ q_emb
    k = max(1, int(top_k))  # sliders can deliver floats; slicing needs an int
    idxs = np.argsort(sims)[::-1][:k]
    retrieved_text = "\n".join(docs[i] for i in idxs)

    # 4.2 Translation using the local t5-small, once per target language.
    try:
        spanish_text = _t5_translate("Spanish", retrieved_text)
        french_text = _t5_translate("French", retrieved_text)
        translations = f"Spanish:\n{spanish_text}\n\nFrench:\n{french_text}"
    except Exception as e:
        # Surface the failure in the UI instead of crashing the app.
        translations = f"Error during translation: {e}"
    return retrieved_text, translations

def _t5_translate(language, text):
    """Translate English *text* to *language* with the shared t5-small.

    Args:
        language: Target language name as used in T5 task prefixes
            (e.g. "Spanish", "French").
        text: English source text; truncated to 512 tokens.

    Returns:
        str: The decoded translation.
    """
    inputs = tokenizer(
        f"translate English to {language}: {text}",
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )
    # Passing attention_mask alongside input_ids avoids pad-token
    # ambiguity during generation.
    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=512,
        num_beams=4,
        early_stopping=True,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# 5) Build the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## 🚀 Mini-RAG + Translation Demo")
    with gr.Row():
        # Main column: query input, retrieval controls, and outputs.
        with gr.Column(scale=3):
            query = gr.Textbox(lines=2, placeholder="Ask me anything…", label="Your question")
            # Slider upper bound tracks the corpus size so users can't
            # request more documents than exist.
            top_k = gr.Slider(minimum=1, maximum=len(docs), value=1, step=1, label="Number of docs to retrieve")
            submit = gr.Button("🔍 Retrieve & Translate")
            retrieved_output = gr.Textbox(label="Retrieved Content (English)")
            translated_output = gr.Textbox(label="Translated Content (Spanish & French)")
            # Wire the button to the RAG pipeline: two inputs, two outputs.
            submit.click(
                fn=rag_translate,
                inputs=[query, top_k],
                outputs=[retrieved_output, translated_output]
            )
        # Side column: one-click example queries and a short explainer.
        with gr.Column(scale=1):
            gr.Markdown("### Example Queries")
            examples = [
                "What is the status of shipment S-2003?",
                "Where did shipment S-2002 originate?",
                "How many tickets mention payment errors?",
                "What resolution was given for T-1001?"
            ]
            for ex in examples:
                btn = gr.Button(ex)
                # `x=ex` binds the current value as a default argument,
                # avoiding Python's late-binding closure pitfall (otherwise
                # every button would fill in the last example).
                btn.click(lambda x=ex: x, None, query)
            gr.Markdown(
                """
**How it works**
1) Retrieval: fetch top-k snippets from JSON KBs.
2) Translation: translate retrieved content to Spanish and French using t5-small locally.
"""
            )
# Launch only when executed directly (not when imported as a module).
if __name__ == "__main__":
    # 0.0.0.0 exposes the server on all interfaces — required when running
    # inside a container (e.g. a Hugging Face Space).
    demo.launch(server_name="0.0.0.0")