FinDoc / build_index /doc2vec.py
xl2533's picture
initial
6c945f2
raw
history blame contribute delete
No virus
1.78 kB
# -*-coding:utf-8 -*-
import os
from tqdm import tqdm
from langchain.vectorstores import FAISS
from langchain.embeddings import OpenAIEmbeddings, HuggingFaceEmbeddings, CohereEmbeddings
from retry import retry
from key import CoherenceKey, OpenaiKey
# Root output directory for persisted FAISS index data. doc2vec() saves the
# index here, or into a sub-folder of it when a folder name is supplied.
OUTPUT_DIR = './output/'
@retry(tries=10, delay=60)
def store_add_texts_with_retry(store, i):
    """Add a single document to a vector store, with automatic retries.

    The call is wrapped by ``retry(tries=10, delay=60)``; see the ``retry``
    package documentation for the exact retry/back-off semantics.

    Args:
        store: Object exposing ``add_texts(texts, metadatas=...)``.
        i: Document-like object providing ``page_content`` and ``metadata``.
    """
    add_texts = store.add_texts
    content = i.page_content
    meta = i.metadata
    # One-element lists: the store API takes batches, we submit one document.
    add_texts([content], metadatas=[meta])
def doc2vec(docs, model, folder_name=None):
    """Embed ``docs`` and persist them as a FAISS index on disk.

    The index is initialised from the first document and the remaining
    documents are then added one at a time, so that a failure part-way
    through still leaves the already-embedded documents saved instead of
    forcing a restart from scratch.

    Args:
        docs: Non-empty sequence of documents. Each document is expected to
            expose ``page_content`` and ``metadata`` (these are read by
            ``store_add_texts_with_retry``). The sequence is not modified.
        model: Name of the embedding backend: ``'openai'``, ``'mpnet'`` or
            ``'cohere'``.
        folder_name: Optional sub-folder of ``OUTPUT_DIR`` to save into. When
            falsy, the index is saved directly in ``OUTPUT_DIR``.

    Raises:
        ValueError: If ``docs`` is empty or ``model`` is not supported.
    """
    if not docs:
        raise ValueError('docs must contain at least one document')

    # Resolve the embedding backend first so that an unsupported model fails
    # fast, before anything is printed or embedded.
    if model == 'openai':
        key = os.getenv('OPENAI_API_KEY')
        embeddings = OpenAIEmbeddings(openai_api_key=key)
    elif model == 'mpnet':
        embeddings = HuggingFaceEmbeddings()
    elif model == 'cohere':
        embeddings = CohereEmbeddings(cohere_api_key=CoherenceKey)
    else:
        raise ValueError(f'Embedding Model {model} not supported')

    # ``out_dir`` rather than ``dir`` to avoid shadowing the builtin.
    out_dir = os.path.join(OUTPUT_DIR, folder_name) if folder_name else OUTPUT_DIR

    print(f'Building faiss Index from {len(docs)} docs')
    print(f'Dumping FAISS to {out_dir}')

    # Use the first document to initialise the store. Slice instead of
    # ``pop(0)`` so the caller's sequence is left untouched.
    first_doc, remaining = docs[0], docs[1:]
    db = FAISS.from_documents([first_doc], embeddings)

    progress = tqdm(remaining, desc="Embedding 🦖", unit="docs", total=len(remaining),
                    bar_format='{l_bar}{bar}| Time Left: {remaining}')
    for index, doc in enumerate(progress):
        try:
            store_add_texts_with_retry(db, doc)
        except Exception as e:
            # Deliberate best-effort: report the failure and stop, keeping
            # whatever has been embedded so far. The save below persists it.
            print(e)
            print("Error on ", doc)
            print("Saving progress")
            print(f"stopped at {index} out of {len(remaining)}")
            break

    # Single save point for both the success path and the early-stop path.
    db.save_local(out_dir)