# comparative-explainability/Transformer-Explainability/BERT_rationale_benchmark/models/model_utils.py
from dataclasses import dataclass
from typing import Dict, List, Set, Tuple

import numpy as np
import torch
from gensim.models import KeyedVectors
from torch import nn
from torch.nn.utils.rnn import (PackedSequence, pack_padded_sequence,
                                pad_packed_sequence, pad_sequence)


@dataclass
class PaddedSequence:
    """A utility class for padding variable-length sequences meant for RNN input.

    This class is in the style of PackedSequence from the PyTorch RNN utils,
    but is somewhat more manual in approach. It provides the ability to generate
    masks for outputs that share the input's dimensions.

    The constructor should never be called directly; instances should only be
    created via the `autopad` classmethod.

    We'd love to delete this, but pad_sequence, pack_padded_sequence, and
    pad_packed_sequence all require shuffling around tuples of information, and
    some convenience methods wrapping them are nice to have. A usage sketch
    follows the class definition.
    """

    data: torch.Tensor
    batch_sizes: torch.Tensor
    batch_first: bool = False
    @classmethod
    def autopad(
        cls, data, batch_first: bool = False, padding_value=0, device=None
    ) -> "PaddedSequence":
        # handle tensors of size 0 (single item)
        data_ = []
        for d in data:
            if len(d.size()) == 0:
                d = d.unsqueeze(0)
            data_.append(d)
        padded = pad_sequence(
            data_, batch_first=batch_first, padding_value=padding_value
        )
        if batch_first:
            batch_lengths = torch.LongTensor([len(x) for x in data_])
            if any([x == 0 for x in batch_lengths]):
                raise ValueError(
                    "Found a 0 length batch element, this can't possibly be right: {}".format(
                        batch_lengths
                    )
                )
        else:
            # TODO actually test this codepath
            batch_lengths = torch.LongTensor([len(x) for x in data])
        return PaddedSequence(padded, batch_lengths, batch_first).to(device=device)

    def pack_other(self, data: torch.Tensor):
        return pack_padded_sequence(
            data, self.batch_sizes, batch_first=self.batch_first, enforce_sorted=False
        )

    @classmethod
    def from_packed_sequence(
        cls, ps: PackedSequence, batch_first: bool, padding_value=0
    ) -> "PaddedSequence":
        padded, batch_sizes = pad_packed_sequence(ps, batch_first, padding_value)
        return PaddedSequence(padded, batch_sizes, batch_first)

    def cuda(self) -> "PaddedSequence":
        return PaddedSequence(
            self.data.cuda(), self.batch_sizes.cuda(), batch_first=self.batch_first
        )

    def to(
        self, dtype=None, device=None, copy=False, non_blocking=False
    ) -> "PaddedSequence":
        # TODO make to() support all of the torch.Tensor to() variants
        return PaddedSequence(
            self.data.to(
                dtype=dtype, device=device, copy=copy, non_blocking=non_blocking
            ),
            self.batch_sizes.to(device=device, copy=copy, non_blocking=non_blocking),
            batch_first=self.batch_first,
        )

    def mask(
        self, on=0, off=0, device="cpu", size=None, dtype=None
    ) -> torch.Tensor:
        if size is None:
            size = self.data.size()
        out_tensor = torch.zeros(*size, dtype=dtype)
        # TODO this can be done more efficiently
        out_tensor.fill_(off)
        # Note: filling everything with `off` and then overwriting the `on` positions
        # is probably less efficient than explicitly populating only the `off` values.
        if self.batch_first:
            for i, bl in enumerate(self.batch_sizes):
                out_tensor[i, :bl] = on
        else:
            for i, bl in enumerate(self.batch_sizes):
                out_tensor[:bl, i] = on
        return out_tensor.to(device)

    def unpad(self, other: torch.Tensor) -> List[torch.Tensor]:
        out = []
        for o, bl in zip(other, self.batch_sizes):
            out.append(o[:bl])
        return out

    def flip(self) -> "PaddedSequence":
        # Swap the batch and time dimensions; the per-example lengths are unchanged.
        return PaddedSequence(
            self.data.transpose(0, 1), self.batch_sizes, not self.batch_first
        )
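

# A minimal usage sketch (not part of the original module) showing how
# PaddedSequence is typically combined with an RNN: `autopad` pads a list of
# variable-length index tensors, `pack_other` feeds a packed batch to an LSTM,
# and `mask`/`unpad` recover per-example outputs. The embedding size, hidden
# size, and example token ids below are illustrative assumptions.
def _padded_sequence_usage_example():
    embedding = nn.Embedding(100, 16, padding_idx=0)
    lstm = nn.LSTM(input_size=16, hidden_size=8, batch_first=True)
    # Three documents of different lengths, already converted to token ids.
    docs = [
        torch.LongTensor([5, 7, 2]),
        torch.LongTensor([3]),
        torch.LongTensor([9, 4]),
    ]
    batch = PaddedSequence.autopad(docs, batch_first=True, padding_value=0)
    packed = batch.pack_other(embedding(batch.data))
    output, _ = lstm(packed)
    output, _ = pad_packed_sequence(output, batch_first=True)
    # Zero out hidden states at padding positions, then split back per document.
    mask = batch.mask(on=1, off=0, size=output.size()[:2], dtype=torch.float)
    output = output * mask.unsqueeze(-1)
    return batch.unpad(output)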


def extract_embeddings(
    vocab: Set[str], embedding_file: str, unk_token: str = "UNK", pad_token: str = "PAD"
) -> Tuple[nn.Embedding, Dict[str, int], List[str]]:
    """Load pretrained word vectors and build a frozen embedding layer.

    Supports word2vec binary files (.bin, via gensim) and whitespace-delimited
    text files (.txt). Returns the embedding layer restricted to `vocab` plus
    the UNK and PAD tokens, a word-to-index mapping (`interner`), and an
    index-to-word list (`deinterner`).
    """
    vocab = vocab | set([unk_token, pad_token])
    if embedding_file.endswith(".bin"):
        # note: KeyedVectors.vocab requires gensim < 4.0
        WVs = KeyedVectors.load_word2vec_format(embedding_file, binary=True)

        word_to_vector = dict()
        WV_matrix = np.matrix([WVs[v] for v in WVs.vocab.keys()])

        # Unknown words get the mean vector; padding gets the zero vector.
        if unk_token not in WVs:
            # flatten the (1, dim) matrix mean so all stored vectors share the same shape
            mean_vector = np.asarray(np.mean(WV_matrix, axis=0)).flatten()
            word_to_vector[unk_token] = mean_vector
        if pad_token not in WVs:
            word_to_vector[pad_token] = np.zeros(WVs.vector_size)

        for v in vocab:
            if v in WVs:
                word_to_vector[v] = WVs[v]

        interner = dict()
        deinterner = list()
        vectors = []
        count = 0
        for word in [pad_token, unk_token] + sorted(
            list(word_to_vector.keys() - {unk_token, pad_token})
        ):
            vector = word_to_vector[word]
            vectors.append(np.array(vector))
            interner[word] = count
            deinterner.append(word)
            count += 1
        vectors = torch.FloatTensor(np.array(vectors))
        embedding = nn.Embedding.from_pretrained(
            vectors, padding_idx=interner[pad_token]
        )
        embedding.weight.requires_grad = False
        return embedding, interner, deinterner
    elif embedding_file.endswith(".txt"):
        word_to_vector = dict()
        vector = []
        with open(embedding_file, "r") as inf:
            for line in inf:
                contents = line.strip().split()
                word = contents[0]
                vector = torch.tensor([float(v) for v in contents[1:]]).unsqueeze(0)
                word_to_vector[word] = vector
        # All rows are assumed to share one dimensionality; use the last one read.
        embed_size = vector.size()

        # Unknown words get the mean vector; padding gets the zero vector.
        if unk_token not in word_to_vector:
            mean_vector = torch.cat(list(word_to_vector.values()), dim=0).mean(dim=0)
            word_to_vector[unk_token] = mean_vector.unsqueeze(0)
        if pad_token not in word_to_vector:
            word_to_vector[pad_token] = torch.zeros(embed_size)

        interner = dict()
        deinterner = list()
        vectors = []
        count = 0
        for word in [pad_token, unk_token] + sorted(
            list(word_to_vector.keys() - {unk_token, pad_token})
        ):
            vector = word_to_vector[word]
            vectors.append(vector)
            interner[word] = count
            deinterner.append(word)
            count += 1
        vectors = torch.cat(vectors, dim=0)
        embedding = nn.Embedding.from_pretrained(
            vectors, padding_idx=interner[pad_token]
        )
        embedding.weight.requires_grad = False
        return embedding, interner, deinterner
    else:
        raise ValueError(
            "Unable to open embeddings file {}: expected a .bin (word2vec) or .txt file".format(
                embedding_file
            )
        )
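

# A minimal usage sketch (not part of the original module). The embedding file
# path and vocabulary below are illustrative assumptions; any word2vec-format
# .bin or whitespace-delimited .txt embedding file works the same way.
def _extract_embeddings_usage_example():
    vocab = {"the", "movie", "was", "great"}
    embedding, interner, deinterner = extract_embeddings(
        vocab, "glove.6B.100d.txt", unk_token="UNK", pad_token="PAD"
    )
    # Words are mapped to rows of the frozen embedding via `interner`;
    # out-of-vocabulary words fall back to the UNK index.
    token_ids = torch.LongTensor(
        [interner.get(w, interner["UNK"]) for w in ["the", "movie", "rocked"]]
    )
    word_vectors = embedding(token_ids)  # shape: (3, embedding_dim)
    assert deinterner[interner["PAD"]] == "PAD"
    return word_vectors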