import torch
import torch.nn.functional as F
from transformers import (AutoModelForSeq2SeqLM, AutoTokenizer,
                          PreTrainedTokenizer, PreTrainedTokenizerFast)
import evaluate
from fire import Fire
import pandas as pd
from tqdm import tqdm
import json
from typing import List, Dict, Union
from collections import defaultdict
from functools import partial
from pprint import pprint

try:  # debugging helper only; do not make ipdb a hard runtime dependency
    from ipdb import set_trace
except ImportError:
    from pdb import set_trace


class Harimplus_Scorer:
    """HaRiM+ scorer for (article, summary) pairs using an encoder-decoder model.

    Each summary is force-decoded twice: once conditioned on its article
    ("s2s" pass) and once conditioned on an empty source ("lm" pass). The
    per-token probabilities of both passes are combined into HaRiM
    (-1 * hallucination risk); HaRiM+ adds the length-normalized
    log-likelihood of the summary given the article:

        harim+ = log_ppl + lambda * harim
    """

    def __init__(self,
                 pretrained_name: str = 'none',
                 tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
                 mixing_factor: float = 7.,  # same as lambda in the paper
                 device: str = 'cuda',
                 src_maxlen=1024,
                 tgt_maxlen=110,
                 ):
        """
        Args:
            pretrained_name: model name/path passed to ``from_pretrained``.
            tokenizer: tokenizer to use; loaded from ``pretrained_name`` when None.
            mixing_factor: weight (lambda) of HaRiM inside HaRiM+.
            device: torch device string. A CUDA device falls back to CPU when
                CUDA is not available.
            src_maxlen: truncation length (in tokens) for source articles.
            tgt_maxlen: truncation length (in tokens) for summaries.
        """
        self._pretrained_name = pretrained_name
        self._lambda = mixing_factor

        # Honour the requested device, but keep the previous behaviour of the
        # default ('cuda') on machines without a GPU.
        if str(device).startswith('cuda') and not torch.cuda.is_available():
            device = 'cpu'
        self._device = torch.device(device)

        self._encdec_model = AutoModelForSeq2SeqLM.from_pretrained(self._pretrained_name)
        if tokenizer is None:
            self._tokenizer = AutoTokenizer.from_pretrained(self._pretrained_name)
        else:
            self._tokenizer = tokenizer
        self._encdec_model.to(self._device)
        self._encdec_model.eval()

        self._src_maxlen = src_maxlen
        self._tgt_maxlen = tgt_maxlen

    def _prep_input(self, src_tgt_txts, src_or_tgt='src'):
        """Tokenize a batch of texts into padded, truncated PyTorch tensors.

        Sources ('src') are truncated to ``src_maxlen`` and have newlines
        replaced by spaces; targets are truncated to ``tgt_maxlen``.
        Accepts a list of strings or a ``pd.Series``. Returns the tokenizer's
        BatchEncoding (``input_ids``, ``attention_mask``, ...).
        """
        max_len = self._src_maxlen if src_or_tgt == 'src' else self._tgt_maxlen
        if isinstance(src_tgt_txts, pd.Series):
            src_tgt_txts = src_tgt_txts.tolist()
        if src_or_tgt == 'src':
            src_tgt_txts = [s.replace("\n", " ") for s in src_tgt_txts]
        return self._tokenizer(src_tgt_txts, padding=True, truncation=True,
                               max_length=max_len, return_tensors='pt')

    # The helpers below do not depend on instance state; they live in the
    # class for ease of use.

    def likelihoods(self, logits, force_decode_indices, tgt_mask):
        """Probability of each force-decoded token, zeroed where ``tgt_mask`` is 0.

        Args:
            logits: decoder logits, (batch, tgt_len, vocab).
            force_decode_indices: target token ids, (batch, tgt_len).
            tgt_mask: 1 for real target tokens, 0 for padding, (batch, tgt_len).
        """
        probs = F.softmax(logits, dim=-1)
        # squeeze only the gathered axis: a bare squeeze() would also drop a
        # size-1 batch or length dimension and break the shapes below.
        probs_force_decode_ = probs.gather(-1, force_decode_indices.unsqueeze(-1)).squeeze(-1)
        probs_force_decode = probs_force_decode_ * tgt_mask
        assert probs_force_decode.shape == force_decode_indices.shape
        return probs_force_decode

    def log_likelihoods(self, logits, force_decode_indices, tgt_mask):
        """Log-probability of each force-decoded token, zeroed where ``tgt_mask`` is 0.

        Same argument shapes as ``likelihoods``.
        """
        ll = F.log_softmax(logits, dim=-1)
        ll_force_decode_ = ll.gather(-1, force_decode_indices.unsqueeze(-1)).squeeze(-1)
        ll_force_decode = ll_force_decode_ * tgt_mask
        return ll_force_decode

    def harim(self, s2s_logits, lm_logits, force_decode_indices, tgt_mask):
        """Token-wise HaRiM (-1 * hallucination risk), zeroed at padding.

        Args:
            s2s_logits: decoder logits conditioned on the article.
            lm_logits: decoder logits conditioned on an empty source.
            force_decode_indices: summary token ids, (batch, tgt_len).
            tgt_mask: 1 for real summary tokens, 0 for padding.
        """
        p_s2s = self.likelihoods(s2s_logits, force_decode_indices, tgt_mask)
        p_lm = self.likelihoods(lm_logits, force_decode_indices, tgt_mask)

        delta = p_s2s - p_lm
        margin_linear = (1 - delta) / 2
        harim = -(1 - p_s2s) * margin_linear + 1
        # At padded positions p_s2s == p_lm == 0, which would evaluate to 0.5
        # and make a summary's score depend on how much padding its minibatch
        # has (i.e. on bsz and its neighbours). Mask those positions out.
        return harim * tgt_mask  # this is -1 * hallucination risk

    def make_minibatches(self, exs: List[str], bsz: int = 32):
        """Split ``exs`` into consecutive chunks of at most ``bsz`` items.

        Raises:
            ValueError: if ``bsz`` is smaller than 1.
        """
        if bsz < 1:
            raise ValueError(f"bsz must be a positive integer, got {bsz}")
        return [exs[start:start + bsz] for start in range(0, len(exs), bsz)]

    def make_empty_minibatches(self, minibatches: List[List[str]]):
        """Return minibatches of the same shape, every example replaced by ''."""
        return [['' for _ in mb] for mb in minibatches]

    def compute(self, predictions: List[str],
                references: List[str],
                bsz: int = 32,
                use_aggregator: bool = False,
                tokenwise_score: bool = False,
                ):
        '''
        returns harim+ score (List[float]) for predictions (summaries) and references (articles)
        **Note**
            - here, predictions = generated summaries to be evaluated, references = article to be summarized (but to follow the convention of the evaluate, we named kwarg as "references")
            - log_ppl equals to bartscore (yuan et al., neurips 2021)

        Returned dict keys: 'harim+', 'harim', 'log_ppl' (one value per example)
        and 'lambda'.

        if tokenwise_score:
            also returns, per minibatch, token-wise harim+ scores ('tok_harim+')
            and the tokenized predictions ('tok_predictions')
        if use_aggregator:
            returning scores are aggregated (mean) over given test set

        Raises:
            ValueError: if predictions and references differ in length.
        '''
        if len(predictions) != len(references):
            raise ValueError(
                "predictions and references must have the same length "
                f"({len(predictions)} != {len(references)})")

        # Articles (references) are the encoder source; summaries (predictions)
        # are the force-decoded target.
        make_minibatches_bsz = partial(self.make_minibatches, bsz=bsz)
        b_srcs, b_tgts = map(make_minibatches_bsz, [references, predictions])
        b_emps = self.make_empty_minibatches(b_srcs)

        scores = defaultdict(list)
        for mini_s, mini_e, mini_t in tqdm(zip(b_srcs, b_emps, b_tgts),
                                           total=len(b_tgts),
                                           desc=f"computing HaRiM+ {bsz=}, core={self._pretrained_name}"):
            src_in = self._prep_input(mini_s, src_or_tgt='src')
            emp_in = self._prep_input(mini_e, src_or_tgt='src')
            tgt_in = self._prep_input(mini_t, src_or_tgt='tgt')
            if emp_in.input_ids.shape[-1] == 0:
                # The tokenizer produced no tokens for '' (input_ids of shape
                # (batch, 0)); feed "<bos><eos>" as the empty source instead.
                boseos = f"{self._tokenizer.bos_token}{self._tokenizer.eos_token}"
                emp_in = self._prep_input([boseos] * len(mini_e), src_or_tgt='src')

            tgt_mask = tgt_in.attention_mask
            assert (tgt_mask == (tgt_in.input_ids != self._tokenizer.pad_token_id)).all()

            src_in = src_in.to(self._device)
            emp_in = emp_in.to(self._device)
            tgt_in = tgt_in.to(self._device)
            tgt_mask = tgt_mask.to(self._device)
            labels = tgt_in.input_ids

            with torch.no_grad():
                # Fields are passed explicitly: forwarding the whole encoding
                # (token_type_ids attribute) causes an error.
                s2s_logits = self._encdec_model(
                    input_ids=src_in.input_ids,
                    attention_mask=src_in.attention_mask,
                    labels=labels,
                    return_dict=True).logits
                lm_logits = self._encdec_model(
                    input_ids=emp_in.input_ids,
                    attention_mask=emp_in.attention_mask,
                    labels=labels,
                    return_dict=True).logits

            sent_lengths = tgt_mask.sum(-1)
            ll_tok = self.log_likelihoods(s2s_logits, labels, tgt_mask)
            ll = ll_tok.sum(-1) / sent_lengths
            harim_tok = self.harim(s2s_logits, lm_logits, labels, tgt_mask)
            harim = harim_tok.sum(-1) / sent_lengths
            # log-likelihood + lambda * negative harim (negative harim = -1 * risk)
            harim_plus_normalized = ll + self._lambda * harim

            scores['harim+'].extend(harim_plus_normalized.tolist())
            scores['harim'].extend(harim.tolist())
            scores['log_ppl'].extend(ll.tolist())
            if tokenwise_score:
                scores['tok_harim+'].append(harim_tok * self._lambda + ll_tok)
                scores['tok_predictions'].append(
                    [self._tokenizer.convert_ids_to_tokens(ids.tolist()) for ids in labels])

        if use_aggregator:
            # Only existing keys are overwritten, so iterating while assigning is safe.
            for k, v in scores.items():
                if not k.startswith('tok_'):
                    scores[k] = sum(v) / len(v)  # aggregate (mean)
        scores['lambda'] = self._lambda
        return scores


