cren / weclone /data /strategies.py
CrenCren's picture
Upload folder using huggingface_hub
88aba71 verified
from dataclasses import dataclass
from typing import List
from .models import ChatMessage
from abc import ABC, abstractmethod
@dataclass
class ConversationStrategy(ABC):
    """Abstract base class for conversation strategies.

    Concrete strategies implement :meth:`is_same_conversation` to decide
    whether an incoming message belongs to the same conversation as the
    messages collected so far.
    """

    # Dataclass field shared by every strategy; this base class stores it
    # but never reads it.
    is_single_chat: bool

    @abstractmethod
    def is_same_conversation(
        self, history_msg: List[ChatMessage], current_msg: ChatMessage
    ) -> bool:
        """Decide whether ``current_msg`` belongs to the same conversation.

        Args:
            history_msg: Messages already gathered for the conversation.
            current_msg: The message being tested against that history.

        Returns:
            ``True`` when the message is part of the same conversation.
            The base implementation has no body and must be overridden.
        """
@dataclass
class TimeWindowStrategy(ConversationStrategy):
    """Strategy that groups messages by the time gap between them."""

    # Largest allowed gap between the newest history message and the current
    # message. The value is compared directly against a number of seconds.
    # NOTE(review): the original comment labelled this field as minutes, yet
    # the comparison below uses total_seconds(); presumably callers convert
    # to seconds before constructing the strategy - confirm against callers.
    time_window: int

    def is_same_conversation(
        self, history_msg: List[ChatMessage], current_msg: ChatMessage
    ) -> bool:
        """Return whether ``current_msg`` falls inside the time window.

        Args:
            history_msg: Messages already gathered for the conversation;
                only the last one is inspected.
            current_msg: The message being tested.

        Returns:
            ``True`` when the absolute difference between the two
            ``CreateTime`` values, in seconds, is at most ``time_window``.
            ``False`` when ``history_msg`` is empty, since there is nothing
            to compare against.
        """
        # Without any history there is no previous timestamp; indexing [-1]
        # would raise IndexError, so answer False as LLMStrategy does.
        if not history_msg:
            return False

        # abs() makes the check independent of message ordering.
        gap = abs(current_msg.CreateTime - history_msg[-1].CreateTime)
        return gap.total_seconds() <= self.time_window
@dataclass
class LLMStrategy(ConversationStrategy):
    """Strategy described as large-language-model based.

    The implementation visible here does not call a model: it only checks
    whether the current message and the latest history message share the
    same ``talker``.
    """

    def is_same_conversation(
        self, history_msg: List[ChatMessage], current_msg: ChatMessage
    ) -> bool:
        """Return whether ``current_msg`` has the same talker as the history.

        Args:
            history_msg: Messages already gathered for the conversation;
                only the last one is inspected.
            current_msg: The message being tested.

        Returns:
            ``False`` for an empty history, otherwise the result of comparing
            the two ``talker`` values for equality.
        """
        if not history_msg:
            return False

        # Fix for an earlier user_id error: compare the talker field rather
        # than user_id.
        previous_msg = history_msg[-1]
        return current_msg.talker == previous_msg.talker
@dataclass
class CompositeStrategy(ConversationStrategy):
    """Strategy that combines the verdicts of several other strategies."""

    # Child strategies, consulted in list order.
    strategies: List[ConversationStrategy]
    # True: every child strategy must agree; False: one agreeing child is
    # enough.
    require_all: bool = True

    def is_same_conversation(
        self, history_msg: List[ChatMessage], current_msg: ChatMessage
    ) -> bool:
        """Combine the child strategies' answers for ``current_msg``.

        Args:
            history_msg: Messages already gathered for the conversation.
            current_msg: The message being tested.

        Returns:
            With ``require_all`` set, ``True`` only if every child strategy
            returns a truthy value (``True`` for an empty strategy list).
            Otherwise ``True`` if at least one child does (``False`` for an
            empty strategy list).
        """
        # A generator lets all()/any() stop at the first decisive answer, so
        # later strategies are not evaluated once the outcome is known.
        verdicts = (
            strategy.is_same_conversation(history_msg, current_msg)
            for strategy in self.strategies
        )
        return all(verdicts) if self.require_all else any(verdicts)