from typing import *  # pylint: disable=wildcard-import,unused-wildcard-import
from abc import ABC, abstractmethod

import math

import torch


class LMScorer(ABC):
    def __init__(self, model_name: str, **kwargs: Any) -> None:
        self._build(model_name, kwargs)

    @overload
    def sentence_score(
        self, text: str, log: bool = False, reduce: str = "prod"
    ) -> float:
        ...

    @overload
    def sentence_score(
        self, text: List[str], log: bool = False, reduce: str = "prod"
    ) -> List[float]:
        ...

    def sentence_score(
        self, text: Union[str, List[str]], log: bool = False, reduce: str = "prod",
    ) -> Union[float, List[float]]:
        sentences = [text] if isinstance(text, str) else text
        scores: List[float] = []
        if len(sentences) == 0:
            return scores

        outputs = self._tokens_log_prob(sentences)
        for output in outputs:
            log_probs = output[0]
            tlen = log_probs.shape[0]

            if reduce == "prod":
                score = log_probs.sum()
            elif reduce == "mean":
                score = log_probs.logsumexp(0) - math.log(tlen)
            elif reduce == "gmean":
                score = log_probs.mean(0)
            elif reduce == "hmean":
                score = log_probs.neg().logsumexp(0).neg() + math.log(tlen)
            else:
                raise ValueError("Unrecognized scoring strategy: %s" % reduce)
            if not log:
                score = score.exp()

            scores.append(score.item())

        return scores[0] if isinstance(text, str) else scores

    @overload
    def tokens_score(
        self, text: str, log: bool = False
    ) -> Tuple[List[float], List[int], List[str]]:
        ...

    @overload
    def tokens_score(
        self, text: List[str], log: bool = False
    ) -> List[Tuple[List[float], List[int], List[str]]]:
        ...

    def tokens_score(
        self, text: Union[str, List[str]], log: bool = False
    ) -> Union[
        Tuple[List[float], List[int], List[str]],
        List[Tuple[List[float], List[int], List[str]]],
    ]:
        sentences = [text] if isinstance(text, str) else text
        outputs: List[Tuple[List[float], List[int], List[str]]] = []
        if len(sentences) == 0:
            return outputs

        for log_probs, ids, tokens in self._tokens_log_prob(sentences):
            scores = log_probs if log else log_probs.exp()
            scores = cast(torch.DoubleTensor, scores)
            output = (scores.tolist(), ids.tolist(), tokens)
            outputs.append(output)

        return outputs[0] if isinstance(text, str) else outputs

    @classmethod
    def supported_model_names(cls) -> Iterable[str]:
        return cls._supported_model_names()

    def _build(self, model_name: str, options: Dict[str, Any]) -> None:
        # pylint: disable=attribute-defined-outside-init, unused-argument
        self.model_name = model_name

    @abstractmethod
    def _tokens_log_prob(
        self, text: List[str]
    ) -> List[Tuple[torch.DoubleTensor, torch.LongTensor, List[str]]]:
        # For each sentence, return its token log probabilities, token ids,
        # and token strings.
        ...  # pragma: no cover

    @classmethod
    @abstractmethod
    def _supported_model_names(cls) -> Iterable[str]:
        ...  # pragma: no cover
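

# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of the library): a minimal concrete
# subclass showing how the abstract hooks above could be implemented. The
# whitespace tokenization and the fixed uniform vocabulary size are
# assumptions made purely for demonstration; a real scorer would delegate to
# an actual language model and its tokenizer.
class _UniformLMScorer(LMScorer):
    _VOCAB_SIZE = 100  # hypothetical vocabulary size, for illustration only

    def _tokens_log_prob(
        self, text: List[str]
    ) -> List[Tuple[torch.DoubleTensor, torch.LongTensor, List[str]]]:
        outputs: List[Tuple[torch.DoubleTensor, torch.LongTensor, List[str]]] = []
        for sentence in text:
            tokens = sentence.split()
            # Assign every token the same log probability, log(1 / vocab_size).
            log_probs = torch.full(
                (len(tokens),), -math.log(self._VOCAB_SIZE), dtype=torch.float64
            )
            ids = torch.arange(len(tokens), dtype=torch.int64)
            outputs.append((log_probs, ids, tokens))
        return outputs

    @classmethod
    def _supported_model_names(cls) -> Iterable[str]:
        return ["uniform"]


if __name__ == "__main__":
    # Tiny usage demo of the public API defined above.
    scorer = _UniformLMScorer("uniform")
    print(scorer.sentence_score("I like this package", log=True))
    print(scorer.tokens_score("I like this package", log=True))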