def test(bsz=16, pretrained_name='facebook/bart-large-cnn', tokenizer=None):
    """Score reference, system and perturbed summaries of two news articles and print them."""
    scorer = Harimplus_Scorer(pretrained_name=pretrained_name, tokenizer=tokenizer)

    art1 = """The respected law professor from Philadelphia now being investigated after allegedly emailing students a link to pornographic footage, was once a contestant on Who Wants to Be a Millionaire, it has emerged. Lisa McElroy, a 50-year-old Drexel professor, appeared on the show in 2010 while it was still hosted my Meredith Vieira. And like her apparent March 31 email mishap, her game show appearance ended with a very public mistake. McElroy, who teaches legal writing, got tripped up on the $12,500 level after flying through the first few questions, notes Philly.com. Wishes she was a millionaire: Drexel law profesor professor Lisa McElroy allegedly sent a link to a pornographic website to her students. In 2010, she appeared on the TV game show Who Wants to Be a Milionaire Mother of two: The mother of two shared an anecdote with then-host Meredith Vieira about having to scramble to find a babysitter for her kids and someone to teach her class after learning she was to appear on the show just two days before taping Lost it: McElroy was tripped up on the $12,500 question. Despite having used two lifelines, she answered wrong and walked away with around $5,000 The questions read: 'As a result of General Motor’s bankruptcy declaration in 2009, what foreign government became one of its largest shareholders?' Even after using two of her lifelines to narrow down the answer, McElroy answered China, which was incorrect. The correct answer was Canada. She walked away with around $5,000. McElroy, who is a children's book and biography author, is apparently also a mother. She opened the appearance by sharing an anecdote with Vieira about having to scramble to find a babysitter after being informed she was chosen to be on Millionaire jsut two days prior to taping. She's accused of sending the inappropriate message this past March 31 under the subject line: 'Great article on writing briefs.' However, when recipients opened the enclosed link, philly.com reports that they were directed to a video of 'a woman engaging in a sexually explicit act'. Lisa McElroy, 50, who teaches legal writing at Drexel University, reportedly sent the inappropriate message on March 31 baring the subject line: 'Great article on writing briefs' Following a number of complaints, the college issued an apology to students. The message read: 'As you may be aware, some students erroneously received an email this morning directing them to a... post that included some inappropriate material. 'We take this matter seriously and apologize for any upset it may have caused.' The university says federal law requires it investigate all reports of inappropriate behaviors of a sexual nature. McElroy did not immediately respond to an email sent to her university account by the Associated Press. When recipients opened the enclosed link, philly.com reports that they were directed to a video of 'a woman engaging in a sexually explicit act' It's not the first time the married mother-of-two has appeared in the spotlight. She is also an accomplished author with a number of published biographies and children's books. On her website, www.lisamcelroy.com, she describes herself as a 'Supreme Court junkie.' She adds that her favorites ways of relaxing include 'crawling under the covers with a dog or two and a really good book' or 'hanging out' with her two adolescent daughters. Regarding the recent email scandal, David Lat - a lawyer and legal commenter -suggests she could have been 'hacked' or made a 'copy/paste error'. While an internal investigation gets underway, it's been reported that McElroy has been placed on administrative leave. While an internal investigation gets underway, it's been reported that McElroy has been placed on administrative leave from Drexel University (seen above)"""
    art2 = """Spain's 2-0 defeat by Holland on Tuesday brought back bitter memories of their disastrous 2014 World Cup, but coach Vicente del Bosque will not be too worried about a third straight friendly defeat, insists Gerard Pique. Holland, whose 5-1 drubbing of Spain in the group stage in Brazil last year marked the end of the Iberian nation's six-year domination of the world game, scored two early goals at the Amsterdam Arena and held on against some determined Spain pressure in the second half for a 2-0 success. They became the first team to inflict two defeats on Del Bosque since he took over in 2008 but the gruff 64-year-old had used the match to try out several new faces and he fielded a largely experimental, second-string team. Stefan de Vrij (right) headed Holland in front against Spain at the Amsterdam Arena on Tuesday Gerard Pique (left) could do nothing to stop Davy Klaassen doubling the Dutch advantage Malaga forward Juanmi and Sevilla midfielder Vitolo became the 55th and 56th players to debut under Del Bosque, while the likes of goalkeeper David de Gea, defenders Raul Albiol, Juan Bernat and Dani Carvajal and midfielder Mario Suarez all started the game. 'The national team's state of health is good,' centre back Gerard Pique told reporters. 'We are in a process where players are coming into the team and gathering experience,' added the Barcelona defender. 'We are second in qualifying (for Euro 2016) and these friendly games are for experimenting. 'I am not that worried about this match because we lost friendlies in previous years and then ended up winning titles.' David de Gea was given a start by Vicente del Bosque but could not keep out De Vrij's header here Dani Carvajal (centre) was another squad player given a chance to impress against Holland Del Bosque will be confident he can find the right mix of players to secure Spain's berth at Euro 2016 in France next year, when they will be chasing an unprecedented third straight title. Slovakia are the surprise leaders in qualifying Group C thanks to a 2-1 win over Spain in Zilina in October and have a maximum 15 points from five of 10 matches. Spain are second on 12 points, three ahead of Ukraine, who they beat 1-0 in Seville on Friday. Del Bosque's side host Slovakia in September in a match that could decide who goes through to the finals as group winners. 'The team is in good shape,' forward Pedro told reporters. 'We have a very clear idea of our playing style and we are able to count on people who are gradually making a place for themselves in the team.'"""

    summaries = [
        "lisa mcelroy , 50 , who teaches legal writing at drexel university , reportedly sent the ` inappropriate ' message on march 31 . when recipients clicked the enclosed link , they were allegedly directed to a video of ' a woman engaging in a sexually explicit act ' . mcelroy appeared on the popular game show in 2010 with then-host meredith vieira but lost the game after reaching just $ 12,500 . along with teaching law , mcelroy is also an accomplished author with a number of published biographies and children 's books . has been placed on leave while school investigates .",  # reference 2.3270
        "lisa mcelroy, a 50-year-old drexel professor, appeared on the show in 2010 while it was still hosted my meredith vieira. she's accused of sending the inappropriate message this past march 31 under the subject line: 'great article on writing briefs' when recipients opened the enclosed link, philly.com reports that they were directed to a video of 'a woman engaging in a sexually explicit act' the married mother-of-two has been placed on administrative leave.",  # BART-large+cnn 4.9714
        "lisa mcelroy , 50 , who teaches legal writing at drexel university , appeared on the show in 2010 while it was still hosted my meredith vieira . she got tripped up on the $ 12,500 level after flying through the first few questions , philly.com reports . mcelroy answered wrong and walked away with around $ 5,000 .",  # BERTSUM=Factual 3.2028
        "lisa mcelroy , 50 , who teaches legal writing at philadelphia university , reportedly sent the ` inappropriate ' message on march 31 . when recipients clicked the enclosed link , they were allegedly directed to a video of ' a woman engaging in a sexually explicit act ' . mcelroy appeared on the popular game show in 2010 with then-host meredith vieira but lost the game after reaching just $ 12,500 . along with teaching law , mcelroy is also an accomplished author with a number of published biographies and children 's books . has been placed on leave while school investigates .",  # wrong subj (philadelphia) 2.2122
        "lisa mcelroy , 50 , who teaches legal writing at drexel university , reportedly did not send the ` inappropriate ' message on march 31 . when recipients clicked the enclosed link , they were allegedly directed to a video of ' a woman engaging in a sexually explicit act ' . mcelroy appeared on the popular game show in 2010 with then-host meredith vieira but lost the game after reaching just $ 12,500 . along with teaching law , mcelroy is also an accomplished author with a number of published biographies and children 's books . has been placed on leave while school investigates .",  # negation 2.2022
        "holland beat spain 2-0 at the amsterdam arena on tuesday night . stefan de vrij and davy klaassen scored goals for holland . defeat recalls horror 5-1 defeat by holland at the world cup . vicente del bosque used game to give younger spain players a chance .",  # reference
        "holland beat spain 2-0 in the group stage in brazil on tuesday night . del bosque will be hoping to find the right mix of players to the world cup . gerard pique could make the right mix of players to the tournament .",  # summary (factuality = 0, rnn)
        "del bosque beat spain 2-0 at the amsterdam arena on tuesday night . stefan de vrij and davy klaassen scored goals for holland . defeat recalls horror 5-1 defeat by holland at the world cup . vicente del bosque used game to give younger spain players a chance .",  # reference + wrong subj
        "holland could not beat spain 2-0 at the amsterdam arena on tuesday night . stefan de vrij and davy klaassen scored goals for holland . defeat recalls horror 5-1 defeat by holland at the world cup . vicente del bosque used game to give younger spain players a chance .",  # reference + negation
    ]
    articles = [art1] * 5 + [art2] * 4

    hp_score = scorer.compute(predictions=summaries, references=articles,
                              use_aggregator=False, bsz=bsz)
    pprint(hp_score)

    # Previously recorded results. NOTE(review): the 'harim+' lists below were
    # recorded with an earlier version of compute(); re-run to refresh them.
    #
    # ## drexel example
    # reference 2.3270
    # BART-large+cnn 4.9714
    # BERTSUM=Factual 3.2028
    # ref + wrong subj (philadelphia) 2.2122
    # ref + negation 2.2022
    # 'harim+': [1.6270232200622559, 1.7585878372192383, 1.3859858512878418, 1.5434350967407227, 1.609492301940918],
    #
    # ## main table result
    # 1.6247 (reference, factual)
    # 0.1173 (rnn, unfactual)
    # 1.3229 (ref + wrong subj)
    # 1.4132 (ref + negation)
    # 'harim+': [1.8230078220367432, 1.5361897945404053, 1.806436538696289, 1.7360382080078125],


if __name__ == '__main__':
    Fire(test)