import torch
import torch.nn.functional as F
from transformers import (AutoModelForSeq2SeqLM,
                          AutoTokenizer,
                          PreTrainedTokenizer,
                          PreTrainedTokenizerFast)
import evaluate

from fire import Fire
import pandas as pd
from tqdm import tqdm
import json

from typing import List, Dict, Union
from collections import defaultdict
from functools import partial
from pprint import pprint


class Harimplus_Scorer:
    def __init__(self,
                    pretrained_name:str='none',
                    tokenizer:Union[PreTrainedTokenizer, PreTrainedTokenizerFast]=None,
                    mixing_factor:float=7., # same as lambda in the paper
                    device:str='cuda',

                    src_maxlen=1024,
                    tgt_maxlen=110,
                ):
        self._pretrained_name = pretrained_name
        self._lambda = mixing_factor

        # honor the requested device, falling back to CPU when CUDA is unavailable
        self._device = torch.device(device if torch.cuda.is_available() else 'cpu')
        self._encdec_model = AutoModelForSeq2SeqLM.from_pretrained(self._pretrained_name)
        if tokenizer is None:
            self._tokenizer = AutoTokenizer.from_pretrained(self._pretrained_name)
        else:
            self._tokenizer = tokenizer
        self._encdec_model.to(self._device)
        self._encdec_model.eval()

        self._src_maxlen = src_maxlen
        self._tgt_maxlen = tgt_maxlen



    def _prep_input(self, src_tgt_txts, src_or_tgt='src'):
        L = self._src_maxlen if src_or_tgt == 'src' else self._tgt_maxlen
        if isinstance(src_tgt_txts, pd.Series):
            src_tgt_txts = src_tgt_txts.tolist()
        if src_or_tgt == 'src':
            # flatten newlines in source articles regardless of input container
            src_tgt_txts = [s.replace("\n", " ") for s in src_tgt_txts]
        return self._tokenizer(src_tgt_txts, padding=True, truncation=True, max_length=L, return_tensors='pt') # BatchEncoding
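    # e.g. (hypothetical inputs) _prep_input(['a summary .'], src_or_tgt='tgt')
    # yields a BatchEncoding whose input_ids / attention_mask are (1, seq_len)
    # tensors, padded to the longest example and truncated to tgt_maxlen.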


    '''The helper functions below do not depend on self, but are included in the class for ease of use.'''
    def likelihoods(self, logits, force_decode_indices, tgt_mask):
        # token-level probabilities of the forced (reference) tokens
        probs = F.softmax(logits, dim=-1)
        probs_force_decode_ = probs.gather(-1, force_decode_indices.unsqueeze(-1)).squeeze(-1)
        probs_force_decode = probs_force_decode_ * tgt_mask  # zero out padding positions
        assert probs_force_decode.shape == force_decode_indices.shape
        return probs_force_decode

    def log_likelihoods(self, logits, force_decode_indices, tgt_mask):
        # token-level log-probabilities of the forced (reference) tokens
        ll = F.log_softmax(logits, dim=-1)
        ll_force_decode_ = ll.gather(-1, force_decode_indices.unsqueeze(-1)).squeeze(-1)
        ll_force_decode = ll_force_decode_ * tgt_mask  # zero out padding positions
        return ll_force_decode
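    # Shape sketch for the gather() calls above (values hypothetical):
    #   logits:               (bsz, seq_len, vocab)
    #   force_decode_indices: (bsz, seq_len)  reference token ids
    #   gather(-1, indices.unsqueeze(-1)).squeeze(-1) -> (bsz, seq_len),
    #   i.e. per position, the (log-)probability assigned to the reference token.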

    def harim(self, s2s_logits, lm_logits, force_decode_indices, tgt_mask):
        p_s2s = self.likelihoods(s2s_logits, force_decode_indices, tgt_mask)
        p_lm = self.likelihoods(lm_logits, force_decode_indices, tgt_mask)

        delta = p_s2s - p_lm             # how much conditioning on the source helps
        margin_linear = (1 - delta) / 2  # linear margin in [0, 1]
        harim = -(1 - p_s2s) * margin_linear + 1
        return harim # this is -1 * hallucination risk
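    # Numeric sanity check of the margin formula (hypothetical token probabilities):
    #   p_s2s=0.9, p_lm=0.5 -> delta=0.4,  margin_linear=0.3, harim = -(1-0.9)*0.3 + 1 = 0.97 (low risk)
    #   p_s2s=0.2, p_lm=0.8 -> delta=-0.6, margin_linear=0.8, harim = -(1-0.2)*0.8 + 1 = 0.36 (high risk)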

    def make_minibatches(self, exs:List[str], bsz:int=32):
        # chunk the example list into consecutive batches of (at most) bsz items
        return [exs[i:i + bsz] for i in range(0, len(exs), bsz)]
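    # e.g. make_minibatches(['a', 'b', 'c'], bsz=2) -> [['a', 'b'], ['c']]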

    def make_empty_minibatches(self, minibatches:List[List[str]]):
        # same batch shapes, but every example replaced by the empty string
        # (the empty-source batches drive the LM branch of the score)
        return [['' for _ in mb] for mb in minibatches]


    def compute(self, predictions:List[str],
                        references:List[str],
                        bsz:int=32,
                        use_aggregator:bool=False,
                        tokenwise_score:bool=False,
                        ):
        '''
        Returns HaRiM+ scores (List[float]) for predictions (summaries) against references (articles).
        **Note**
            - Here, predictions are the generated summaries to be evaluated and references are the articles being summarized (the kwarg is named "references" to follow the convention of the `evaluate` library).
            - log_ppl is equivalent to BARTScore (Yuan et al., NeurIPS 2021).

        if tokenwise_score:
            also returns minibatch chunks of token-level HaRiM+ scores and log-likelihoods along with the tokenized predictions (List[str])
        if use_aggregator:
            the returned scores are aggregated (mean) over the given test set
        '''


        # tokenize/prep src/tgts: sources are the articles (references), targets are the summaries (predictions)
        make_minibatches_bsz = partial(self.make_minibatches, bsz=bsz)
        b_srcs, b_tgts = map(make_minibatches_bsz, [references, predictions])
        b_emps = self.make_empty_minibatches(b_srcs)

        scores=defaultdict(list)
        for mini_s, mini_e, mini_t in tqdm(zip(b_srcs, b_emps, b_tgts), total=len(b_tgts), desc=f"computing HaRiM+ {bsz=}, core={self._pretrained_name}"):
            src_in = self._prep_input(mini_s, src_or_tgt='src')
            emp_in = self._prep_input(mini_e, src_or_tgt='src')
            tgt_in = self._prep_input(mini_t, src_or_tgt='tgt')
            if emp_in.input_ids.shape[-1] == 0: # some tokenizers encode '' to zero tokens; fall back to bos+eos
                boseos = f"{self._tokenizer.bos_token}{self._tokenizer.eos_token}"
                mini_e_ = [boseos for _ in range(len(mini_e))]
                emp_in = self._prep_input( mini_e_, src_or_tgt='src' )

            src_in.data['labels'] = tgt_in.input_ids
            emp_in.data['labels'] = tgt_in.input_ids
            tgt_mask = tgt_in.attention_mask

            assert (tgt_in.attention_mask == (tgt_in.input_ids != self._tokenizer.pad_token_id)).all()
            src_in = src_in.to(self._device)
            emp_in = emp_in.to(self._device)
            tgt_in = tgt_in.to(self._device)
            tgt_mask = tgt_mask.to(self._device)

            with torch.no_grad():
                # NOTE: pass the tensors explicitly; forwarding the whole BatchEncoding
                # can inject token_type_ids, which seq2seq models reject
                s2s_logits = self._encdec_model(input_ids=src_in.input_ids,
                                                attention_mask=src_in.attention_mask,
                                                labels=tgt_in.input_ids,
                                                return_dict=True).logits
                lm_logits = self._encdec_model(input_ids=emp_in.input_ids,
                                               attention_mask=emp_in.attention_mask,
                                               labels=tgt_in.input_ids,
                                               return_dict=True).logits
                sent_lengths = tgt_mask.sum(-1)
                ll_tok = self.log_likelihoods(s2s_logits, src_in.labels, tgt_mask)
                ll = ll_tok.sum(-1) / sent_lengths

                harim_tok = self.harim(s2s_logits, lm_logits, src_in.labels, tgt_mask)
                harim = harim_tok.sum(-1) / sent_lengths

                harim_plus_normalized = ll + self._lambda * harim # loglikelihood + lambda * negative_harim (negative harim=-1* risk)

                scores['harim+'].extend(harim_plus_normalized.tolist())
                scores['harim'].extend(harim.tolist())
                scores['log_ppl'].extend(ll.tolist())

                if tokenwise_score:
                    scores['tok_harim+'].append(harim_tok * self._lambda + ll_tok)
                    scores['tok_predictions'].append([self._tokenizer.convert_ids_to_tokens(idxs) for idxs in src_in.labels])

        if use_aggregator: # aggregate only after the whole set is scored
            for k, v in scores.items():
                if not k.startswith('tok_'):
                    scores[k] = sum(v) / len(v) # mean over examples
        scores['lambda'] = self._lambda
        return scores
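
# Minimal usage sketch (illustrative; the checkpoint name and texts are
# assumptions, not fixed by this module):
#   scorer = Harimplus_Scorer(pretrained_name='facebook/bart-large-cnn')
#   out = scorer.compute(predictions=['a short summary .'],
#                        references=['the full source article ...'],
#                        bsz=1, use_aggregator=True)
#   with use_aggregator=True, out['harim+'] is a float (mean); otherwise a per-example list.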



def test(bsz=16, pretrained_name='facebook/bart-large-cnn', tokenizer=None):
    # the constructor already falls back to AutoTokenizer when tokenizer is None
    scorer = Harimplus_Scorer(pretrained_name=pretrained_name, tokenizer=tokenizer)

    art1 = """The respected law professor from Philadelphia now being investigated after allegedly emailing students a link to pornographic footage, was once a contestant on Who Wants to Be a Millionaire, it has emerged. Lisa McElroy, a 50-year-old Drexel professor, appeared on the show in 2010 while it was still hosted my Meredith Vieira. And like her apparent March 31 email mishap, her game show appearance ended with a very public mistake. McElroy, who teaches legal writing, got tripped up on the $12,500 level after flying through the first few questions, notes Philly.com. Wishes she was a millionaire: Drexel law profesor professor Lisa McElroy allegedly sent a link to a pornographic website to her students. In 2010, she appeared on the TV game show Who Wants to Be a Milionaire Mother of two: The mother of two shared an anecdote with then-host Meredith Vieira about having to scramble to find a babysitter for her kids and someone to teach her class after learning she was to appear on the show just two days before taping Lost it: McElroy was tripped up on the $12,500 question. Despite having used two lifelines, she answered wrong and walked away with around $5,000 The questions read: 'As a result of General Motor’s bankruptcy declaration in 2009, what foreign government became one of its largest shareholders?' Even after using two of her lifelines to narrow down the answer, McElroy answered China, which was incorrect. The correct answer was Canada. She walked away with around $5,000. McElroy, who is a children's book and biography author, is apparently also a mother. She opened the appearance by sharing an anecdote with Vieira about having to scramble to find a babysitter after being informed she was chosen to be on Millionaire jsut two days prior to taping. She's accused of sending the inappropriate message this past March 31 under the subject line: 'Great article on writing briefs.' However, when recipients opened the enclosed link, philly.com reports that they were directed to a video of 'a woman engaging in a sexually explicit act'. Lisa McElroy, 50, who teaches legal writing at Drexel University, reportedly sent the inappropriate message on March 31 baring the subject line: 'Great article on writing briefs' Following a number of complaints, the college issued an apology to students. The message read: 'As you may be aware, some students erroneously received an email this morning directing them to a... post that included some inappropriate material. 'We take this matter seriously and apologize for any upset it may have caused.' The university says federal law requires it investigate all reports of inappropriate behaviors of a sexual nature. McElroy did not immediately respond to an email sent to her university account by the Associated Press. When recipients opened the enclosed link, philly.com reports that they were directed to a video of 'a woman engaging in a sexually explicit act' It's not the first time the married mother-of-two has appeared in the spotlight. She is also an accomplished author with a number of published biographies and children's books. On her website, www.lisamcelroy.com, she describes herself as a 'Supreme Court junkie.' She adds that her favorites ways of relaxing include 'crawling under the covers with a dog or two and a really good book' or 'hanging out' with her two adolescent daughters. Regarding the recent email scandal, David Lat - a lawyer and legal commenter -suggests she could have been 'hacked' or made a 'copy/paste error'. 
While an internal investigation gets underway, it's been reported that McElroy has been placed on administrative leave. While an internal investigation gets underway, it's been reported that McElroy has been placed on administrative leave from Drexel University (seen above)"""
    art2 = """Spain's 2-0 defeat by Holland on Tuesday brought back bitter memories of their disastrous 2014 World Cup, but coach Vicente del Bosque will not be too worried about a third straight friendly defeat, insists Gerard Pique. Holland, whose 5-1 drubbing of Spain in the group stage in Brazil last year marked the end of the Iberian nation's six-year domination of the world game, scored two early goals at the Amsterdam Arena and held on against some determined Spain pressure in the second half for a 2-0 success. They became the first team to inflict two defeats on Del Bosque since he took over in 2008 but the gruff 64-year-old had used the match to try out several new faces and he fielded a largely experimental, second-string team. Stefan de Vrij (right) headed Holland in front against Spain at the Amsterdam Arena on Tuesday Gerard Pique (left) could do nothing to stop Davy Klaassen doubling the Dutch advantage Malaga forward Juanmi and Sevilla midfielder Vitolo became the 55th and 56th players to debut under Del Bosque, while the likes of goalkeeper David de Gea, defenders Raul Albiol, Juan Bernat and Dani Carvajal and midfielder Mario Suarez all started the game. 'The national team's state of health is good,' centre back Gerard Pique told reporters. 'We are in a process where players are coming into the team and gathering experience,' added the Barcelona defender. 'We are second in qualifying (for Euro 2016) and these friendly games are for experimenting. 'I am not that worried about this match because we lost friendlies in previous years and then ended up winning titles.' David de Gea was given a start by Vicente del Bosque but could not keep out De Vrij's header here Dani Carvajal (centre) was another squad player given a chance to impress against Holland Del Bosque will be confident he can find the right mix of players to secure Spain's berth at Euro 2016 in France next year, when they will be chasing an unprecedented third straight title. Slovakia are the surprise leaders in qualifying Group C thanks to a 2-1 win over Spain in Zilina in October and have a maximum 15 points from five of 10 matches. Spain are second on 12 points, three ahead of Ukraine, who they beat 1-0 in Seville on Friday. Del Bosque's side host Slovakia in September in a match that could decide who goes through to the finals as group winners. 'The team is in good shape,' forward Pedro told reporters. 'We have a very clear idea of our playing style and we are able to count on people who are gradually making a place for themselves in the team.'"""

    summaries = [
        "lisa mcelroy , 50 , who teaches legal writing at drexel university , reportedly sent the ` inappropriate ' message on march 31 . when recipients clicked the enclosed link , they were allegedly directed to a video of ' a woman engaging in a sexually explicit act ' . mcelroy appeared on the popular game show in 2010 with then-host meredith vieira but lost the game after reaching just $ 12,500 . along with teaching law , mcelroy is also an accomplished author with a number of published biographies and children 's books . has been placed on leave while school investigates .", # reference 2.3270
        "lisa mcelroy, a 50-year-old drexel professor, appeared on the show in 2010 while it was still hosted my meredith vieira. she's accused of sending the inappropriate message this past march 31 under the subject line: 'great article on writing briefs' when recipients opened the enclosed link, philly.com reports that they were directed to a video of 'a woman engaging in a sexually explicit act' the married mother-of-two has been placed on administrative leave.", # BART-large+cnn 4.9714
        "lisa mcelroy , 50 , who teaches legal writing at drexel university , appeared on the show in 2010 while it was still hosted my meredith vieira . she got tripped up on the $ 12,500 level after flying through the first few questions , philly.com reports . mcelroy answered wrong and walked away with around $ 5,000 .", # BERTSUM=Factual 3.2028

        "lisa mcelroy , 50 , who teaches legal writing at philadelphia university , reportedly sent the ` inappropriate ' message on march 31 . when recipients clicked the enclosed link , they were allegedly directed to a video of ' a woman engaging in a sexually explicit act ' . mcelroy appeared on the popular game show in 2010 with then-host meredith vieira but lost the game after reaching just $ 12,500 . along with teaching law , mcelroy is also an accomplished author with a number of published biographies and children 's books . has been placed on leave while school investigates .", # wrong subj (philadelphia) 2.2122
        "lisa mcelroy , 50 , who teaches legal writing at drexel university , reportedly did not send the ` inappropriate ' message on march 31 . when recipients clicked the enclosed link , they were allegedly directed to a video of ' a woman engaging in a sexually explicit act ' . mcelroy appeared on the popular game show in 2010 with then-host meredith vieira but lost the game after reaching just $ 12,500 . along with teaching law , mcelroy is also an accomplished author with a number of published biographies and children 's books . has been placed on leave while school investigates .", # negation 2.2022

        "holland beat spain 2-0 at the amsterdam arena on tuesday night . stefan de vrij and davy klaassen scored goals for holland . defeat recalls horror 5-1 defeat by holland at the world cup . vicente del bosque used game to give younger spain players a chance .",# reference
        "holland beat spain 2-0 in the group stage in brazil on tuesday night . del bosque will be hoping to find the right mix of players to the world cup . gerard pique could make the right mix of players to the tournament .",# summary (factuality = 0, rnn)
        "del bosque beat spain 2-0 at the amsterdam arena on tuesday night . stefan de vrij and davy klaassen scored goals for holland . defeat recalls horror 5-1 defeat by holland at the world cup . vicente del bosque used game to give younger spain players a chance .",# reference + wrong subj
        "holland could not beat spain 2-0 at the amsterdam arena on tuesday night . stefan de vrij and davy klaassen scored goals for holland . defeat recalls horror 5-1 defeat by holland at the world cup . vicente del bosque used game to give younger spain players a chance .",# reference + negation



    ]
    articles = [art1] * 5 + [art2] * 4
    hp_score = scorer.compute(predictions=summaries, references=articles, use_aggregator=False, bsz=bsz)
    pprint(hp_score)



'''
## drexel example
# reference 2.3270
# BART-large+cnn 4.9714
# BERTSUM=Factual 3.2028
# ref + wrong subj (philadelphia) 2.2122
# ref + negation 2.2022


   'harim+': [1.6270232200622559,
            1.7585878372192383,
            1.3859858512878418,
            1.5434350967407227,
            1.609492301940918],



## main table result

1.6247 (reference, factual)
0.1173 (rnn, unfactual)
1.3229 (ref + wrong subj)
1.4132 (ref + negation)

        'harim+': [1.8230078220367432,
                1.5361897945404053,
                1.806436538696289,
                1.7360382080078125],

'''
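
# Run from the command line via python-fire; test()'s keyword arguments become
# CLI flags (the script name below is an assumption):
#   python harim_scorer.py --bsz 8 --pretrained_name facebook/bart-large-cnn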

if __name__ == '__main__':
    Fire(test)