# WordsApp/create_db.py
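"""Build the words database for WordsApp.

Parses an uploaded PDF (or raw text), asks an LLM to pick out vocabulary
words from each chunk, filters them against the NLTK English word list,
and persists the results in a Chroma vector database for similarity lookup.
"""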
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
import tempfile
import config
import nltk
from typing import List
from nltk.corpus import words
from loguru import logger
from llm.call_llm import get_completion_from_messages
from embedding.call_embedding import get_embedding
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import PyMuPDFLoader
from langchain.vectorstores import Chroma
from prompts import system_message_select
WORDS_DB_PATH = "../words_db"
VECTOR_DB_PATH = "./vector_db/chroma"
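# NOTE: both paths are relative to the directory the script is run from;
# WORDS_DB_PATH is only referenced by the store_words() stub below.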
def parse_file(file_path):
    """Parse a PDF file, extract words from it, and store them in the vector DB."""
    docs = []
    # Check the file type by extension; only PDF is supported for now.
    file_type = file_path.split('.')[-1]
    if file_type == 'pdf':
        loader = PyMuPDFLoader(file_path)
        content = loader.load()
        docs.extend(content)
    else:
        return "File type not supported"
    if len(docs) > 5:
        return "File too large, please select a PDF file with fewer than 5 pages"
    slices = split_text(docs)                # split content into chunks
    extracted_words = extract_words(slices)  # extract words from the chunks
    try:
        vectorize_words(extracted_words)     # store words in the vector database
    except Exception as e:
        logger.error(e)
    return ""
def parse_text(input: str):
    """Parse raw text input (placeholder: currently returns the text unchanged)."""
    content = input
    return content
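# A plausible next step for parse_text (an assumption, not wired up yet):
# feed the text through get_words_from_text() and vectorize_words(),
# mirroring the PDF path in parse_file().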
def split_text(docs: List[object]):
    """Split documents into overlapping chunks for the LLM to process."""
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1500,
        chunk_overlap=150
    )
    splits = text_splitter.split_documents(docs)
    logger.info(f"Split {len(docs)}-page document into {len(splits)} slices")
    return splits
def extract_words(splits: List[object]):
    """Ask the LLM to extract vocabulary words from each chunk."""
    all_words = []
    for chunk in splits:
        tmp_content = chunk.page_content
        messages = [
            {'role': 'system',
             'content': system_message_select},
            {'role': 'user',
             'content': f"{tmp_content}"},
        ]
        response = get_completion_from_messages(messages, api_key=config.api_key)
        # The system prompt is expected to make the model return a
        # comma-separated word list, e.g. "apple, banana, cherry".
        words_list = response.split(", ")
        # str.split never returns an empty list, so extend unconditionally;
        # wash_words() filters out empty or invalid entries later.
        all_words.extend(words_list)
    all_words = wash_words(all_words)
    logger.info(f"Extracted {len(all_words)} words from slices")
    return all_words
def wash_words(input_words: List[str]):
    """Filter the raw LLM output down to unique, valid English words."""
    # Drop fragments that are implausibly short or long.
    words_list = [word for word in input_words
                  if 3 <= len(word) <= 30]
    # Keep only words found in the NLTK English word corpus, deduplicated.
    nltk.download('words', quiet=True)
    english_words = set(words.words())
    filtered_words = [word.lower() for word in words_list if word.lower() in english_words]
    filtered_words = list(set(filtered_words))
    logger.info(f"Washed {len(filtered_words)} words into a list of valid English words")
    return filtered_words
def get_words_from_text(input: str):
    """Split raw text into words on spaces."""
    return input.split(' ')
def store_words(input: str, db_path=WORDS_DB_PATH):
    """Store words in the words database (not implemented yet)."""
pass
def vectorize_words(input: List[str], embedding=None):
    """Embed words and persist them in a Chroma vector database.

    Note: the `embedding` parameter is currently unused; OpenAI embeddings
    are always requested via get_embedding().
    """
    model = get_embedding("openai", embedding_key=config.api_key)
    persist_path = VECTOR_DB_PATH
    vectordb = Chroma.from_texts(
        texts=input,
        embedding=model,
        persist_directory=persist_path
    )
    vectordb.persist()
    logger.info(f"Vectorized {len(input)} words")
    return vectordb
def get_similar_k_words(query_word, k=3) -> List[str]:
    """Get the k most similar words from the vector DB."""
    model = get_embedding("openai", embedding_key=config.api_key)
    vectordb = Chroma(persist_directory=VECTOR_DB_PATH, embedding_function=model)
    # Maximal marginal relevance balances similarity to the query
    # with diversity among the returned words.
    similar_words = vectordb.max_marginal_relevance_search(query_word, k=k)
    similar_words = [word.page_content for word in similar_words]
    logger.info(f"Got {k} similar words {similar_words} from DB")
    return similar_words
def create_db(input, chat_history):
    """Add words from the input (a file upload or a text string) to the database."""
    action_msg = ""  # description of the user action: put file or text into the database
    output = ""
    # 1. File upload (arrives as a tempfile wrapper)
    if isinstance(input, tempfile._TemporaryFileWrapper):
        tmp_file_path = input.name
        file_name = os.path.basename(tmp_file_path)
        action_msg = f"Add words from my file: {file_name} to database"
        try:
            parse_file(tmp_file_path)  # TODO: surface the error strings parse_file can return
            output = f"Words from your file: {file_name} have been added to the database"
        except Exception as e:
            logger.error(e)
            output = f"Error: failed to generate a dictionary from your file: {file_name}"
    # 2. Text input
    elif isinstance(input, str):
        action_msg = f"Add words from my text: {input} to database"
        try:
            parse_text(input)  # TODO: text input is parsed but not yet stored
            output = f"Words from your text: {input} have been added to the database"
        except Exception as e:
            logger.error(e)
            output = f"Error: failed to generate a dictionary from your text: {input}"
    chat_history.append((action_msg, output))
    return chat_history
if __name__ == "__main__":
create_db(embeddings="m3e")
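    # Hypothetical follow-up, assuming the vector DB has been populated:
    # print(get_similar_k_words("apple", k=3))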