llm_knowledge_base / langchain_KB.py
allinaigc's picture
Upload 2 files
e418d71 verified
raw
history blame
No virus
3.76 kB
"""
"""
# -*- coding: utf-8 -*-
import requests
import streamlit as st
import openai
# from openai import embeddings
import os
from dotenv import load_dotenv
import numpy as np
import pandas as pd
import csv
import tempfile
from tempfile import NamedTemporaryFile
import pathlib
from pathlib import Path
import re
from re import sub
import time
from time import sleep
# import pretty_errors
import warnings
import PyPDF2
from openai import OpenAI
warnings.filterwarnings('ignore')
# The imports below are the core pieces used to load local knowledge.
# NOTE: install the loader extras first with `pip install "unstructured[all-docs]"`,
# otherwise the document loaders raise at import/use time.
# from langchain.document_loaders import UnstructuredFileLoader ## older version.
from langchain_community.document_loaders.unstructured import UnstructuredFileLoader ## new version.
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
# Pull variables from a local .env file into the process environment.
load_dotenv()
### Configure the OpenAI API key from the `user_token` environment variable.
### A missing `user_token` raises KeyError here, which fails fast at import time.
os.environ["OPENAI_API_KEY"] = os.environ['user_token']
openai.api_key = os.environ['user_token']
# Build the client only now: OpenAI() reads OPENAI_API_KEY from the environment
# when it is constructed, so it must run after the key has been set above.
client = OpenAI()
# filepath = "/Users/yunshi/Downloads/txt_dir/Sparks_of_AGI.pdf"
def langchain_localKB_construct(filepath, username):
    """Build a FAISS knowledge base from a PDF and save it in the user's folder.

    The PDF is loaded with ``PyPDFLoader``, split into overlapping character
    chunks, embedded with the ``BAAI/bge-large-zh-v1.5`` HuggingFace model and
    indexed with FAISS. The index is written to ``./{username}/faiss_index``.

    Args:
        filepath: The PDF to index. Either a path (``str`` or ``os.PathLike``)
            or an object whose ``.name`` attribute is the path of the file on
            disk (for example an open temporary file).
        username: Name of the per-user directory that receives the index.

    Returns:
        The FAISS vector store that was just built.
    """
    print('开始构建Langchain知识库...')

    # Plain paths are used as-is. Checking for them first matters because
    # ``pathlib.Path.name`` is only the basename, which would drop the directory.
    if isinstance(filepath, (str, os.PathLike)):
        pdf_path = os.fspath(filepath)
    else:
        pdf_path = filepath.name
    print('now filepath:', pdf_path)

    # NOTE (from the original author): only PyPDFLoader exposes the PDF page
    # information, which is why UnstructuredFileLoader is not used here.
    loader = PyPDFLoader(pdf_path)
    docs = loader.load()

    # Split the text into chunks of up to 1000 characters with 200 of overlap.
    docs = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200).split_documents(docs)

    # Create the vector database. Per the original author's note, this
    # embedding model is used while online, connected to HuggingFace.
    embedding_model_name = 'BAAI/bge-large-zh-v1.5'
    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name)
    vector_store = FAISS.from_documents(docs, embeddings)

    # Persist the index so it can be reloaded later for retrieval.
    vector_store.save_local(f'./{username}/faiss_index')
    return vector_store
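# Example usage (pass an open file so that `.name` is the path on disk):
# vector_store = langchain_localKB_construct(filepath=open('/Users/yunshi/Downloads/txt_dir/Sparks_of_AGI.pdf', 'rb'), username='demo')
# print(vector_store)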
### 根据prompt来检索本地知识库并回答。
def langchain_RAG(prompt, username):
    """Retrieve the passages most similar to ``prompt`` and build a RAG prompt.

    Loads the FAISS index stored at ``./{username}/faiss_index``, runs a
    similarity search for ``prompt`` and wraps the five best-matching passages
    together with the question into a single prompt string.

    Args:
        prompt: The user's question.
        username: Name of the per-user directory that holds the saved index.

    Returns:
        A tuple ``(total_prompt, docs)``: the assembled prompt string and the
        list of documents returned by the similarity search.
    """
    # The query must be embedded with the same model that built the index.
    # langchain_localKB_construct uses this HuggingFace model, so a different
    # one here (the code previously used OpenAIEmbeddings) produces query
    # vectors that do not match the stored ones.
    embeddings = HuggingFaceEmbeddings(model_name='BAAI/bge-large-zh-v1.5')

    # NOTE(security): allow_dangerous_deserialization=True lets load_local
    # unpickle the saved docstore. Only load indexes written by this
    # application; never point this at files from an untrusted source.
    vector_store = FAISS.load_local(
        f'./{username}/faiss_index',
        embeddings,
        allow_dangerous_deserialization=True,
    )
    docs = vector_store.similarity_search(prompt, k=5)

    # `context` is a list, so the f-string below embeds its Python list repr.
    context = [doc.page_content for doc in docs]
    total_prompt = f"已知信息:\n{context}\n 根据这些已知信息来回答问题:\n{prompt}"
    return total_prompt, docs
# Example usage: langchain_RAG('what are main challenges of AGI?', username='demo')