|
import torch |
|
import logging |
|
|
|
from torch import Tensor |
|
from typing import Mapping |
|
|
|
|
|
def _setup_logger():
    """Configure the root logger for console output and return it.

    The root logger's level is set to INFO, and its handler list is replaced
    by exactly one ``logging.StreamHandler``. With no stream argument, that
    handler writes to ``sys.stderr``. Records are formatted as
    ``[<asctime> <levelname>] <message>``.

    Returns:
        The root ``logging.Logger`` instance.
    """
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(
        logging.Formatter("[%(asctime)s %(levelname)s] %(message)s")
    )

    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    # Assign the list rather than calling addHandler(): any previously
    # installed root handlers are dropped, so calling this again never
    # produces duplicate console lines.
    root_logger.handlers = [stream_handler]
    return root_logger


# Runs at import time: importing this module reconfigures the root logger.
logger = _setup_logger()
|
|
|
|
|
def move_to_cuda(sample):
    """Recursively copy every tensor inside ``sample`` to the current CUDA device.

    The nesting structure is rebuilt as follows:

    - ``dict`` (including dict subclasses) becomes a plain ``dict``.
    - ``list`` stays a ``list``.
    - Namedtuples keep their own type; other tuples become plain ``tuple``.
    - Any other ``Mapping`` is rebuilt via ``type(m)({...})``.
    - Non-tensor leaves are returned unchanged.

    Args:
        sample: A tensor, a (possibly nested) container of tensors, or any
            other object.

    Returns:
        The same structure with each tensor replaced by ``tensor.cuda(...)``.
        Any non-tensor ``sample`` that supports ``len()`` and is empty
        (e.g. ``[]``, ``{}``, ``""``) yields ``{}``.
    """
    # Short-circuit only empty non-tensor containers. Tensors are excluded
    # because len() raises on a 0-d tensor, and a tensor with an empty first
    # dimension must still be moved rather than replaced by {}. Objects
    # without __len__ (None, numbers, ...) fall through and are returned as-is.
    if (
        not torch.is_tensor(sample)
        and hasattr(sample, "__len__")
        and len(sample) == 0
    ):
        return {}

    def _move_to_cuda(maybe_tensor):
        if torch.is_tensor(maybe_tensor):
            # non_blocking only makes the copy asynchronous when the source is
            # in pinned memory; otherwise it has no effect.
            return maybe_tensor.cuda(non_blocking=True)
        if isinstance(maybe_tensor, dict):
            return {key: _move_to_cuda(value) for key, value in maybe_tensor.items()}
        if isinstance(maybe_tensor, list):
            return [_move_to_cuda(x) for x in maybe_tensor]
        if isinstance(maybe_tensor, tuple):
            moved = [_move_to_cuda(x) for x in maybe_tensor]
            if hasattr(maybe_tensor, "_fields"):
                # namedtuple: rebuild positionally so field access survives.
                return type(maybe_tensor)(*moved)
            return tuple(moved)
        if isinstance(maybe_tensor, Mapping):
            # Non-dict mappings: rebuild with the original type.
            return type(maybe_tensor)(
                {k: _move_to_cuda(v) for k, v in maybe_tensor.items()}
            )
        return maybe_tensor

    return _move_to_cuda(sample)
|
|
|
|
|
def pool(last_hidden_states: Tensor,
         attention_mask: Tensor,
         pool_type: str) -> Tensor:
    """Reduce per-token hidden states to a single vector per sequence.

    Args:
        last_hidden_states: Token representations. The code sums over dim 1
            and indexes ``[:, 0]``, so it is presumably
            ``(batch, seq_len, hidden)``; TODO confirm with callers.
        attention_mask: Mask aligned with the first dims of
            ``last_hidden_states`` (it is broadcast via ``[..., None]``).
            Nonzero entries mark real tokens. Assumed to be 0/1 (or bool);
            TODO confirm.
        pool_type: ``"avg"`` for the mean over unmasked tokens, or ``"cls"``
            for the hidden state at position 0 (zeroed if that position is
            masked).

    Returns:
        The pooled tensor. For 3-D input this is ``(batch, hidden)``.

    Raises:
        ValueError: If ``pool_type`` is not ``"avg"`` or ``"cls"``.
    """
    # Validate before doing any tensor work.
    if pool_type not in ("avg", "cls"):
        raise ValueError(f"pool_type {pool_type} not supported")

    # Zero out padded positions so they contribute nothing to the sum.
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)

    if pool_type == "avg":
        # Clamp the token count to at least 1. Otherwise a fully-masked row
        # computes 0 / 0 = NaN; with the clamp it yields a zero vector.
        # Unchanged for rows with at least one (0/1-masked) token.
        token_counts = attention_mask.sum(dim=1)[..., None].clamp(min=1)
        return last_hidden.sum(dim=1) / token_counts

    # "cls": first position. Masking was applied above, so a masked first
    # token (e.g. left padding) yields zeros.
    return last_hidden[:, 0]