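"""Build a vocabulary database from a PDF file or raw text.

The pipeline loads a PDF, splits it into slices, asks an LLM to select
vocabulary words, cleans them against the NLTK English word list, and stores
their embeddings in a Chroma vector store for similarity lookups.
"""
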
import os
import sys
import tempfile

from typing import List

import nltk
from langchain.document_loaders import PyMuPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from loguru import logger
from nltk.corpus import words

# Make the project root importable so local modules resolve.
sys.path.append(os.path.dirname(os.path.dirname(__file__)))

import config
from embedding.call_embedding import get_embedding
from llm.call_llm import get_completion_from_messages
from prompts import system_message_select

WORDS_DB_PATH = "../words_db"
VECTOR_DB_PATH = "./vector_db/chroma"


def parse_file(file_path):
    """Parse a PDF file, extract vocabulary words, and vectorize them."""
    docs = []

    # Only PDF files are supported for now.
    file_type = file_path.split('.')[-1]
    if file_type == 'pdf':
        loader = PyMuPDFLoader(file_path)
        content = loader.load()
        docs.extend(content)
    else:
        return "File type not supported"
    if len(docs) > 5:
        return "File too large, please select a PDF file with at most 5 pages"

    slices = split_text(docs)
    extracted_words = extract_words(slices)
    try:
        vectorize_words(extracted_words)
    except Exception as e:
        logger.error(e)

    return ""


def parse_text(input: str):
    """Return the raw input text."""
    content = input
    return content


def split_text(docs: List[object]):
    """Split documents into overlapping text slices."""
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1500,
        chunk_overlap=150
    )
    splits = text_splitter.split_documents(docs)
    logger.info(f"Split a {len(docs)}-page document into {len(splits)} slices")
    return splits


def extract_words(splits: List[object]):
    """Extract vocabulary words from text slices via the LLM."""
    all_words = []
    for doc_slice in splits:
        tmp_content = doc_slice.page_content
        messages = [
            {'role': 'system',
             'content': system_message_select},
            {'role': 'user',
             'content': f"{tmp_content}"},
        ]
        response = get_completion_from_messages(messages, api_key=config.api_key)
        # The response is expected as a comma-separated word list.
        words_list = response.split(", ")
        if words_list:
            all_words.extend(words_list)
    all_words = wash_words(all_words)
    logger.info(f"Extracted {len(all_words)} words from slices")
    return all_words


def wash_words(input_words: list[str]):
    """Wash words into a list of correct English words."""
    words_list = [word for word in input_words
                  if 3 <= len(word) <= 30]
    # Keep only entries found in the NLTK English word corpus.
    nltk.download('words')
    english_words = set(words.words())
    filtered_words = [word.lower() for word in words_list if word.lower() in english_words]
    filtered_words = list(set(filtered_words))
    logger.info(f"Washed input into {len(filtered_words)} correct English words")
    return filtered_words


def get_words_from_text(input: str):
    words = input.split(' ')
    return words


def store_words(input: str, db_path=WORDS_DB_PATH):
    """Store words into database"""
    pass


def vectorize_words(input: list[str], embedding=None):
    """Vectorize words and persist them in the Chroma store."""
    model = get_embedding("openai", embedding_key=config.api_key)
    persist_path = VECTOR_DB_PATH
    vectordb = Chroma.from_texts(
        texts=input,
        embedding=model,
        persist_directory=persist_path
    )
    vectordb.persist()
    logger.info(f"Vectorized {len(input)} words")
    return vectordb


def get_similar_k_words(query_word, k=3) -> List[str]:
    """Return the k most similar words to query_word from the vector store."""
    model = get_embedding("openai", embedding_key=config.api_key)
    vectordb = Chroma(persist_directory=VECTOR_DB_PATH, embedding_function=model)
    similar_words = vectordb.max_marginal_relevance_search(query_word, k=k)
    similar_words = [word.page_content for word in similar_words]
    logger.info(f"Got {k} similar words {similar_words} from DB")
    return similar_words


def create_db(input, chat_history):
    """Add words from an uploaded file or raw text to the database."""
    action_msg = ""
    output = ""

    if isinstance(input, tempfile._TemporaryFileWrapper):
        # Uploaded file: parse the PDF and vectorize its words.
        tmp_file_path = input.name
        file_name = tmp_file_path.split('/')[-1]
        action_msg = f"Add words from my file: {file_name} to database"
        try:
            parse_file(tmp_file_path)
            output = f"Words from your file: {file_name} have been added to the database"
        except Exception as e:
            logger.error(e)
            output = f"Error: failed to generate a dictionary from your file: {file_name}"

    elif isinstance(input, str):
        action_msg = f"Add words from my text: {input} to database"
        try:
            parse_text(input)
            output = f"Words from your text: {input} have been added to the database"
        except Exception as e:
            logger.error(e)
            output = f"Error: failed to generate a dictionary from your text: {input}"
    chat_history.append((action_msg, output))

    return chat_history


if __name__ == "__main__":
    # Minimal manual test; the sample text is illustrative only.
    chat_history = create_db("hello world example text", [])
    logger.info(chat_history)