JohnSmith9982 committed
Commit f1812bf (1 parent: 102761b)

Delete modules/llama_func.py

Files changed (1)
  1. modules/llama_func.py +0 -166
modules/llama_func.py DELETED
@@ -1,166 +0,0 @@
- import os
- import logging
- import hashlib
-
- from llama_index import download_loader
- from llama_index import (
-     Document,
-     LLMPredictor,
-     PromptHelper,
-     QuestionAnswerPrompt,
-     RefinePrompt,
- )
- import colorama
- import PyPDF2
- from tqdm import tqdm
-
- from modules.presets import *
- from modules.utils import *
- from modules.config import local_embedding
-
-
- def get_index_name(file_src):
-     """Name the index after the MD5 hash of the uploaded files' contents."""
-     file_paths = [x.name for x in file_src]
-     file_paths.sort(key=lambda x: os.path.basename(x))
-
-     md5_hash = hashlib.md5()
-     for file_path in file_paths:
-         with open(file_path, "rb") as f:
-             while chunk := f.read(8192):
-                 md5_hash.update(chunk)
-
-     return md5_hash.hexdigest()
-
-
- def block_split(text):
-     """Split text into 1000-character Document blocks."""
-     blocks = []
-     while len(text) > 0:
-         blocks.append(Document(text[:1000]))
-         text = text[1000:]
-     return blocks
-
-
- def get_documents(file_src):
-     documents = []
-     logging.debug("Loading documents...")
-     logging.debug(f"file_src: {file_src}")
-     for file in file_src:
-         filepath = file.name
-         filename = os.path.basename(filepath)
-         file_type = os.path.splitext(filepath)[1]
-         logging.info(f"loading file: {filename}")
-         try:
-             if file_type == ".pdf":
-                 logging.debug("Loading PDF...")
-                 try:
-                     from modules.pdf_func import parse_pdf
-                     from modules.config import advance_docs
-
-                     two_column = advance_docs["pdf"].get("two_column", False)
-                     pdftext = parse_pdf(filepath, two_column).text
-                 except Exception:
-                     # Fall back to plain PyPDF2 extraction if the custom parser fails.
-                     pdftext = ""
-                     with open(filepath, "rb") as pdfFileObj:
-                         pdfReader = PyPDF2.PdfReader(pdfFileObj)
-                         for page in tqdm(pdfReader.pages):
-                             pdftext += page.extract_text()
-                 text_raw = pdftext
-             elif file_type == ".docx":
-                 logging.debug("Loading Word...")
-                 DocxReader = download_loader("DocxReader")
-                 loader = DocxReader()
-                 text_raw = loader.load_data(file=filepath)[0].text
-             elif file_type == ".epub":
-                 logging.debug("Loading EPUB...")
-                 EpubReader = download_loader("EpubReader")
-                 loader = EpubReader()
-                 text_raw = loader.load_data(file=filepath)[0].text
-             elif file_type == ".xlsx":
-                 logging.debug("Loading Excel...")
-                 text_list = excel_to_string(filepath)
-                 for elem in text_list:
-                     documents.append(Document(elem))
-                 continue
-             else:
-                 logging.debug("Loading text file...")
-                 with open(filepath, "r", encoding="utf-8") as f:
-                     text_raw = f.read()
-         except Exception as e:
-             logging.error(f"Error loading file {filename}: {e}")
-             continue  # skip this file instead of reusing a stale text_raw
-         text = add_space(text_raw)
-         # text = block_split(text)
-         # documents += text
-         documents += [Document(text)]
-     logging.debug("Documents loaded.")
-     return documents
-
-
- def construct_index(
-     api_key,
-     file_src,
-     max_input_size=4096,
-     num_outputs=5,
-     max_chunk_overlap=20,
-     chunk_size_limit=600,
-     embedding_limit=None,
-     separator=" ",
- ):
-     from langchain.chat_models import ChatOpenAI
-     from langchain.embeddings.huggingface import HuggingFaceEmbeddings
-     from llama_index import GPTSimpleVectorIndex, ServiceContext, LangchainEmbedding, OpenAIEmbedding
-
-     if api_key:
-         os.environ["OPENAI_API_KEY"] = api_key
-     else:
-         # A dependency's awkward design requires an API key here even when unused.
-         os.environ["OPENAI_API_KEY"] = "sk-xxxxxxx"
-     chunk_size_limit = None if chunk_size_limit == 0 else chunk_size_limit
-     embedding_limit = None if embedding_limit == 0 else embedding_limit
-     separator = " " if separator == "" else separator
-
-     prompt_helper = PromptHelper(
-         max_input_size=max_input_size,
-         num_output=num_outputs,
-         max_chunk_overlap=max_chunk_overlap,
-         embedding_limit=embedding_limit,
-         chunk_size_limit=chunk_size_limit,
-         separator=separator,
-     )
-     index_name = get_index_name(file_src)
-     if os.path.exists(f"./index/{index_name}.json"):
-         logging.info("Found a cached index file, loading it...")
-         return GPTSimpleVectorIndex.load_from_disk(f"./index/{index_name}.json")
-     else:
-         try:
-             documents = get_documents(file_src)
-             if local_embedding:
-                 embed_model = LangchainEmbedding(HuggingFaceEmbeddings(model_name="sentence-transformers/distiluse-base-multilingual-cased-v2"))
-             else:
-                 embed_model = OpenAIEmbedding()
-             logging.info("Building the index...")
-             with retrieve_proxy():
-                 service_context = ServiceContext.from_defaults(
-                     prompt_helper=prompt_helper,
-                     chunk_size_limit=chunk_size_limit,
-                     embed_model=embed_model,
-                 )
-                 index = GPTSimpleVectorIndex.from_documents(
-                     documents, service_context=service_context
-                 )
-             logging.debug("Index built!")
-             os.makedirs("./index", exist_ok=True)
-             index.save_to_disk(f"./index/{index_name}.json")
-             logging.debug("Index saved to disk!")
-             return index
-
-         except Exception as e:
-             logging.error(f"Failed to build the index: {e}")
-             print(e)
-             return None
-
-
- def add_space(text):
-     """Insert a space after full-width CJK punctuation to help text splitting."""
-     punctuations = {",": ", ", "。": "。 ", "?": "? ", "!": "! ", ":": ": ", ";": "; "}
-     for cn_punc, en_punc in punctuations.items():
-         text = text.replace(cn_punc, en_punc)
-     return text
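
For reference, a minimal sketch of how the deleted construct_index entry point could be driven. The call matches the signature removed above; the UploadedFile wrapper and the file paths are hypothetical stand-ins, since construct_index only reads the .name attribute of each item in file_src.

from modules.llama_func import construct_index

class UploadedFile:
    # Hypothetical stand-in for an uploaded-file object; only .name is read.
    def __init__(self, name):
        self.name = name

files = [UploadedFile("docs/report.pdf"), UploadedFile("notes.txt")]
index = construct_index(api_key="sk-xxxxxxx", file_src=files)
if index is None:
    print("Index construction failed; check the logs.")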