from dataclasses import dataclass
from typing import Dict, List, Set, Tuple

import numpy as np
import torch
from gensim.models import KeyedVectors
from torch import nn
from torch.nn.utils.rnn import (PackedSequence, pack_padded_sequence,
                                pad_packed_sequence, pad_sequence)


@dataclass(eq=True, frozen=True)
class PaddedSequence:
    """A utility class for padding variable-length sequences meant for RNN input.

    This class is in the style of PackedSequence from the PyTorch RNN utilities,
    but is somewhat more manual in approach. It provides the ability to generate
    masks for outputs of the same input dimensions.

    The constructor should never be called directly; use the autopad classmethod
    instead.

    We'd love to delete this, but pad_sequence, pack_padded_sequence, and
    pad_packed_sequence all require shuffling around tuples of information, and
    some convenience methods using these are nice to have.
    """

    data: torch.Tensor
    batch_sizes: torch.Tensor
    batch_first: bool = False

    @classmethod
    def autopad(
        cls, data, batch_first: bool = False, padding_value=0, device=None
    ) -> "PaddedSequence":
        # handle tensors of size 0 (single item)
        data_ = []
        for d in data:
            if len(d.size()) == 0:
                d = d.unsqueeze(0)
            data_.append(d)
        padded = pad_sequence(
            data_, batch_first=batch_first, padding_value=padding_value
        )
        if batch_first:
            batch_lengths = torch.LongTensor([len(x) for x in data_])
            if any(x == 0 for x in batch_lengths):
                raise ValueError(
                    "Found a 0 length batch element, this can't possibly be right: {}".format(
                        batch_lengths
                    )
                )
        else:
            # TODO actually test this codepath
            batch_lengths = torch.LongTensor([len(x) for x in data])
        return PaddedSequence(padded, batch_lengths, batch_first).to(device=device)

    def pack_other(self, data: torch.Tensor):
        return pack_padded_sequence(
            data, self.batch_sizes, batch_first=self.batch_first, enforce_sorted=False
        )

    @classmethod
    def from_packed_sequence(
        cls, ps: PackedSequence, batch_first: bool, padding_value=0
    ) -> "PaddedSequence":
        padded, batch_sizes = pad_packed_sequence(ps, batch_first, padding_value)
        return PaddedSequence(padded, batch_sizes, batch_first)

    def cuda(self) -> "PaddedSequence":
        return PaddedSequence(
            self.data.cuda(), self.batch_sizes.cuda(), batch_first=self.batch_first
        )

    def to(
        self, dtype=None, device=None, copy=False, non_blocking=False
    ) -> "PaddedSequence":
        # TODO make to() support all of the torch.Tensor to() variants
        return PaddedSequence(
            self.data.to(
                dtype=dtype, device=device, copy=copy, non_blocking=non_blocking
            ),
            self.batch_sizes.to(device=device, copy=copy, non_blocking=non_blocking),
            batch_first=self.batch_first,
        )

    def mask(
        self, on=int(0), off=int(0), device="cpu", size=None, dtype=None
    ) -> torch.Tensor:
        if size is None:
            size = self.data.size()
        out_tensor = torch.zeros(*size, dtype=dtype)
        # TODO this can be done more efficiently
        out_tensor.fill_(off)
        # note to self: these are probably less efficient than explicitly populating the off values instead of the on values.
        if self.batch_first:
            for i, bl in enumerate(self.batch_sizes):
                out_tensor[i, :bl] = on
        else:
            for i, bl in enumerate(self.batch_sizes):
                out_tensor[:bl, i] = on
        return out_tensor.to(device)

    def unpad(self, other: torch.Tensor) -> List[torch.Tensor]:
        out = []
        for o, bl in zip(other, self.batch_sizes):
            out.append(o[:bl])
        return out
def flip(self) -> "PaddedSequence":
return PaddedSequence(
self.data.transpose(0, 1), not self.batch_first, self.padding_value
)
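

# Illustrative usage sketch (not part of the original module): it shows how
# autopad, mask, and unpad fit together for a batch of variable-length index
# tensors. The helper name below is hypothetical and exists only for this example.
def _example_padded_sequence():
    batch = [torch.tensor([1, 2, 3]), torch.tensor([4, 5]), torch.tensor([6])]
    padded = PaddedSequence.autopad(batch, batch_first=True, padding_value=0)
    # padded.data has shape (batch=3, max_len=3); the mask marks real positions with 1.
    attention_mask = padded.mask(on=1, off=0, dtype=torch.long)
    # unpad recovers the original, unpadded rows from any tensor shaped like padded.data.
    originals = padded.unpad(padded.data)
    return padded, attention_mask, originals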


def extract_embeddings(
    vocab: Set[str], embedding_file: str, unk_token: str = "UNK", pad_token: str = "PAD"
) -> Tuple[nn.Embedding, Dict[str, int], List[str]]:
    """Builds a frozen nn.Embedding from a pretrained word2vec .bin file
    (restricted to the given vocab) or a whitespace-separated .txt file.

    Returns the embedding layer, a word -> index mapping (interner), and the
    inverse index -> word list (deinterner). A missing unk_token vector defaults
    to the mean embedding and a missing pad_token vector defaults to zeros.
    """
    vocab = vocab | set([unk_token, pad_token])
    if embedding_file.endswith(".bin"):
        # note: KeyedVectors.vocab is the gensim 3.x API (renamed key_to_index in gensim 4)
        WVs = KeyedVectors.load_word2vec_format(embedding_file, binary=True)

        word_to_vector = dict()
        WV_matrix = np.matrix([WVs[v] for v in WVs.vocab.keys()])

        if unk_token not in WVs:
            # np.matrix.mean returns a (1, dim) matrix; flatten it so it stacks
            # cleanly with the (dim,) word vectors below
            mean_vector = np.asarray(np.mean(WV_matrix, axis=0)).ravel()
            word_to_vector[unk_token] = mean_vector
        if pad_token not in WVs:
            word_to_vector[pad_token] = np.zeros(WVs.vector_size)

        for v in vocab:
            if v in WVs:
                word_to_vector[v] = WVs[v]

        interner = dict()
        deinterner = list()
        vectors = []
        count = 0
        for word in [pad_token, unk_token] + sorted(
            list(word_to_vector.keys() - {unk_token, pad_token})
        ):
            vector = word_to_vector[word]
            vectors.append(np.array(vector))
            interner[word] = count
            deinterner.append(word)
            count += 1
        vectors = torch.FloatTensor(np.array(vectors))
        embedding = nn.Embedding.from_pretrained(
            vectors, padding_idx=interner[pad_token]
        )
        embedding.weight.requires_grad = False
        return embedding, interner, deinterner
    elif embedding_file.endswith(".txt"):
        word_to_vector = dict()
        vector = []
        with open(embedding_file, "r") as inf:
            for line in inf:
                contents = line.strip().split()
                word = contents[0]
                vector = torch.tensor([float(v) for v in contents[1:]]).unsqueeze(0)
                word_to_vector[word] = vector
        embed_size = vector.size()
        if unk_token not in word_to_vector:
            mean_vector = torch.cat(list(word_to_vector.values()), dim=0).mean(dim=0)
            word_to_vector[unk_token] = mean_vector.unsqueeze(0)
        if pad_token not in word_to_vector:
            word_to_vector[pad_token] = torch.zeros(embed_size)
        interner = dict()
        deinterner = list()
        vectors = []
        count = 0
        for word in [pad_token, unk_token] + sorted(
            list(word_to_vector.keys() - {unk_token, pad_token})
        ):
            vector = word_to_vector[word]
            vectors.append(vector)
            interner[word] = count
            deinterner.append(word)
            count += 1
        vectors = torch.cat(vectors, dim=0)
        embedding = nn.Embedding.from_pretrained(
            vectors, padding_idx=interner[pad_token]
        )
        embedding.weight.requires_grad = False
        return embedding, interner, deinterner
    else:
        raise ValueError("Unable to open embeddings file {}".format(embedding_file))
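

# Illustrative usage sketch (not part of the original module). The embeddings
# path below is a hypothetical placeholder; any word2vec-format .bin file or
# whitespace-separated .txt file should work.
def _example_extract_embeddings():
    vocab = {"the", "cat", "sat"}
    embedding, interner, deinterner = extract_embeddings(
        vocab, "embeddings/glove.6B.200d.txt"  # hypothetical path
    )
    # interner maps each word to its row in the (frozen) embedding matrix and
    # deinterner is the inverse lookup; UNK/PAD rows are always present.
    unk_index = interner["UNK"]
    return embedding, interner, deinterner, unk_index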