File size: 8,714 Bytes
7756dfe 868518d 1b59f0f 49bc88a cfb951e a9e075d 49bc88a 75dc93c 223da69 df5ef1a 6183cb0 a9e075d 6183cb0 1b59f0f 49bc88a 252fd1e e931572 49bc88a 9e5ab6c 49bc88a 0361024 49bc88a a9e075d e931572 5a49e7f ffcb61e e931572 52111a1 5a49e7f e931572 5a49e7f e931572 5a49e7f e931572 5a49e7f 49bc88a 0361024 c3d3c82 df5ef1a 868518d 49bc88a 0361024 758f12d 7057632 758f12d 0361024 758f12d 0361024 931e174 0361024 931e174 758f12d 2825b70 df5ef1a 758f12d df5ef1a 758f12d 931e174 df5ef1a 758f12d df5ef1a 803396c 758f12d b841e5b 2e36b27 758f12d df5ef1a 2e36b27 49bc88a 803396c df5ef1a 803396c 758f12d 223da69 2e36b27 7de46e8 2e36b27 868518d c3d3c82 49bc88a c3d3c82 a7ea7d8 c3d3c82 a7ea7d8 0361024 223da69 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 |
import gradio as gr
import numpy as np
import os
import pandas as pd
import faiss
from huggingface_hub import hf_hub_download
import time
import json
import fastapi
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import threading
import math
# Writable cache directory for Hugging Face Hub downloads.
CACHE_DIR = "/home/user/cache"
os.makedirs(CACHE_DIR, exist_ok=True)

# Limit native thread pools to keep CPU/memory usage down.
os.environ["OMP_NUM_THREADS"] = "2"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# faiss (and the OpenMP runtime it links) is imported at the top of the file,
# before the env var above is set, so OMP_NUM_THREADS may be read too late to
# affect FAISS. Apply the same limit through faiss's own setter as well.
faiss.omp_set_num_threads(int(os.environ["OMP_NUM_THREADS"]))

# Global state.
# Currently served FAISS index and metadata table. predict() maps FAISS ids to
# metadata rows via .iloc, so the two must describe the same documents.
index = None
metadata = None
# Timestamp recorded by refresh_index() on its most recent reload (0 = none yet).
last_updated = 0
# Minimum number of seconds between Hub update checks in refresh_index().
index_refresh_interval = 300  # 5 minutes
# Background index refresher.
def refresh_index():
    """Poll the Hub repo and hot-swap the FAISS index and metadata when they change.

    Runs forever and is meant for a daemon thread. The loop ticks every 30 s.
    At most once per ``index_refresh_interval`` seconds it resolves
    ``faiss_index.bin`` and ``metadata.csv`` through ``hf_hub_download``
    without ``force_download``, so an unchanged file is served from
    ``CACHE_DIR`` instead of being downloaded again. It then compares the
    resolved cache paths with the ones currently served. Only when they
    differ does it:

    * parse the files,
    * check that the index is 768-dimensional, and
    * replace the globals ``index``, ``metadata`` and ``last_updated``.

    Errors are printed and the loop keeps running.
    """
    global index, metadata, last_updated
    served_files = None  # resolved cache paths of the files currently served
    last_checked = 0.0  # time of the previous Hub check; 0 forces an immediate one
    while True:
        try:
            now = time.time()
            if now - last_checked > index_refresh_interval:
                last_checked = now
                print("🔄 检查索引更新...")
                token = os.getenv("HF_TOKEN")
                index_path = hf_hub_download(
                    repo_id="GOGO198/GOGO_rag_index",
                    filename="faiss_index.bin",
                    cache_dir=CACHE_DIR,
                    token=token,
                )
                metadata_path = hf_hub_download(
                    repo_id="GOGO198/GOGO_rag_index",
                    filename="metadata.csv",
                    cache_dir=CACHE_DIR,
                    token=token,
                )
                # The Hub cache stores each file version separately, so a
                # changed remote file resolves to a different real path.
                # (force_download + mtime, used previously, changes on every
                # download and therefore always looked like an update.)
                current_files = (
                    os.path.realpath(index_path),
                    os.path.realpath(metadata_path),
                )
                if served_files is None and index is not None and metadata is not None:
                    # First check after load_resources(): the globals were
                    # loaded from this same repo/cache, so adopt these files
                    # as the baseline instead of re-reading them.
                    served_files = current_files
                elif current_files != served_files:
                    print("📥 检测到新索引,重新加载...")
                    new_index = faiss.read_index(index_path)
                    if new_index.d != 768:
                        # Never swap in an index that predict() cannot query.
                        raise ValueError(f"❌ 索引维度错误:预期768维, 实际{new_index.d}维")
                    new_metadata = pd.read_csv(metadata_path)
                    # Two back-to-back assignments. This is NOT atomic: a
                    # concurrent reader may briefly see the new index with the
                    # old metadata.
                    index = new_index
                    metadata = new_metadata
                    served_files = current_files
                    last_updated = time.time()
                    print(f"✅ 索引更新完成 | 记录数: {len(new_metadata)}")
        except Exception as e:
            # Top-level boundary of a long-running thread: log and keep polling.
            print(f"索引更新失败: {str(e)}")
        # Wake up every 30 seconds.
        time.sleep(30)
