| import os |
| from pydantic import BaseModel |
| from tqdm import tqdm |
| import json |
| import uuid |
| import time |
| import redis |
| import pandas as pd |
| from openai import OpenAI |
| from langchain.embeddings.base import Embeddings |
| from langchain_core.documents import Document |
| from langchain_milvus import Milvus, BM25BuiltInFunction |
| from langchain_text_splitters import RecursiveCharacterTextSplitter |
| from langchain_classic.retrievers.parent_document_retriever import ParentDocumentRetriever |
| from langchain_core.stores import InMemoryStore |
| from dotenv import load_dotenv |
|
|
| |
# Load variables from a .env file into the process environment at import time,
# before settings such as OPENAI_API_KEY are read with os.getenv further down.
load_dotenv()
|
|
|
|
| |
| |
| |
|
|
def get_redis_client():
    """Build a pooled Redis client and report whether the server answers.

    The client is backed by a connection pool (at most 10 connections)
    pointing at ``0.0.0.0:6379``, database 0, with no password. A single
    ``ping`` is issued as a health check and its outcome is printed.

    Returns:
        The ``redis.StrictRedis`` client. It is returned even when the ping
        fails, so later commands on it may still raise.
    """
    connection_pool = redis.ConnectionPool(
        host='0.0.0.0',
        port=6379,
        db=0,
        password=None,
        max_connections=10,
    )
    client = redis.StrictRedis(connection_pool=connection_pool)

    # Health check only: a connection failure is reported, not raised.
    try:
        client.ping()
    except redis.ConnectionError:
        print("无法连接到 Redis !")
    else:
        print("成功连接到 Redis !")

    return client
|
|
|
|
| |
def cache_set(r, question: str, answer: str):
    """Store a question/answer pair in the Redis hash ``qa``.

    Args:
        r: Redis client (any object exposing ``hset`` and ``expire``).
        question: Hash field under which the answer is stored.
        answer: Value cached for that question.

    Note:
        The expiry is set on the hash key itself, so each call restarts the
        3600-second countdown for the whole ``qa`` hash, not for one field.
    """
    hash_name = "qa"
    ttl_seconds = 3600

    r.hset(hash_name, question, answer)
    # The TTL applies to the key "qa" as a whole (see the note above).
    r.expire(hash_name, ttl_seconds)
|
|
|
|
| |
def cache_get(r, question: str):
    """Look up the cached answer for ``question`` in the Redis hash ``qa``.

    Args:
        r: Redis client (any object exposing ``hget``).
        question: Hash field to read.

    Returns:
        Whatever ``r.hget("qa", question)`` returns for that field.
    """
    cached_answer = r.hget("qa", question)
    return cached_answer
|
|
|
|
| |
| |
| |
|
|
class OpenAIEmbeddings(Embeddings):
    """Custom embedding class backed by the OpenAI Embedding API."""

    def __init__(self):
        # The API key is read from the OPENAI_API_KEY environment variable.
        api_key = os.getenv("OPENAI_API_KEY")
        self.client = OpenAI(api_key=api_key)

    def embed_documents(self, texts):
        """Embed every text with the ``text-embedding-3-small`` model.

        One API request is issued per text, in input order.

        Args:
            texts: Iterable of strings to embed.

        Returns:
            A list with one embedding per input text, in the same order.
        """
        def embed_one(text):
            # Each request carries a single-item input list, so the only
            # result is at index 0.
            response = self.client.embeddings.create(
                model="text-embedding-3-small",
                input=[text],
            )
            return response.data[0].embedding

        return [embed_one(text) for text in texts]

    def embed_query(self, text):
        """Embed a single query string by delegating to ``embed_documents``."""
        vectors = self.embed_documents([text])
        return vectors[0]
