piecurus's picture
extractive
7215c40
raw
history blame
6.65 kB
from typing import List, Union
import torch
import streamlit as st
import numpy as np
from numpy import ndarray
from transformers import (AlbertModel, AlbertTokenizer, BertModel,
BertTokenizer, DistilBertModel, DistilBertTokenizer,
PreTrainedModel, PreTrainedTokenizer, XLMModel,
XLMTokenizer, XLNetModel, XLNetTokenizer)
@st.cache(allow_output_mutation=True)
def load_hf_model(base_model, model_name, device):
    """
    Load a pretrained model with hidden states enabled and move it to ``device``.

    The result is cached by Streamlit, so repeated calls with the same arguments
    reuse the already-loaded weights. ``allow_output_mutation=True`` stops
    ``st.cache`` from re-hashing the (large) returned model on every cache hit
    to detect mutation; the returned model is expected to be mutated by callers
    (e.g. switched to eval mode).

    :param base_model: A ``transformers`` model class exposing ``from_pretrained``.
    :param model_name: Name or path passed to ``from_pretrained``.
    :param device: The ``torch.device`` the model is moved to.
    :return: The loaded model, placed on ``device``.
    """
    model = base_model.from_pretrained(model_name, output_hidden_states=True).to(device)
    return model
class BertParent:
    """
    Base handler for BERT-family models.

    Wraps a Hugging Face model/tokenizer pair and turns sentences into
    embedding vectors by pooling one or more of the model's hidden layers.
    """

    # Supported model keywords -> (model class, tokenizer class).
    MODELS = {
        'bert-base-uncased': (BertModel, BertTokenizer),
        'bert-large-uncased': (BertModel, BertTokenizer),
        'xlnet-base-cased': (XLNetModel, XLNetTokenizer),
        'xlm-mlm-enfr-1024': (XLMModel, XLMTokenizer),
        'distilbert-base-uncased': (DistilBertModel, DistilBertTokenizer),
        'albert-base-v1': (AlbertModel, AlbertTokenizer),
        'albert-large-v1': (AlbertModel, AlbertTokenizer)
    }

    def __init__(
        self,
        model: str,
        custom_model: PreTrainedModel = None,
        custom_tokenizer: PreTrainedTokenizer = None,
        gpu_id: int = 0,
    ):
        """
        :param model: Model keyword (a key of ``MODELS``) passed to ``from_pretrained``.
            Any other value is only usable when both ``custom_model`` and
            ``custom_tokenizer`` are supplied.
        :param custom_model: Optional pre-built model used instead of loading one.
        :param custom_tokenizer: Optional pre-built tokenizer used instead of loading one.
        :param gpu_id: CUDA device index used when CUDA is available; ignored on CPU.
        :raises ValueError: If ``model`` is not a known keyword but a model or
            tokenizer would still have to be loaded for it.
        :raises AssertionError: If CUDA is available and ``gpu_id`` is not a valid
            device index.
        """
        base_model, base_tokenizer = self.MODELS.get(model, (None, None))
        # Fail early with a clear message instead of an AttributeError on
        # ``None.from_pretrained`` further down.
        if base_model is None and (custom_model is None or custom_tokenizer is None):
            raise ValueError(
                f"Unknown model {model!r}; expected one of {sorted(self.MODELS)} "
                "or both `custom_model` and `custom_tokenizer`."
            )

        self.device = torch.device("cpu")
        if torch.cuda.is_available():
            assert (
                isinstance(gpu_id, int) and (0 <= gpu_id and gpu_id < torch.cuda.device_count())
            ), f"`gpu_id` must be an integer between 0 to {torch.cuda.device_count() - 1}. But got: {gpu_id}"
            self.device = torch.device(f"cuda:{gpu_id}")

        if custom_model is not None:
            self.model = custom_model.to(self.device)
        else:
            # load_hf_model is wrapped in st.cache, so repeated constructions
            # with the same arguments can reuse the loaded weights.
            self.model = load_hf_model(base_model, model, self.device)

        if custom_tokenizer is not None:
            self.tokenizer = custom_tokenizer
        else:
            self.tokenizer = base_tokenizer.from_pretrained(model)

        # Embeddings are only ever extracted, never trained: use inference mode.
        self.model.eval()

    def tokenize_input(self, text: str) -> torch.Tensor:
        """
        Tokenize ``text`` into a batch containing a single sequence of token ids.

        Uses ``tokenize`` + ``convert_tokens_to_ids`` (not ``encode``), so for
        Hugging Face tokenizers no special tokens are added.

        :param text: Text to tokenize.
        :return: A ``(1, num_tokens)`` integer tensor on ``self.device``.
        """
        tokenized_text = self.tokenizer.tokenize(text)
        indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokenized_text)
        return torch.tensor([indexed_tokens], dtype=torch.long).to(self.device)

    def _pooled_handler(self, hidden: torch.Tensor, reduce_option: str) -> torch.Tensor:
        """
        Pool a hidden-state tensor over dimension 1.

        :param hidden: Hidden states; dimension 1 is reduced (presumably the token
            axis of a ``(batch, tokens, features)`` tensor).
        :param reduce_option: ``'max'`` or ``'median'``; any other value falls
            back to the mean.
        :return: The pooled tensor with singleton dimensions squeezed.
        """
        if reduce_option == 'max':
            # max/median with ``dim`` return (values, indices); keep the values.
            return hidden.max(dim=1)[0].squeeze()
        if reduce_option == 'median':
            return hidden.median(dim=1)[0].squeeze()
        return hidden.mean(dim=1).squeeze()

    def extract_embeddings(
        self,
        text: str,
        hidden: Union[List[int], int] = -2,
        reduce_option: str = 'mean',
        hidden_concat: bool = False,
    ) -> torch.Tensor:
        """
        Extract a single embedding for ``text``.

        The reduce options ``'concat_last_4'`` and ``'reduce_last_4'`` are
        deprecated shortcuts that ignore ``hidden`` and ``hidden_concat`` and use
        the last four hidden layers. Otherwise:

        * ``hidden`` is an int: that layer is pooled with ``reduce_option``.
        * ``hidden`` is a list and ``hidden_concat`` is true: the layers are
          concatenated along the last dimension and mean-pooled over
          dimension 1 (``reduce_option`` is ignored).
        * ``hidden`` is a list otherwise: the layers are concatenated along
          dimension 1 and pooled with ``reduce_option``.

        :param text: The text to extract embeddings for.
        :param hidden: Index (or list of indices) into the model's hidden states.
        :param reduce_option: ``'mean'``, ``'max'``, ``'median'`` or one of the
            deprecated shortcuts above.
        :param hidden_concat: Whether to concatenate multiple hidden layers along
            the last dimension before pooling.
        :return: The pooled embedding tensor (singleton dimensions squeezed).
        """
        tokens_tensor = self.tokenize_input(text)
        # The last model output holds the per-layer hidden states (models are
        # loaded with output_hidden_states=True); the one before it is unused.
        _, hidden_states = self.model(tokens_tensor)[-2:]

        # deprecated temporary keyword functions.
        if reduce_option == 'concat_last_4':
            last_4 = [hidden_states[i] for i in (-1, -2, -3, -4)]
            cat_hidden_states = torch.cat(tuple(last_4), dim=-1)
            return torch.mean(cat_hidden_states, dim=1).squeeze()
        if reduce_option == 'reduce_last_4':
            last_4 = [hidden_states[i] for i in (-1, -2, -3, -4)]
            return torch.cat(tuple(last_4), dim=1).mean(dim=1).squeeze()

        if isinstance(hidden, int):
            return self._pooled_handler(hidden_states[hidden], reduce_option)

        last_states = [hidden_states[i] for i in hidden]
        if hidden_concat:
            cat_hidden_states = torch.cat(tuple(last_states), dim=-1)
            return torch.mean(cat_hidden_states, dim=1).squeeze()
        hidden_s = torch.cat(tuple(last_states), dim=1)
        return self._pooled_handler(hidden_s, reduce_option)

    def create_matrix(
        self,
        content: List[str],
        hidden: Union[List[int], int] = -2,
        reduce_option: str = 'mean',
        hidden_concat: bool = False,
    ) -> ndarray:
        """
        Create a matrix with one embedding row per sentence.

        :param content: The list of sentences.
        :param hidden: Which hidden layer(s) to use.
        :param reduce_option: The reduce option to run.
        :param hidden_concat: Whether or not to concat multiple hidden layers.
        :return: A numpy array stacking the embedding of each sentence.
        """
        # The embeddings go straight to numpy, so gradients are never needed;
        # no_grad avoids building an autograd graph for every sentence.
        with torch.no_grad():
            embeddings = [
                np.squeeze(
                    self.extract_embeddings(
                        t, hidden=hidden, reduce_option=reduce_option, hidden_concat=hidden_concat
                    ).detach().cpu().numpy()
                )
                for t in content
            ]
        return np.asarray(embeddings)

    def __call__(
        self,
        content: List[str],
        hidden: Union[List[int], int] = -2,
        reduce_option: str = 'mean',
        hidden_concat: bool = False,
    ) -> ndarray:
        """
        Create a matrix from the embeddings; alias for ``create_matrix``.

        :param content: The list of sentences.
        :param hidden: Which hidden layer(s) to use.
        :param reduce_option: The reduce option to run.
        :param hidden_concat: Whether or not to concat multiple hidden layers.
        :return: A numpy array matrix of the given content.
        """
        return self.create_matrix(content, hidden, reduce_option, hidden_concat)