Spaces:
Paused
Paused
File size: 1,190 Bytes
0fdb130 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
import torch
def mean_pooling(token_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Perform attention-aware mean pooling.

    This method takes in embeddings of shape (batch, sequence, embedding_size) and performs average
    pooling across the sequence dimension to yield embeddings of size (batch, embedding_size).
    Positions whose attention mask value is 0 contribute nothing to the average.

    From:
        https://github.com/UKPLab/sentence-transformers/blob/0b5ef4be93d2b21de3a918a084b48aab6ba48595/sentence_transformers/model_card_templates.py#L134  # noqa: E501

    Args:
        token_embeddings (`torch.Tensor`): The embeddings we wish to pool over of shape
            (batch, sequence, embedding_size). This will pool over the sequence to yield
            (batch, embedding_size).
        attention_mask (`torch.Tensor`): The binary attention mask across the embeddings of shape
            (batch, sequence). It is cast to float32 before use, so integer and boolean masks
            are both accepted.

    Returns:
        (`torch.Tensor`) The mean pooled embeddings of size (batch, embedding_size). A sequence
        whose mask is entirely 0 pools to a zero vector.
    """
    # Cast the (batch, sequence, 1) mask to float and let it broadcast against the embeddings,
    # rather than first expanding it to a full (batch, sequence, embedding_size) float copy.
    mask = attention_mask.unsqueeze(-1).float()
    masked_sum = torch.sum(token_embeddings * mask, dim=1)
    # Clamp the per-sequence token count so a fully masked sequence does not divide by zero.
    token_counts = torch.clamp(mask.sum(dim=1), min=1e-9)
    return masked_sum / token_counts
|