# comparative-explainability/Transformer-Explainability/BERT_rationale_benchmark/models/pipeline/pipeline_utils.py
import itertools
import logging
from collections import defaultdict, namedtuple
from itertools import chain
from typing import Any, Dict, List, Tuple

import numpy as np
import torch
import torch.nn as nn
from rationale_benchmark.metrics import (PositionScoredDocument, Rationale,
                                         partial_match_score,
                                         score_hard_rationale_predictions,
                                         score_soft_tokens)
from rationale_benchmark.models.model_utils import PaddedSequence
from rationale_benchmark.utils import Annotation
from sklearn.metrics import accuracy_score, classification_report

SentenceEvidence = namedtuple(
    "SentenceEvidence", "kls ann_id query docid index sentence"
)
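
# Field guide for SentenceEvidence, inferred from how it is used below (shapes
# are descriptive, not enforced by the namedtuple itself):
#   kls      -- the label: a 0/1 evidence flag for sentence identification, a
#               tuple of per-wordpiece 0/1 flags for token identification, or an
#               interned document class for classification
#   ann_id   -- the originating Annotation's annotation_id
#   query    -- the annotation's query (interned token ids by the time it
#               reaches the models below)
#   docid    -- the document the sentence came from
#   index    -- sentence position within the document (-1 when the "sentence"
#               is a concatenation of evidence spans rather than a real sentence)
#   sentence -- a tuple of (interned) tokens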

def token_annotations_to_evidence_classification(
    annotations: List[Annotation],
    documents: Dict[str, List[List[Any]]],
    class_interner: Dict[str, int],
) -> List[SentenceEvidence]:
    """Converts token-level evidence annotations into classification examples.

    For every (annotation, document) pair, the evidence spans are concatenated
    in token order into a single pseudo-sentence whose label is the interned
    document-level classification (index is set to -1 for these).
    """
    ret = []
    for ann in annotations:
        docid_to_ev = defaultdict(list)
        for evidence in ann.all_evidences():
            docid_to_ev[evidence.docid].append(evidence)
        for docid, evidences in docid_to_ev.items():
            evidences = sorted(evidences, key=lambda ev: ev.start_token)
            text = []
            covered_tokens = set()
            doc = list(chain.from_iterable(documents[docid]))
            for evidence in evidences:
                assert (
                    evidence.start_token >= 0
                    and evidence.end_token > evidence.start_token
                )
                assert evidence.start_token < len(doc) and evidence.end_token <= len(
                    doc
                )
                text.extend(evidence.text)
                new_tokens = set(range(evidence.start_token, evidence.end_token))
                if len(new_tokens & covered_tokens) > 0:
                    raise ValueError(
                        "Overlapping token ranges in the evidence spans are not supported"
                    )
                covered_tokens |= new_tokens
            assert len(text) > 0
            ret.append(
                SentenceEvidence(
                    kls=class_interner[ann.classification],
                    query=ann.query,
                    ann_id=ann.annotation_id,
                    docid=docid,
                    index=-1,
                    sentence=tuple(text),
                )
            )
    return ret
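
# Sketch of the expected output shape (hypothetical data, for illustration only):
#   cls_examples = token_annotations_to_evidence_classification(anns, documents, class_interner)
#   cls_examples[0].kls       # interned document label (an int)
#   cls_examples[0].sentence  # tuple of evidence tokens concatenated across spans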

def annotations_to_evidence_classification(
    annotations: List[Annotation],
    documents: Dict[str, List[List[Any]]],
    class_interner: Dict[str, int],
    include_all: bool,
) -> List[SentenceEvidence]:
    """Converts corpus-level annotations to sentence-level classification examples.

    As this module is about a pipelined approach to evidence identification,
    inputs to both the evidence identifier and the evidence classifier need to
    be at the sentence level; this function converts the data to that form.

    Returns a flat list of SentenceEvidence objects, each labeled with the
    interned document-level class. With include_all=True every sentence of
    every evidence document is included; otherwise only sentences covered by
    an evidence span are kept.
    """
    ret = []
    for ann in annotations:
        ann_id = ann.annotation_id
        docids = set(ev.docid for ev in chain.from_iterable(ann.evidences))
        annotations_for_doc = defaultdict(list)
        for d in docids:
            for index, sent in enumerate(documents[d]):
                annotations_for_doc[d].append(
                    SentenceEvidence(
                        kls=class_interner[ann.classification],
                        query=ann.query,
                        ann_id=ann.annotation_id,
                        docid=d,
                        index=index,
                        sentence=tuple(sent),
                    )
                )
        if include_all:
            ret.extend(chain.from_iterable(annotations_for_doc.values()))
        else:
            contributes = set()
            for ev in chain.from_iterable(ann.evidences):
                for index in range(ev.start_sentence, ev.end_sentence):
                    contributes.add(annotations_for_doc[ev.docid][index])
            ret.extend(contributes)
    assert len(ret) > 0
    return ret
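
# Example (hypothetical identifiers, for illustration): with include_all=False
# only sentences that fall inside some evidence span survive, each labeled with
# the document-level class:
#   sents = annotations_to_evidence_classification(anns, documents, class_interner, include_all=False)
#   {s.index for s in sents if s.docid == some_docid}  # indices of evidence sentences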

def annotations_to_evidence_identification(
    annotations: List[Annotation], documents: Dict[str, List[List[Any]]]
) -> Dict[str, Dict[str, List[SentenceEvidence]]]:
    """Converts corpus-level annotations to sentence-level relevance judgments.

    As this module is about a pipelined approach to evidence identification,
    inputs to both the evidence identifier and the evidence classifier need to
    be at the sentence level; this function converts the data to that form.

    The return type is of the form
        annotation id -> docid -> [sentence level annotations]
    """
    ret = defaultdict(dict)  # annotation id -> docid -> sentences
    for ann in annotations:
        ann_id = ann.annotation_id
        for ev_group in ann.evidences:
            for ev in ev_group:
                if len(ev.text) == 0:
                    continue
                if ev.docid not in ret[ann_id]:
                    ret[ann.annotation_id][ev.docid] = []
                    # populate the document with "not evidence"; to be filled in later
                    for index, sent in enumerate(documents[ev.docid]):
                        ret[ann.annotation_id][ev.docid].append(
                            SentenceEvidence(
                                kls=0,
                                query=ann.query,
                                ann_id=ann.annotation_id,
                                docid=ev.docid,
                                index=index,
                                sentence=sent,
                            )
                        )
                # define the evidence sections of the document
                for s in range(ev.start_sentence, ev.end_sentence):
                    ret[ann.annotation_id][ev.docid][s] = SentenceEvidence(
                        kls=1,
                        ann_id=ann.annotation_id,
                        query=ann.query,
                        docid=ev.docid,
                        index=ret[ann.annotation_id][ev.docid][s].index,
                        sentence=ret[ann.annotation_id][ev.docid][s].sentence,
                    )
    return ret
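
# Resulting structure, per the loops above (only documents that carry evidence
# for an annotation appear):
#   ev = annotations_to_evidence_identification(anns, documents)
#   ev[ann_id][docid]         # one SentenceEvidence per sentence of the document
#   ev[ann_id][docid][i].kls  # 1 if sentence i lies inside an evidence span, else 0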

def annotations_to_evidence_token_identification(
    annotations: List[Annotation],
    source_documents: Dict[str, List[List[str]]],
    interned_documents: Dict[str, List[List[int]]],
    token_mapping: Dict[str, List[List[Tuple[int, int]]]],
) -> Dict[str, Dict[str, List[SentenceEvidence]]]:
    """Converts corpus-level annotations to wordpiece-level relevance judgments.

    Every wordpiece of every evidence document receives a 0/1 label; the labels
    for each sentence are stored as a tuple in the kls field of its
    SentenceEvidence. The return type is of the form
        annotation id -> docid -> [sentence level annotations]
    """
    # TODO should we simplify to use only source text?
    ret = defaultdict(lambda: defaultdict(list))  # annotation id -> docid -> sentences
    positive_tokens = 0
    negative_tokens = 0
    for ann in annotations:
        annid = ann.annotation_id
        docids = set(ev.docid for ev in chain.from_iterable(ann.evidences))
        sentence_offsets = defaultdict(list)  # docid -> [(start, end)]
        classes = defaultdict(list)  # docid -> [token is yea or nay]
        for docid in docids:
            start = 0
            assert len(source_documents[docid]) == len(interned_documents[docid])
            for whole_token_sent, wordpiece_sent in zip(
                source_documents[docid], interned_documents[docid]
            ):
                classes[docid].extend([0 for _ in wordpiece_sent])
                end = start + len(wordpiece_sent)
                sentence_offsets[docid].append((start, end))
                start = end
        for ev in chain.from_iterable(ann.evidences):
            if len(ev.text) == 0:
                continue
            flat_token_map = list(chain.from_iterable(token_mapping[ev.docid]))
            if ev.start_token != -1:
                # start, end = token_mapping[ev.docid][ev.start_token][0], token_mapping[ev.docid][ev.end_token][1]
                start, end = (
                    flat_token_map[ev.start_token][0],
                    flat_token_map[ev.end_token - 1][1],
                )
            else:
                start = flat_token_map[sentence_offsets[ev.start_sentence][0]][0]
                end = flat_token_map[sentence_offsets[ev.end_sentence - 1][1]][1]
            for i in range(start, end):
                classes[ev.docid][i] = 1
        for docid, offsets in sentence_offsets.items():
            token_assignments = classes[docid]
            positive_tokens += sum(token_assignments)
            negative_tokens += len(token_assignments) - sum(token_assignments)
            for s, (start, end) in enumerate(offsets):
                sent = interned_documents[docid][s]
                ret[annid][docid].append(
                    SentenceEvidence(
                        kls=tuple(token_assignments[start:end]),
                        query=ann.query,
                        ann_id=ann.annotation_id,
                        docid=docid,
                        index=s,
                        sentence=sent,
                    )
                )
    logging.info(
        f"Have {positive_tokens} positive wordpiece tokens, {negative_tokens} negative wordpiece tokens"
    )
    return ret
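
# Resulting structure (wordpiece-level labels, per the counting above):
#   ev = annotations_to_evidence_token_identification(anns, source_docs, interned_docs, token_mapping)
#   ev[ann_id][docid][i].kls       # tuple of 0/1 flags, one per wordpiece in sentence i
#   ev[ann_id][docid][i].sentence  # the interned wordpiece sentence itself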

def make_preds_batch(
    classifier: nn.Module,
    batch_elements: List[SentenceEvidence],
    device=None,
    criterion: nn.Module = None,
    tensorize_model_inputs: bool = True,
) -> Tuple[float, List[float], List[int], List[int]]:
    """Batch predictions

    Args:
        classifier: a module that looks like an AttentiveClassifier
        batch_elements: a list of elements to make predictions over. These must be SentenceEvidence objects.
        device: Optional; what compute device this should run on
        criterion: Optional; a loss function
        tensorize_model_inputs: should we convert our data to tensors before passing it to the model? Useful if we have a model that performs its own tokenization
    """
    # delete any "None" padding, if any (imposed by the use of the "grouper");
    # materialize as a list so it can be iterated more than once below
    batch_elements = list(filter(lambda x: x is not None, batch_elements))
    targets, queries, sentences = zip(
        *[(s.kls, s.query, s.sentence) for s in batch_elements]
    )
    ids = [(s.ann_id, s.docid, s.index) for s in batch_elements]
    targets = torch.tensor(targets, dtype=torch.long, device=device)
    if tensorize_model_inputs:
        queries = [torch.tensor(q, dtype=torch.long) for q in queries]
        sentences = [torch.tensor(s, dtype=torch.long) for s in sentences]
    preds = classifier(queries, ids, sentences)
    targets = targets.to(device=preds.device)
    if criterion:
        loss = criterion(preds, targets)
    else:
        loss = None
    # .float() because pytorch 1.3 introduces a bug where argmax is unsupported for float16
    hard_preds = torch.argmax(preds.float(), dim=-1)
    return loss, preds, hard_preds, targets
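
# Return contract of make_preds_batch (see the code above):
#   loss       -- criterion(preds, targets) if a criterion was given, else None
#   preds      -- raw (soft) classifier outputs, one row per batch element
#   hard_preds -- argmax over the last dimension of preds
#   targets    -- the gold labels, moved to the prediction device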

def make_preds_epoch(
    classifier: nn.Module,
    data: List[SentenceEvidence],
    batch_size: int,
    device=None,
    criterion: nn.Module = None,
    tensorize_model_inputs: bool = True,
):
    """Predictions for more than one batch.

    Args:
        classifier: a module that looks like an AttentiveClassifier
        data: a list of elements to make predictions over. These must be SentenceEvidence objects.
        batch_size: the biggest chunk we can fit in one batch.
        device: Optional; what compute device this should run on
        criterion: Optional; a loss function
        tensorize_model_inputs: should we convert our data to tensors before passing it to the model? Useful if we have a model that performs its own tokenization
    """
    epoch_loss = 0
    epoch_soft_pred = []
    epoch_hard_pred = []
    epoch_truth = []
    batches = _grouper(data, batch_size)
    classifier.eval()
    for batch in batches:
        loss, soft_preds, hard_preds, targets = make_preds_batch(
            classifier,
            batch,
            device,
            criterion=criterion,
            tensorize_model_inputs=tensorize_model_inputs,
        )
        if loss is not None:
            epoch_loss += loss.sum().item()
        epoch_hard_pred.extend(hard_preds)
        epoch_soft_pred.extend(soft_preds.cpu())
        epoch_truth.extend(targets)
    epoch_loss /= len(data)
    epoch_hard_pred = [x.item() for x in epoch_hard_pred]
    epoch_truth = [x.item() for x in epoch_truth]
    return epoch_loss, epoch_soft_pred, epoch_hard_pred, epoch_truth
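
# Minimal usage sketch (classifier, eval_data, and the batch size are
# assumptions of the caller; callers in this file wrap it in torch.no_grad()):
#   with torch.no_grad():
#       loss, soft, hard, gold = make_preds_epoch(
#           classifier, eval_data, batch_size=16, device=None,
#           criterion=None, tensorize_model_inputs=True,
#       )
#   accuracy_score(gold, hard)  # hard/gold are plain python ints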

def make_token_preds_batch(
    classifier: nn.Module,
    batch_elements: List[SentenceEvidence],
    token_mapping: Dict[str, List[List[Tuple[int, int]]]],
    device=None,
    criterion: nn.Module = None,
    tensorize_model_inputs: bool = True,
) -> Tuple[float, List[float], List[int], List[int]]:
    """Batch predictions

    Args:
        classifier: a module that looks like an AttentiveClassifier
        batch_elements: a list of elements to make predictions over. These must be SentenceEvidence objects.
        device: Optional; what compute device this should run on
        criterion: Optional; a loss function
        tensorize_model_inputs: should we convert our data to tensors before passing it to the model? Useful if we have a model that performs its own tokenization
    """
    # delete any "None" padding, if any (imposed by the use of the "grouper");
    # materialize as a list so it can be iterated more than once below
    batch_elements = list(filter(lambda x: x is not None, batch_elements))
    targets, queries, sentences = zip(
        *[(s.kls, s.query, s.sentence) for s in batch_elements]
    )
    ids = [(s.ann_id, s.docid, s.index) for s in batch_elements]
    targets = PaddedSequence.autopad(
        [torch.tensor(t, dtype=torch.long, device=device) for t in targets],
        batch_first=True,
        device=device,
    )
    aggregate_spans = [token_mapping[s.docid][s.index] for s in batch_elements]
    if tensorize_model_inputs:
        queries = [torch.tensor(q, dtype=torch.long) for q in queries]
        sentences = [torch.tensor(s, dtype=torch.long) for s in sentences]
    preds = classifier(queries, ids, sentences, aggregate_spans)
    targets = targets.to(device=preds.device)
    mask = targets.mask(on=1, off=0, device=preds.device, dtype=torch.float)
    if criterion:
        loss = criterion(
            preds, (targets.data.to(device=preds.device) * mask).squeeze()
        ).sum()
    else:
        loss = None
    hard_preds = [
        torch.round(x).to(dtype=torch.int).cpu() for x in targets.unpad(preds)
    ]
    targets = [[y.item() for y in x] for x in targets.unpad(targets.data.cpu())]
    return loss, preds, hard_preds, targets  # targets.unpad(targets.data.cpu())

# TODO fix the arguments
def make_token_preds_epoch(
    classifier: nn.Module,
    data: List[SentenceEvidence],
    token_mapping: Dict[str, List[List[Tuple[int, int]]]],
    batch_size: int,
    device=None,
    criterion: nn.Module = None,
    tensorize_model_inputs: bool = True,
):
    """Predictions for more than one batch.

    Args:
        classifier: a module that looks like an AttentiveClassifier
        data: a list of elements to make predictions over. These must be SentenceEvidence objects.
        batch_size: the biggest chunk we can fit in one batch.
        device: Optional; what compute device this should run on
        criterion: Optional; a loss function
        tensorize_model_inputs: should we convert our data to tensors before passing it to the model? Useful if we have a model that performs its own tokenization
    """
    epoch_loss = 0
    epoch_soft_pred = []
    epoch_hard_pred = []
    epoch_truth = []
    batches = _grouper(data, batch_size)
    classifier.eval()
    for batch in batches:
        loss, soft_preds, hard_preds, targets = make_token_preds_batch(
            classifier,
            batch,
            token_mapping,
            device,
            criterion=criterion,
            tensorize_model_inputs=tensorize_model_inputs,
        )
        if loss is not None:
            epoch_loss += loss.sum().item()
        epoch_hard_pred.extend(hard_preds)
        epoch_soft_pred.extend(soft_preds.cpu().tolist())
        epoch_truth.extend(targets)
    epoch_loss /= len(data)
    return epoch_loss, epoch_soft_pred, epoch_hard_pred, epoch_truth
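
# Unlike make_preds_epoch, the token-level variant returns per-sentence lists:
#   epoch_soft_pred[i]  -- list of per-wordpiece scores for sentence i (likely
#                          padded to the longest sentence in its batch)
#   epoch_hard_pred[i]  -- rounded 0/1 int tensor for sentence i (unpadded)
#   epoch_truth[i]      -- list of gold 0/1 token labels for sentence i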

# copied from https://docs.python.org/3/library/itertools.html#itertools-recipes
def _grouper(iterable, n, fillvalue=None):
    "Collect data into fixed-length chunks or blocks"
    # grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx
    args = [iter(iterable)] * n
    return itertools.zip_longest(*args, fillvalue=fillvalue)

def score_rationales(
    truth: List[Annotation],
    documents: Dict[str, List[List[int]]],
    input_data: List[SentenceEvidence],
    scores: List[float],
) -> dict:
    """Scores sentence-level evidence predictions against the gold rationales.

    Produces hard rationale scores (best-scoring sentence per (annotation,
    document) pair), partial match scores, and soft token/sentence scores via
    the rationale_benchmark metrics.
    """
    results = {}
    doc_to_sent_scores = dict()  # (annid, docid) -> [sentence scores]
    for sent, score in zip(input_data, scores):
        k = (sent.ann_id, sent.docid)
        if k not in doc_to_sent_scores:
            doc_to_sent_scores[k] = [0.0 for _ in range(len(documents[sent.docid]))]
        if not isinstance(score[1], float):
            score[1] = score[1].item()
        doc_to_sent_scores[(sent.ann_id, sent.docid)][sent.index] = score[1]
    # hard rationale scoring
    best_sentence = {k: np.argmax(np.array(v)) for k, v in doc_to_sent_scores.items()}
    predicted_rationales = []
    for (ann_id, docid), sent_idx in best_sentence.items():
        start_token = sum(len(s) for s in documents[docid][:sent_idx])
        end_token = start_token + len(documents[docid][sent_idx])
        predicted_rationales.append(Rationale(ann_id, docid, start_token, end_token))
    true_rationales = list(
        chain.from_iterable(Rationale.from_annotation(rat) for rat in truth)
    )
    results["hard_rationale_scores"] = score_hard_rationale_predictions(
        true_rationales, predicted_rationales
    )
    results["hard_rationale_partial_match_scores"] = partial_match_score(
        true_rationales, predicted_rationales, [0.5]
    )
    # soft rationale scoring
    instance_format = []
    for (ann_id, docid), sentences in doc_to_sent_scores.items():
        soft_token_predictions = []
        for sent_score, sent_text in zip(sentences, documents[docid]):
            soft_token_predictions.extend(sent_score for _ in range(len(sent_text)))
        instance_format.append(
            {
                "annotation_id": ann_id,
                "rationales": [
                    {
                        "docid": docid,
                        "soft_rationale_predictions": soft_token_predictions,
                        "soft_sentence_predictions": sentences,
                    }
                ],
            }
        )
    flattened_documents = {
        k: list(chain.from_iterable(v)) for k, v in documents.items()
    }
    token_scoring_format = PositionScoredDocument.from_results(
        instance_format, truth, flattened_documents, use_tokens=True
    )
    results["soft_token_scores"] = score_soft_tokens(token_scoring_format)
    sentence_scoring_format = PositionScoredDocument.from_results(
        instance_format, truth, documents, use_tokens=False
    )
    results["soft_sentence_scores"] = score_soft_tokens(sentence_scoring_format)
    return results
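
# Keys populated in the dict returned by score_rationales (all computed above):
#   "hard_rationale_scores", "hard_rationale_partial_match_scores",
#   "soft_token_scores", "soft_sentence_scores"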

def decode(
    evidence_identifier: nn.Module,
    evidence_classifier: nn.Module,
    train: List[Annotation],
    val: List[Annotation],
    test: List[Annotation],
    docs: Dict[str, List[List[int]]],
    class_interner: Dict[str, int],
    batch_size: int,
    tensorize_model_inputs: bool,
    decoding_docs: Dict[str, List[Any]] = None,
) -> dict:
    """Identifies and then classifies evidence

    Args:
        evidence_identifier: a module for identifying evidence statements
        evidence_classifier: a module for making a classification based on evidence statements
        train: A List of interned Annotations
        val: A List of interned Annotations
        test: A List of interned Annotations
        docs: A Dict of Documents, which are interned sentences.
        class_interner: Converts an Annotation's final class into ints
        batch_size: how big should our batches be?
        tensorize_model_inputs: should we convert our data to tensors before passing it to the model? Useful if we have a model that performs its own tokenization
        decoding_docs: Optional; documents used for rationale scoring and for decoding output (defaults to docs)
    """
    device = None
    class_labels = [k for k, v in sorted(class_interner.items(), key=lambda x: x[1])]
    if decoding_docs is None:
        decoding_docs = docs

    def prep(data: List[Annotation]) -> List[Tuple[SentenceEvidence, SentenceEvidence]]:
        """Prepares data for evidence identification and classification.

        Creates paired evaluation data, wherein each (annotation, docid, sentence, kls)
        tuple appears first with the kls determining whether the sentence is evidence,
        and second with the kls giving the overall classification for the
        (annotation, docid) pair. This allows selection, based on the scores of the
        evidence_identifier, of input to the evidence_classifier.
        """
        identification_data = annotations_to_evidence_identification(data, docs)
        classification_data = annotations_to_evidence_classification(
            data, docs, class_interner, include_all=True
        )
        ann_doc_sents = defaultdict(
            lambda: defaultdict(dict)
        )  # ann id -> docid -> sent idx -> sent data
        ret = []
        for sent_ev in classification_data:
            id_data = identification_data[sent_ev.ann_id][sent_ev.docid][sent_ev.index]
            ret.append((id_data, sent_ev))
            assert id_data.ann_id == sent_ev.ann_id
            assert id_data.docid == sent_ev.docid
            assert id_data.index == sent_ev.index
        assert len(ret) == len(classification_data)
        return ret

    def decode_batch(
        data: List[Tuple[SentenceEvidence, SentenceEvidence]],
        name: str,
        score: bool = False,
        annotations: List[Annotation] = None,
    ) -> dict:
        """Identifies evidence statements and then makes classifications based on them.

        Args:
            data: a paired list of SentenceEvidences, differing only in the kls field.
                The first corresponds to whether or not something is evidence, and the second corresponds to an evidence class
            name: a name for a results dict
        """
        num_uniques = len(set((x.ann_id, x.docid) for x, _ in data))
        logging.info(
            f"Decoding dataset {name} with {len(data)} sentences, {num_uniques} annotations"
        )
        identifier_data, classifier_data = zip(*data)
        results = dict()
        IdentificationClassificationResult = namedtuple(
            "IdentificationClassificationResult",
            "identification_data classification_data soft_identification hard_identification soft_classification hard_classification",
        )
        with torch.no_grad():
            # make predictions for the evidence_identifier
            evidence_identifier.eval()
            evidence_classifier.eval()
            (
                _,
                soft_identification_preds,
                hard_identification_preds,
                _,
            ) = make_preds_epoch(
                evidence_identifier,
                identifier_data,
                batch_size,
                device,
                tensorize_model_inputs=tensorize_model_inputs,
            )
            assert len(soft_identification_preds) == len(data)
            identification_results = defaultdict(list)
            for id_data, cls_data, soft_id_pred, hard_id_pred in zip(
                identifier_data,
                classifier_data,
                soft_identification_preds,
                hard_identification_preds,
            ):
                res = IdentificationClassificationResult(
                    identification_data=id_data,
                    classification_data=cls_data,
                    # 1 is p(evidence|sent,query)
                    soft_identification=soft_id_pred[1].float().item(),
                    hard_identification=hard_id_pred,
                    soft_classification=None,
                    hard_classification=False,
                )
                identification_results[(id_data.ann_id, id_data.docid)].append(res)
            best_identification_results = {
                key: max(value, key=lambda x: x.soft_identification)
                for key, value in identification_results.items()
            }
            logging.info(
                f"Selected the best sentence for {len(identification_results)} examples from a total of {len(soft_identification_preds)} sentences"
            )
            ids, classification_data = zip(
                *[
                    (k, v.classification_data)
                    for k, v in best_identification_results.items()
                ]
            )
            (
                _,
                soft_classification_preds,
                hard_classification_preds,
                classification_truth,
            ) = make_preds_epoch(
                evidence_classifier,
                classification_data,
                batch_size,
                device,
                tensorize_model_inputs=tensorize_model_inputs,
            )
            classification_results = dict()
            for eyeD, soft_class, hard_class in zip(
                ids, soft_classification_preds, hard_classification_preds
            ):
                input_id_result = best_identification_results[eyeD]
                res = IdentificationClassificationResult(
                    identification_data=input_id_result.identification_data,
                    classification_data=input_id_result.classification_data,
                    soft_identification=input_id_result.soft_identification,
                    hard_identification=input_id_result.hard_identification,
                    soft_classification=soft_class,
                    hard_classification=hard_class,
                )
                classification_results[eyeD] = res
            if score:
                truth = []
                pred = []
                for res in classification_results.values():
                    truth.append(res.classification_data.kls)
                    pred.append(res.hard_classification)
                # results[f'{name}_f1'] = classification_report(classification_truth, pred, target_names=class_labels, output_dict=True)
                results[f"{name}_f1"] = classification_report(
                    classification_truth,
                    hard_classification_preds,
                    target_names=class_labels,
                    output_dict=True,
                )
                results[f"{name}_acc"] = accuracy_score(
                    classification_truth, hard_classification_preds
                )
                results[f"{name}_rationale"] = score_rationales(
                    annotations,
                    decoding_docs,
                    identifier_data,
                    soft_identification_preds,
                )
            # turn the above results into a format suitable for scoring via the rationale scorer
            # n.b. the sentence-level evidence predictions (hard and soft) are
            # broadcast to the token level for scoring. The comprehensiveness class
            # score is also a lie since the pipeline model above is faithful by
            # design.
            decoded = dict()
            decoded_scores = defaultdict(list)
            for (ann_id, docid), pred in classification_results.items():
                sentence_prediction_scores = [
                    x.soft_identification
                    for x in identification_results[(ann_id, docid)]
                ]
                sentence_start_token = sum(
                    len(s)
                    for s in decoding_docs[docid][: pred.identification_data.index]
                )
                sentence_end_token = sentence_start_token + len(
                    decoding_docs[docid][pred.classification_data.index]
                )
                hard_rationale_predictions = [
                    {
                        "start_token": sentence_start_token,
                        "end_token": sentence_end_token,
                    }
                ]
                soft_rationale_predictions = []
                for sent_result in sorted(
                    identification_results[(ann_id, docid)],
                    key=lambda x: x.identification_data.index,
                ):
                    soft_rationale_predictions.extend(
                        sent_result.soft_identification
                        for _ in range(
                            len(
                                decoding_docs[sent_result.identification_data.docid][
                                    sent_result.identification_data.index
                                ]
                            )
                        )
                    )
                if ann_id not in decoded:
                    decoded[ann_id] = {
                        "annotation_id": ann_id,
                        "rationales": [],
                        "classification": class_labels[pred.hard_classification],
                        "classification_scores": {
                            class_labels[i]: s.item()
                            for i, s in enumerate(pred.soft_classification)
                        },
                        # TODO this should turn into the data distribution for the predicted class
                        # "comprehensiveness_classification_scores": 0.0,
                        "truth": pred.classification_data.kls,
                    }
                decoded[ann_id]["rationales"].append(
                    {
                        "docid": docid,
                        "hard_rationale_predictions": hard_rationale_predictions,
                        "soft_rationale_predictions": soft_rationale_predictions,
                        "soft_sentence_predictions": sentence_prediction_scores,
                    }
                )
                decoded_scores[ann_id].append(pred.soft_classification)
            # in practice, this is always a single element operation:
            # in evidence inference (prompt is really a prompt + document), fever (we split documents into two classifications), movies (you only have one opinion about a movie), or boolQ (single document prompts)
            # this exists to support weird models we *might* implement for cose/esnli
            for ann_id, scores_list in decoded_scores.items():
                scores = torch.stack(scores_list)
                score_avg = torch.mean(scores, dim=0)
                # .float() because pytorch 1.3 introduces a bug where argmax is unsupported for float16
                hard_pred = torch.argmax(score_avg.float()).item()
                decoded[ann_id]["classification"] = class_labels[hard_pred]
                decoded[ann_id]["classification_scores"] = {
                    class_labels[i]: s.item() for i, s in enumerate(score_avg)
                }
            return results, list(decoded.values())

    test_results, test_decoded = decode_batch(prep(test), "test", score=False)
    val_results, val_decoded = dict(), []
    train_results, train_decoded = dict(), []
    # val_results, val_decoded = decode_batch(prep(val), 'val', score=True, annotations=val)
    # train_results, train_decoded = decode_batch(prep(train), 'train', score=True, annotations=train)
    return (
        dict(**train_results, **val_results, **test_results),
        train_decoded,
        val_decoded,
        test_decoded,
    )
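
# Minimal usage sketch (all arguments are assumptions of the caller; only the
# test split is decoded here, mirroring the commented-out val/train calls above):
#   results, train_decoded, val_decoded, test_decoded = decode(
#       evidence_identifier, evidence_classifier,
#       train, val, test, interned_docs, class_interner,
#       batch_size=16, tensorize_model_inputs=True,
#   )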

def decode_evidence_tokens_and_classify(
    evidence_token_identifier: nn.Module,
    evidence_classifier: nn.Module,
    train: List[Annotation],
    val: List[Annotation],
    test: List[Annotation],
    docs: Dict[str, List[List[int]]],
    source_documents: Dict[str, List[List[str]]],
    token_mapping: Dict[str, List[List[Tuple[int, int]]]],
    class_interner: Dict[str, int],
    batch_size: int,
    decoding_docs: Dict[str, List[Any]],
    use_cose_hack: bool = False,
) -> dict:
    """Identifies evidence tokens and then classifies documents from them.

    Args:
        evidence_token_identifier: a module for identifying evidence tokens
        evidence_classifier: a module for making a classification based on evidence statements
        train: A List of interned Annotations
        val: A List of interned Annotations
        test: A List of interned Annotations
        docs: A Dict of Documents, which are interned sentences.
        source_documents: the original (whole-token) documents, keyed by docid
        token_mapping: per docid, for each sentence a list of (start, end) wordpiece spans, one per source token
        class_interner: Converts an Annotation's final class into ints
        batch_size: how big should our batches be?
        decoding_docs: documents used for token identification and rationale scoring (defaults to docs)
        use_cose_hack: regroup predictions into 5-way COS-E choices and rewrite the class labels accordingly
    """
    device = None
    class_labels = [k for k, v in sorted(class_interner.items(), key=lambda x: x[1])]
    if decoding_docs is None:
        decoding_docs = docs

    def prep(data: List[Annotation]) -> List[Tuple[SentenceEvidence, SentenceEvidence]]:
        """Prepares data for evidence identification and classification.

        Creates paired evaluation data, wherein each (annotation, docid, sentence, kls)
        tuple appears first with the kls determining whether the sentence is evidence,
        and second with the kls giving the overall classification for the
        (annotation, docid) pair. This allows selection, based on the scores of the
        evidence_token_identifier, of input to the evidence_classifier.
        """
        # identification_data = annotations_to_evidence_identification(data, docs)
        classification_data = token_annotations_to_evidence_classification(
            data, docs, class_interner
        )
        # annotation id -> docid -> [SentenceEvidence]
        identification_data = annotations_to_evidence_token_identification(
            data,
            source_documents=decoding_docs,
            interned_documents=docs,
            token_mapping=token_mapping,
        )
        ann_doc_sents = defaultdict(
            lambda: defaultdict(dict)
        )  # ann id -> docid -> sent idx -> sent data
        ret = []
        for sent_ev in classification_data:
            id_data = identification_data[sent_ev.ann_id][sent_ev.docid][sent_ev.index]
            ret.append((id_data, sent_ev))
            assert id_data.ann_id == sent_ev.ann_id
            assert id_data.docid == sent_ev.docid
            # assert id_data.index == sent_ev.index
        assert len(ret) == len(classification_data)
        return ret

    def decode_batch(
        data: List[Tuple[SentenceEvidence, SentenceEvidence]],
        name: str,
        score: bool = False,
        annotations: List[Annotation] = None,
        class_labels: dict = class_labels,
    ) -> dict:
        """Identifies evidence tokens and then makes classifications based on them.

        Args:
            data: a paired list of SentenceEvidences, differing only in the kls field.
                The first corresponds to whether or not something is evidence, and the second corresponds to an evidence class
            name: a name for a results dict
        """
        num_uniques = len(set((x.ann_id, x.docid) for x, _ in data))
        logging.info(
            f"Decoding dataset {name} with {len(data)} sentences, {num_uniques} annotations"
        )
        identifier_data, classifier_data = zip(*data)
        results = dict()
        with torch.no_grad():
            # make predictions for the evidence_token_identifier
            evidence_token_identifier.eval()
            evidence_classifier.eval()
            (
                _,
                soft_identification_preds,
                hard_identification_preds,
                id_preds_truth,
            ) = make_token_preds_epoch(
                evidence_token_identifier,
                identifier_data,
                token_mapping,
                batch_size,
                device,
                tensorize_model_inputs=True,
            )
            assert len(soft_identification_preds) == len(data)
            evidence_only_cls = []
            for id_data, cls_data, soft_id_pred, hard_id_pred in zip(
                identifier_data,
                classifier_data,
                soft_identification_preds,
                hard_identification_preds,
            ):
                assert cls_data.ann_id == id_data.ann_id
                sent = []
                for start, end in token_mapping[cls_data.docid][0]:
                    if bool(hard_id_pred[start]):
                        sent.extend(id_data.sentence[start:end])
                # assert len(sent) > 0
                new_cls_data = SentenceEvidence(
                    cls_data.kls,
                    cls_data.ann_id,
                    cls_data.query,
                    cls_data.docid,
                    cls_data.index,
                    tuple(sent),
                )
                evidence_only_cls.append(new_cls_data)
            (
                _,
                soft_classification_preds,
                hard_classification_preds,
                classification_truth,
            ) = make_preds_epoch(
                evidence_classifier,
                evidence_only_cls,
                batch_size,
                device,
                tensorize_model_inputs=True,
            )
            if use_cose_hack:
                logging.info(
                    "Reformatting identification and classification results to fit COS-E"
                )
                grouping = 5
                new_soft_identification_preds = []
                new_hard_identification_preds = []
                new_id_preds_truth = []
                new_soft_classification_preds = []
                new_hard_classification_preds = []
                new_classification_truth = []
                new_identifier_data = []
                class_labels = []
                # TODO fix the labels for COS-E
                for i in range(0, len(soft_identification_preds), grouping):
                    cls_scores = torch.stack(
                        soft_classification_preds[i : i + grouping]
                    )
                    cls_scores = nn.functional.softmax(cls_scores, dim=-1)
                    cls_scores = cls_scores[:, 1]
                    choice = torch.argmax(cls_scores)
                    cls_labels = [
                        x.ann_id.split("_")[-1]
                        for x in evidence_only_cls[i : i + grouping]
                    ]
                    class_labels = cls_labels  # we need to update the class labels because of the terrible hackery used to train this
                    cls_truths = [x.kls for x in evidence_only_cls[i : i + grouping]]
                    # cls_choice = evidence_only_cls[i + choice].ann_id.split('_')[-1]
                    cls_truth = np.argmax(cls_truths)
                    new_soft_identification_preds.append(
                        soft_identification_preds[i + choice]
                    )
                    new_hard_identification_preds.append(
                        hard_identification_preds[i + choice]
                    )
                    new_id_preds_truth.append(id_preds_truth[i + choice])
                    new_soft_classification_preds.append(
                        soft_classification_preds[i + choice]
                    )
                    new_hard_classification_preds.append(choice)
                    new_identifier_data.append(identifier_data[i + choice])
                    # new_hard_classification_preds.append(hard_classification_preds[i + choice])
                    # new_classification_truth.append(classification_truth[i + choice])
                    new_classification_truth.append(cls_truth)
                soft_identification_preds = new_soft_identification_preds
                hard_identification_preds = new_hard_identification_preds
                id_preds_truth = new_id_preds_truth
                soft_classification_preds = new_soft_classification_preds
                hard_classification_preds = new_hard_classification_preds
                classification_truth = new_classification_truth
                identifier_data = new_identifier_data
            if score:
                results[f"{name}_f1"] = classification_report(
                    classification_truth,
                    hard_classification_preds,
                    target_names=class_labels,
                    output_dict=True,
                )
                results[f"{name}_acc"] = accuracy_score(
                    classification_truth, hard_classification_preds
                )
                results[f"{name}_token_pred_acc"] = accuracy_score(
                    list(chain.from_iterable(id_preds_truth)),
                    list(chain.from_iterable(hard_identification_preds)),
                )
                results[f"{name}_token_pred_f1"] = classification_report(
                    list(chain.from_iterable(id_preds_truth)),
                    list(chain.from_iterable(hard_identification_preds)),
                    output_dict=True,
                )
                # TODO for token level stuff!
                soft_id_scores = [
                    [1 - x, x] for x in chain.from_iterable(soft_identification_preds)
                ]
                results[f"{name}_rationale"] = score_rationales(
                    annotations, decoding_docs, identifier_data, soft_id_scores
                )
                logging.info(f"Results: {results}")
            # turn the above results into a format suitable for scoring via the rationale scorer
            # n.b. the sentence-level evidence predictions (hard and soft) are
            # broadcast to the token level for scoring. The comprehensiveness class
            # score is also a lie since the pipeline model above is faithful by
            # design.
            decoded = dict()
            scores = []
            assert len(identifier_data) == len(soft_identification_preds)
            for (
                id_data,
                soft_id_pred,
                hard_id_pred,
                soft_cls_preds,
                hard_cls_pred,
            ) in zip(
                identifier_data,
                soft_identification_preds,
                hard_identification_preds,
                soft_classification_preds,
                hard_classification_preds,
            ):
                docid = id_data.docid
                if use_cose_hack:
                    docid = "_".join(docid.split("_")[0:-1])
                assert len(docid) > 0
                rationales = {
                    "docid": docid,
                    "hard_rationale_predictions": [],
                    # token level classifications, a value must be provided per-token
                    # in an ideal world, these correspond to the hard-decoding above.
                    "soft_rationale_predictions": [],
                    # sentence level classifications, a value must be provided for every
                    # sentence in each document, or not at all
                    "soft_sentence_predictions": [1.0],
                }
                last = -1
                start_span = -1
                for pos, (start, _) in enumerate(token_mapping[id_data.docid][0]):
                    rationales["soft_rationale_predictions"].append(soft_id_pred[start])
                    if bool(hard_id_pred[start]):
                        if start_span == -1:
                            start_span = pos
                        last = pos
                    else:
                        if start_span != -1:
                            rationales["hard_rationale_predictions"].append(
                                {
                                    "start_token": start_span,
                                    "end_token": last + 1,
                                }
                            )
                            last = -1
                            start_span = -1
                if start_span != -1:
                    rationales["hard_rationale_predictions"].append(
                        {
                            "start_token": start_span,
                            "end_token": last + 1,
                        }
                    )
                ann_id = id_data.ann_id
                if use_cose_hack:
                    ann_id = "_".join(ann_id.split("_")[0:-1])
                # softmax over the classification scores; dim made explicit for newer PyTorch
                soft_cls_preds = nn.functional.softmax(soft_cls_preds, dim=-1)
                decoded[id_data.ann_id] = {
                    "annotation_id": ann_id,
                    "rationales": [rationales],
                    "classification": class_labels[hard_cls_pred],
                    "classification_scores": {
                        class_labels[i]: score.item()
                        for i, score in enumerate(soft_cls_preds)
                    },
                }
            return results, list(decoded.values())

    # test_results, test_decoded = dict(), []
    # val_results, val_decoded = dict(), []
    train_results, train_decoded = dict(), []
    val_results, val_decoded = decode_batch(
        prep(val), "val", score=True, annotations=val, class_labels=class_labels
    )
    test_results, test_decoded = decode_batch(
        prep(test), "test", score=False, class_labels=class_labels
    )
    # train_results, train_decoded = decode_batch(prep(train), 'train', score=True, annotations=train, class_labels=class_labels)
    return (
        dict(**train_results, **val_results, **test_results),
        train_decoded,
        val_decoded,
        test_decoded,
    )
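
# Minimal usage sketch (arguments are assumptions of the caller; val is scored
# and test is only decoded, mirroring the calls above):
#   results, train_decoded, val_decoded, test_decoded = decode_evidence_tokens_and_classify(
#       evidence_token_identifier, evidence_classifier,
#       train, val, test, interned_docs, source_docs, token_mapping,
#       class_interner, batch_size=16, decoding_docs=source_docs,
#       use_cose_hack=False,
#   )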