# -*- coding: utf-8 -*-
import torch
from polos.tokenizers_ import TextEncoderBase
def average_pooling(
tokens: torch.Tensor,
embeddings: torch.Tensor,
mask: torch.Tensor,
padding_index: int,
) -> torch.Tensor:
"""Average pooling function.
:param tokens: Word ids [batch_size x seq_length]
:param embeddings: Word embeddings [batch_size x seq_length x hidden_size]
:param mask: Padding mask [batch_size x seq_length]
:param padding_index: Padding value.
"""
wordemb = mask_fill(0.0, tokens, embeddings, padding_index)
sentemb = torch.sum(wordemb, 1)
sum_mask = mask.unsqueeze(-1).expand(embeddings.size()).float().sum(1)
return sentemb / sum_mask
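
# Minimal usage sketch added for illustration only; `_demo_average_pooling`, the dummy
# tensors and the padding_index value are assumptions, not part of the original module.
def _demo_average_pooling() -> torch.Tensor:
    tokens = torch.tensor([[5, 6, 0], [7, 0, 0]])  # two sequences, padded with token id 0
    embeddings = torch.randn(2, 3, 4)              # [batch_size x seq_length x hidden_size]
    mask = tokens.ne(0)                            # True on real tokens, False on padding
    # Each row of the result is the mean of its non-padded word embeddings.
    return average_pooling(tokens, embeddings, mask, padding_index=0)
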
def max_pooling(
tokens: torch.Tensor, embeddings: torch.Tensor, padding_index: int
) -> torch.Tensor:
"""Max pooling function.
:param tokens: Word ids [batch_size x seq_length]
:param embeddings: Word embeddings [batch_size x seq_length x hidden_size]
:param padding_index: Padding value.
"""
return mask_fill(float("-inf"), tokens, embeddings, padding_index).max(dim=1)[0]
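
# Companion sketch for max_pooling; the shapes and the padding_index of 0 are
# illustrative assumptions only.
def _demo_max_pooling() -> torch.Tensor:
    tokens = torch.tensor([[5, 6, 0], [7, 0, 0]])  # padded with token id 0
    embeddings = torch.randn(2, 3, 4)              # [batch_size x seq_length x hidden_size]
    # Padded positions are filled with -inf first, so they never win the max.
    return max_pooling(tokens, embeddings, padding_index=0)
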
def mask_fill(
fill_value: float,
tokens: torch.Tensor,
embeddings: torch.Tensor,
padding_index: int,
) -> torch.Tensor:
"""
Function that masks embeddings representing padded elements.
    :param fill_value: Value used to fill the embeddings of padded tokens.
:param tokens: The input sequences [bsz x seq_len].
:param embeddings: word embeddings [bsz x seq_len x hiddens].
:param padding_index: Index of the padding token.
"""
padding_mask = tokens.eq(padding_index).unsqueeze(-1)
return embeddings.float().masked_fill_(padding_mask, fill_value).type_as(embeddings)
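
# Illustrative sketch of mask_fill on a tiny batch; the concrete values are assumptions.
def _demo_mask_fill() -> torch.Tensor:
    tokens = torch.tensor([[5, 0], [6, 7]])        # token id 0 acts as padding here
    embeddings = torch.ones(2, 2, 3)
    masked = mask_fill(0.0, tokens, embeddings, padding_index=0)
    # masked[0, 1] is now all zeros; every other position keeps its original value.
    return masked
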
def sort_sequences(inputs: torch.Tensor, input_lengths: torch.Tensor):
"""
    Sort sequences by length of the input sequence (in descending order).
:param inputs (Tensor): input sequences, size [B, T, D]
:param input_lengths (Tensor): length of each sequence, size [B]
"""
lengths_sorted, sorted_idx = input_lengths.sort(descending=True)
_, unsorted_idx = sorted_idx.sort()
return inputs[sorted_idx], lengths_sorted, unsorted_idx
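
# Hedged sketch showing sort_sequences together with the unsorting index; the sizes
# below are made up for illustration.
def _demo_sort_sequences() -> bool:
    inputs = torch.randn(3, 5, 8)                  # [B, T, D]
    lengths = torch.tensor([2, 5, 3])
    sorted_inputs, sorted_lengths, unsort_idx = sort_sequences(inputs, lengths)
    # sorted_lengths is tensor([5, 3, 2]); indexing with unsort_idx restores the batch order.
    return torch.equal(sorted_inputs[unsort_idx], inputs)
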
def apply_to_sample(f, sample):
    """Recursively applies ``f`` to every tensor inside ``sample`` (dicts, lists and tensors)."""
if hasattr(sample, "__len__") and len(sample) == 0:
return {}
def _apply(x):
if torch.is_tensor(x):
return f(x)
elif isinstance(x, dict):
return {key: _apply(value) for key, value in x.items()}
elif isinstance(x, list):
            return [_apply(item) for item in x]
else:
return x
return _apply(sample)
def move_to_cuda(sample):
""" Moves a sample to cuda. Works with dictionaries, tensors and lists. """
def _move_to_cuda(tensor):
return tensor.cuda()
return apply_to_sample(_move_to_cuda, sample)
def move_to_cpu(sample):
""" Moves a sample to cuda. Works with dictionaries, tensors and lists. """
def _move_to_cpu(tensor):
return tensor.cpu()
return apply_to_sample(_move_to_cpu, sample)
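
# Sketch of the apply_to_sample-based movers on a nested sample; the sample structure is
# an assumption for illustration and the CUDA branch only runs when a GPU is available.
def _demo_move_sample():
    sample = {"tokens": torch.tensor([1, 2, 3]), "extras": [torch.zeros(2), "not a tensor"]}
    cpu_sample = move_to_cpu(sample)               # tensors stay on CPU, strings pass through
    if torch.cuda.is_available():
        cpu_sample = move_to_cpu(move_to_cuda(sample))
    return cpu_sample
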
# --------------- LASER auxiliary functions from Facebook Research ------------------------------
def buffered_arange(max):
    """Like ``torch.arange(max)``, but reuses a cached LongTensor buffer to avoid reallocation."""
if not hasattr(buffered_arange, "buf"):
buffered_arange.buf = torch.LongTensor()
if max > buffered_arange.buf.numel():
torch.arange(max, out=buffered_arange.buf)
return buffered_arange.buf[:max]
def convert_padding_direction(
    src_tokens, padding_idx, right_to_left=False, left_to_right=False
):
    """Converts a padded batch from right-padding to left-padding, or vice versa, by rotating each row."""
assert right_to_left ^ left_to_right
pad_mask = src_tokens.eq(padding_idx)
if not pad_mask.any():
# no padding, return early
return src_tokens
if left_to_right and not pad_mask[:, 0].any():
# already right padded
return src_tokens
if right_to_left and not pad_mask[:, -1].any():
# already left padded
return src_tokens
max_len = src_tokens.size(1)
    positions = buffered_arange(max_len).type_as(src_tokens).expand_as(src_tokens)
    num_pads = pad_mask.long().sum(dim=1, keepdim=True)
    if right_to_left:
        index = torch.remainder(positions - num_pads, max_len)
    else:
        index = torch.remainder(positions + num_pads, max_len)
return src_tokens.gather(1, index)
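
# Illustrative sketch of convert_padding_direction; the pad id of 1 and the token values
# are assumptions chosen only to show the rotation.
def _demo_convert_padding_direction() -> torch.Tensor:
    src_tokens = torch.tensor([[4, 5, 1, 1], [6, 1, 1, 1]])  # right padded, pad id 1
    # Returns tensor([[1, 1, 4, 5], [1, 1, 1, 6]]): the same tokens, now left padded.
    return convert_padding_direction(src_tokens, padding_idx=1, right_to_left=True)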