import json

import numpy as np
import gradio as gr
from sentence_transformers import SentenceTransformer
from transformers import T5Tokenizer, T5ForConditionalGeneration


# Embedding model used for retrieval
embedder = SentenceTransformer('all-MiniLM-L6-v2')

# Local seq2seq model used for translation
tokenizer = T5Tokenizer.from_pretrained('t5-small')
model = T5ForConditionalGeneration.from_pretrained('t5-small')


def load_docs(path, fmt):
    """Load a JSON array and format each record into a retrievable text snippet."""
    with open(path, 'r') as f:
        data = json.load(f)
    return [fmt(item) for item in data]


# Build the document store from the two JSON knowledge bases
call_docs = load_docs(
    'call_center_data.json',
    lambda d: f"Ticket {d['ticket_id']} ({d['customer']}): {d['description']}. Resolution: {d['resolution']}"
)
truck_docs = load_docs(
    'trucking_data.json',
    lambda d: f"Shipment {d['shipment_id']} from {d['origin']} to {d['destination']} is {d['status']}. Last update: {d['last_update']}"
)
docs = call_docs + truck_docs
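
# For reference, a sketch of the record shape the two loaders above assume; the field
# names come from the format lambdas, the values are placeholders only:
#
#   call_center_data.json: [{"ticket_id": ..., "customer": ..., "description": ..., "resolution": ...}, ...]
#   trucking_data.json:    [{"shipment_id": ..., "origin": ..., "destination": ..., "status": ..., "last_update": ...}, ...]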

# Pre-compute document embeddings; normalizing them lets a plain dot product act as cosine similarity
doc_embeddings = embedder.encode(
    docs,
    convert_to_numpy=True,
    normalize_embeddings=True
)


def rag_translate(query, top_k):
    # Embed the query and rank documents by cosine similarity
    q_emb = embedder.encode(
        [query],
        convert_to_numpy=True,
        normalize_embeddings=True
    )[0]
    sims = doc_embeddings @ q_emb
    idxs = np.argsort(sims)[::-1][:int(top_k)]  # slider values may arrive as floats
    retrieved = [docs[i] for i in idxs]
    retrieved_text = "\n".join(retrieved)

    # Translate the retrieved English text with the local T5 model.
    # Note: t5-small's pretrained prefixes cover English->French (plus German and Romanian),
    # so the English->Spanish prompt may yield lower-quality output.
    try:
        spanish_input = f"translate English to Spanish: {retrieved_text}"
        french_input = f"translate English to French: {retrieved_text}"

        spanish_inputs = tokenizer(spanish_input, return_tensors="pt", padding=True, truncation=True, max_length=512)
        spanish_outputs = model.generate(
            spanish_inputs["input_ids"],
            max_length=512,
            num_beams=4,
            early_stopping=True
        )
        spanish_text = tokenizer.decode(spanish_outputs[0], skip_special_tokens=True)

        french_inputs = tokenizer(french_input, return_tensors="pt", padding=True, truncation=True, max_length=512)
        french_outputs = model.generate(
            french_inputs["input_ids"],
            max_length=512,
            num_beams=4,
            early_stopping=True
        )
        french_text = tokenizer.decode(french_outputs[0], skip_special_tokens=True)

        translations = f"Spanish:\n{spanish_text}\n\nFrench:\n{french_text}"
    except Exception as e:
        translations = f"Error during translation: {e}"

    return retrieved_text, translations


# Gradio UI: query and retrieval controls on the left, example queries on the right
with gr.Blocks() as demo:
    gr.Markdown("## 🚀 Mini-RAG + Translation Demo")
    with gr.Row():
        with gr.Column(scale=3):
            query = gr.Textbox(lines=2, placeholder="Ask me anything…", label="Your question")
            top_k = gr.Slider(minimum=1, maximum=len(docs), value=1, step=1, label="Number of docs to retrieve")
            submit = gr.Button("🔍 Retrieve & Translate")
            retrieved_output = gr.Textbox(label="Retrieved Content (English)")
            translated_output = gr.Textbox(label="Translated Content (Spanish & French)")
            submit.click(
                fn=rag_translate,
                inputs=[query, top_k],
                outputs=[retrieved_output, translated_output]
            )
        with gr.Column(scale=1):
            gr.Markdown("### Example Queries")
            examples = [
                "What is the status of shipment S-2003?",
                "Where did shipment S-2002 originate?",
                "How many tickets mention payment errors?",
                "What resolution was given for T-1001?"
            ]
            for ex in examples:
                btn = gr.Button(ex)
                # Bind ex as a default argument so each button fills the query box with its own text
                btn.click(lambda x=ex: x, None, query)

    gr.Markdown(
        """
        **How it works**

        1) Retrieval: embed the question and fetch the top-k most similar snippets from the JSON knowledge bases.
        2) Translation: translate the retrieved content into Spanish and French with t5-small, running locally.
        """
    )


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0")
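
# To try the demo locally (assuming this script is saved as app.py and the two JSON
# files sit next to it): run `python app.py`, then open http://localhost:7860
# (Gradio's default port) in a browser.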