|
from dataclasses import dataclass
|
|
from typing import List
|
|
from .models import ChatMessage
|
|
from abc import ABC, abstractmethod
|
|
|
|
|
|
@dataclass
class ConversationStrategy(ABC):
    """Abstract base class for conversation strategies.

    A strategy decides whether a newly arrived message continues the
    conversation formed by a list of earlier messages. Concrete subclasses
    must override ``is_same_conversation``.
    """

    # Stored on every strategy instance; this base class never reads it.
    # NOTE(review): presumably distinguishes one-on-one chats from group
    # chats — confirm with callers.
    is_single_chat: bool

    @abstractmethod
    def is_same_conversation(
        self, history_msg: List[ChatMessage], current_msg: ChatMessage
    ) -> bool:
        """Decide whether ``current_msg`` belongs to the conversation in ``history_msg``.

        Args:
            history_msg: Messages previously grouped into the conversation.
            current_msg: The incoming message to classify.

        Returns:
            True if ``current_msg`` continues the conversation, False otherwise.
        """
        ...
|
|
|
|
|
|
@dataclass
class TimeWindowStrategy(ConversationStrategy):
    """Strategy based on a time window.

    The current message is treated as part of the same conversation when its
    ``CreateTime`` lies within ``time_window`` seconds of the last message in
    the history.
    """

    # Maximum allowed gap in seconds; compared against timedelta.total_seconds().
    time_window: int

    def is_same_conversation(
        self, history_msg: List[ChatMessage], current_msg: ChatMessage
    ) -> bool:
        """Return True if ``current_msg`` is within the time window of the history.

        Args:
            history_msg: Earlier messages; the last element is used as the
                reference point (presumably the most recent one).
            current_msg: The incoming message to classify.

        Returns:
            True when the absolute gap between the two ``CreateTime`` values is
            at most ``time_window`` seconds; False otherwise, including when
            ``history_msg`` is empty.
        """
        if not history_msg:
            # Nothing to compare against. Previously this raised IndexError;
            # returning False matches how LLMStrategy treats an empty history.
            return False

        # Assumes CreateTime values subtract to a datetime.timedelta (e.g. both
        # are datetime objects). abs() makes the check order-insensitive.
        time_diff = abs(current_msg.CreateTime - history_msg[-1].CreateTime).total_seconds()
        return time_diff <= self.time_window
|
|
|
|
|
|
@dataclass
class LLMStrategy(ConversationStrategy):
    """Strategy intended to rely on a large language model's judgement.

    NOTE(review): the current implementation does not call any model. It
    only checks whether the incoming message has the same ``talker`` as the
    last message in the history — presumably a placeholder; confirm.
    """

    def is_same_conversation(
        self, history_msg: List[ChatMessage], current_msg: ChatMessage
    ) -> bool:
        """Return True when ``current_msg`` shares its talker with the last history message.

        An empty ``history_msg`` always yields False.
        """
        if not history_msg:
            return False
        speaker = current_msg.talker
        return speaker == history_msg[-1].talker
|
|
|
|
|
|
@dataclass
class CompositeStrategy(ConversationStrategy):
    """Composite strategy that combines several sub-strategies.

    ``strategies`` are consulted in list order. With ``require_all`` set
    (the default), every sub-strategy must agree (logical AND); otherwise a
    single positive verdict is enough (logical OR).
    """

    # Sub-strategies consulted in order; put cheap ones first, since
    # evaluation stops as soon as the outcome is decided.
    strategies: List[ConversationStrategy]
    # True -> AND semantics (all must agree); False -> OR semantics (any suffices).
    require_all: bool = True

    def is_same_conversation(
        self, history_msg: List[ChatMessage], current_msg: ChatMessage
    ) -> bool:
        """Combine the sub-strategies' verdicts for ``current_msg``.

        Evaluation short-circuits via ``all``/``any``. With ``require_all``,
        sub-strategies after the first False are not called. Without it,
        sub-strategies after the first True are not called. This avoids
        running potentially expensive checks once the result is known.

        Returns:
            The combined verdict. For an empty ``strategies`` list this is
            True when ``require_all`` is set and False otherwise (the
            standard ``all``/``any`` semantics).
        """
        # A generator (not a list) lets all()/any() stop early instead of
        # eagerly invoking every sub-strategy.
        verdicts = (
            strategy.is_same_conversation(history_msg, current_msg)
            for strategy in self.strategies
        )
        return all(verdicts) if self.require_all else any(verdicts)
|
|
|