File size: 12,839 Bytes
260bcd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
"""
扩展点系统实现
提供预定义的扩展点和钩子机制
"""

from typing import Dict, List, Any, Callable, Optional, TypeVar, Generic
from abc import ABC, abstractmethod
from dataclasses import dataclass
import threading

from .plugin_manager import get_plugin_manager
from ..logging.logging_service import get_logging_service


T = TypeVar('T')


@dataclass
class ExtensionContext:
    """扩展上下文"""
    extension_point: str
    data: Dict[str, Any]
    metadata: Dict[str, Any]


class ExtensionHook(ABC, Generic[T]):
    """扩展钩子基类"""

    @abstractmethod
    def execute(self, context: ExtensionContext) -> T:
        """执行扩展钩子

        Args:
            context: 扩展上下文

        Returns:
            扩展结果
        """
        pass


class DocumentProcessingHook(ExtensionHook[Dict[str, Any]]):
    """文档处理扩展钩子"""

    def execute(self, context: ExtensionContext) -> Dict[str, Any]:
        """执行文档处理扩展

        Args:
            context: 包含文档信息的上下文

        Returns:
            处理后的文档数据
        """
        # 默认实现 - 插件可以重写
        return context.data


class QueryProcessingHook(ExtensionHook[Dict[str, Any]]):
    """查询处理扩展钩子"""

    def execute(self, context: ExtensionContext) -> Dict[str, Any]:
        """执行查询处理扩展

        Args:
            context: 包含查询信息的上下文

        Returns:
            处理后的查询数据
        """
        # 默认实现 - 插件可以重写
        return context.data


class ResponseEnhancementHook(ExtensionHook[str]):
    """响应增强扩展钩子"""

    def execute(self, context: ExtensionContext) -> str:
        """执行响应增强扩展

        Args:
            context: 包含响应信息的上下文

        Returns:
            增强后的响应内容
        """
        # 默认实现 - 插件可以重写
        return context.data.get('response', '')


class ExtensionPointRegistry:
    """扩展点注册表"""

    def __init__(self):
        """初始化扩展点注册表"""
        self._logger = get_logging_service()
        self._plugin_manager = get_plugin_manager()
        self._lock = threading.RLock()

        # 预定义扩展点
        self._extension_points = {
            # 文档处理扩展点
            'document.before_upload': [],           # 文档上传前
            'document.after_upload': [],            # 文档上传后
            'document.before_processing': [],       # 文档处理前
            'document.after_processing': [],        # 文档处理后
            'document.before_indexing': [],         # 文档索引前
            'document.after_indexing': [],          # 文档索引后

            # 查询处理扩展点
            'query.before_processing': [],          # 查询处理前
            'query.after_processing': [],           # 查询处理后
            'query.before_retrieval': [],          # 检索前
            'query.after_retrieval': [],           # 检索后
            'query.before_generation': [],         # 生成前
            'query.after_generation': [],          # 生成后

            # 响应处理扩展点
            'response.before_formatting': [],       # 响应格式化前
            'response.after_formatting': [],        # 响应格式化后
            'response.before_delivery': [],         # 响应交付前
            'response.after_delivery': [],          # 响应交付后

            # 系统级扩展点
            'system.startup': [],                   # 系统启动
            'system.shutdown': [],                  # 系统关闭
            'system.error': [],                     # 系统错误
            'system.maintenance': [],               # 系统维护

            # RAG特定扩展点
            'rag.context_enhancement': [],          # 上下文增强
            'rag.answer_validation': [],            # 答案验证
            'rag.source_filtering': [],             # 来源过滤
            'rag.relevance_scoring': [],            # 相关性评分
        }

        self._logger.info("扩展点注册表初始化完成")

    def register_hook(self, extension_point: str, hook: ExtensionHook) -> bool:
        """注册扩展钩子

        Args:
            extension_point: 扩展点名称
            hook: 扩展钩子实例

        Returns:
            是否成功注册
        """
        with self._lock:
            try:
                if extension_point not in self._extension_points:
                    self._extension_points[extension_point] = []

                self._extension_points[extension_point].append(hook)

                # 同时注册到插件管理器
                self._plugin_manager.register_extension_point(
                    extension_point,
                    lambda context: hook.execute(context)
                )

                self._logger.info(f"注册扩展钩子成功: {extension_point}")
                return True

            except Exception as e:
                self._logger.error(f"注册扩展钩子失败: {extension_point}", exception=e)
                return False

    def unregister_hook(self, extension_point: str, hook: ExtensionHook) -> bool:
        """取消注册扩展钩子

        Args:
            extension_point: 扩展点名称
            hook: 扩展钩子实例

        Returns:
            是否成功取消注册
        """
        with self._lock:
            try:
                if extension_point in self._extension_points:
                    hooks = self._extension_points[extension_point]
                    if hook in hooks:
                        hooks.remove(hook)
                        self._logger.info(f"取消注册扩展钩子成功: {extension_point}")
                        return True

                return False

            except Exception as e:
                self._logger.error(f"取消注册扩展钩子失败: {extension_point}", exception=e)
                return False

    def execute_extension_point(self,
                              extension_point: str,
                              data: Dict[str, Any],
                              metadata: Dict[str, Any] = None) -> List[Any]:
        """执行扩展点

        Args:
            extension_point: 扩展点名称
            data: 传递给扩展的数据
            metadata: 扩展元数据

        Returns:
            扩展执行结果列表
        """
        with self._lock:
            try:
                if extension_point not in self._extension_points:
                    self._logger.warning(f"未知扩展点: {extension_point}")
                    return []

                # 创建扩展上下文
                context = ExtensionContext(
                    extension_point=extension_point,
                    data=data,
                    metadata=metadata or {}
                )

                results = []
                hooks = self._extension_points[extension_point]

                for hook in hooks:
                    try:
                        result = hook.execute(context)
                        results.append(result)
                    except Exception as e:
                        self._logger.error(f"执行扩展钩子失败: {extension_point}", exception=e)

                # 同时调用插件管理器的扩展点
                plugin_results = self._plugin_manager.call_extension_point(
                    extension_point, context
                )
                results.extend(plugin_results)

                return results

            except Exception as e:
                self._logger.error(f"执行扩展点失败: {extension_point}", exception=e)
                return []

    def get_extension_points(self) -> List[str]:
        """获取所有扩展点名称

        Returns:
            扩展点名称列表
        """
        with self._lock:
            return list(self._extension_points.keys())

    def get_hook_count(self, extension_point: str) -> int:
        """获取扩展点的钩子数量

        Args:
            extension_point: 扩展点名称

        Returns:
            钩子数量
        """
        with self._lock:
            return len(self._extension_points.get(extension_point, []))


