from typing import List, Dict, Callable, Optional

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import (
    DirectoryLoader,
    UnstructuredMarkdownLoader,
    PyPDFLoader,
    TextLoader
)
import os
import requests
import base64
from PIL import Image
import io


class DocumentLoader:
    """Generic document loader that dispatches on the file extension."""

    def __init__(self, file_path: str, original_filename: Optional[str] = None):
        self.file_path = file_path
        # Fall back to the on-disk basename when no original filename is provided.
        self.original_filename = original_filename or os.path.basename(file_path)
        self.extension = os.path.splitext(self.original_filename)[1].lower()
        self.api_key = os.getenv("API_KEY")
        self.api_base = os.getenv("BASE_URL")

    def process_image(self, image_path: str) -> str:
        """Describe an image with the SiliconFlow VLM model and return the text description."""
        try:
            # Read the image and encode it as base64 for the API payload.
            with open(image_path, 'rb') as image_file:
                image_data = image_file.read()
                base64_image = base64.b64encode(image_data).decode('utf-8')

            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            }

            response = requests.post(
                f"{self.api_base}/chat/completions",
                headers=headers,
                json={
                    "model": "Qwen/Qwen2.5-VL-72B-Instruct",
                    "messages": [
                        {
                            "role": "user",
                            "content": [
                                {
                                    "type": "image_url",
                                    "image_url": {
                                        "url": f"data:image/jpeg;base64,{base64_image}",
                                        "detail": "high"
                                    }
                                },
                                {
                                    "type": "text",
                                    "text": "Please describe this image in detail, including the main objects, scene, activities, colors, layout, and other key information."
                                }
                            ]
                        }
                    ],
                    "temperature": 0.7,
                    "max_tokens": 500
                }
            )

            if response.status_code != 200:
                raise Exception(f"Image processing API call failed: {response.text}")

            description = response.json()["choices"][0]["message"]["content"]
            return description

        except Exception as e:
            print(f"Error while processing image: {str(e)}")
            return "Image processing failed"

    def load(self):
        try:
            print(f"Loading file: {self.file_path}, original filename: {self.original_filename}, extension: {self.extension}")

            if self.extension == '.md':
                try:
                    loader = UnstructuredMarkdownLoader(self.file_path, encoding='utf-8')
                    return loader.load()
                except UnicodeDecodeError:
                    # Fall back to GBK for non-UTF-8 files.
                    loader = UnstructuredMarkdownLoader(self.file_path, encoding='gbk')
                    return loader.load()
            elif self.extension == '.pdf':
                loader = PyPDFLoader(self.file_path)
                return loader.load()
            elif self.extension == '.txt':
                try:
                    loader = TextLoader(self.file_path, encoding='utf-8')
                    return loader.load()
                except UnicodeDecodeError:
                    loader = TextLoader(self.file_path, encoding='gbk')
                    return loader.load()
            elif self.extension in ['.png', '.jpg', '.jpeg', '.gif', '.bmp']:
                # Images become a single Document whose content is the VLM description.
                description = self.process_image(self.file_path)

                from langchain.schema import Document
                doc = Document(
                    page_content=description,
                    metadata={
                        'source': self.file_path,
                        'file_name': self.original_filename,
                        'img_url': os.path.abspath(self.file_path)
                    }
                )
                return [doc]
            else:
                print(f"Unsupported file extension: {self.extension}")
                raise ValueError(f"Unsupported file format: {self.extension}")

        except UnicodeDecodeError:
            print(f"File encoding error, trying another encoding: {self.file_path}")
            if self.extension in ['.md', '.txt']:
                try:
                    loader = TextLoader(self.file_path, encoding='gbk')
                    return loader.load()
                except Exception as e:
                    print(f"GBK encoding also failed: {str(e)}")
                    raise
        except Exception as e:
            print(f"Error loading file {self.file_path}: {str(e)}")
            import traceback
            traceback.print_exc()
            raise

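# Usage sketch (illustrative only; the path below is hypothetical): DocumentLoader can
# also be used on its own. load() returns a list of langchain Documents, and image
# files yield a single Document whose page_content is the VLM-generated description.
#
#   loader = DocumentLoader("/tmp/example.md")
#   docs = loader.load()
#   print(docs[0].page_content[:200])
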
class DocumentProcessor:
    """Loads documents and splits them into chunks for indexing."""

    def __init__(self):
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
        )

    def get_index_name(self, path: str) -> str:
        """Generate an index name from a file path."""
        if os.path.isdir(path):
            # Directories are indexed under the directory name.
            return f"rag_{os.path.basename(path).lower()}"
        else:
            # Single files are indexed under the filename without its extension.
            return f"rag_{os.path.splitext(os.path.basename(path))[0].lower()}"

    def process(self, path: str, progress_callback: Optional[Callable] = None, original_filename: Optional[str] = None) -> List[Dict]:
        """
        Load and process documents from a directory or a single file.

        Args:
            path: Path to the document or directory.
            progress_callback: Optional callback used to report processing progress.
            original_filename: Original filename (may contain non-ASCII characters such as Chinese).

        Returns:
            A list of processed document chunks.
        """
        if os.path.isdir(path):
            documents = []
            total_files = sum(len(files) for _, _, files in os.walk(path))
            processed_files = 0
            processed_size = 0

            for root, _, files in os.walk(path):
                for file in files:
                    file_path = os.path.join(root, file)
                    try:
                        if progress_callback:
                            file_size = os.path.getsize(file_path)
                            processed_size += file_size
                            processed_files += 1
                            progress_callback(processed_size, f"Processing file {processed_files}/{total_files}: {file}")

                        loader = DocumentLoader(file_path, original_filename=file)
                        docs = loader.load()

                        for doc in docs:
                            doc.metadata['file_name'] = file
                        documents.extend(docs)
                    except Exception as e:
                        print(f"Warning: error loading file {file_path}: {str(e)}")
                        continue
        else:
            try:
                if progress_callback:
                    file_size = os.path.getsize(path)
                    progress_callback(file_size * 0.3, f"Loading file: {original_filename or os.path.basename(path)}")

                loader = DocumentLoader(path, original_filename=original_filename)
                documents = loader.load()

                if progress_callback:
                    progress_callback(file_size * 0.6, "Processing file content...")

                file_name = original_filename or os.path.basename(path)
                for doc in documents:
                    doc.metadata['file_name'] = file_name
            except Exception as e:
                print(f"Error loading file: {str(e)}")
                raise

        chunks = self.text_splitter.split_documents(documents)

        if progress_callback:
            if os.path.isdir(path):
                progress_callback(processed_size, f"Document splitting complete: {len(chunks)} chunks")
            else:
                file_size = os.path.getsize(path)
                progress_callback(file_size * 0.9, f"Document splitting complete: {len(chunks)} chunks")

        processed_docs = []
        for i, chunk in enumerate(chunks):
            processed_docs.append({
                'id': f'doc_{i}',
                'content': chunk.page_content,
                'metadata': chunk.metadata
            })

        return processed_docs
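

# A minimal usage sketch (not part of the original module). It assumes API_KEY and
# BASE_URL are set in the environment and that the example path below exists; the
# path and the callback are hypothetical illustrations of the expected interfaces.
if __name__ == "__main__":
    def show_progress(size: float, message: str) -> None:
        # progress_callback receives a processed-bytes figure and a status message.
        print(f"[{size:.0f} bytes] {message}")

    processor = DocumentProcessor()
    print(processor.get_index_name("docs/example.pdf"))  # -> "rag_example"
    for item in processor.process("docs/example.pdf", progress_callback=show_progress):
        print(item['id'], item['metadata'].get('file_name'))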