# Rag_proj / app.py
# sgt444pepper's picture
# Update app.py
# 426c844 verified
import os
import warnings

import faiss
import gradio as gr
import numpy as np
import torch
from litellm import completion
from rank_bm25 import BM25Okapi
from sentence_transformers import CrossEncoder
from transformers import AutoModel, AutoTokenizer
# SECURITY: the Groq API key used to be hard-coded here, which publishes the
# secret to anyone who can read this file. Provide it through the environment
# instead (for example a Space secret named GROQ_API_KEY); litellm picks it up
# from os.environ. Any key previously committed here should be considered
# leaked and revoked.
if not os.environ.get("GROQ_API_KEY"):
    warnings.warn(
        "GROQ_API_KEY is not set; requests to the Groq model will fail to authenticate.",
        RuntimeWarning,
        stacklevel=1,
    )
# System prompt sent as the first message of every completion request.
# The trailing backslash after the opening quotes suppresses the leading
# newline, so the prompt starts directly with "You are ..." (the original
# had a stray "/" here, which leaked into the prompt text).
PROMPT = """\
You are a virtual representative of a retail company and a consultant for customers.
To generate answers, use only information from the context!
Do not ask additional questions, but simply offer the product available in the context!
Your goal is to answer customers' questions, thus helping them.
You should advise the customer in choosing products using the context.
If you could not find a specific answer:
- Answer "I do not know. For more information, please contact: +380954673526" and nothing more.
You always maintain a polite, professional tone. The format of the answer should be simple, understandable and clear. Avoid long explanations if they are not necessary.
"""
# Shared models, loaded once when the module is imported.
# Encoder used by Retriver for dense (FAISS) embeddings; the tokenizer and the
# model must come from the same checkpoint.
_EMBEDDER_CHECKPOINT = "sentence-transformers/all-mpnet-base-v2"
tokenizer = AutoTokenizer.from_pretrained(_EMBEDDER_CHECKPOINT)
model = AutoModel.from_pretrained(_EMBEDDER_CHECKPOINT)

# Cross-encoder used by Reranker to score (query, chunk) pairs.
reranker_model = CrossEncoder("cross-encoder/ms-marco-TinyBERT-L-6")
def load_documents(file_paths):
    """Return the whole text of each file, stripped, one list entry per path.

    Files are read as UTF-8, in the order given.
    """
    def _read_stripped(path):
        # One file at a time so each handle is closed before the next opens.
        with open(path, 'r', encoding='utf-8') as handle:
            return handle.read().strip()

    return [_read_stripped(p) for p in file_paths]
def load_documents_with_chunking(file_paths, chunk_size=500, overlap=0):
    """Read each file and split its stripped text into fixed-size character chunks.

    Args:
        file_paths: iterable of paths to UTF-8 text files.
        chunk_size: maximum number of characters per chunk; must be > 0.
        overlap: number of characters each chunk shares with the previous
            chunk of the same file; must satisfy ``0 <= overlap < chunk_size``.
            The default 0 gives plain, non-overlapping chunks.

    Returns:
        A flat list of chunks, ordered by file and then by position. Empty or
        whitespace-only files contribute no chunks.

    Raises:
        ValueError: if ``chunk_size`` or ``overlap`` is out of range.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(f"overlap must be in [0, chunk_size), got {overlap}")

    step = chunk_size - overlap
    documents = []
    for path in file_paths:
        with open(path, 'r', encoding='utf-8') as file:
            text = file.read().strip()
        if not text:
            continue
        # A chunk starting at i adds new text only while i < len(text) - overlap;
        # later starts would lie entirely inside the previous chunk. max(..., 1)
        # keeps the single chunk of a text shorter than the overlap.
        stop = max(len(text) - overlap, 1)
        documents.extend(text[i:i + chunk_size] for i in range(0, stop, step))
    return documents
class Retriver:
    """Hybrid retriever over an in-memory list of text chunks.

    Keeps two indexes over the same ``documents`` list:

    * a BM25 index over whitespace-split tokens (lexical search), and
    * a FAISS ``IndexFlatL2`` over transformer embeddings (semantic search).

    Search methods return items of ``documents``. The class name keeps its
    original spelling so existing callers keep working.
    """

    def __init__(self, documents, tokenizer, model):
        """Build both indexes.

        Args:
            documents: list of text chunks to search over.
            tokenizer: Hugging Face tokenizer paired with ``model``.
            model: Hugging Face encoder whose output has ``last_hidden_state``.
        """
        self.documents = documents
        self.bm25 = BM25Okapi([doc.split() for doc in documents])
        self.tokenizer = tokenizer
        self.model = model
        self.index = self.create_faiss_index()

    def create_faiss_index(self):
        """Embed every document and load the vectors into an exact L2 index.

        The embeddings are unit-normalized, so L2 distance ranks the same way
        as cosine similarity.
        """
        embeddings = self.embed_documents(self.documents)
        index = faiss.IndexFlatL2(embeddings.shape[1])
        index.add(embeddings)
        return index

    @staticmethod
    def _mean_pool(last_hidden_state, attention_mask):
        """Average token vectors per sequence, ignoring padding positions."""
        mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
        summed = (last_hidden_state * mask).sum(dim=1)
        # clamp avoids division by zero for an (unexpected) all-padding row.
        counts = mask.sum(dim=1).clamp(min=1e-9)
        return summed / counts

    def embed_documents(self, docs, batch_size=32):
        """Encode ``docs`` into unit-length float32 vectors, one row per doc.

        Args:
            docs: non-empty list of strings.
            batch_size: documents per forward pass; bounds peak memory.

        Returns:
            A C-contiguous ``float32`` array of shape ``(len(docs), hidden)``,
            as required by FAISS.

        Raises:
            ValueError: if ``batch_size`` < 1 or ``docs`` is empty.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        if not docs:
            raise ValueError("docs must contain at least one string")

        batches = []
        with torch.no_grad():
            for start in range(0, len(docs), batch_size):
                tokens = self.tokenizer(
                    docs[start:start + batch_size],
                    padding=True,
                    truncation=True,
                    return_tensors="pt",
                )
                output = self.model(**tokens)
                # Mask-aware pooling: padded positions must not dilute the mean,
                # otherwise short documents embed differently depending on
                # which other documents share their batch.
                pooled = self._mean_pool(output.last_hidden_state, tokens["attention_mask"])
                pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
                batches.append(pooled.cpu().numpy())
        return np.ascontiguousarray(np.concatenate(batches, axis=0), dtype=np.float32)

    def search_bm25(self, query, top_k=5):
        """Return up to ``top_k`` documents with the highest BM25 score, best first."""
        if top_k <= 0:
            # Guard: a negative slice bound would otherwise return almost everything.
            return []
        scores = self.bm25.get_scores(query.split())
        top_indices = np.argsort(scores)[::-1][:top_k]
        return [self.documents[i] for i in top_indices]

    def search_semantic(self, query, top_k=5):
        """Return up to ``top_k`` documents nearest to ``query`` in embedding space."""
        k = min(top_k, self.index.ntotal)
        if k <= 0:
            return []
        query_embedding = self.embed_documents([query])
        _, indices = self.index.search(query_embedding, k)
        # FAISS fills missing hits with label -1; skip them instead of letting
        # Python's negative indexing return the last document.
        return [self.documents[i] for i in indices[0] if i >= 0]
