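"""Advanced RAG with Multilingual Support.

A Hugging Face Space by Ramon Mayor Martins. Paste a text to serve as the
knowledge base, then ask questions about it in any language; answers are
generated with Retrieval-Augmented Generation (RAG) over the pasted text.
"""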
import torch
import gradio as gr
from typing import Tuple
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.prompts import PromptTemplate
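
# Assumed runtime dependencies (not pinned here): transformers, torch,
# langchain, langchain-community, chromadb (backing Chroma), gradio, and
# sentence-transformers (required by HuggingFaceEmbeddings).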


class EnhancedRAGSystem:
    def __init__(self):
        self.chunk_size = 500
        self.chunk_overlap = 50
        self.k_documents = 4

        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
            length_function=len
        )
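
        # Multilingual E5 embeddings index the chunks so questions can be
        # asked in a different language than the source text; Flan-T5 is the
        # generator that writes the final answer.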
        self.embedding_model_name = "intfloat/multilingual-e5-large"
        self.llm_model_name = "google/flan-t5-large"

        self.prompt_template = PromptTemplate(
            template="""Use the context below to answer the question.
If the answer is not in the context, say "I don't have enough information in the context to answer this question."
Context: {context}
Question: {question}
Detailed answer:""",
            input_variables=["context", "question"]
        )

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.embeddings = HuggingFaceEmbeddings(
            model_name=self.embedding_model_name,
            model_kwargs={'device': self.device}
        )

        self.tokenizer = AutoTokenizer.from_pretrained(self.llm_model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.llm_model_name)
        self.model.to(self.device)

        self.pipe = pipeline(
            "text2text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            max_length=512,
            device=0 if torch.cuda.is_available() else -1,
            # `temperature` is a generation argument, not a model-load kwarg,
            # and it only takes effect when sampling is enabled.
            do_sample=True,
            temperature=0.7
        )
        self.llm = HuggingFacePipeline(pipeline=self.pipe)
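
    # Splits the pasted text, embeds the chunks into a Chroma collection, and
    # wires up a "stuff"-type RetrievalQA chain over the top-k retrieved chunks.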
    def process_documents(self, text: str) -> bool:
        try:
            texts = self.text_splitter.split_text(text)

            # Drop any previously indexed collection so repeated submissions
            # do not mix chunks from earlier texts into the retrieval results.
            if hasattr(self, 'vectorstore'):
                self.vectorstore.delete_collection()

            self.vectorstore = Chroma.from_texts(
                texts,
                self.embeddings,
                metadatas=[{"source": f"chunk_{i}", "text": t} for i, t in enumerate(texts)],
                collection_name="enhanced_rag_docs"
            )
            self.retriever = self.vectorstore.as_retriever(
                search_kwargs={"k": self.k_documents}
            )
            self.qa_chain = RetrievalQA.from_chain_type(
                llm=self.llm,
                chain_type="stuff",
                retriever=self.retriever,
                return_source_documents=True,
                chain_type_kwargs={"prompt": self.prompt_template}
            )
            return True
        except Exception as e:
            print(f"Processing error: {str(e)}")
            return False
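
    # Runs the QA chain and returns the generated answer together with short
    # previews of the retrieved source chunks.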
    def answer_question(self, question: str) -> Tuple[str, str]:
        try:
            response = self.qa_chain.invoke({"query": question})
            answer = response["result"]

            sources = []
            for i, doc in enumerate(response["source_documents"], 1):
                text_preview = doc.page_content[:100] + "..."
                sources.append(f"Excerpt {i}: {text_preview}")
            sources_text = "\n".join(sources)

            return answer, sources_text
        except Exception as e:
            return f"Error answering: {str(e)}", ""
def create_enhanced_interface():
    rag_system = EnhancedRAGSystem()

    def process_and_answer(text: str, question: str) -> str:
        if not text.strip() or not question.strip():
            return "Please provide both text and question."
        if not rag_system.process_documents(text):
            return "Error processing the text."
        answer, sources = rag_system.answer_question(question)
        if sources:
            return f"""Answer: {answer}
Relevant excerpts consulted:
{sources}"""
        return answer

    # CSS for the page header
    custom_css = """
    .custom-description {
        margin-bottom: 20px;
        text-align: center;
    }
    .custom-description a {
        text-decoration: none;
        color: #007bff;
        margin: 0 5px;
    }
    .custom-description a:hover {
        text-decoration: underline;
    }
    """

    with gr.Blocks(css=custom_css) as interface:
        gr.HTML("""
            <div class="custom-description">
                <h1>Advanced RAG with Multilingual Support</h1>
                <p>Ramon Mayor Martins:
                    <a href="https://rmayormartins.github.io/" target="_blank">Website</a> |
                    <a href="https://huggingface.co/rmayormartins" target="_blank">Spaces</a> |
                    <a href="https://github.com/rmayormartins" target="_blank">GitHub</a>
                </p>
                <p>This system uses Retrieval-Augmented Generation (RAG) to answer questions about your texts in multiple languages.
                Simply paste your text and ask questions in any language!</p>
            </div>
        """)

        with gr.Row():
            with gr.Column():
                text_input = gr.Textbox(
                    label="Base Text",
                    placeholder="Paste the text that will serve as the knowledge base...",
                    lines=10
                )
                question_input = gr.Textbox(
                    label="Your Question",
                    placeholder="What would you like to know about the text?"
                )
                submit_btn = gr.Button("Submit")
            with gr.Column():
                output = gr.Textbox(label="Answer")

        examples = [
            ["The Earth is the third planet from the Sun. It has one natural satellite called the Moon. It is the only known planet to harbor life.",
             "What is Earth's natural satellite?"],
            ["La Tierra es el tercer planeta del Sistema Solar. Tiene un satélite natural llamado Luna. Es el único planeta conocido que alberga vida.",
             "¿Cuál es el satélite natural de la Tierra?"],
            ["A Terra é o terceiro planeta do Sistema Solar. Tem um satélite natural chamado Lua. É o único planeta conhecido que abriga vida.",
             "Qual é o satélite natural da Terra?"],
            ["The Sun is a medium-sized star at the center of our Solar System. It provides light and heat to all planets.",
             "What is the Sun?"],
            ["El Sol es una estrella de tamaño medio en el centro de nuestro Sistema Solar. Proporciona luz y calor a todos los planetas.",
             "¿Qué es el Sol?"],
            ["O Sol é uma estrela de tamanho médio no centro do nosso Sistema Solar. Ele fornece luz e calor para todos os planetas.",
             "O que é o Sol?"]
        ]

        gr.Examples(
            examples=examples,
            inputs=[text_input, question_input],
            outputs=output,
            fn=process_and_answer,
            cache_examples=True
        )

        submit_btn.click(
            fn=process_and_answer,
            inputs=[text_input, question_input],
            outputs=output
        )

    return interface


if __name__ == "__main__":
    demo = create_enhanced_interface()
    demo.launch()
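
# A minimal sketch of programmatic use without the Gradio UI (assumes the
# large model downloads above succeed):
#
#   rag = EnhancedRAGSystem()
#   if rag.process_documents("The Earth is the third planet from the Sun."):
#       answer, sources = rag.answer_question("Which planet is Earth?")
#       print(answer)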