# Mini-RAG + Translation demo: retrieval over JSON knowledge bases with
# sentence-transformers, translation via a local t5-small model, Gradio UI.
import json
import numpy as np
import gradio as gr
from sentence_transformers import SentenceTransformer
from transformers import T5Tokenizer, T5ForConditionalGeneration
# 1) Load the embedding model for retrieval.
# all-MiniLM-L6-v2 is a small sentence-embedding model; downloaded on first run.
embedder = SentenceTransformer('all-MiniLM-L6-v2')
# 2) Load the T5 model and tokenizer for translation.
# t5-small runs locally; translation is driven by task prefixes such as
# "translate English to Spanish: ...".
tokenizer = T5Tokenizer.from_pretrained('t5-small')
model = T5ForConditionalGeneration.from_pretrained('t5-small')
# 3) Load JSON knowledge bases
def load_docs(path, fmt):
    """Load a JSON array from *path* and format each item into a text snippet.

    Args:
        path: Path to a JSON file containing a list of objects.
        fmt: Callable mapping one decoded JSON object to its snippet string.

    Returns:
        list[str]: One formatted snippet per item, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Explicit encoding: JSON is UTF-8 by spec; without this, open() falls
    # back to the platform default and can mis-decode non-ASCII content.
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return [fmt(item) for item in data]
# Build the in-memory knowledge base: each JSON record becomes one
# retrievable text snippet.
# NOTE(review): the lambdas assume fixed record schemas (ticket_id/customer/
# description/resolution and shipment_id/origin/destination/status/
# last_update) — a KeyError here means the JSON files drifted; confirm
# against the data files.
call_docs = load_docs(
    'call_center_data.json',
    lambda d: f"Ticket {d['ticket_id']} ({d['customer']}): {d['description']}. Resolution: {d['resolution']}"
)
truck_docs = load_docs(
    'trucking_data.json',
    lambda d: f"Shipment {d['shipment_id']} from {d['origin']} to {d['destination']} is {d['status']}. Last update: {d['last_update']}"
)
# Single flat corpus searched at query time.
docs = call_docs + truck_docs
# 4) Pre-compute embeddings for all documents once at startup.
# normalize_embeddings=True makes every vector unit-length, so retrieval can
# use a plain dot product as cosine similarity.
doc_embeddings = embedder.encode(
    docs,
    convert_to_numpy=True,
    normalize_embeddings=True
)
def _translate(text, target_language):
    """Translate *text* from English to *target_language* with local t5-small.

    Uses the task-prefix convention t5 was trained with ("translate English
    to <lang>: ..."). Input is truncated at 512 tokens, so very long
    retrieved contexts are clipped before translation.
    """
    inputs = tokenizer(
        f"translate English to {target_language}: {text}",
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )
    outputs = model.generate(
        inputs["input_ids"],
        # Pass the attention mask explicitly; generate() without it warns and
        # may treat padding tokens as real content.
        attention_mask=inputs["attention_mask"],
        max_length=512,
        num_beams=4,
        early_stopping=True,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def rag_translate(query, top_k):
    """Retrieve the top-k most similar KB snippets, then translate them.

    Args:
        query: Free-text user question.
        top_k: Number of documents to retrieve (Gradio sliders may deliver
            this as a float; it is coerced and clamped to >= 1).

    Returns:
        tuple[str, str]: (retrieved English text, combined Spanish/French
        translation or an error message).
    """
    # 4.1 Retrieval: both document and query embeddings are L2-normalized,
    # so cosine similarity reduces to a dot product.
    q_emb = embedder.encode(
        [query],
        convert_to_numpy=True,
        normalize_embeddings=True,
    )[0]
    sims = doc_embeddings @ q_emb
    # Coerce the widget value; a float (or 0) here would break the slice.
    k = max(1, int(top_k))
    top_idxs = np.argsort(sims)[::-1][:k]
    retrieved_text = "\n".join(docs[i] for i in top_idxs)

    # 4.2 Translation: identical pipeline per target language, factored into
    # _translate to avoid the duplicated generate/decode code.
    try:
        spanish_text = _translate(retrieved_text, "Spanish")
        french_text = _translate(retrieved_text, "French")
        translations = f"Spanish:\n{spanish_text}\n\nFrench:\n{french_text}"
    except Exception as e:
        # Best-effort: surface the failure in the UI instead of crashing the app.
        translations = f"Error during translation: {e}"
    return retrieved_text, translations
# 5) Build the Gradio interface.
# NOTE(review): original indentation was lost; nesting reconstructed from the
# with-statement structure — confirm the trailing Markdown belongs at
# Blocks level rather than inside the sidebar column.
with gr.Blocks() as demo:
    gr.Markdown("## 🚀 Mini-RAG + Translation Demo")
    with gr.Row():
        with gr.Column(scale=3):
            # Main interaction column: question in, retrieval + translation out.
            query = gr.Textbox(lines=2, placeholder="Ask me anything…", label="Your question")
            top_k = gr.Slider(minimum=1, maximum=len(docs), value=1, step=1, label="Number of docs to retrieve")
            submit = gr.Button("🔍 Retrieve & Translate")
            retrieved_output = gr.Textbox(label="Retrieved Content (English)")
            translated_output = gr.Textbox(label="Translated Content (Spanish & French)")
            # Wire the button: rag_translate returns a 2-tuple matching the
            # two output textboxes in order.
            submit.click(
                fn=rag_translate,
                inputs=[query, top_k],
                outputs=[retrieved_output, translated_output]
            )
        with gr.Column(scale=1):
            # Sidebar: one button per canned example query.
            gr.Markdown("### Example Queries")
            examples = [
                "What is the status of shipment S-2003?",
                "Where did shipment S-2002 originate?",
                "How many tickets mention payment errors?",
                "What resolution was given for T-1001?"
            ]
            for ex in examples:
                btn = gr.Button(ex)
                # ex is bound as a default argument so each button keeps its
                # own text (avoids the late-binding closure pitfall).
                btn.click(lambda x=ex: x, None, query)
    gr.Markdown(
        """
**How it works**
1) Retrieval: fetch top-k snippets from JSON KBs.
2) Translation: translate retrieved content to Spanish and French using t5-small locally.
"""
    )
# Bind to all interfaces so the demo is reachable from outside the container.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0")