def load_resources():
    """Load the FAISS index and metadata from the Hub, then start the refresher.

    Steps:
      1. Delete leftover ``*.lock`` files anywhere under CACHE_DIR.
      2. If the global ``index`` is unset, download ``faiss_index.bin`` from
         ``GOGO198/GOGO_rag_index`` and load it. It must be 768-dimensional.
      3. If the global ``metadata`` is unset, download and read ``metadata.csv``.
      4. Start a daemon thread running refresh_index() unless one is already
         running.

    Raises:
        Any download/parse error from step 2 or 3, which is printed and then
        re-raised. Raises ValueError if the index is not 768-dimensional. On
        failure the corresponding global is left unchanged.
    """
    global index, metadata

    # Remove leftover lock files from an interrupted run. huggingface_hub keeps
    # its locks in sub-directories of the cache (e.g. ".locks/..."), so walk
    # the whole tree instead of listing only the top level.
    for dirpath, _subdirs, filenames in os.walk(CACHE_DIR):
        for filename in filenames:
            if not filename.endswith('.lock'):
                continue
            lock_path = os.path.join(dirpath, filename)
            try:
                os.remove(lock_path)
                print(f"🧹 清理锁文件: {os.path.relpath(lock_path, CACHE_DIR)}")
            except OSError as e:
                # Best effort: a lock that cannot be deleted must not abort startup.
                print(f"⚠️ 无法删除锁文件 {lock_path}: {e}")

    if index is None or metadata is None:
        print("🔄 正在加载所有资源...")

    # FAISS index (768-dimensional).
    if index is None:
        print("📥 正在下载FAISS索引...")
        try:
            index_path = hf_hub_download(
                repo_id="GOGO198/GOGO_rag_index",
                filename="faiss_index.bin",
                cache_dir=CACHE_DIR,
                token=os.getenv("HF_TOKEN"),
            )
            new_index = faiss.read_index(index_path)
            if new_index.d != 768:
                raise ValueError("❌ 索引维度错误:预期768维")
            # Publish only after validation so a bad index never becomes global.
            index = new_index
            print(f"✅ FAISS索引加载完成 | 维度: {index.d}")
        except Exception as e:
            print(f"❌ FAISS索引加载失败: {str(e)}")
            raise

    # Metadata table.
    if metadata is None:
        print("📄 正在下载元数据...")
        try:
            metadata_path = hf_hub_download(
                repo_id="GOGO198/GOGO_rag_index",
                filename="metadata.csv",
                cache_dir=CACHE_DIR,
                token=os.getenv("HF_TOKEN"),
            )
            metadata = pd.read_csv(metadata_path)
            print(f"✅ 元数据加载完成 | 记录数: {len(metadata)}")
        except Exception as e:
            print(f"❌ 元数据加载失败: {str(e)}")
            raise

    # Start the background refresher only once, even if this function is
    # called again.
    if not any(t.name == "index-refresher" for t in threading.enumerate()):
        threading.Thread(target=refresh_index, name="index-refresher", daemon=True).start()
