# Spaces status: Sleeping
import streamlit as st
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
# Load the RAG tokenizer, retriever and generator from the
# "facebook/rag-token-nq" checkpoint.
@st.cache_resource
def _load_rag_components():
    """Load the RAG tokenizer, retriever and generator and wire them together.

    Streamlit re-executes the script on every user interaction, so the
    loading is cached with ``st.cache_resource`` to avoid re-reading the
    checkpoints on each rerun.

    Returns:
        tuple: ``(tokenizer, retriever, model)`` where ``model`` has
        ``retriever`` attached, so ``model.generate`` can be called with
        question ``input_ids`` alone.
    """
    rag_tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
    # The RAG configuration documents "legacy", "exact", "compressed" and
    # "custom" as index names; "faiss" is not among them.
    # NOTE(review): use_dummy_dataset=True follows the library's documented
    # example and avoids downloading the full wiki_dpr index -- confirm this
    # is the dataset you want to retrieve from.
    rag_retriever = RagRetriever.from_pretrained(
        "facebook/rag-token-nq",
        index_name="exact",
        use_dummy_dataset=True,
        trust_remote_code=True,
    )
    # The retriever must be handed to the model; without it, generate() has
    # no source of context documents for the question.
    # NOTE(review): the checkpoint name says "rag-token" while the class is
    # the sequence variant -- confirm this pairing is intentional.
    rag_model = RagSequenceForGeneration.from_pretrained(
        "facebook/rag-token-nq",
        retriever=rag_retriever,
    )
    return rag_tokenizer, rag_retriever, rag_model


tokenizer, retriever, model = _load_rag_components()
# Set up FAISS for multilingual document retrieval
@st.cache_resource
def setup_faiss():
    """Build an in-memory FAISS L2 index over the example documents.

    The documents are embedded with the LaBSE sentence encoder and added to
    a flat (exhaustive-search) L2 index. The result is cached with
    ``st.cache_resource`` so the encoder is not reloaded and the index is not
    rebuilt every time Streamlit reruns the script.

    Returns:
        tuple: ``(faiss_index, docs)`` -- the populated ``faiss.IndexFlatL2``
        and the list of document strings, where row ``i`` of the index is the
        embedding of ``docs[i]``.
    """
    # Load multilingual embeddings for documents (LaBSE)
    model_embed = SentenceTransformer('sentence-transformers/LaBSE')
    # Example multilingual documents
    docs = [
        "How to learn programming?",
        "Comment apprendre la programmation?",
        "پروگرامنگ سیکھنے کا طریقہ کیا ہے؟"
    ]
    # FAISS consumes C-contiguous float32 NumPy arrays. Requesting NumPy
    # output directly avoids calling np.array() on a torch tensor, which
    # fails when the tensor lives on a GPU.
    embeddings = np.ascontiguousarray(
        model_embed.encode(docs, convert_to_numpy=True),
        dtype=np.float32,
    )
    faiss_index = faiss.IndexFlatL2(embeddings.shape[1])
    faiss_index.add(embeddings)
    return faiss_index, docs


# Set up FAISS index
faiss_index, docs = setup_faiss()
# Retrieve documents based on query
@st.cache_resource
def _load_query_encoder():
    """Load the LaBSE sentence encoder once and reuse it across reruns."""
    return SentenceTransformer('sentence-transformers/LaBSE')


def retrieve_docs(query):
    """Return the indexed document closest to *query* in embedding space.

    Args:
        query: The question text to embed and look up.

    Returns:
        The entry of the module-level ``docs`` list whose embedding has the
        smallest L2 distance to the query embedding according to
        ``faiss_index``.
    """
    # Embed the query with the cached encoder instead of constructing a new
    # SentenceTransformer for every call. FAISS needs a C-contiguous float32
    # array of shape (n_queries, dim).
    query_embedding = np.ascontiguousarray(
        _load_query_encoder().encode([query], convert_to_numpy=True),
        dtype=np.float32,
    )
    # Perform retrieval using FAISS; only the single nearest neighbour is
    # requested and the distances are not needed.
    _, neighbour_ids = faiss_index.search(query_embedding, 1)
    # Get the most relevant document
    return docs[neighbour_ids[0][0]]
# Handle question-answering
def answer_question(query):
    """Generate an answer for *query* using the retrieved document as context.

    Args:
        query: The user's question text.

    Returns:
        The decoded text of the first sequence produced by ``model.generate``,
        with special tokens removed.
    """
    # Look up the closest example document for this question.
    context_doc = retrieve_docs(query)

    # Encode the question and the retrieved document together as one input.
    encoded = tokenizer(
        query,
        context_doc,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )

    # Run generation on the encoded ids and their attention mask.
    output_ids = model.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
    )

    # Only the first generated sequence is turned back into text.
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
# Streamlit interface for user input
st.title("Multilingual RAG Translator/Answer Bot")
st.write("Ask a question in your preferred language (Urdu, French, Hindi)")
query = st.text_input("Enter your question:")
# A whitespace-only entry is truthy but carries no question; strip it first so
# blank submissions do not trigger retrieval and generation.
cleaned_query = query.strip() if query else ""
if cleaned_query:
    answer = answer_question(cleaned_query)
    st.write(f"Answer: {answer}")