File size: 2,550 Bytes
acd7cf4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from dataclasses import  dataclass, field
from typing import Set

from graphgen.models.evaluate.base_evaluator import BaseEvaluator
from graphgen.models.text.text_pair import TextPair
from graphgen.utils import detect_main_language, NLTKHelper, create_event_loop


nltk_helper = NLTKHelper()

@dataclass
class MTLDEvaluator(BaseEvaluator):
    """
    Evaluator for the Measure of Textual Lexical Diversity (MTLD).

    MTLD scores lexical diversity as the mean number of tokens per
    "factor" — a maximal run of tokens whose type-token ratio (TTR)
    stays above a threshold. The final score averages a forward and a
    backward pass over the token list. Higher is better; the minimum
    for non-empty input is 1.0.
    """

    # Stopword sets used to filter function words before scoring.
    stopwords_en: Set[str] = field(
        default_factory=lambda: set(nltk_helper.get_stopwords("english"))
    )
    stopwords_zh: Set[str] = field(
        default_factory=lambda: set(nltk_helper.get_stopwords("chinese"))
    )

    async def evaluate_single(self, pair: TextPair) -> float:
        """Score the answer text of *pair* off the event loop (tokenization is CPU-bound)."""
        loop = create_event_loop()
        return await loop.run_in_executor(None, self._calculate_mtld_score, pair.answer)

    def _calculate_mtld_score(self, text: str, threshold: float = 0.72) -> float:
        """
        Compute MTLD as the average of the forward and backward passes.

        :param text: text to score; empty or whitespace-only text scores 0.0.
        :param threshold: TTR cutoff that closes a factor segment
            (0.72 is the conventional value from McCarthy & Jarvis, 2010).
        :return: MTLD score; 0.0 when no usable tokens remain after
            filtering, otherwise >= 1.0.
        """
        if not text or not text.strip():
            return 0.0

        lang = detect_main_language(text)
        tokens = nltk_helper.word_tokenize(text, lang)

        # Drop punctuation and language-appropriate stopwords in a single pass.
        stopwords = self.stopwords_zh if lang == "zh" else self.stopwords_en
        filtered_tokens = [
            word for word in tokens if word.isalnum() and word not in stopwords
        ]

        if not filtered_tokens:
            # Return a float for consistency with the annotated return type.
            return 0.0

        # Forward pass, then backward pass over the reversed token list.
        forward_mtld = self._compute_factors(filtered_tokens, threshold)
        backward_mtld = self._compute_factors(filtered_tokens[::-1], threshold)

        return (forward_mtld + backward_mtld) / 2

    @staticmethod
    def _compute_factors(tokens: list, threshold: float) -> float:
        """
        Run one MTLD pass: count factors (segments whose TTR falls to the
        threshold) and return tokens-per-factor.

        The trailing incomplete segment contributes a fractional factor
        proportional to how far its TTR has dropped toward the threshold.
        """
        factors = 0.0
        # Only the segment length is needed, not the tokens themselves.
        segment_len = 0
        unique_words: Set[str] = set()

        for token in tokens:
            segment_len += 1
            unique_words.add(token)
            ttr = len(unique_words) / segment_len

            if ttr <= threshold:
                # Segment saturated: count a full factor and start a new one.
                factors += 1
                segment_len = 0
                unique_words = set()

        # Handle the final, incomplete segment.
        if segment_len:
            ttr = len(unique_words) / segment_len
            if ttr <= threshold:
                factors += 1
            else:
                # Partial credit for the TTR drop achieved so far.
                factors += 1 - (ttr - threshold) / (1 - threshold)

        # factors == 0 only when every segment (including the tail) stayed at
        # TTR == 1.0 (all tokens unique); fall back to the token count.
        return len(tokens) / factors if factors > 0 else len(tokens)