class RAGExtensionPoints:
    """RAG系统扩展点封装类"""

    def __init__(self, registry: ExtensionPointRegistry = None):
        """初始化RAG扩展点

        Args:
            registry: 扩展点注册表实例
        """
        self._registry = registry or get_extension_registry()
        self._logger = get_logging_service()

    def before_document_upload(self, document_data: Dict[str, Any]) -> Dict[str, Any]:
        """文档上传前扩展点

        Args:
            document_data: 文档数据

        Returns:
            处理后的文档数据
        """
        results = self._registry.execute_extension_point(
            'document.before_upload',
            document_data
        )

        # 合并结果
        if results:
            for result in results:
                if isinstance(result, dict):
                    document_data.update(result)

        return document_data

    def after_document_upload(self, document_data: Dict[str, Any]) -> Dict[str, Any]:
        """文档上传后扩展点

        Args:
            document_data: 文档数据

        Returns:
            处理后的文档数据
        """
        results = self._registry.execute_extension_point(
            'document.after_upload',
            document_data
        )

        # 合并结果
        if results:
            for result in results:
                if isinstance(result, dict):
                    document_data.update(result)

        return document_data

    def before_query_processing(self, query_data: Dict[str, Any]) -> Dict[str, Any]:
        """查询处理前扩展点

        Args:
            query_data: 查询数据

        Returns:
            处理后的查询数据
        """
        results = self._registry.execute_extension_point(
            'query.before_processing',
            query_data
        )

        # 合并结果
        if results:
            for result in results:
                if isinstance(result, dict):
                    query_data.update(result)

        return query_data

    def after_query_processing(self, query_data: Dict[str, Any]) -> Dict[str, Any]:
        """查询处理后扩展点

        Args:
            query_data: 查询数据

        Returns:
            处理后的查询数据
        """
        results = self._registry.execute_extension_point(
            'query.after_processing',
            query_data
        )

        # 合并结果
        if results:
            for result in results:
                if isinstance(result, dict):
                    query_data.update(result)

        return query_data

    def enhance_context(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
        """上下文增强扩展点

        Args:
            context_data: 上下文数据

        Returns:
            增强后的上下文数据
        """
        results = self._registry.execute_extension_point(
            'rag.context_enhancement',
            context_data
        )

        # 合并结果
        if results:
            for result in results:
                if isinstance(result, dict):
                    context_data.update(result)

        return context_data

    def validate_answer(self, answer_data: Dict[str, Any]) -> Dict[str, Any]:
        """答案验证扩展点

        Args:
            answer_data: 答案数据

        Returns:
            验证后的答案数据
        """
        results = self._registry.execute_extension_point(
            'rag.answer_validation',
            answer_data
        )

        # 合并结果
        if results:
            for result in results:
                if isinstance(result, dict):
                    answer_data.update(result)

        return answer_data


# 全局扩展点注册表实例
_extension_registry_instance: Optional[ExtensionPointRegistry] = None
_extension_registry_lock = threading.Lock()


def get_extension_registry() -> ExtensionPointRegistry:
    """获取扩展点注册表单例实例"""
    global _extension_registry_instance

    if _extension_registry_instance is None:
        with _extension_registry_lock:
            if _extension_registry_instance is None:
                _extension_registry_instance = ExtensionPointRegistry()

    return _extension_registry_instance


def get_rag_extensions() -> RAGExtensionPoints:
    """获取RAG扩展点实例"""
    return RAGExtensionPoints()