Spaces:
Runtime error
Runtime error
File size: 6,328 Bytes
6065472 487e498 6065472 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
import math
import numpy as np
import torch
import torch.nn as nn
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
def sort_pack_padded_sequence(input, lengths):
    """Pack a padded, batch-first tensor after ordering it by decreasing length.

    Args:
        input: padded batch; its first dimension is indexed by batch entry.
        lengths: tensor with the valid length of every batch entry.

    Returns:
        Tuple ``(packed, inv_ix)``: the ``PackedSequence`` built from the
        length-sorted batch, and an index tensor that maps the sorted rows
        back to the order in which the caller supplied them.
    """
    lengths_desc, order = torch.sort(lengths, descending=True)
    # The lengths are moved to the CPU before being handed to the packer.
    packed = pack_padded_sequence(
        input[order], lengths_desc.cpu(), batch_first=True)

    # Invert the permutation: original row order[k] now lives at sorted row k.
    restore = order.clone()
    restore[order] = torch.arange(0, len(order)).type_as(restore)
    return packed, restore
def pad_unsort_packed_sequence(input, inv_ix):
    """Unpack a ``PackedSequence`` and reorder its rows with ``inv_ix``.

    Args:
        input: the ``PackedSequence`` to convert back to a padded tensor.
        inv_ix: index tensor applied to the batch dimension of the padded
            result (e.g. the inverse permutation of an earlier sort).

    Returns:
        The padded, batch-first tensor with rows selected by ``inv_ix``.
    """
    # The per-sequence lengths returned by the unpacker are not needed here.
    padded, _lengths = pad_packed_sequence(input, batch_first=True)
    return padded[inv_ix]
def pack_wrapper(module, attn_feats, attn_feat_lens):
    """Apply ``module`` to a padded batch via its packed representation.

    The batch is sorted by length and packed, ``module`` is applied, and the
    result is padded again and returned in the caller's original row order.

    Args:
        module: callable applied to the packed data. Instances of
            ``torch.nn.RNNBase`` receive the ``PackedSequence`` itself and
            the first element of their return value is used; any other
            module receives only the flat packed data tensor.
        attn_feats: padded, batch-first features.
        attn_feat_lens: tensor of valid lengths, one per batch entry.

    Returns:
        Padded, batch-first output of ``module`` in the original order.
    """
    packed, inv_ix = sort_pack_padded_sequence(attn_feats, attn_feat_lens)
    if not isinstance(module, torch.nn.RNNBase):
        # Non-recurrent modules see the flat data only; the original batch
        # sizes are re-attached so the result can be padded again.
        outputs = PackedSequence(module(packed[0]), packed[1])
    else:
        outputs = module(packed)[0]
    return pad_unsort_packed_sequence(outputs, inv_ix)
def generate_length_mask(lens, max_length=None):
    """Build a boolean mask marking the valid positions of each sequence.

    Args:
        lens: sequence lengths, one per batch entry; anything accepted by
            ``torch.as_tensor``.
        max_length: width of the mask (int or 0-dim tensor). When None, the
            largest value in ``lens`` is used, or 0 if ``lens`` is empty.

    Returns:
        Bool tensor of shape [N, max_length] on the device of ``lens`` where
        ``mask[i, t]`` is True exactly when ``t < lens[i]``.
    """
    lens = torch.as_tensor(lens)
    if max_length is None:
        # Tensor-level max avoids iterating the tensor element by element;
        # an empty batch yields an empty mask instead of raising.
        max_length = lens.max() if lens.numel() > 0 else 0
    max_length = int(max_length)

    # Create the position row directly on the target device and let
    # broadcasting produce the [N, max_length] comparison.
    positions = torch.arange(max_length, device=lens.device)
    return positions.unsqueeze(0) < lens.view(-1, 1)
def mean_with_lens(features, lens):
    """
    Average ``features`` over the length dimension, ignoring padded steps.

    features: [N, T, ...] (assume the second dimension represents length)
    lens: [N,]

    Returns a tensor of shape [N, ...]: the sum over the first ``lens[i]``
    steps of each entry divided by ``lens[i]``. A zero length is not guarded
    and results in a division by zero.
    """
    lens = torch.as_tensor(lens)
    # The mask has to span the whole T dimension whatever the largest length
    # is, so that width is requested directly.
    mask = generate_length_mask(lens, features.size(1))
    mask = mask.to(features.device)  # [N, T]
    # Append singleton dims so the mask broadcasts over trailing feature dims.
    while mask.ndim < features.ndim:
        mask = mask.unsqueeze(-1)
    feature_sum = (features * mask).sum(1)
    # Same for the divisor: [N] -> [N, 1, ...] to match the reduced features.
    while lens.ndim < feature_sum.ndim:
        lens = lens.unsqueeze(1)
    return feature_sum / lens.to(features.device)
