File size: 4,312 Bytes
066da6e
 
 
8f120f7
5700c07
8f120f7
5700c07
743990e
8f120f7
5700c07
 
 
 
 
066da6e
 
 
 
07210bd
5da808a
 
066da6e
 
5da808a
 
066da6e
 
 
07210bd
5700c07
066da6e
 
 
 
 
4e60ee7
e6d5216
b453200
e6d5216
 
 
 
 
066da6e
4e60ee7
066da6e
e6d5216
add60ac
5700c07
066da6e
5700c07
 
 
 
e6d5216
5700c07
 
 
 
 
 
e6d5216
5700c07
e6d5216
 
5700c07
 
 
 
 
 
066da6e
5700c07
e6d5216
 
 
066da6e
e6d5216
066da6e
e6d5216
066da6e
 
 
e6d5216
066da6e
 
 
 
e6d5216
 
 
066da6e
e6d5216
066da6e
e6d5216
066da6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5700c07
066da6e
 
8f120f7
743990e
6cea45f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import json
import numpy as np
import gradio as gr
from sentence_transformers import SentenceTransformer
from transformers import T5Tokenizer, T5ForConditionalGeneration

# 1) Load the embedding model for retrieval.
# all-MiniLM-L6-v2 is a compact sentence-transformers model; its embeddings
# are used below for similarity-based document retrieval.
embedder = SentenceTransformer('all-MiniLM-L6-v2')

# 2) Load the T5 model and tokenizer for translation.
# t5-small understands the "translate English to <language>:" task prefixes
# that rag_translate builds at request time.
tokenizer = T5Tokenizer.from_pretrained('t5-small')
model = T5ForConditionalGeneration.from_pretrained('t5-small')

# 3) Load JSON knowledge bases
def load_docs(path, fmt):
    """Load a JSON array from *path* and render each element as a string.

    Parameters
    ----------
    path : str
        Path to a JSON file whose top-level value is a list of objects.
    fmt : callable
        Function mapping one decoded item to its display string.

    Returns
    -------
    list[str]
        One formatted string per item, in file order.
    """
    # Explicit UTF-8: the platform-default encoding (e.g. cp1252 on Windows)
    # would break on non-ASCII characters in the knowledge-base files.
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return [fmt(item) for item in data]

# 3) Build the document corpus from both JSON knowledge bases.

def _format_ticket(d):
    """Render one call-center ticket record as a single English sentence."""
    return f"Ticket {d['ticket_id']} ({d['customer']}): {d['description']}. Resolution: {d['resolution']}"

def _format_shipment(d):
    """Render one trucking-shipment record as a single English sentence."""
    return f"Shipment {d['shipment_id']} from {d['origin']} to {d['destination']} is {d['status']}. Last update: {d['last_update']}"

call_docs = load_docs('call_center_data.json', _format_ticket)
truck_docs = load_docs('trucking_data.json', _format_shipment)
docs = call_docs + truck_docs

# 4) Embed every document once at startup. normalize_embeddings=True makes
# each vector unit length, so a plain dot product in rag_translate equals
# cosine similarity.
doc_embeddings = embedder.encode(docs, convert_to_numpy=True, normalize_embeddings=True)

def _translate_to(language, text):
    """Translate *text* from English to *language* (e.g. "Spanish", "French")
    with the local t5-small model and return the decoded string."""
    prompt = f"translate English to {language}: {text}"
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512)
    # Pass the attention mask alongside the ids so any padding produced by
    # the tokenizer is ignored during generation.
    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=512,
        num_beams=4,
        early_stopping=True
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

def rag_translate(query, top_k):
    """Retrieve the top-k documents most similar to *query* and translate
    them to Spanish and French.

    Parameters
    ----------
    query : str
        The user's question, embedded and compared against the corpus.
    top_k : int | float
        Number of documents to retrieve (Gradio sliders may deliver floats).

    Returns
    -------
    tuple[str, str]
        (retrieved English text, combined Spanish/French translations or an
        error description).
    """
    # 4.1 Retrieval. Both corpus and query embeddings are unit-normalized,
    # so the dot product below is cosine similarity.
    k = max(1, min(int(top_k), len(docs)))  # guard against floats / out-of-range values
    q_emb = embedder.encode(
        [query],
        convert_to_numpy=True,
        normalize_embeddings=True
    )[0]
    sims = doc_embeddings @ q_emb
    idxs = np.argsort(sims)[::-1][:k]
    retrieved_text = "\n".join(docs[i] for i in idxs)

    # 4.2 Translation with local t5-small. Failures are reported as text in
    # the UI instead of crashing the Gradio callback.
    try:
        spanish_text = _translate_to('Spanish', retrieved_text)
        french_text = _translate_to('French', retrieved_text)
        translations = f"Spanish:\n{spanish_text}\n\nFrench:\n{french_text}"
    except Exception as e:
        translations = f"Error during translation: {e}"

    return retrieved_text, translations

# 5) Build the Gradio interface.
with gr.Blocks() as demo:
    gr.Markdown("## 🚀 Mini-RAG + Translation Demo")
    with gr.Row():
        # Left column: query input, retrieval controls, and the two outputs.
        with gr.Column(scale=3):
            query = gr.Textbox(lines=2, placeholder="Ask me anything…", label="Your question")
            # Slider upper bound is the corpus size, so top_k can never
            # exceed the number of available documents.
            top_k = gr.Slider(minimum=1, maximum=len(docs), value=1, step=1, label="Number of docs to retrieve")
            submit = gr.Button("🔍 Retrieve & Translate")
            retrieved_output = gr.Textbox(label="Retrieved Content (English)")
            translated_output = gr.Textbox(label="Translated Content (Spanish & French)")
            submit.click(
                fn=rag_translate,
                inputs=[query, top_k],
                outputs=[retrieved_output, translated_output]
            )
        # Right column: one button per canned example; clicking one copies
        # its text into the query box.
        with gr.Column(scale=1):
            gr.Markdown("### Example Queries")
            examples = [
                "What is the status of shipment S-2003?",
                "Where did shipment S-2002 originate?",
                "How many tickets mention payment errors?",
                "What resolution was given for T-1001?"
            ]
            for ex in examples:
                btn = gr.Button(ex)
                # Default-arg binding (x=ex) captures the current example,
                # sidestepping Python's late-binding closure pitfall.
                btn.click(lambda x=ex: x, None, query)
    gr.Markdown(
        """
        **How it works**  
        1) Retrieval: fetch top-k snippets from JSON KBs.  
        2) Translation: translate retrieved content to Spanish and French using t5-small locally.  
        """
    )

# Script entry point. Binding 0.0.0.0 exposes the app on all network
# interfaces (needed inside containers); be aware this makes it reachable
# from other machines on the network.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0")