File size: 10,207 Bytes
20f7a0a 9e00cc6 20f7a0a 9e00cc6 20f7a0a 9e00cc6 20f7a0a 9e00cc6 20f7a0a 9e00cc6 20f7a0a 9e00cc6 20f7a0a 9e00cc6 20f7a0a 9e00cc6 20f7a0a 9e00cc6 20f7a0a 9e00cc6 20f7a0a 9e00cc6 20f7a0a 9e00cc6 20f7a0a 9e00cc6 20f7a0a 9e00cc6 20f7a0a 9e00cc6 20f7a0a 9e00cc6 20f7a0a 1625bb7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 |
from typing import List, Dict, Callable, Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import (
DirectoryLoader,
UnstructuredMarkdownLoader,
PyPDFLoader,
TextLoader
)
import os
import requests
import base64
from PIL import Image
import io
class DocumentLoader:
"""通用文档加载器"""
def __init__(self, file_path: str, original_filename: str = None):
self.file_path = file_path
# 使用传入的原始文件名或者从路径提取的文件名
self.original_filename = original_filename or os.path.basename(file_path)
# 从原始文件名中获取扩展名,确保中文文件名也能正确识别文件类型
self.extension = os.path.splitext(self.original_filename)[1].lower()
self.api_key = os.getenv("API_KEY")
self.api_base = os.getenv("BASE_URL")
def process_image(self, image_path: str) -> str:
"""使用 SiliconFlow VLM 模型处理图片"""
try:
# 读取图片并转换为base64
with open(image_path, 'rb') as image_file:
image_data = image_file.read()
base64_image = base64.b64encode(image_data).decode('utf-8')
# 调用 SiliconFlow API
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
response = requests.post(
f"{self.api_base}/chat/completions",
headers=headers,
json={
"model": "Qwen/Qwen2.5-VL-72B-Instruct",
"messages": [
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}",
"detail": "high"
}
},
{
"type": "text",
"text": "请详细描述这张图片的内容,包括主要对象、场景、活动、颜色、布局等关键信息。"
}
]
}
],
"temperature": 0.7,
"max_tokens": 500
}
)
if response.status_code != 200:
raise Exception(f"图片处理API调用失败: {response.text}")
description = response.json()["choices"][0]["message"]["content"]
return description
except Exception as e:
print(f"处理图片时出错: {str(e)}")
return "图片处理失败"
def load(self):
try:
print(f"正在加载文件: {self.file_path}, 原始文件名: {self.original_filename}, 扩展名: {self.extension}")
if self.extension == '.md':
try:
loader = UnstructuredMarkdownLoader(self.file_path, encoding='utf-8')
return loader.load()
except UnicodeDecodeError:
# 如果UTF-8失败,尝试GBK
loader = UnstructuredMarkdownLoader(self.file_path, encoding='gbk')
return loader.load()
elif self.extension == '.pdf':
loader = PyPDFLoader(self.file_path)
return loader.load()
elif self.extension == '.txt':
try:
loader = TextLoader(self.file_path, encoding='utf-8')
return loader.load()
except UnicodeDecodeError:
# 如果UTF-8失败,尝试GBK
loader = TextLoader(self.file_path, encoding='gbk')
return loader.load()
elif self.extension in ['.png', '.jpg', '.jpeg', '.gif', '.bmp']:
# 处理图片
description = self.process_image(self.file_path)
# 创建一个包含图片描述的文档
from langchain.schema import Document
doc = Document(
page_content=description,
metadata={
'source': self.file_path,
'file_name': self.original_filename, # 使用原始文件名
'img_url': os.path.abspath(self.file_path) # 存储图片的绝对路径
}
)
return [doc]
else:
print(f"不支持的文件扩展名: {self.extension}")
raise ValueError(f"不支持的文件格式: {self.extension}")
except UnicodeDecodeError:
# 如果默认编码处理失败,尝试其他编码
print(f"文件编码错误,尝试其他编码: {self.file_path}")
if self.extension in ['.md', '.txt']:
try:
loader = TextLoader(self.file_path, encoding='gbk')
return loader.load()
except Exception as e:
print(f"尝试GBK编码也失败: {str(e)}")
raise
except Exception as e:
print(f"加载文件 {self.file_path} 时出错: {str(e)}")
import traceback
traceback.print_exc()
raise
class DocumentProcessor:
def __init__(self):
self.text_splitter = RecursiveCharacterTextSplitter(
chunk_size=1000,
chunk_overlap=200,
length_function=len,
)
def get_index_name(self, path: str) -> str:
"""根据文件路径生成索引名称"""
if os.path.isdir(path):
# 如果是目录,使用目录名
return f"rag_{os.path.basename(path).lower()}"
else:
# 如果是文件,使用文件名(不含扩展名)
return f"rag_{os.path.splitext(os.path.basename(path))[0].lower()}"
def process(self, path: str, progress_callback: Optional[Callable] = None, original_filename: str = None) -> List[Dict]:
"""
加载并处理文档,支持目录或单个文件
参数:
path: 文档路径
progress_callback: 进度回调函数,用于报告处理进度
original_filename: 原始文件名(包括中文)
返回:处理后的文档列表
"""
if os.path.isdir(path):
documents = []
total_files = sum([len(files) for _, _, files in os.walk(path)])
processed_files = 0
processed_size = 0
for root, _, files in os.walk(path):
for file in files:
file_path = os.path.join(root, file)
try:
# 更新处理进度
if progress_callback:
file_size = os.path.getsize(file_path)
processed_size += file_size
processed_files += 1
progress_callback(processed_size, f"处理文件 {processed_files}/{total_files}: {file}")
# 为目录中的每个文件,传递原始文件名
loader = DocumentLoader(file_path, original_filename=file)
docs = loader.load()
# 添加文件名到metadata
for doc in docs:
doc.metadata['file_name'] = file # 使用原始文件名
documents.extend(docs)
except Exception as e:
print(f"警告:加载文件 {file_path} 时出错: {str(e)}")
continue
else:
try:
if progress_callback:
file_size = os.path.getsize(path)
progress_callback(file_size * 0.3, f"加载文件: {original_filename or os.path.basename(path)}")
# 为单个文件,传递原始文件名
loader = DocumentLoader(path, original_filename=original_filename)
documents = loader.load()
# 更新进度
if progress_callback:
progress_callback(file_size * 0.6, f"处理文件内容...")
# 使用原始文件名而不是存储的文件名
file_name = original_filename or os.path.basename(path)
for doc in documents:
doc.metadata['file_name'] = file_name
except Exception as e:
print(f"加载文件时出错: {str(e)}")
raise
# 分块
chunks = self.text_splitter.split_documents(documents)
# 更新进度
if progress_callback:
if os.path.isdir(path):
progress_callback(processed_size, f"文档分块完成,共{len(chunks)}个文档片段")
else:
file_size = os.path.getsize(path)
progress_callback(file_size * 0.9, f"文档分块完成,共{len(chunks)}个文档片段")
# 处理成统一格式
processed_docs = []
for i, chunk in enumerate(chunks):
processed_docs.append({
'id': f'doc_{i}',
'content': chunk.page_content,
'metadata': chunk.metadata
})
return processed_docs |