def max_with_lens(features, lens):
    """
    Take the maximum of ``features`` over the length dimension, ignoring
    padded steps.

    features: [N, T, ...] (assume the second dimension represents length)
    lens: [N,]

    Returns a tensor of shape [N, ...]. Padded steps are overwritten with
    ``-inf`` on a copy before reducing, so ``features`` itself is untouched.
    """
    lens = torch.as_tensor(lens)
    # The mask has to span the whole T dimension whatever the largest length
    # is, so that width is requested directly.
    mask = generate_length_mask(lens, features.size(1))
    mask = mask.to(features.device)  # [N, T]
    masked = features.clone()
    # Padded positions can never win the max once set to -inf.
    masked[~mask] = float("-inf")
    feature_max, _ = masked.max(1)
    return feature_max
def repeat_tensor(x, n):
    """Stack ``n`` copies of ``x`` along a new leading dimension.

    Args:
        x: tensor of any shape ``S``.
        n: number of copies.

    Returns:
        Tensor of shape ``[n, *S]``.
    """
    # Tile only the new leading dimension; every existing one is kept once.
    tile_counts = [n] + [1] * x.dim()
    return x.unsqueeze(0).repeat(*tile_counts)
def init(m, method="kaiming"):
    """Initialize the parameters of a single module in place.

    Args:
        m: the module to initialize. ``Conv1d``/``Conv2d``/``Linear``/
            ``Embedding`` weights are drawn with ``method`` and an existing
            bias is zeroed; ``BatchNorm1d``/``BatchNorm2d`` weights are set
            to 1 and biases to 0. Any other module is left untouched.
        method: ``"kaiming"`` (``kaiming_uniform_``) or ``"xavier"``
            (``xavier_uniform_``).

    Raises:
        ValueError: if ``method`` is unknown and ``m`` is one of the weighted
            layer types listed above. ``ValueError`` is an ``Exception``
            subclass, so handlers catching ``Exception`` keep working.
    """
    if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
        # Batch norm built with affine=False carries weight = bias = None.
        if m.weight is not None:
            nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
        return

    if not isinstance(m, (nn.Conv2d, nn.Conv1d, nn.Linear, nn.Embedding)):
        return

    if method == "kaiming":
        nn.init.kaiming_uniform_(m.weight)
    elif method == "xavier":
        nn.init.xavier_uniform_(m.weight)
    else:
        raise ValueError(f"initialization method {method} not supported")

    # nn.Embedding has no bias attribute; conv/linear may be built bias-free.
    bias = getattr(m, "bias", None)
    if bias is not None:
        nn.init.constant_(bias, 0)
def compute_batch_score(decode_res,
                        key2refs,
                        keys,
                        start_idx,
                        end_idx,
                        vocabulary,
                        scorer):
    """
    Score the decoded sentences of one batch against their references.

    Args:
        decode_res: decoding results of model, [N, max_length]
        key2refs: references of all samples, dict(<key> -> [ref_1, ref_2, ..., ref_n]
        keys: keys of this batch, used to match decode results and refs
        start_idx: token index that is skipped when building a candidate
        end_idx: token index that terminates a candidate
        vocabulary: object whose ``idx2word`` maps a token index to a word
        scorer: object providing ``compute_score(references, hypothesis)``;
            when None a pycocoevalcap ``Cider`` scorer is created
    Return:
        scores of this batch, [N,]
    """
    if scorer is None:
        from pycocoevalcap.cider.cider import Cider
        scorer = Cider()

    hypothesis, references = {}, {}
    for row, key in enumerate(keys):
        # A key that occurs several times is decoded from its first row only.
        if key in hypothesis:
            continue
        words = []
        for token in decode_res[row]:
            if token == start_idx:
                continue
            if token == end_idx:
                break
            words.append(vocabulary.idx2word[token])
        hypothesis[key] = [" ".join(words), ]
        references[key] = key2refs[key]

    # Only the per-sample scores are used; they are matched to keys by the
    # iteration order of ``references``.
    _, per_sample = scorer.compute_score(references, hypothesis)
    key2score = {key: per_sample[pos] for pos, key in enumerate(references)}

    n_rows = decode_res.shape[0]
    results = np.zeros(n_rows)
    for row in range(n_rows):
        results[row] = key2score[keys[row]]
    return results
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence, then apply dropout.

    Args:
        d_model: size of the last (embedding) dimension; even and odd values
            are both supported.
        dropout: dropout probability applied after the encodings are added.
        max_len: number of positions for which encodings are precomputed.
    """

    def __init__(self, d_model, dropout=0.1, max_len=100):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one cosine column fewer than sine
        # columns, so the surplus frequency is dropped.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        pe = pe.unsqueeze(0).transpose(0, 1)  # [max_len, 1, d_model]
        # Stored as a frozen parameter (not a buffer), as in the original.
        self.register_parameter("pe", nn.Parameter(pe, requires_grad=False))

    def forward(self, x):
        """Return ``dropout(x + pe[:T])`` for ``x`` of shape [T, N, E]."""
        # x: [T, N, E]; the singleton middle dim of pe broadcasts over N.
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
|