|
|
|
|
| |
| |
| |
|
|
class Milvus_vector():
    """Build a Milvus vector store with a dense field and a BM25 sparse field."""

    def __init__(self, uri="./milvus_agent.db"):
        """Store the connection URI, the embedding model and index settings.

        Args:
            uri: Milvus connection URI, forwarded as ``connection_args["uri"]``.
        """
        self.URI = uri
        self.embeddings = OpenAIEmbeddings()

        # Index definition for the dense (embedding) vector field.
        self.dense_index = {
            "metric_type": "IP",
            "index_type": "IVF_FLAT",
        }
        # Index definition for the sparse field produced by the BM25 function.
        self.sparse_index = {
            "metric_type": "BM25",
            "index_type": "SPARSE_INVERTED_INDEX"
        }

    def create_vector_store(self, docs):
        """Create the vector store and insert every document in ``docs``.

        The first 10 documents create the store through
        ``Milvus.from_documents``. The remaining documents are inserted in
        batches of 5, pausing one second after each full batch; a final
        partial batch is inserted as well.

        Args:
            docs: Sliceable sequence of ``Document`` objects to index.

        Returns:
            The ``Milvus`` vector store, also kept on ``self.vectorstore``.
        """
        initial_size = 10
        batch_size = 5

        init_docs = docs[:initial_size]
        self.vectorstore = Milvus.from_documents(
            documents=init_docs,
            embedding=self.embeddings,
            builtin_function=BM25BuiltInFunction(),
            index_params=[self.dense_index, self.sparse_index],
            vector_field=["dense", "sparse"],
            connection_args={
                "uri": self.URI,
            },
            consistency_level="Bounded",
            drop_old=False,
        )
        print("已初始化创建 Milvus ‼")

        # Count what was really inserted, which may be fewer than 10.
        count = len(init_docs)
        batch = []
        for doc in tqdm(docs[initial_size:]):
            batch.append(doc)
            if len(batch) >= batch_size:
                # Use the synchronous add_documents: aadd_documents returns a
                # coroutine that must be awaited, so calling it from this
                # synchronous method would insert nothing.
                self.vectorstore.add_documents(batch)
                count += len(batch)
                batch = []
                print(f"已插入 {count} 条数据......")
                time.sleep(1)

        # Insert the trailing partial batch so no documents are dropped.
        if batch:
            self.vectorstore.add_documents(batch)
            count += len(batch)
            print(f"已插入 {count} 条数据......")

        print(f"总共插入 {count} 条数据......")
        print("已创建 Milvus 索引完成 ‼")

        return self.vectorstore
|
|
|
|
| |
| |
| |
|
|
class Pdf_retriever():
    """Build a parent/child document retriever on top of a Milvus store."""

    def __init__(self, uri="./pdf_agent.db"):
        """Store the connection URI, embedding model, index settings and splitters.

        Args:
            uri: Milvus connection URI, forwarded as ``connection_args["uri"]``.
        """
        self.URI = uri
        self.embeddings = OpenAIEmbeddings()

        # Index definition for the dense (embedding) vector field.
        self.dense_index = {
            "metric_type": "IP",
            "index_type": "IVF_FLAT",
        }
        # Index definition for the sparse field produced by the BM25 function.
        self.sparse_index = {
            "metric_type": "BM25",
            "index_type": "SPARSE_INVERTED_INDEX"
        }

        # Store handed to the ParentDocumentRetriever as its docstore.
        # NOTE(review): InMemoryStore presumably does not persist alongside
        # the Milvus data between runs; confirm this is intended.
        self.docstore = InMemoryStore()

        # Small child chunks (200 chars, 50 overlap) are passed to the
        # retriever as child_splitter.
        self.child_splitter = RecursiveCharacterTextSplitter(
            chunk_size=200,
            chunk_overlap=50,
            length_function=len,
            separators=["\n\n", "\n", "。", "!", "?", ";", ",", " ", ""]
        )
        # Larger parent chunks (1000 chars, 200 overlap) are passed as
        # parent_splitter.
        self.parent_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200
        )

    def create_pdf_vector_store(self, docs):
        """Create the Milvus store and retriever, then add all of ``docs``.

        Documents are added through the retriever in batches of 10, pausing
        one second after each full batch; a final partial batch is added as
        well.

        Args:
            docs: Iterable of ``Document`` objects to index.

        Returns:
            The ``ParentDocumentRetriever``, also kept on ``self.retriever``.
        """
        batch_size = 10

        self.milvus_vectorstore = Milvus(
            embedding_function=self.embeddings,
            builtin_function=BM25BuiltInFunction(),
            vector_field=["dense", "sparse"],
            # Reuse the index settings from __init__ rather than repeating
            # the literals; copies keep the instance attributes untouched.
            index_params=[dict(self.dense_index), dict(self.sparse_index)],
            connection_args={"uri": self.URI},
            consistency_level="Bounded",
            drop_old=False,
        )

        self.retriever = ParentDocumentRetriever(
            vectorstore=self.milvus_vectorstore,
            docstore=self.docstore,
            child_splitter=self.child_splitter,
            parent_splitter=self.parent_splitter,
        )

        count = 0
        batch = []
        for doc in tqdm(docs):
            batch.append(doc)
            if len(batch) >= batch_size:
                self.retriever.add_documents(batch)
                count += len(batch)
                batch = []
                print(f"已插入 {count} 条数据......")
                time.sleep(1)

        # Add the trailing partial batch so no documents are dropped.
        if batch:
            self.retriever.add_documents(batch)
            count += len(batch)
            print(f"已插入 {count} 条数据......")

        print(f"总共插入 {count} 条数据......")
        print("基于PDF文档数据的 Milvus 索引完成 ‼")

        return self.retriever