# Load the index and metadata at import time so they exist before the API
# serves any request.
load_resources()
def sanitize_floats(obj):
    """Recursively replace NaN/±inf floats with 0.0 so *obj* is strict-JSON safe.

    Walks dict values, lists and tuples, and keeps each container's type.
    Python floats and NumPy floating scalars (e.g. ``np.float32``, which is
    not a ``float`` subclass) are returned as plain ``float``. Non-finite
    values become ``0.0``. Anything else is returned unchanged.

    >>> sanitize_floats({"d": [float("nan"), 1.5]})
    {'d': [0.0, 1.5]}
    """
    if isinstance(obj, (float, np.floating)):
        value = float(obj)
        # JSON has no NaN/Infinity literals; strict encoders reject them.
        return value if math.isfinite(value) else 0.0
    if isinstance(obj, dict):
        return {k: sanitize_floats(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [sanitize_floats(x) for x in obj]
    if isinstance(obj, tuple):
        return tuple(sanitize_floats(x) for x in obj)
    return obj
def predict(vector, top_k=3):
    """Search the FAISS index for the nearest neighbours of *vector*.

    Args:
        vector: Sequence of numbers. Its length must equal the index dimension.
        top_k: Maximum number of hits to return (capped at the index size).

    Returns:
        A dict with "status" set to "success" or "error".

        On success, "results" is a list of dicts with the keys "source",
        "content", "confidence" and "distance", where confidence is
        1 / (1 + distance) clamped to [0, 1]. A row that fails to build is
        reported as {"error", "confidence": 0.5, "distance": 0.0}.

        On failure, "message" describes the problem and, for unexpected
        exceptions, "details" carries the traceback.
    """
    import traceback  # needed for the error payload; not imported at file level

    # Read each global once: refresh_index() reassigns them from a background
    # thread, and a single request should keep using the same objects.
    current_index = index
    current_metadata = metadata
    try:
        print(f"接收向量: {vector[:3]}... (长度: {len(vector)})")
        # FAISS expects a float32 matrix of shape (n_queries, d).
        query_vector = np.asarray(vector, dtype='float32').reshape(1, -1)

        if current_index.ntotal == 0:
            return {
                "status": "success",
                "results": [],
                "message": "索引为空"
            }
        if query_vector.shape[1] != current_index.d:
            return {
                "status": "error",
                "message": f"无效输入: 向量长度 {query_vector.shape[1]} 与索引维度 {current_index.d} 不一致",
            }

        # Never ask for more neighbours than there are vectors in the index.
        k = min(top_k, current_index.ntotal)
        if k <= 0:
            return {"status": "success", "results": []}

        print(f"执行FAISS搜索 | k={k}")
        D, I = current_index.search(query_vector, k=k)
        print(f"搜索结果索引: {I[0]}")
        print(f"搜索距离: {D[0]}")

        results = []
        for idx, distance in zip(I[0], D[0]):
            # FAISS pads missing hits with id -1; iloc[-1] would silently
            # return the last metadata row, so skip them.
            if idx < 0:
                continue
            try:
                # Non-finite or negative distances get a large sentinel so the
                # confidence stays small but finite.
                # NOTE(review): assumes an L2-style index where distances are
                # >= 0; an inner-product index would need different handling.
                if not np.isfinite(distance) or distance < 0:
                    distance = 100.0
                distance = float(distance)
                confidence = min(1.0, max(0.0, 1 / (1 + distance)))

                row = current_metadata.iloc[int(idx)]
                content = row.get("content", "")
                results.append({
                    "source": row["source"],
                    # An empty CSV cell is read back as NaN; report it as "".
                    "content": "" if pd.isna(content) else content,
                    "confidence": confidence,
                    "distance": distance,
                })
            except Exception as e:
                # Keep the response JSON-serialisable even for a bad row.
                results.append({
                    "error": str(e),
                    "confidence": 0.5,
                    "distance": 0.0
                })

        return {
            "status": "success",
            "results": sanitize_floats(results)
        }
    except Exception as e:
        return {
            "status": "error",
            "message": f"服务器内部错误: {str(e)}",
            "details": {"trace": traceback.format_exc()},
        }
# FastAPI application that serves the /predict endpoint.
app = FastAPI()

# Cross-origin settings: every origin, method and header is accepted.
# NOTE(review): allow_credentials=True combined with a wildcard origin may let
# any site make credentialed requests; /predict reads no cookies or auth
# headers, so allow_credentials=False is worth considering.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)
@app.post("/predict")
async def api_predict(request: Request):
    """Handle POST /predict with a JSON body of the form {"vector": [numbers...]}.

    Responses:
        400 - the body is not valid JSON or not a JSON object, or "vector" is
              missing, empty, or not a list.
        500 - an unexpected exception escaped; the message includes its text.
        200 - otherwise, the dict returned by predict(). That dict may itself
              carry "status": "error".
    """
    try:
        try:
            data = await request.json()
        except ValueError:
            # json.JSONDecodeError and UnicodeDecodeError are ValueErrors.
            # A malformed body is a client error, not a server error.
            return JSONResponse(
                status_code=400,
                content={"status": "error", "message": "无效输入: 请求体不是合法的JSON"}
            )
        # Valid JSON can still be an array, number or string, which has no .get().
        vector = data.get("vector") if isinstance(data, dict) else None
        if not vector or not isinstance(vector, list):
            return JSONResponse(
                status_code=400,
                content={"status": "error", "message": "无效输入: 需要向量数据"}
            )
        result = predict(vector)
        return JSONResponse(content=result)
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={
                "status": "error",
                "message": f"服务器内部错误了: {str(e)}"
            }
        )
# Entry point: print a readiness banner, then serve only the FastAPI app.
if __name__ == "__main__":
    # Report the resources loaded at import time.
    rule = "=" * 50
    print(rule)
    print("Space启动完成 | 准备接收请求")
    print(f"索引维度: {index.d}")
    print(f"元数据记录数: {len(metadata)}")
    print(rule)
    # Gradio is imported but never launched; uvicorn serves `app` directly.
    uvicorn.run(app, host="0.0.0.0", port=7860)