# SucheRAG / app.py
import gradio as gr
from langchain.chains import RagChain  # NOTE: assumed project-specific chain; stock LangChain does not ship a RagChain
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings, SentenceTransformerEmbeddings
from transformers import RagTokenizer, RagSequenceForGeneration
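# PATH_WORK and CHROMA_DIR are used below but never defined in the original
# script; the values here are assumed placeholders pointing at the persisted
# Chroma index and should be adjusted to the actual environment.
PATH_WORK = "./work/"
CHROMA_DIR = "chroma_db"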
# Initialize the Sentence-BERT model for the embeddings, wrapped for LangChain
# so that Chroma can call it as an embedding function
embedding_model = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
# Initialize the tokenizer and the RAG model
tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq")
# Connect to the Chroma DB and load the documents
chroma_db = Chroma(embedding_function=embedding_model, persist_directory=PATH_WORK + CHROMA_DIR)
# Build a custom retriever on top of the Chroma DB and its embeddings
retriever = chroma_db.as_retriever()
# Assemble the RAG chain with the custom retriever
rag_chain = RagChain(model=model, retriever=retriever, tokenizer=tokenizer, vectorstore=chroma_db)
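# Hypothetical sanity check, assuming the chain accepts a {"Frage": ...} mapping
# and returns 'output' and 'docs' keys (the interface get_rag_response relies on):
#   print(rag_chain({"Frage": "Was ist RAG?"})["output"])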
#############################################
def document_retrieval_chroma2():
    # HF embeddings -----------------------------------
    # Alternative embedding models for the vector store (they produce the similarity
    # vectors). The ...InstructEmbeddings variant is very compute-intensive and expects
    # an INSTRUCTOR model, so the plain sentence-transformers wrapper is used here:
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2", model_kwargs={"device": "cpu"})
    # Somewhat less compute-intensive:
    #embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2", model_kwargs={"device": "cpu"}, encode_kwargs={'normalize_embeddings': False})
    # Or simply without LangChain:
    #embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    # Chroma DB in which to store the embeddings
    db = Chroma(embedding_function=embeddings, persist_directory=PATH_WORK + CHROMA_DIR)
    print("Chroma DB ready ...................")
    return db
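# Illustrative alternative wiring (not executed): the retriever above could also
# be built from this helper, limiting retrieval to the top 3 matches:
#   db = document_retrieval_chroma2()
#   retriever = db.as_retriever(search_kwargs={"k": 3})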
def get_rag_response(prompt):
    global rag_chain
    # Use the RAG chain to generate the answer
    result = rag_chain({"Frage": prompt})
    # Extract the relevant documents
    docs = result['docs']
    passages = [doc['text'] for doc in docs]
    links = [doc['url'] for doc in docs]
    # Assemble the answer
    answer = result['output']
    response = {
        "answer": answer,
        "documents": [{"link": link, "passage": passage} for link, passage in zip(links, passages)]
    }
    return response
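# Shape of the value returned by get_rag_response() (illustrative values only):
#   {"answer": "...",
#    "documents": [{"link": "https://...", "passage": "..."}]}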
def chatbot_response(chat_history):
    # user() below has already appended the latest question to the history
    # (the textbox is cleared by then), so look it up there, answer it, and
    # fill in the bot side of that turn
    user_input = chat_history[-1][0]
    response = get_rag_response(user_input)
    answer = response['answer']
    documents = response['documents']
    doc_links = "\n\n".join(f"Link: {doc['link']}\nDocument excerpts: {doc['passage']}" for doc in documents)
    chat_history[-1][1] = f"{answer}\n\nRelevant documents:\n{doc_links}"
    return chat_history
#############################
#GUI.........
def user(user_input, history):
    # Append the new user turn to the history and clear the textbox
    return "", history + [[user_input, None]]
with gr.Blocks() as chatbot:
    chat_interface = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")
    # Event listeners: first record the user turn, then generate the bot reply
    msg.submit(user, [msg, chat_interface], [msg, chat_interface], queue=False).then(
        chatbot_response, chat_interface, chat_interface
    )
    clear.click(lambda: None, None, chat_interface, queue=False)

chatbot.launch()