|
|
|
|
| |
| |
| |
|
|
def prepare_document(file_path=('./data/dialog.jsonl', './data/train.jsonl')):
    """Load a JSONL dialogue file into a list of ``Document`` objects.

    Each non-blank line must be a JSON object with ``query`` and ``response``
    keys. The two values are joined with a newline to form the page content,
    and every document gets a fresh UUID as ``doc_id`` metadata.

    Args:
        file_path: Either a single path, or a sequence of paths of which only
            the first entry is read (any further entries are ignored).

    Returns:
        The list of loaded ``Document`` objects, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks ``query`` or ``response``.
    """
    # Accept a bare path too; indexing a string would yield its first character.
    if isinstance(file_path, (str, os.PathLike)):
        source_path = file_path
    else:
        source_path = file_path[0]

    docs = []
    with open(source_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Skip blank lines instead of letting json.loads fail on them.
                continue
            content = json.loads(line)
            prompt = content['query'] + "\n" + content['response']
            docs.append(
                Document(page_content=prompt, metadata={"doc_id": str(uuid.uuid4())})
            )

    print(f"已加载 {len(docs)} 条数据!")

    return docs
|
|
|
|
| |
| |
| |
|
|
def prepare_pdf_document(file_path="./pdf_output/pdf_detailed_text.xlsx"):
    """Load the ``text_content`` column of an Excel sheet as ``Document`` objects.

    Rows whose ``text_content`` is missing are dropped. Each remaining value
    is converted to a string, stripped of surrounding whitespace and wrapped
    in a ``Document`` with a fresh UUID as ``doc_id`` metadata.

    Args:
        file_path: Path of the Excel file to read.

    Returns:
        The list of ``Document`` objects, in row order.
    """
    frame = pd.read_excel(file_path)
    frame = frame.dropna(subset=['text_content'])

    documents = []
    for _, record in frame.iterrows():
        raw_text = record['text_content']
        if pd.notna(raw_text):
            text_content = str(raw_text)
        else:
            text_content = ""

        documents.append(
            Document(
                page_content=text_content.strip(),
                metadata={"doc_id": str(uuid.uuid4())},
            )
        )

    print(f"成功加载 {len(documents)} 个文档")

    return documents
|
|
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    # The two triple-quoted blocks below are inert string literals holding
    # disabled pipelines (dialogue-data indexing and PDF retriever setup).
    # They are never executed; only the Redis check at the end runs.
    '''
    # 预处理即将插入 Milvus 的文档数据
    docs = prepare_document()
    print("预处理文档数据成功......")

    # 创建 Milvus 连接
    milvus_vectorstore = Milvus_vector()
    print("创建Milvus连接成功......")

    # 创建向量索引
    vectorstore = milvus_vectorstore.create_vector_store(docs)
    print("全部初始化完成, 可以开始问答了......")
    '''
    ''''
    # 将 PDF 后处理文档中的数据, 封装成Document
    docs = prepare_pdf_document()
    print("预处理 PDF 文档数据成功......")
    # print(docs[0])

    pdf_vectorstore = Pdf_retriever()
    print("创建 PDF Milvus 连接成功......")

    retriever = pdf_vectorstore.create_pdf_vector_store(docs)
    print("创建基于 Milvus 数据库的父子文档检索器成功......")
    print(retriever)
    '''
    # Create the Redis client and print it.
    redis_client = get_redis_client()
    print("创建Redis连接成功......")
    print(redis_client)

    print("全部初始化完成, 可以开始问答了......")