class Reranker:
    """Re-order candidate documents by a relevance model's scores."""

    def __init__(self, reranker):
        """Wrap ``reranker``, an object with ``predict(pairs)`` that returns one
        score per ``(query, document)`` pair (e.g. a ``CrossEncoder``)."""
        self.model = reranker

    def rank(self, query, documents, top_k=None):
        """Return ``documents`` sorted from most to least relevant to ``query``.

        Args:
            query: the user question.
            documents: candidate texts to score.
            top_k: if given, keep only the ``top_k`` best documents.

        Returns:
            A new list. When ``documents`` is empty the model is not called
            and ``[]`` is returned. Equal scores keep their input order.
        """
        if not documents:
            # Don't call the scorer on an empty batch; CrossEncoder.predict
            # indexes the first pair.
            return []
        pairs = [(query, doc) for doc in documents]
        scores = np.asarray(self.model.predict(pairs), dtype=float)
        # Stable sort on negated scores gives descending order while keeping
        # input order among ties.
        order = np.argsort(-scores, kind="stable")
        ranked_docs = [documents[i] for i in order]
        return ranked_docs if top_k is None else ranked_docs[:top_k]
class QAChatbot:
    """Answer customer questions with retrieval-augmented generation.

    For each question: BM25 and semantic retrieval, merge the candidates,
    rerank them, then send the best chunks as context to an LLM via litellm.
    """

    def __init__(self, indexer, reranker, llm_model="groq/llama3-8b-8192", context_size=3):
        """
        Args:
            indexer: object with ``search_bm25(query)`` and
                ``search_semantic(query)`` returning lists of text chunks.
            reranker: object with ``rank(query, documents)`` returning the
                documents best-first.
            llm_model: litellm model identifier used for generation.
                NOTE(review): confirm the provider still serves this model id.
            context_size: number of top-ranked chunks placed in the prompt.
        """
        self.indexer = indexer
        self.reranker = reranker
        self.llm_model = llm_model
        self.context_size = context_size

    @staticmethod
    def _merge_unique(*result_lists):
        """Concatenate result lists, dropping duplicates but keeping first-seen order.

        Unlike ``list(set(...))`` this is deterministic across runs, because
        it does not depend on string hash randomization.
        """
        return list(dict.fromkeys(doc for results in result_lists for doc in results))

    def generate_answer(self, query):
        """Retrieve context for ``query`` and ask the LLM to answer it.

        Returns:
            The raw response object returned by litellm ``completion``.
        """
        bm25_results = self.indexer.search_bm25(query)
        semantic_results = self.indexer.search_semantic(query)
        candidates = self._merge_unique(bm25_results, semantic_results)
        ranked_docs = self.reranker.rank(query, candidates)
        context = "\n".join(ranked_docs[:self.context_size])
        response = completion(
            model=self.llm_model,
            messages=[
                {
                    "role": "system",
                    "content": PROMPT,
                },
                {
                    "role": "user",
                    "content": f"Context: {context}\n\nQuestion: {query}\nAnswer:",
                },
            ],
        )
        return response
def chatbot_interface(query, history):
    """Gradio ``ChatInterface`` callback that returns the assistant's reply text.

    Args:
        query: the latest user message.
        history: prior conversation passed in by Gradio; not used, so each
            question is answered on its own.

    Relies on the module-level ``chatbot`` created in the ``__main__`` block.
    Calling this before that block has run raises ``NameError``.
    """
    llm_response = chatbot.generate_answer(query)
    reply_message = llm_response["choices"][0]["message"]
    return reply_message["content"]
# Chat UI; each user message is routed to chatbot_interface.
iface = gr.ChatInterface(fn=chatbot_interface, type="messages")

if __name__ == "__main__":
    # Build the retrieval pipeline once at start-up. The names below are
    # module globals; chatbot_interface reads `chatbot` at request time.
    file_paths = ["Company_eng.txt", "base_eng.txt"]
    reranker = Reranker(reranker_model)
    documents = load_documents(file_paths)
    indexer = Retriver(documents, tokenizer, model)
    chatbot = QAChatbot(indexer, reranker)

    iface.launch()