Spaces:
Running
Running
File size: 7,913 Bytes
2885a60 f469daa 2885a60 b5aef9f 2885a60 b5aef9f 2885a60 b5aef9f 2885a60 f469daa 2885a60 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 |
import torch
import torch.nn.functional as F
from transformers import (AutoModelForSeq2SeqLM,
AutoTokenizer,
PreTrainedTokenizer,
PreTrainedTokenizerFast)
import evaluate
from fire import Fire
import pandas as pd
from tqdm import tqdm
import json
from typing import List, Dict, Union
from collections import defaultdict
from functools import partial
from pprint import pprint
from ipdb import set_trace
class Harimplus_Scorer:
    """Score summaries against their source articles with HaRiM+.

    For each summary, HaRiM+ = (length-normalized log-likelihood)
    + lambda * (length-normalized HaRiM). Both token-level terms come from one
    seq2seq model scored twice on the same summary tokens: once conditioned on
    the article, and once conditioned on an empty input.
    """

    def __init__(self,
                 pretrained_name: str = 'none',
                 tokenizer: "Union[PreTrainedTokenizer, PreTrainedTokenizerFast]" = None,
                 mixing_factor: float = 7.,  # same as lambda in the paper
                 device: str = 'cuda',
                 src_maxlen=1024,
                 tgt_maxlen=110,
                 ):
        """Load the seq2seq model (and tokenizer, unless one is given) onto a device.

        Args:
            pretrained_name: model name/path passed to ``from_pretrained``.
            tokenizer: optional tokenizer; loaded from ``pretrained_name`` when None.
            mixing_factor: lambda, the weight of the HaRiM term in HaRiM+.
            device: requested torch device (e.g. 'cuda', 'cuda:1', 'cpu').
                A CUDA device falls back to CPU when CUDA is unavailable;
                None picks CUDA if available, else CPU.
            src_maxlen: truncation length (in tokens) for articles.
            tgt_maxlen: truncation length (in tokens) for summaries.
        """
        self._pretrained_name = pretrained_name
        self._lambda = mixing_factor
        self._device = self._resolve_device(device)
        self._encdec_model = AutoModelForSeq2SeqLM.from_pretrained(self._pretrained_name)
        if tokenizer is None:
            self._tokenizer = AutoTokenizer.from_pretrained(self._pretrained_name)
        else:
            self._tokenizer = tokenizer
        self._encdec_model.to(self._device)
        self._encdec_model.eval()
        self._src_maxlen = src_maxlen
        self._tgt_maxlen = tgt_maxlen

    @staticmethod
    def _resolve_device(device):
        """Turn the requested device into a torch.device, falling back to CPU without CUDA."""
        if device is None:
            return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        resolved = torch.device(device)
        if resolved.type == 'cuda' and not torch.cuda.is_available():
            return torch.device('cpu')
        return resolved

    def _prep_input(self, src_tgt_txts, src_or_tgt='src'):
        """Tokenize a batch of articles ('src') or summaries ('tgt') into padded 'pt' tensors.

        Newlines in articles are replaced with spaces before tokenization.
        Articles are truncated to ``src_maxlen`` and summaries to ``tgt_maxlen``.
        """
        max_length = self._src_maxlen if src_or_tgt == 'src' else self._tgt_maxlen
        if isinstance(src_tgt_txts, pd.Series):
            src_tgt_txts = src_tgt_txts.tolist()
        if src_or_tgt == 'src':
            src_tgt_txts = [s.replace("\n", " ") for s in src_tgt_txts]
        return self._tokenizer(src_tgt_txts, padding=True, truncation=True,
                               max_length=max_length, return_tensors='pt')  # ModelInput dataclass

    # The helpers below do not use instance state; they live on the class for ease of use.

    def likelihoods(self, logits, force_decode_indices, tgt_mask):
        """Return the probability of each forced-decoded token, zeroed where ``tgt_mask`` is 0.

        ``logits`` has one extra trailing (vocabulary) axis relative to
        ``force_decode_indices``; the result has the shape of ``force_decode_indices``.
        """
        probs = F.softmax(logits, dim=-1)
        # Squeeze only the gathered axis: a bare squeeze() would also drop batch/length
        # axes of size 1 and then broadcast against tgt_mask into the wrong shape.
        probs_force_decode = probs.gather(-1, force_decode_indices.unsqueeze(-1)).squeeze(-1) * tgt_mask
        assert probs_force_decode.shape == force_decode_indices.shape
        return probs_force_decode

    def log_likelihoods(self, logits, force_decode_indices, tgt_mask):
        """Return the log-probability of each forced-decoded token, zeroed where ``tgt_mask`` is 0."""
        ll = F.log_softmax(logits, dim=-1)
        ll_force_decode = ll.gather(-1, force_decode_indices.unsqueeze(-1)).squeeze(-1) * tgt_mask
        return ll_force_decode

    def harim(self, s2s_logits, lm_logits, force_decode_indices, tgt_mask):
        """Return token-level HaRiM (-1 * hallucination risk), zeroed where ``tgt_mask`` is 0.

        ``s2s_logits`` come from the article-conditioned pass and ``lm_logits``
        from the empty-input pass, both over the same target tokens.
        """
        p_s2s = self.likelihoods(s2s_logits, force_decode_indices, tgt_mask)
        p_lm = self.likelihoods(lm_logits, force_decode_indices, tgt_mask)
        delta = p_s2s - p_lm
        margin_linear = (1 - delta) / 2
        harim = -(1 - p_s2s) * margin_linear + 1
        # Padding has p_s2s = p_lm = 0, which would otherwise contribute 0.5 per pad
        # token and make a summary's score depend on its batch's padding length.
        return harim * tgt_mask  # this is -1 * hallucination risk

    def make_minibatches(self, exs: List[str], bsz: int = 32):
        """Split ``exs`` into consecutive chunks of ``bsz`` items (the last may be shorter).

        Raises:
            ValueError: if ``bsz`` < 1.
        """
        if bsz < 1:
            raise ValueError(f"bsz must be a positive integer, got {bsz}")
        return [exs[start:start + bsz] for start in range(0, len(exs), bsz)]

    def make_empty_minibatches(self, minibatches: List[List[str]]):
        """Return new minibatches of the same sizes with every item replaced by ''."""
        return [['' for _ in mb] for mb in minibatches]

    def compute(self, predictions: List[str],
                references: List[str],
                bsz: int = 32,
                use_aggregator: bool = False,
                return_details: bool = False,
                tokenwise_score: bool = False,
                ):
        '''Return HaRiM+ scores for predictions (summaries) given references (articles).

        **Note**
        - here, predictions = generated summaries to be evaluated, references = articles
          to be summarized (named "references" to follow the convention of `evaluate`).
        - log_ppl equals bartscore (Yuan et al., NeurIPS 2021).

        Args:
            predictions: summaries to evaluate.
            references: source articles, aligned one-to-one with ``predictions``.
            bsz: minibatch size for the model forward passes.
            use_aggregator: if True, replace each per-example score list by its mean.
            return_details: if True, return the whole score dict ('harim+', 'harim',
                'log_ppl', 'lambda', and the 'tok_*' entries when requested);
                otherwise return only the 'harim+' entry.
            tokenwise_score: if True, also collect, per minibatch, token-level
                HaRiM+ ('tok_harim+') and log-likelihood ('tok_log_ppl') tensors on CPU
                (padding positions are 0), plus the tokenized predictions
                ('tok_predictions'). These are only visible with ``return_details``.

        Returns:
            HaRiM+ scores (a list, or their mean when aggregated), or the detail dict.

        Raises:
            ValueError: if ``predictions`` and ``references`` differ in length.
        '''
        if len(predictions) != len(references):
            raise ValueError(
                "predictions and references must have the same length, "
                f"got {len(predictions)} and {len(references)}")
        summaries = predictions
        articles = references
        b_srcs = self.make_minibatches(articles, bsz=bsz)
        b_tgts = self.make_minibatches(summaries, bsz=bsz)
        b_emps = self.make_empty_minibatches(b_srcs)
        scores = defaultdict(list)
        for mini_s, mini_e, mini_t in tqdm(zip(b_srcs, b_emps, b_tgts), total=len(b_tgts),
                                           desc=f"computing HaRiM+ {bsz=}, core={self._pretrained_name}"):
            src_in = self._prep_input(mini_s, src_or_tgt='src')
            emp_in = self._prep_input(mini_e, src_or_tgt='src')
            tgt_in = self._prep_input(mini_t, src_or_tgt='tgt')
            if emp_in.input_ids.shape[-1] == 0:
                # The tokenizer added no special tokens to '', leaving a zero-length encoder
                # input; feed the bos/eos markers instead (skipping any the tokenizer lacks).
                bos = self._tokenizer.bos_token or ''
                eos = self._tokenizer.eos_token or ''
                emp_in = self._prep_input([f"{bos}{eos}"] * len(mini_e), src_or_tgt='src')
            src_in = src_in.to(self._device)
            emp_in = emp_in.to(self._device)
            tgt_in = tgt_in.to(self._device)
            tgt_ids = tgt_in.input_ids
            tgt_mask = tgt_in.attention_mask
            with torch.no_grad():
                # Pass tensors explicitly: forwarding token_type_ids causes an error.
                s2s_logits = self._encdec_model(
                    input_ids=src_in.input_ids,
                    attention_mask=src_in.attention_mask,
                    labels=tgt_ids,
                    return_dict=True).logits
                lm_logits = self._encdec_model(
                    input_ids=emp_in.input_ids,
                    attention_mask=emp_in.attention_mask,
                    labels=tgt_ids,
                    return_dict=True).logits
            sent_lengths = tgt_mask.sum(-1)
            ll_tok = self.log_likelihoods(s2s_logits, tgt_ids, tgt_mask)
            ll = ll_tok.sum(-1) / sent_lengths
            harim_tok = self.harim(s2s_logits, lm_logits, tgt_ids, tgt_mask)
            harim = harim_tok.sum(-1) / sent_lengths
            # loglikelihood + lambda * negative_harim (negative harim = -1 * risk)
            harim_plus_normalized = ll + self._lambda * harim
            scores['harim+'].extend(harim_plus_normalized.tolist())
            scores['harim'].extend(harim.tolist())
            scores['log_ppl'].extend(ll.tolist())
            if tokenwise_score:
                # Move to CPU so per-batch tensors do not accumulate on the accelerator.
                scores['tok_harim+'].append((harim_tok * self._lambda + ll_tok).cpu())
                scores['tok_log_ppl'].append(ll_tok.cpu())
                scores['tok_predictions'].append(
                    [self._tokenizer.convert_ids_to_tokens(ids.tolist()) for ids in tgt_ids])
        if use_aggregator:
            for k, v in scores.items():
                if not k.startswith('tok_'):
                    scores[k] = sum(v) / len(v)  # aggregate (mean)
        scores['lambda'] = self._lambda
        if not return_details:
            scores = scores['harim+']
        return scores
|