# 1177_chatbot/RAG_class.py
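"""RAG helper for the 1177 chatbot: chunks scraped pages, embeds them with a
Swedish sentence-BERT model, and stores/queries them in a ChromaDB collection."""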
# Imports: recursive text splitter, embedding model, and Chroma vector store.
import os
import uuid

import chromadb
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer


class RAG_1177:
    def __init__(self):
        self.db_name = "RAG_1177"
        # Overlapping chunks preserve context across chunk boundaries.
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=2500, chunk_overlap=500, length_function=len
        )
        # Swedish sentence-embedding model from KBLab.
        self.model = SentenceTransformer('KBLab/sentence-bert-swedish-cased')
        # Persistent client so the index survives between runs.
        self.client = chromadb.PersistentClient(path="RAG_1177_db")
        self.db = self.client.get_or_create_collection(self.db_name)
        self.url_list_path = "all_urls_list.txt"
        self.text_folder = "scraped_texts/"

    def chunk_text_file(self, file_name):
        """Read one scraped text file and return its chunks as plain strings."""
        file_path = self.text_folder + file_name
        with open(file_path, 'r', encoding='utf-8') as f:
            text = f.read()
        chunks = self.text_splitter.create_documents([text])
        # Keep only the raw text of each chunk, not the Document wrappers.
        return [chunk.page_content for chunk in chunks]

    def get_file_names(self, folder_path):
        """List files sorted by the numeric suffix before the extension."""
        doc_list = os.listdir(folder_path)
        # Numeric sort keeps file order aligned with line order in the URL list.
        return sorted(doc_list, key=lambda x: int(x.split('-')[-1].split('.')[0]))

    def get_embeddings(self, text):
        """Embed a string or list of strings; return plain Python lists."""
        embeddings = self.model.encode(text)
        return embeddings.tolist()

    def get_url(self, url_index):
        """Return the URL stored at the given line index of the URL list file."""
        with open(self.url_list_path, 'r') as f:
            urls = f.readlines()
        return urls[url_index].strip()

    def get_ids(self, num_ids):
        """Generate unique string IDs for a batch of new documents."""
        return [str(uuid.uuid4()) for _ in range(num_ids)]

    def get_url_dict(self, url, count):
        """Build one {'url': ...} metadata dict per chunk."""
        return [{"url": url} for _ in range(count)]

    def delete_collection(self):
        """Drop the whole collection from the persistent client."""
        self.client.delete_collection(self.db_name)

    def retrieve(self, query, num_results):
        """Embed the query, fetch the closest chunks, and format their source URLs."""
        query_emb = self.get_embeddings(query)
        result = self.db.query(query_embeddings=query_emb, n_results=num_results,
                               include=['documents', 'metadatas'])
        # Chroma returns one result list per query; a single query sits at index 0.
        result_metadatas = result['metadatas'][0]
        result_docs = result['documents'][0]
        # set() deduplicates URLs shared by several chunks (order is arbitrary).
        unique_urls = set(item['url'] for item in result_metadatas)
        result_urls = "Läs mer på:\n"  # Swedish for "Read more at:".
        for i, url in enumerate(unique_urls, start=1):
            result_urls += f"{i}: {url}\n"
        return result_docs, result_urls

    def insert(self, docs, emb, urls, ids):
        """Add documents with their embeddings and URL metadata to the collection."""
        self.db.add(documents=docs, embeddings=emb, metadatas=urls, ids=ids)
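

# --- Usage sketch (not part of the original file) ---
# A minimal, hypothetical indexing-and-retrieval flow. It assumes the sorted
# files in scraped_texts/ line up one-to-one with the lines of
# all_urls_list.txt; the example query is invented for illustration.
if __name__ == "__main__":
    rag = RAG_1177()
    # Index every scraped page: chunk, embed, and store with its source URL.
    for idx, file_name in enumerate(rag.get_file_names(rag.text_folder)):
        chunks = rag.chunk_text_file(file_name)
        embeddings = rag.get_embeddings(chunks)
        metadatas = rag.get_url_dict(rag.get_url(idx), len(chunks))
        ids = rag.get_ids(len(chunks))
        rag.insert(chunks, embeddings, metadatas, ids)
    # "Hur behandlar man feber hos barn?" = "How do you treat fever in children?"
    docs, sources = rag.retrieve("Hur behandlar man feber hos barn?", num_results=3)
    print(docs)
    print(sources)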