import math
import os
import threading
import time
import traceback  # used when reporting errors from predict()

import faiss
import numpy as np
import pandas as pd
import uvicorn
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from huggingface_hub import hf_hub_download

# Create a writable cache directory
CACHE_DIR = "/home/user/cache"
os.makedirs(CACHE_DIR, exist_ok=True)

# Limit thread usage to reduce memory footprint
os.environ["OMP_NUM_THREADS"] = "2"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Globals shared between the API handlers and the refresh thread
index = None
metadata = None

# State for the background index refresh
last_updated = 0
index_refresh_interval = 300  # refresh at most once every 5 minutes

# Background refresher: periodically reload the index when it changes on the Hub
def refresh_index():
    global index, metadata, last_updated

    while True:
        try:
            # Only re-check the Hub once the refresh interval has elapsed
            current_time = time.time()
            if current_time - last_updated > index_refresh_interval:
                print("🔄 Checking for index updates...")

                # Fetch the latest metadata file
                METADATA_PATH = hf_hub_download(
                    repo_id="GOGO198/GOGO_rag_index",
                    filename="metadata.csv",
                    cache_dir=CACHE_DIR,
                    token=os.getenv("HF_TOKEN"),
                    force_download=True  # bypass the local cache
                )

                # Compare the file's modification time with the last load
                file_mtime = os.path.getmtime(METADATA_PATH)
                if file_mtime > last_updated:
                    print("📥 New index detected, reloading...")

                    # Re-download and load the FAISS index
                    INDEX_PATH = hf_hub_download(
                        repo_id="GOGO198/GOGO_rag_index",
                        filename="faiss_index.bin",
                        cache_dir=CACHE_DIR,
                        token=os.getenv("HF_TOKEN"),
                        force_download=True
                    )
                    new_index = faiss.read_index(INDEX_PATH)
                    new_metadata = pd.read_csv(METADATA_PATH)

                    # Swap in the new objects. Note: two separate assignments
                    # are not truly atomic; a concurrent request may briefly
                    # see the new index paired with the old metadata.
                    index = new_index
                    metadata = new_metadata
                    last_updated = file_mtime

                    print(f"✅ Index refresh complete | records: {len(metadata)}")

        except Exception as e:
            print(f"Index refresh failed: {str(e)}")

        # Poll every 30 seconds
        time.sleep(30)

def load_resources():
    """Load all required resources (expects a 768-dimensional index)."""
    global index, metadata

    # Remove stale lock files left over from interrupted downloads
    lock_files = [f for f in os.listdir(CACHE_DIR) if f.endswith('.lock')]
    for lock_file in lock_files:
        try:
            os.remove(os.path.join(CACHE_DIR, lock_file))
            print(f"🧹 Removed lock file: {lock_file}")
        except OSError:
            pass

    if index is None or metadata is None:
        print("🔄 Loading all resources...")

        # Load the FAISS index (768-dimensional)
        if index is None:
            print("📥 Downloading FAISS index...")
            try:
                INDEX_PATH = hf_hub_download(
                    repo_id="GOGO198/GOGO_rag_index",
                    filename="faiss_index.bin",
                    cache_dir=CACHE_DIR,
                    token=os.getenv("HF_TOKEN")
                )
                index = faiss.read_index(INDEX_PATH)

                if index.d != 768:
                    raise ValueError(f"❌ Wrong index dimension: expected 768, got {index.d}")

                print(f"✅ FAISS index loaded | dimension: {index.d}")
            except Exception as e:
                print(f"❌ Failed to load FAISS index: {str(e)}")
                raise

        # Load the metadata
        if metadata is None:
            print("📄 Downloading metadata...")
            try:
                METADATA_PATH = hf_hub_download(
                    repo_id="GOGO198/GOGO_rag_index",
                    filename="metadata.csv",
                    cache_dir=CACHE_DIR,
                    token=os.getenv("HF_TOKEN")
                )
                metadata = pd.read_csv(METADATA_PATH)
                print(f"✅ Metadata loaded | records: {len(metadata)}")
            except Exception as e:
                print(f"❌ Failed to load metadata: {str(e)}")
                raise

# Start the background refresh thread
threading.Thread(target=refresh_index, daemon=True).start()

# Make sure resources are loaded before the API takes requests
load_resources()

def sanitize_floats(obj):
    """Recursively replace NaN/Inf floats so the result is JSON-safe."""
    if isinstance(obj, float):
        if math.isnan(obj) or math.isinf(obj):
            return 0.0  # replace illegal values with a default
        return obj
    elif isinstance(obj, dict):
        return {k: sanitize_floats(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [sanitize_floats(x) for x in obj]
    else:
        return obj
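
# A quick illustration of sanitize_floats (hypothetical values): NaN and Inf
# both collapse to the 0.0 default, while nested containers are walked
# recursively.
#
#   sanitize_floats({"d": float("nan"), "xs": [1.0, float("inf")]})
#   # -> {"d": 0.0, "xs": [1.0, 0.0]}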

def predict(vector):
    try:
        print(f"Received vector: {vector[:3]}... (length: {len(vector)})")

        # Coerce the input into the shape FAISS expects: (1, d) float32
        query_vector = np.array(vector).astype('float32').reshape(1, -1)

        # Dynamic result count (capped at the total document count)
        k = min(3, index.ntotal)
        if k == 0:
            return {
                "status": "success",
                "results": [],
                "message": "Index is empty"
            }

        print(f"Running FAISS search | k={k}")
        D, I = index.search(query_vector, k)

        # Log the raw search output
        print(f"Result indices: {I[0]}")
        print(f"Result distances: {D[0]}")

        # Build the response payload
        results = []
        for i in range(k):
            try:
                idx = I[0][i]
                distance = D[0][i]

                # Fix 1: guard against illegal floats
                if not np.isfinite(distance) or distance < 0:
                    distance = 100.0  # fall back to a safe sentinel value

                # Fix 2: map distance to a confidence score in [0, 1]
                confidence = 1 / (1 + distance)
                confidence = max(0.0, min(1.0, confidence))  # clamp to [0, 1]

                # Fix 3: force plain Python floats for JSON serialization
                distance = float(distance)
                confidence = float(confidence)

                result = {
                    "source": metadata.iloc[idx]["source"],
                    "content": metadata.iloc[idx].get("content", ""),
                    "confidence": confidence,
                    "distance": distance
                }
                results.append(result)
            except Exception as e:
                # Keep failed entries JSON-safe as well
                results.append({
                    "error": str(e),
                    "confidence": 0.5,
                    "distance": 0.0
                })

        return {
            "status": "success",
            "results": sanitize_floats(results)  # deep-clean before returning
        }
    except Exception as e:
        # Return a structured error response
        return {
            "status": "error",
            "message": f"Internal server error: {str(e)}",
            "details": sanitize_floats({"trace": traceback.format_exc()})
        }
        
# Create the FastAPI application
app = FastAPI()

# Add CORS support (wide open; fine for a demo, tighten for production)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.post("/predict")
async def api_predict(request: Request):
    """API prediction endpoint."""
    try:
        data = await request.json()
        vector = data.get("vector")

        if not vector or not isinstance(vector, list):
            return JSONResponse(
                status_code=400,
                content={"status": "error", "message": "Invalid input: vector data required"}
            )

        result = predict(vector)
        return JSONResponse(content=result)

    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={
                "status": "error",
                "message": f"Internal server error: {str(e)}"
            }
        )

# Application entry point
if __name__ == "__main__":
    # Report the loaded resources
    print("=" * 50)
    print("Space startup complete | ready to accept requests")
    print(f"Index dimension: {index.d}")
    print(f"Metadata records: {len(metadata)}")
    print("=" * 50)

    # Serve only the FastAPI app
    uvicorn.run(app, host="0.0.0.0", port=7860)
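
# Example client call: a minimal sketch of how to query the /predict endpoint,
# assuming the server is reachable at http://localhost:7860, the `requests`
# package is installed, and the query vector is 768-dimensional (the all-zeros
# vector below is only a placeholder for a real embedding).
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:7860/predict",
#       json={"vector": [0.0] * 768},
#   )
#   print(resp.json())  # {"status": "success", "results": [...]}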