import os

import faiss
import gradio as gr
import numpy as np
import torch
from litellm import completion
from rank_bm25 import BM25Okapi
from sentence_transformers import CrossEncoder
from transformers import AutoTokenizer, AutoModel

# Provide the Groq API key via the environment (e.g. a Space secret);
# never hardcode real keys in source.
# os.environ["GROQ_API_KEY"] = "<your-groq-api-key>"

PROMPT = """
You are a virtual representative of a retail company and a consultant for customers.
To generate answers, use only information from the context!
Do not ask additional questions; simply offer the products available in the context!
Your goal is to answer customers' questions and help them choose products using the context.
If you cannot find a specific answer:
- Reply "I do not know. For more information, please contact: +380954673526" and nothing more.
Always maintain a polite, professional tone. Answers should be simple, understandable, and clear. Avoid long explanations when they are not necessary.
"""

tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-mpnet-base-v2")
model = AutoModel.from_pretrained("sentence-transformers/all-mpnet-base-v2")
reranker_model = CrossEncoder("cross-encoder/ms-marco-TinyBERT-L-6")

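# Retrieval pipeline:
#   1. BM25 (lexical) and FAISS (dense) retrieval each propose candidates.
#   2. A cross-encoder reranks the merged candidate set against the query.
#   3. The top-ranked chunks become the context for the Groq LLM call.
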
def load_documents(file_paths):
    documents = []
    for path in file_paths:
        with open(path, 'r', encoding='utf-8') as file:
            documents.append(file.read().strip())
    return documents

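# Note: load_documents() indexes each file as one document; for long files,
# the chunked loader below returns shorter, more focused passages.
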
def load_documents_with_chunking(file_paths, chunk_size=500):
    documents = []
    for path in file_paths:
        with open(path, 'r', encoding='utf-8') as file:
            text = file.read().strip()
        for i in range(0, len(text), chunk_size):
            documents.append(text[i:i + chunk_size])
    return documents

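# Sketch (assumption, not used above): fixed-size character chunks can cut
# sentences at hard boundaries. Overlapping consecutive chunks keeps the
# surrounding context retrievable; `overlap` is a hypothetical parameter,
# not part of the original loader.
def load_documents_with_overlap(file_paths, chunk_size=500, overlap=100):
    documents = []
    step = chunk_size - overlap
    for path in file_paths:
        with open(path, 'r', encoding='utf-8') as file:
            text = file.read().strip()
        for i in range(0, len(text), step):
            documents.append(text[i:i + chunk_size])
    return documents
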
class Retriever:
    def __init__(self, documents, tokenizer, model):
        self.documents = documents
        self.bm25 = BM25Okapi([doc.split() for doc in documents])
        self.tokenizer = tokenizer
        self.model = model
        self.index = self.create_faiss_index()

    def create_faiss_index(self):
        embeddings = self.embed_documents(self.documents)
        dimension = embeddings.shape[1]
        index = faiss.IndexFlatL2(dimension)
        index.add(embeddings)
        return index

    def embed_documents(self, docs):
        tokens = self.tokenizer(docs, padding=True, truncation=True, return_tensors="pt")
        with torch.no_grad():
            hidden = self.model(**tokens).last_hidden_state
        # Mean-pool over real tokens only; the attention mask excludes padding
        # positions so they don't dilute the embedding.
        mask = tokens["attention_mask"].unsqueeze(-1).float()
        embeddings = (hidden * mask).sum(dim=1) / mask.sum(dim=1)
        return embeddings.numpy()

    def search_bm25(self, query, top_k=5):
        query_terms = query.split()
        scores = self.bm25.get_scores(query_terms)
        top_indices = np.argsort(scores)[::-1][:top_k]
        return [self.documents[i] for i in top_indices]

    def search_semantic(self, query, top_k=5):
        query_embedding = self.embed_documents([query])
        distances, indices = self.index.search(query_embedding, top_k)
        return [self.documents[i] for i in indices[0]]

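# Example usage (hypothetical corpus), showing both retrieval paths:
#   retriever = Retriever(["We ship worldwide.", "Returns within 30 days."],
#                         tokenizer, model)
#   retriever.search_bm25("shipping", top_k=1)           # lexical match
#   retriever.search_semantic("delivery cost", top_k=1)  # semantic match
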
class Reranker:
    def __init__(self, reranker):
        self.model = reranker

    def rank(self, query, documents):
        # The cross-encoder scores each (query, document) pair; sort best-first.
        pairs = [(query, doc) for doc in documents]
        scores = self.model.predict(pairs)
        return [documents[i] for i in np.argsort(scores)[::-1]]

class QAChatbot:
    def __init__(self, indexer, reranker):
        self.indexer = indexer
        self.reranker = reranker

    def generate_answer(self, query):
        bm25_results = self.indexer.search_bm25(query)
        semantic_results = self.indexer.search_semantic(query)
        # Deduplicate the two candidate lists; set() loses order, but the
        # reranker imposes a meaningful order anyway.
        combined_results = list(set(bm25_results + semantic_results))
        ranked_docs = self.reranker.rank(query, combined_results)
        context = "\n".join(ranked_docs[:3])
        response = completion(
            model="groq/llama3-8b-8192",
            messages=[
                {"role": "system", "content": PROMPT},
                {"role": "user", "content": f"Context: {context}\n\nQuestion: {query}\nAnswer:"},
            ],
        )
        return response

def chatbot_interface(query, history):
    # `chatbot` is built once at startup (see __main__ below) so the corpus
    # is not re-embedded and re-indexed on every request.
    answer = chatbot.generate_answer(query)
    return answer.choices[0].message.content

iface = gr.ChatInterface(fn=chatbot_interface, type="messages")

if __name__ == "__main__":
    file_paths = ["Company_eng.txt", "base_eng.txt"]
    documents = load_documents(file_paths)
    indexer = Retriever(documents, tokenizer, model)
    reranker = Reranker(reranker_model)
    chatbot = QAChatbot(indexer, reranker)
    iface.launch()