"""
文档处理和向量化模块
负责文档加载、文本分块、向量化和向量数据库初始化
"""
try:
from langchain_text_splitters import RecursiveCharacterTextSplitter
except ImportError:
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
# Try langchain_milvus first; if unavailable, fall back to langchain_community and patch it
try:
    from langchain_milvus import Milvus
    print("✅ Using langchain-milvus (new package)")
except ImportError:
    try:
        from langchain_community.vectorstores import Milvus
        print("⚠️ Using langchain_community.vectorstores.Milvus (legacy)")
        # Monkeypatch: work around the legacy URI validation that breaks Milvus Lite local file paths.
        # The old _create_connection_alias requires the URI to start with http/https.
        def _patched_create_connection_alias(self, connection_args):
            uri = connection_args.get("uri")
            # Derive a unique alias from local file URIs
            if uri:
                import hashlib
                return hashlib.md5(uri.encode()).hexdigest()
            return "default"
        # Apply the patch
        Milvus._create_connection_alias = _patched_create_connection_alias
        print("🔧 Applied Milvus Lite path-validation patch")
    except ImportError:
        pass  # Neither package is available; vector store initialization will fail later
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.retrievers import BM25Retriever
from config import (
KNOWLEDGE_BASE_URLS,
CHUNK_SIZE,
CHUNK_OVERLAP,
COLLECTION_NAME,
EMBEDDING_MODEL,
    # Hybrid search settings
ENABLE_HYBRID_SEARCH,
HYBRID_SEARCH_WEIGHTS,
KEYWORD_SEARCH_K,
BM25_K1,
BM25_B,
    # Vector store settings
VECTOR_STORE_TYPE,
MILVUS_HOST,
MILVUS_PORT,
MILVUS_USER,
MILVUS_PASSWORD,
MILVUS_URI,
MILVUS_INDEX_TYPE,
MILVUS_INDEX_PARAMS,
MILVUS_SEARCH_PARAMS,
    # Query expansion settings
ENABLE_QUERY_EXPANSION,
QUERY_EXPANSION_MODEL,
QUERY_EXPANSION_PROMPT,
MAX_EXPANDED_QUERIES,
    # Multimodal settings
ENABLE_MULTIMODAL,
MULTIMODAL_IMAGE_MODEL,
SUPPORTED_IMAGE_FORMATS,
IMAGE_EMBEDDING_DIM,
MULTIMODAL_WEIGHTS
)
from reranker import create_reranker
# Imports for multimodal support
import base64
import io
from PIL import Image
import numpy as np
from typing import List, Dict, Any, Optional, Union
try:
    from langchain_core.documents import Document
except ImportError:
    from langchain.schema import Document
class CustomEnsembleRetriever:
    """Custom ensemble retriever that combines vector search and BM25 keyword search."""

    def __init__(self, retrievers, weights):
        self.retrievers = retrievers
        self.weights = weights

    def invoke(self, query):
        """Run all retrievers and merge their results."""
        # Collect results from each retriever
        all_results = []
        for i, retriever in enumerate(self.retrievers):
            results = retriever.invoke(query)
            for doc in results:
                # Record which retriever produced the document and its weight
                doc.metadata["retriever_index"] = i
                doc.metadata["retriever_weight"] = self.weights[i]
                all_results.append(doc)
        return self._process_results(all_results)

    async def ainvoke(self, query):
        """Run all retrievers concurrently and merge their results."""
        import asyncio
        # Fetch results from all retrievers concurrently
        # Note: this assumes every retriever supports ainvoke
        tasks = [retriever.ainvoke(query) for retriever in self.retrievers]
        results_list = await asyncio.gather(*tasks)
        all_results = []
        for i, results in enumerate(results_list):
            for doc in results:
                # Record which retriever produced the document and its weight
                doc.metadata["retriever_index"] = i
                doc.metadata["retriever_weight"] = self.weights[i]
                all_results.append(doc)
        return self._process_results(all_results)

    def _process_results(self, all_results):
        """Sort and deduplicate the merged results."""
        # Simple ordering: by retriever index first, then by descending weight
        all_results.sort(key=lambda x: (x.metadata["retriever_index"], -x.metadata["retriever_weight"]))
        # Deduplicate by document content
        unique_results = []
        seen_content = set()
        for doc in all_results:
            content = doc.page_content
            if content not in seen_content:
                seen_content.add(content)
                unique_results.append(doc)
        return unique_results
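
# Illustrative sketch (not executed): how CustomEnsembleRetriever is typically wired up,
# mirroring its use in DocumentProcessor.setup_knowledge_base below. The names
# vector_retriever and bm25_retriever are placeholders for objects built elsewhere.
#
#   ensemble = CustomEnsembleRetriever(
#       retrievers=[vector_retriever, bm25_retriever],
#       weights=[0.7, 0.3],  # e.g. HYBRID_SEARCH_WEIGHTS["vector"], HYBRID_SEARCH_WEIGHTS["keyword"]
#   )
#   docs = ensemble.invoke("What is adaptive RAG?")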
class DocumentProcessor:
    """Document processor responsible for loading, processing, and vectorizing documents."""

    def __init__(self):
        self.text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
            chunk_size=CHUNK_SIZE,
            chunk_overlap=CHUNK_OVERLAP
        )
        # Try to initialize embeddings with error handling
        try:
            import torch
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            print(f"✅ Detected device: {device}")
            if device == 'cuda':
                print(f"   GPU model: {torch.cuda.get_device_name(0)}")
                print(f"   GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f}GB")
            self.embeddings = HuggingFaceEmbeddings(
                model_name="sentence-transformers/all-MiniLM-L6-v2",  # lightweight embedding model (EMBEDDING_MODEL from config is not used here)
                model_kwargs={'device': device},  # automatically pick GPU or CPU
                encode_kwargs={'normalize_embeddings': True}  # normalize embedding vectors
            )
            print(f"✅ HuggingFace embedding model initialized (device: {device})")
        except Exception as e:
            print(f"⚠️ HuggingFace embedding initialization failed: {e}")
            print("Trying fallback embeddings...")
            # Fall back to fake embeddings so the pipeline can still run
            from langchain_community.embeddings import FakeEmbeddings
            self.embeddings = FakeEmbeddings(size=384)  # for testing purposes only
            print("✅ Using test (fake) embeddings")
        self.vectorstore = None
        self.retriever = None
        self.bm25_retriever = None  # BM25 retriever
        self.ensemble_retriever = None  # ensemble retriever
        # Initialize the reranker
        self.reranker = None
        self._setup_reranker()
        # Initialize multimodal support
        self.image_embeddings_model = None
        self._setup_multimodal()
        # Initialize query expansion
        self.query_expansion_model = None
        self._setup_query_expansion()
    def _setup_reranker(self):
        """
        Set up the reranker.
        Uses a CrossEncoder for higher reranking accuracy.
        """
        try:
            # CrossEncoder reranker (highest accuracy) ⭐
            print("🔧 Initializing CrossEncoder reranker...")
            self.reranker = create_reranker(
                'crossencoder',
                model_name='cross-encoder/ms-marco-MiniLM-L-6-v2',  # lightweight model
                max_length=512
            )
            print("✅ CrossEncoder reranker initialized")
        except Exception as e:
            print(f"⚠️ CrossEncoder initialization failed: {e}")
            print("🔄 Falling back to the hybrid reranker...")
            try:
                # Fall back to the hybrid reranker
                self.reranker = create_reranker('hybrid', self.embeddings)
                print("✅ Hybrid reranker initialized")
            except Exception as e2:
                print(f"⚠️ Reranker initialization failed entirely: {e2}")
                print("⚠️ Retrieval will run without reranking")
    def _setup_multimodal(self):
        """Set up multimodal support."""
        if not ENABLE_MULTIMODAL:
            print("⚠️ Multimodal support is disabled")
            return
        try:
            print("🔧 Initializing multimodal support...")
            from transformers import CLIPProcessor, CLIPModel
            import torch
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            self.image_embeddings_model = CLIPModel.from_pretrained(MULTIMODAL_IMAGE_MODEL).to(device)
            self.image_processor = CLIPProcessor.from_pretrained(MULTIMODAL_IMAGE_MODEL)
            print(f"✅ Multimodal support initialized (device: {device})")
        except Exception as e:
            print(f"⚠️ Multimodal initialization failed: {e}")
            print("⚠️ Falling back to text-only retrieval")
            self.image_embeddings_model = None
    def _setup_query_expansion(self):
        """Set up query expansion."""
        if not ENABLE_QUERY_EXPANSION:
            print("⚠️ Query expansion is disabled")
            return
        try:
            print("🔧 Initializing query expansion...")
            from langchain_community.llms import Ollama
            self.query_expansion_model = Ollama(model=QUERY_EXPANSION_MODEL)
            print(f"✅ Query expansion initialized (model: {QUERY_EXPANSION_MODEL})")
        except Exception as e:
            print(f"⚠️ Query expansion initialization failed: {e}")
            print("⚠️ Continuing without query expansion")
            self.query_expansion_model = None
    def load_documents(self, urls=None):
        """Load documents from URLs."""
        if urls is None:
            urls = KNOWLEDGE_BASE_URLS
        print(f"Loading documents from {len(urls)} URLs...")
        docs = [WebBaseLoader(url).load() for url in urls]
        docs_list = [item for sublist in docs for item in sublist]
        print(f"Loaded {len(docs_list)} documents")
        return docs_list
    def split_documents(self, docs):
        """Split documents into chunks."""
        print("Splitting documents...")
        doc_splits = self.text_splitter.split_documents(docs)
        print(f"Splitting complete: {len(doc_splits)} chunks")
        return doc_splits
    def initialize_vectorstore(self):
        """Initialize the vector database connection."""
        if self.vectorstore:
            return
        print("Connecting to the vector database...")
        # Milvus is the only supported backend here
        try:
            # Build connection arguments
            connection_args = {}
            is_local_file = False
            # Prefer the URI if one is configured
            if MILVUS_URI and len(MILVUS_URI.strip()) > 0:
                is_local_file = not (MILVUS_URI.startswith("http://") or MILVUS_URI.startswith("https://"))
                real_uri = MILVUS_URI
                if is_local_file:
                    import os
                    # Milvus Lite requires an absolute path in some versions/environments
                    if not os.path.isabs(real_uri):
                        real_uri = os.path.abspath(real_uri)
                        print(f"📂 Converted relative path to absolute path: {real_uri}")
                    # Make sure the parent directory exists
                    parent_dir = os.path.dirname(real_uri)
                    if parent_dir and not os.path.exists(parent_dir):
                        print(f"📂 Creating Milvus storage directory: {parent_dir}")
                        os.makedirs(parent_dir, exist_ok=True)
                mode_name = "Lite (Local File)" if is_local_file else "Cloud (HTTP)"
                print(f"🔄 Connecting to Milvus {mode_name} ({real_uri})...")
                connection_args["uri"] = real_uri
                if not is_local_file and MILVUS_PASSWORD:
                    connection_args["token"] = MILVUS_PASSWORD
            else:
                print(f"🔄 Connecting to Milvus Server ({MILVUS_HOST}:{MILVUS_PORT})...")
                connection_args = {
                    "host": MILVUS_HOST,
                    "port": MILVUS_PORT,
                    "user": MILVUS_USER,
                    "password": MILVUS_PASSWORD
                }
            # Explicitly establish a global connection (fixes ConnectionNotExistException)
            try:
                from pymilvus import connections, utility
                print("🔌 Establishing a global pymilvus connection (alias: default)...")
                # Drop any stale connection in case the parameters have changed
                if connections.has_connection("default"):
                    connections.disconnect("default")
                connections.connect(alias="default", **connection_args)
                print("✅ pymilvus global connection established")
                # Check whether the collection exists (done early to avoid errors inside LangChain)
                if utility.has_collection(COLLECTION_NAME, using="default"):
                    print(f"✅ Collection {COLLECTION_NAME} already exists")
                else:
                    print(f"ℹ️ Collection {COLLECTION_NAME} does not exist; the Milvus class will create it")
            except ImportError:
                print("⚠️ pymilvus not found, skipping explicit connection")
            except Exception as e:
                print(f"⚠️ Explicit connection attempt failed: {e}")
                # Keep going; LangChain may still be able to handle the connection itself
            # Choose the index type
            # Milvus Lite (local mode) only supports FLAT, IVF_FLAT, and AUTOINDEX; HNSW is not supported
            final_index_type = MILVUS_INDEX_TYPE
            final_index_params = MILVUS_INDEX_PARAMS
            if is_local_file and MILVUS_INDEX_TYPE == "HNSW":
                print("⚠️ Milvus Lite (local mode) detected; HNSW is unsupported, switching to AUTOINDEX")
                final_index_type = "AUTOINDEX"
                final_index_params = {}  # AUTOINDEX needs no extra parameters
            # Initialize the Milvus vector store (without dropping existing data)
            # Note: since the global 'default' connection is already established above,
            # connection_args is reduced to just that alias so LangChain reuses it
            # instead of reconnecting or overwriting it with different parameters.
            self.vectorstore = Milvus(
                embedding_function=self.embeddings,
                collection_name=COLLECTION_NAME,
                connection_args={"alias": "default"},  # ✅ reuse the established connection
                index_params={
                    "metric_type": "L2",
                    "index_type": final_index_type,
                    "params": final_index_params
                },
                search_params={
                    "metric_type": "L2",
                    "params": MILVUS_SEARCH_PARAMS
                },
                drop_old=False,  # ✅ key for persistence: keep the existing index
                auto_id=True
            )
            print("✅ Connected to the Milvus vector database")
        except ImportError:
            print("❌ pymilvus is not installed; run: pip install pymilvus")
            raise
        except Exception as e:
            print(f"❌ Milvus connection failed: {e}")
            raise
        # Configure the retriever
        retriever_kwargs = {}
        # if ENABLE_MULTIMODAL:
        #     retriever_kwargs["expr"] = "data_type == 'text'"
        self.retriever = self.vectorstore.as_retriever(search_kwargs=retriever_kwargs)
    def check_existing_urls(self, urls: List[str]) -> set:
        """Check which URLs already exist in the vector store."""
        if not self.vectorstore:
            return set()
        existing = set()
        print("Checking for existing documents...")
        try:
            # Probe via similarity search with a strict filter
            # Note: this assumes the source field can be used in a filter expression
            for url in urls:
                # The query text is irrelevant here; only the filter matters, and k=1 is enough
                try:
                    # Milvus expr syntax
                    expr = f'source == "{url}"'
                    res = self.vectorstore.similarity_search(
                        "test",
                        k=1,
                        expr=expr
                    )
                    if res:
                        existing.add(url)
                except Exception as e:
                    # If that fails (likely a schema difference), try the nested metadata field
                    try:
                        expr = f'metadata["source"] == "{url}"'
                        res = self.vectorstore.similarity_search(
                            "test",
                            k=1,
                            expr=expr
                        )
                        if res:
                            existing.add(url)
                    except Exception:
                        pass
            print(f"✅ Found {len(existing)} existing URLs")
        except Exception as e:
            print(f"⚠️ Failed to check existing URLs: {e}")
        return existing
    def add_documents_to_vectorstore(self, doc_splits):
        """Add documents to the vector store."""
        if not doc_splits:
            return
        print(f"Adding {len(doc_splits)} chunks to the vector database...")
        if not self.vectorstore:
            self.initialize_vectorstore()
        # Attach metadata: tag each chunk as text or image based on its source
        for doc in doc_splits:
            if 'data_type' not in doc.metadata:
                source = doc.metadata.get('source', '')
                if any(fmt in source.lower() for fmt in SUPPORTED_IMAGE_FORMATS):
                    doc.metadata['data_type'] = 'image'
                else:
                    doc.metadata['data_type'] = 'text'
        self.vectorstore.add_documents(doc_splits)
        print("✅ Documents added")
    def create_vectorstore(self, doc_splits, persist_directory=None):
        """(Deprecated) Kept for backward compatibility; delegates to the new logic."""
        print("⚠️ create_vectorstore is deprecated; use initialize_vectorstore and add_documents_to_vectorstore")
        self.initialize_vectorstore()
        if doc_splits:
            self.add_documents_to_vectorstore(doc_splits)
        return self.vectorstore, self.retriever
    def get_all_documents_from_vectorstore(self, limit: Optional[int] = None) -> List[Document]:
        """Read all document contents from the persisted vector store and build Document objects."""
        if not self.vectorstore:
            return []
        try:
            # Note: _collection.get(include=...) is a Chroma-style API; on a Milvus backend this
            # call is likely to fail, in which case the except below returns an empty list.
            data = self.vectorstore._collection.get(include=["documents", "metadatas"])  # type: ignore
            docs_raw = data.get("documents") or []
            metas = data.get("metadatas") or []
            docs: List[Document] = []
            for i, content in enumerate(docs_raw):
                if content:
                    meta = metas[i] if i < len(metas) else {}
                    docs.append(Document(page_content=content, metadata=meta))
            if limit:
                return docs[:limit]
            return docs
        except Exception as e:
            print(f"⚠️ Failed to read documents from the vector store: {e}")
            return []
    def setup_knowledge_base(self, urls=None, enable_graphrag=False):
        """Set up the full knowledge base (load, split, vectorize).

        Args:
            urls: list of document URLs
            enable_graphrag: whether to build a GraphRAG index

        Returns:
            vectorstore, retriever, doc_splits
        """
        if urls is None:
            urls = KNOWLEDGE_BASE_URLS
        # 1. Initialize the vector store connection
        self.initialize_vectorstore()
        # 2. Check which URLs already exist (deduplication)
        existing_urls = self.check_existing_urls(urls)
        new_urls = [url for url in urls if url not in existing_urls]
        doc_splits = []
        if new_urls:
            print(f"🔄 Found {len(new_urls)} new URLs, processing...")
            docs = self.load_documents(new_urls)
            doc_splits = self.split_documents(docs)
            self.add_documents_to_vectorstore(doc_splits)
        else:
            print("✅ All URLs already exist; skipping loading and vectorization")
        # 3. Initialize hybrid search (BM25)
        if ENABLE_HYBRID_SEARCH:
            print("Initializing hybrid search (BM25)...")
            try:
                bm25_docs = []
                # If the store already holds older data, rebuild BM25 from everything in the DB;
                # indexing only the new chunks would leave the keyword index incomplete.
                if len(existing_urls) > 0:
                    print("🔄 Loading all documents from the vector store to rebuild the BM25 index...")
                    # Note: this assumes the corpus fits in memory
                    all_docs = self.get_all_documents_from_vectorstore()
                    bm25_docs = all_docs
                else:
                    # Fresh build: only the newly split chunks exist
                    bm25_docs = doc_splits
                if bm25_docs:
                    self.bm25_retriever = BM25Retriever.from_documents(
                        bm25_docs,
                        bm25_params={"k1": BM25_K1, "b": BM25_B},  # BM25Okapi parameters go through bm25_params
                        k=KEYWORD_SEARCH_K
                    )
                    self.ensemble_retriever = CustomEnsembleRetriever(
                        retrievers=[self.retriever, self.bm25_retriever],
                        weights=[HYBRID_SEARCH_WEIGHTS["vector"], HYBRID_SEARCH_WEIGHTS["keyword"]]
                    )
                    print("✅ Hybrid search initialized")
                else:
                    print("⚠️ No documents available to initialize BM25")
            except Exception as e:
                print(f"⚠️ Hybrid search initialization failed: {e}")
                self.ensemble_retriever = None
        # Return doc_splits for GraphRAG indexing (note: only the newly added chunks are returned)
        return self.vectorstore, self.retriever, doc_splits
    async def async_expand_query(self, query: str) -> List[str]:
        """Asynchronously expand a query."""
        if not self.query_expansion_model:
            return [query]
        try:
            # Use the LLM to generate expanded queries
            prompt = QUERY_EXPANSION_PROMPT.format(query=query)
            expanded_queries_text = await self.query_expansion_model.ainvoke(prompt)
            # Parse the expanded queries
            expanded_queries = [query]  # keep the original query
            for line in expanded_queries_text.strip().split('\n'):
                line = line.strip()
                if line and not line.startswith('#') and not line.startswith('//'):
                    # Strip a possible numeric prefix such as "1. "
                    if line[0].isdigit() and '.' in line[:5]:
                        line = line.split('.', 1)[1].strip()
                    expanded_queries.append(line)
            # Cap the number of expanded queries
            return expanded_queries[:MAX_EXPANDED_QUERIES + 1]
        except Exception as e:
            print(f"⚠️ Async query expansion failed: {e}")
            return [query]
    async def async_hybrid_retrieve(self, query: str, top_k: int = 5, filter_type: str = "text") -> List:
        """Asynchronous hybrid retrieval.

        Args:
            filter_type: data type filter, "text" (default), "image", or "all" (no filtering)
        """
        # Build search parameters
        search_kwargs = {}
        if filter_type != "all" and ENABLE_MULTIMODAL:
            search_kwargs["expr"] = f"data_type == '{filter_type}'"
        if not ENABLE_HYBRID_SEARCH or not self.ensemble_retriever:
            # Pure vector search supports search_kwargs directly
            if self.vectorstore:
                return await self.vectorstore.asimilarity_search(query, k=top_k, **search_kwargs)
            return await self.retriever.ainvoke(query)
        try:
            # Hybrid retrieval.
            # CustomEnsembleRetriever.invoke/ainvoke does not forward search_kwargs, and the standard
            # retriever interface takes no per-call parameters, so filtering is applied by mutating
            # the vector retriever's search_kwargs in place (a LangChain retriever feature) before
            # running the ensemble. Strict filtering would be cleaner at the vectorstore level.
            if filter_type != "all" and ENABLE_MULTIMODAL:
                self.retriever.search_kwargs["expr"] = f"data_type == '{filter_type}'"
            else:
                self.retriever.search_kwargs.pop("expr", None)
            results = await self.ensemble_retriever.ainvoke(query)
            return results[:top_k]
        except Exception as e:
            print(f"⚠️ Async hybrid retrieval failed: {e}")
            print("Falling back to vector search")
            if self.vectorstore:
                return await self.vectorstore.asimilarity_search(query, k=top_k, **search_kwargs)
            return await self.retriever.ainvoke(query)
    async def async_enhanced_retrieve(self, query: str, top_k: int = 5, rerank_candidates: int = 20,
                                      image_paths: List[str] = None, use_query_expansion: bool = None):
        """Asynchronous enhanced retrieval."""
        import asyncio
        # Decide whether to use query expansion
        if use_query_expansion is None:
            use_query_expansion = ENABLE_QUERY_EXPANSION
        # If enabled, generate expanded queries
        if use_query_expansion:
            expanded_queries = await self.async_expand_query(query)
            print(f"Query expansion: {len(expanded_queries)} queries")
        else:
            expanded_queries = [query]
        # Multimodal retrieval (still synchronous; run it in a thread pool)
        if image_paths and ENABLE_MULTIMODAL:
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(None, self.multimodal_retrieve, query, image_paths, top_k)
        # Hybrid or vector retrieval
        all_candidate_docs = []
        # Filtering strategy: for plain text queries, restrict the search to text data as a
        # performance optimization at large scale; when images are involved, search all modalities.
        filter_type = "text"
        if image_paths:
            filter_type = "all"
        # Build the filter expression (only used when calling the vectorstore directly;
        # async_hybrid_retrieve handles filtering internally)
        search_kwargs = {}
        if filter_type != "all" and ENABLE_MULTIMODAL:
            search_kwargs["expr"] = f"data_type == '{filter_type}'"

        async def retrieve_single(q):
            if ENABLE_HYBRID_SEARCH:
                # Hybrid retrieval with dynamic filtering
                docs = await self.async_hybrid_retrieve(q, rerank_candidates, filter_type=filter_type)
            else:
                # Vector retrieval with the filter expression
                if self.vectorstore:
                    docs = await self.vectorstore.asimilarity_search(
                        q,
                        k=rerank_candidates,
                        **search_kwargs  # pass expr through
                    )
                else:
                    # Fallback to the plain retriever
                    docs = await self.retriever.ainvoke(q)
                if len(docs) > rerank_candidates:
                    docs = docs[:rerank_candidates]
            return docs

        # Run retrieval for all queries concurrently
        results = await asyncio.gather(*[retrieve_single(q) for q in expanded_queries])
        for docs in results:
            all_candidate_docs.extend(docs)
        # Deduplicate by document content
        unique_docs = []
        seen_content = set()
        for doc in all_candidate_docs:
            content = doc.page_content
            if content not in seen_content:
                seen_content.add(content)
                unique_docs.append(doc)
        print(f"Retrieved {len(unique_docs)} candidate documents")
        # Rerank (if a reranker is available)
        # Reranking is compute-heavy, so it runs in a thread pool
        if self.reranker and len(unique_docs) > top_k:
            try:
                loop = asyncio.get_running_loop()
                reranked_results = await loop.run_in_executor(
                    None,
                    self.reranker.rerank,
                    query, unique_docs, top_k
                )
                final_docs = [doc for doc, score in reranked_results]
                scores = [score for doc, score in reranked_results]
                print(f"Reranking returned {len(final_docs)} documents")
                print(f"Rerank score range: {min(scores):.4f} - {max(scores):.4f}")
                return final_docs
            except Exception as e:
                print(f"⚠️ Reranking failed: {e}; using the raw retrieval results")
                return unique_docs[:top_k]
        else:
            return unique_docs[:top_k]
    def expand_query(self, query: str) -> List[str]:
        """Expand a query into related queries."""
        if not self.query_expansion_model:
            return [query]
        try:
            # Use the LLM to generate expanded queries
            prompt = QUERY_EXPANSION_PROMPT.format(query=query)
            expanded_queries_text = self.query_expansion_model.invoke(prompt)
            # Parse the expanded queries
            expanded_queries = [query]  # keep the original query
            for line in expanded_queries_text.strip().split('\n'):
                line = line.strip()
                if line and not line.startswith('#') and not line.startswith('//'):
                    # Strip a possible numeric prefix such as "1. "
                    if line[0].isdigit() and '.' in line[:5]:
                        line = line.split('.', 1)[1].strip()
                    expanded_queries.append(line)
            # Cap the number of expanded queries
            return expanded_queries[:MAX_EXPANDED_QUERIES + 1]  # +1 because the original query is included
        except Exception as e:
            print(f"⚠️ Query expansion failed: {e}")
            return [query]
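
    # Illustrative example of the parsing done by expand_query / async_expand_query above
    # (the actual output depends on QUERY_EXPANSION_PROMPT and the expansion model):
    #   LLM output: "1. what is agent memory\n2. how do LLM agents persist memory"
    #   Result:     [original_query, "what is agent memory", "how do LLM agents persist memory"]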
    def encode_image(self, image_path: str) -> np.ndarray:
        """Encode an image into an embedding vector."""
        if not self.image_embeddings_model:
            raise ValueError("Multimodal support is not initialized")
        try:
            import torch  # torch is only imported locally elsewhere, so import it here as well
            # Load and preprocess the image
            image = Image.open(image_path).convert('RGB')
            inputs = self.image_processor(images=image, return_tensors="pt")
            # Move inputs to the same device as the CLIP model
            device = next(self.image_embeddings_model.parameters()).device
            inputs = {k: v.to(device) for k, v in inputs.items()}
            # Compute the image embedding
            with torch.no_grad():
                image_features = self.image_embeddings_model.get_image_features(**inputs)
            # Normalize the embedding vector
            image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
            return image_features.cpu().numpy().flatten()
        except Exception as e:
            print(f"⚠️ Image encoding failed: {e}")
            raise
    def multimodal_retrieve(self, query: str, image_paths: List[str] = None, top_k: int = 5) -> List:
        """Multimodal retrieval combining text and images."""
        if not ENABLE_MULTIMODAL or not self.image_embeddings_model:
            # Fall back to text retrieval when multimodal support is unavailable
            return self.hybrid_retrieve(query, top_k) if ENABLE_HYBRID_SEARCH else self.retriever.invoke(query)[:top_k]
        # 1. Text query (text-to-text and text-to-image)
        # A text query should be able to hit both text chunks and related images, so no data_type
        # restriction is applied here; the retriever (or hybrid retriever) covers all modalities,
        # which works because CLIP places text and images in the same embedding space.
        text_docs = []
        if query:
            text_docs = self.hybrid_retrieve(query, top_k) if ENABLE_HYBRID_SEARCH else self.retriever.invoke(query)[:top_k]
        # Without image input, return the text query results directly
        if not image_paths:
            return text_docs
        try:
            # 2. Image query (image-to-text and image-to-image)
            image_results = []
            for image_path in image_paths:
                # Check the file format
                file_ext = image_path.split('.')[-1].lower()
                if file_ext not in SUPPORTED_IMAGE_FORMATS:
                    print(f"⚠️ Unsupported image format: {file_ext}")
                    continue
                # Encode the image
                image_embedding = self.encode_image(image_path)
                # Search with the image embedding
                if self.vectorstore:
                    # An image can retrieve both textual descriptions and similar images,
                    # so no data_type restriction is applied
                    img_docs = self.vectorstore.similarity_search_by_vector(
                        embedding=image_embedding,
                        k=top_k
                    )
                    image_results.extend(img_docs)
            # Merge the text and image results, then deduplicate by content
            all_docs = text_docs + image_results
            unique_docs = []
            seen_content = set()
            for doc in all_docs:
                content = doc.page_content
                if content not in seen_content:
                    seen_content.add(content)
                    unique_docs.append(doc)
            final_docs = unique_docs[:top_k]
            print(f"✅ Multimodal retrieval complete, returning {len(final_docs)} results")
            return final_docs
        except Exception as e:
            print(f"⚠️ Multimodal retrieval failed: {e}")
            print("Falling back to text retrieval")
            return text_docs
    def hybrid_retrieve(self, query: str, top_k: int = 5) -> List:
        """Hybrid retrieval combining vector search and keyword search."""
        if not ENABLE_HYBRID_SEARCH or not self.ensemble_retriever:
            # Fall back to vector retrieval when hybrid search is unavailable
            return self.retriever.invoke(query)[:top_k]
        try:
            # Run hybrid retrieval through the ensemble retriever
            results = self.ensemble_retriever.invoke(query)
            return results[:top_k]
        except Exception as e:
            print(f"⚠️ Hybrid retrieval failed: {e}")
            print("Falling back to vector search")
            return self.retriever.invoke(query)[:top_k]
    def enhanced_retrieve(self, query: str, top_k: int = 5, rerank_candidates: int = 20,
                          image_paths: List[str] = None, use_query_expansion: bool = None):
        """Enhanced retrieval: fetch extra candidates, then rerank; supports query expansion and multimodality.

        Args:
            query: query string
            top_k: number of documents to return
            rerank_candidates: number of candidate documents before reranking
            image_paths: list of image paths for multimodal retrieval
            use_query_expansion: whether to use query expansion; None means use the configured default
        """
        # Decide whether to use query expansion
        if use_query_expansion is None:
            use_query_expansion = ENABLE_QUERY_EXPANSION
        # If enabled, generate expanded queries
        if use_query_expansion:
            expanded_queries = self.expand_query(query)
            print(f"Query expansion: {len(expanded_queries)} queries")
        else:
            expanded_queries = [query]
        # Multimodal retrieval (if images are provided)
        if image_paths and ENABLE_MULTIMODAL:
            return self.multimodal_retrieve(query, image_paths, top_k)
        # Hybrid or vector retrieval
        all_candidate_docs = []
        for expanded_query in expanded_queries:
            if ENABLE_HYBRID_SEARCH:
                # Hybrid retrieval
                docs = self.hybrid_retrieve(expanded_query, rerank_candidates)
            else:
                # Vector retrieval
                docs = self.retriever.invoke(expanded_query)
                if len(docs) > rerank_candidates:
                    docs = docs[:rerank_candidates]
            all_candidate_docs.extend(docs)
        # Deduplicate by document content
        unique_docs = []
        seen_content = set()
        for doc in all_candidate_docs:
            content = doc.page_content
            if content not in seen_content:
                seen_content.add(content)
                unique_docs.append(doc)
        print(f"Retrieved {len(unique_docs)} candidate documents")
        # Rerank (if a reranker is available)
        if self.reranker and len(unique_docs) > top_k:
            try:
                reranked_results = self.reranker.rerank(query, unique_docs, top_k)
                final_docs = [doc for doc, score in reranked_results]
                scores = [score for doc, score in reranked_results]
                print(f"Reranking returned {len(final_docs)} documents")
                print(f"Rerank score range: {min(scores):.4f} - {max(scores):.4f}")
                return final_docs
            except Exception as e:
                print(f"⚠️ Reranking failed: {e}; using the raw retrieval results")
                return unique_docs[:top_k]
        else:
            # No reranking, or not enough candidates
            return unique_docs[:top_k]
def compare_retrieval_methods(self, query: str, top_k: int = 5, image_paths: List[str] = None):
"""比较不同检索方法的效果"""
if not self.retriever:
return {}
results = {
'query': query,
'image_paths': image_paths
}
        # Baseline vector retrieval (invoke replaces the deprecated get_relevant_documents)
original_docs = self.retriever.invoke(query)[:top_k]
results['vector_retrieval'] = {
'count': len(original_docs),
'documents': [{
'content': doc.page_content[:200] + '...' if len(doc.page_content) > 200 else doc.page_content,
'metadata': getattr(doc, 'metadata', {})
} for doc in original_docs]
}
        # Hybrid retrieval (if enabled)
if ENABLE_HYBRID_SEARCH and self.ensemble_retriever:
hybrid_docs = self.hybrid_retrieve(query, top_k)
results['hybrid_retrieval'] = {
'count': len(hybrid_docs),
'documents': [{
'content': doc.page_content[:200] + '...' if len(doc.page_content) > 200 else doc.page_content,
'metadata': getattr(doc, 'metadata', {})
} for doc in hybrid_docs]
}
        # Query-expansion retrieval (if enabled)
if ENABLE_QUERY_EXPANSION and self.query_expansion_model:
expanded_docs = self.enhanced_retrieve(query, top_k, use_query_expansion=True)
results['expanded_query_retrieval'] = {
'count': len(expanded_docs),
'documents': [{
'content': doc.page_content[:200] + '...' if len(doc.page_content) > 200 else doc.page_content,
'metadata': getattr(doc, 'metadata', {})
} for doc in expanded_docs]
}
        # Multimodal retrieval (if enabled and images are provided)
if ENABLE_MULTIMODAL and image_paths:
multimodal_docs = self.multimodal_retrieve(query, image_paths, top_k)
results['multimodal_retrieval'] = {
'count': len(multimodal_docs),
'documents': [{
'content': doc.page_content[:200] + '...' if len(doc.page_content) > 200 else doc.page_content,
'metadata': getattr(doc, 'metadata', {})
} for doc in multimodal_docs]
}
        # Enhanced retrieval (with reranking)
enhanced_docs = self.enhanced_retrieve(query, top_k)
results['enhanced_retrieval'] = {
'count': len(enhanced_docs),
'documents': [{
'content': doc.page_content[:200] + '...' if len(doc.page_content) > 200 else doc.page_content,
'metadata': getattr(doc, 'metadata', {})
} for doc in enhanced_docs]
}
        # Attach configuration info
results['configuration'] = {
'hybrid_search_enabled': ENABLE_HYBRID_SEARCH,
'query_expansion_enabled': ENABLE_QUERY_EXPANSION,
'multimodal_enabled': ENABLE_MULTIMODAL,
'reranker_used': self.reranker is not None,
'hybrid_weights': HYBRID_SEARCH_WEIGHTS if ENABLE_HYBRID_SEARCH else None,
'multimodal_weights': MULTIMODAL_WEIGHTS if ENABLE_MULTIMODAL else None
}
return results
def format_docs(self, docs):
"""格式化文档用于生成"""
return "\n\n".join(doc.page_content for doc in docs)
def initialize_document_processor():
    """Initialize the document processor and set up the knowledge base."""
    print("🚀 Initializing the document processor (Milvus backend)...")
    processor = DocumentProcessor()
    # Set up the knowledge base directly; Milvus connection and indexing
    # are handled inside DocumentProcessor.setup_knowledge_base
    vectorstore, retriever, doc_splits = processor.setup_knowledge_base()
    return processor, vectorstore, retriever, doc_splits
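
# Minimal smoke-test sketch (assumes config.py is populated and a Milvus / Milvus Lite target is
# reachable). It only exercises functions defined in this module and prints the top results for a
# sample query; the query text is illustrative.
if __name__ == "__main__":
    processor, vectorstore, retriever, doc_splits = initialize_document_processor()
    sample_query = "What is adaptive RAG?"
    docs = processor.enhanced_retrieve(sample_query, top_k=3)
    for i, doc in enumerate(docs, 1):
        source = doc.metadata.get("source", "unknown")
        preview = doc.page_content[:200].replace("\n", " ")
        print(f"[{i}] {source}: {preview}")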