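# HaRiM+ scorer: scores machine-generated summaries with a seq2seq LM by mixing the
# summary's log-likelihood with a lambda-weighted (negative) hallucination-risk term.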
import torch
import torch.nn.functional as F
from transformers import (AutoModelForSeq2SeqLM,
                        AutoTokenizer,
                        PreTrainedTokenizer,
                        PreTrainedTokenizerFast)
import evaluate

from fire import Fire
import pandas as pd
from tqdm import tqdm
import json

from typing import List, Dict, Union
from collections import defaultdict
from functools import partial
from pprint import pprint


class Harimplus_Scorer:
    def __init__(self,
                    pretrained_name:str='none',
                    tokenizer:Union[PreTrainedTokenizer, PreTrainedTokenizerFast]=None,
                    mixing_factor:float=7., # same as lambda in the paper
                    device:str='cuda',

                    src_maxlen=1024,
                    tgt_maxlen=110,
                ):
        self._pretrained_name = pretrained_name
        self._lambda = mixing_factor

        self._device = torch.device(device if torch.cuda.is_available() else 'cpu') # honor the requested device, fall back to CPU
        self._encdec_model = AutoModelForSeq2SeqLM.from_pretrained(self._pretrained_name)
        if tokenizer is None:
            self._tokenizer = AutoTokenizer.from_pretrained(self._pretrained_name)
        else:
            self._tokenizer = tokenizer
        self._encdec_model.to(self._device)
        self._encdec_model.eval()

        self._src_maxlen = src_maxlen
        self._tgt_maxlen = tgt_maxlen
    def _prep_input(self, src_tgt_txts, src_or_tgt='src'):
        L = self._src_maxlen if src_or_tgt=='src' else self._tgt_maxlen
        if isinstance(src_tgt_txts, pd.Series):
            src_tgt_txts = src_tgt_txts.tolist()
        if src_or_tgt == 'src': # flatten newlines in articles regardless of input type
            src_tgt_txts = [ s.replace("\n", " ") for s in src_tgt_txts ]
        return self._tokenizer(src_tgt_txts, padding=True, truncation=True, max_length=L, return_tensors='pt') # BatchEncoding


    '''The helper methods below have no dependency on instance state; they are kept inside the class for ease of use.'''
    def likelihoods(self, logits, force_decode_indices, tgt_mask):
        probs = F.softmax(logits, dim=-1)
        probs_force_decode_ = probs.gather(-1, force_decode_indices.unsqueeze(-1)).squeeze(-1) # squeeze only the gathered dim so bsz=1 still works
        probs_force_decode= probs_force_decode_ * tgt_mask
        assert probs_force_decode.shape == force_decode_indices.shape
        return probs_force_decode

    def log_likelihoods(self, logits, force_decode_indices, tgt_mask):
        ll = F.log_softmax(logits, dim=-1)
        ll_force_decode_ = ll.gather(-1, force_decode_indices.unsqueeze(-1)).squeeze(-1) # squeeze only the gathered dim so bsz=1 still works
        ll_force_decode = ll_force_decode_ * tgt_mask

        return ll_force_decode

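    # Per-token HaRiM: harim = 1 - (1 - p_s2s) * (1 - delta)/2 with delta = p_s2s - p_lm;
    # equivalently -1 * hallucination risk. Tokens that the source-conditioned model prefers
    # over the empty-source LM contribute lower estimated risk.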
    def harim(self, s2s_logits, lm_logits, force_decode_indices, tgt_mask ):
        p_s2s, p_lm = self.likelihoods(s2s_logits, force_decode_indices, tgt_mask), \
                        self.likelihoods(lm_logits, force_decode_indices, tgt_mask)

        delta = p_s2s - p_lm
        margin_linear = (1-delta) / 2
        harim = -(1-p_s2s) * margin_linear + 1
        return harim # this is -1 * hallucination risk

    def make_minibatches(self, exs:List[str], bsz:int=32):
        # chunk the examples into batches of at most bsz
        return [ exs[i:i+bsz] for i in range(0, len(exs), bsz) ]

    def make_empty_minibatches(self, minibatches:List[List[str]]):
        # same batch shapes as minibatches, but every example is the empty string
        return [ ['' for _ in mb] for mb in minibatches ]


    def compute(self, predictions:List[str],
                        references:List[str],
                        bsz:int=32,
                        use_aggregator:bool=False,
                        return_details:bool=False,
                        tokenwise_score:bool=False,
                        ):
        '''
        Returns HaRiM+ scores (List[float]) for predictions (summaries) against references (articles).
        **Note**
            - Here, predictions = generated summaries to be evaluated and references = articles to be
              summarized. (The articles are passed as "references" only to follow the kwarg convention
              of the evaluate library.)
            - log_ppl is equivalent to BARTScore (Yuan et al., NeurIPS 2021).

        if tokenwise_score:
            additionally returns minibatch chunks of tokenwise HaRiM+ scores and log-likelihoods
            with the tokenized predictions (List[str])
        if use_aggregator:
            the returned scores are aggregated (mean) over the given test set

        See the usage sketch at the bottom of this file.
        '''

        # tokenize/prep src/tgts
        make_minibatches_bsz = partial(self.make_minibatches, bsz=bsz)
        summaries = predictions
        articles = references
        b_srcs, b_tgts = map(make_minibatches_bsz, [articles, summaries])
        b_emps = self.make_empty_minibatches(b_srcs)
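        # Empty-source minibatches make the same seq2seq model behave as an (approximately)
        # unconditional LM: decoding the summary with no article supplies the p_lm term.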

        scores=defaultdict(list)
        for mini_s, mini_e, mini_t in tqdm(zip(b_srcs, b_emps, b_tgts), total=len(b_tgts), desc=f"computing HaRiM+ {bsz=}, core={self._pretrained_name}"):
            src_in = self._prep_input(mini_s, src_or_tgt='src')
            emp_in = self._prep_input(mini_e, src_or_tgt='src')
            tgt_in = self._prep_input(mini_t, src_or_tgt='tgt')
            if emp_in.input_ids.shape[-1]==0: # some tokenizers map '' to zero tokens, e.g. shape == (bsz, 0)
                boseos = f"{self._tokenizer.bos_token}{self._tokenizer.eos_token}"
                mini_e_ = [boseos for _ in range(len(mini_e))]
                emp_in = self._prep_input( mini_e_, src_or_tgt='src' )


            tgt_mask = tgt_in.attention_mask

            src_in = src_in.to(self._device)
            emp_in = emp_in.to(self._device)
            tgt_in = tgt_in.to(self._device)
            tgt_mask = tgt_mask.to(self._device)

            with torch.no_grad():
                # pass only input_ids / attention_mask / labels; forwarding token_type_ids raises an error for seq2seq models
                s2s_logits = self._encdec_model.forward(
                                                    input_ids = src_in.input_ids,
                                                    attention_mask = src_in.attention_mask,
                                                    labels = tgt_in.input_ids,
                                                    return_dict=True).logits
                lm_logits = self._encdec_model.forward(
                                                    input_ids = emp_in.input_ids,
                                                    attention_mask = emp_in.attention_mask,
                                                    labels = tgt_in.input_ids,
                                                    return_dict=True).logits
                sent_lengths = tgt_mask.sum(-1)
                ll_tok = self.log_likelihoods(s2s_logits, tgt_in.input_ids, tgt_mask)
                ll = ll_tok.sum(-1) / sent_lengths

                harim_tok = self.harim(s2s_logits, lm_logits, tgt_in.input_ids, tgt_mask)
                harim = harim_tok.sum(-1) / sent_lengths

                harim_plus_normalized = ll + self._lambda * harim # loglikelihood + lambda * negative_harim (negative harim=-1* risk)

                scores['harim+'].extend(harim_plus_normalized.tolist())
                scores['harim'].extend(harim.tolist())
                scores['log_ppl'].extend(ll.tolist())

                if tokenwise_score:
                    scores['tok_harim+'].append(harim_tok*self._lambda + ll_tok)
                    scores['tok_predictions'].append( [self._tokenizer.convert_ids_to_tokens(idxs) for idxs in tgt_in.input_ids] ) # the decoder labels live in tgt_in, not src_in

        if use_aggregator: # aggregate after scoring all minibatches
            for k, v in scores.items():
                if not k.startswith('tok_'):
                    scores[k] = sum(v)/len(v) # aggregate (mean)
        scores['lambda'] = self._lambda
        if not return_details:
            scores = scores['harim+']
        return scores
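

# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal example of invoking the scorer. The checkpoint name 'facebook/bart-large-cnn'
# is an assumption: any summarization checkpoint loadable via AutoModelForSeq2SeqLM with
# bos/eos tokens should work here.
if __name__ == '__main__':
    scorer = Harimplus_Scorer(pretrained_name='facebook/bart-large-cnn')
    articles = ['Officials said the storm knocked out power to thousands of homes overnight.']
    summaries = ['A storm left thousands of homes without power.']
    scores = scorer.compute(predictions=summaries, references=articles, bsz=1)
    print(scores)  # List[float]: one HaRiM+ score per (summary, article) pair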