from dataclasses import dataclass, field from typing import Set from graphgen.models.evaluate.base_evaluator import BaseEvaluator from graphgen.models.text.text_pair import TextPair from graphgen.utils import detect_main_language, NLTKHelper, create_event_loop nltk_helper = NLTKHelper() @dataclass class MTLDEvaluator(BaseEvaluator): """ 衡量文本词汇多样性的指标 """ stopwords_en: Set[str] = field(default_factory=lambda: set(nltk_helper.get_stopwords("english"))) stopwords_zh: Set[str] = field(default_factory=lambda: set(nltk_helper.get_stopwords("chinese"))) async def evaluate_single(self, pair: TextPair) -> float: loop = create_event_loop() return await loop.run_in_executor(None, self._calculate_mtld_score, pair.answer) def _calculate_mtld_score(self, text: str, threshold=0.72) -> float: """ 计算MTLD (向前和向后的平均值) min is 1.0 higher is better """ if not text or not text.strip(): return 0.0 lang = detect_main_language(text) tokens = nltk_helper.word_tokenize(text, lang) stopwords = self.stopwords_zh if lang == "zh" else self.stopwords_en filtered_tokens = [word for word in tokens if word not in stopwords] filtered_tokens = [word for word in filtered_tokens if word.isalnum()] if not filtered_tokens: return 0 # 计算向前的MTLD forward_factors = self._compute_factors(filtered_tokens, threshold) # 计算向后的MTLD backward_factors = self._compute_factors(filtered_tokens[::-1], threshold) # 取平均值 return (forward_factors + backward_factors) / 2 @staticmethod def _compute_factors(tokens: list, threshold: float) -> float: factors = 0 current_segment = [] unique_words = set() for token in tokens: current_segment.append(token) unique_words.add(token) ttr = len(unique_words) / len(current_segment) if ttr <= threshold: factors += 1 current_segment = [] unique_words = set() # 处理最后一个不完整片段 if current_segment: ttr = len(unique_words) / len(current_segment) if ttr <= threshold: factors += 1 else: factors += (1 - (ttr - threshold) / (1 - threshold)) return len(tokens) / factors if factors > 0 else len(tokens)