Sarthak
chore: moved model2vec to an internal package
473c3a0
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Self, TypeVar
import numpy as np
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from distiller.model2vec import StaticModel
if TYPE_CHECKING:
from tokenizers import Encoding, Tokenizer
logger = logging.getLogger(__name__)
class FinetunableStaticModel(nn.Module):
def __init__(self, *, vectors: torch.Tensor, tokenizer: Tokenizer, out_dim: int = 2, pad_id: int = 0) -> None:
"""
Initialize a trainable StaticModel from a StaticModel.
        :param vectors: The embeddings of the StaticModel.
:param tokenizer: The tokenizer.
:param out_dim: The output dimension of the head.
        :param pad_id: The padding id. This is set to 0 in almost all model2vec models.
"""
super().__init__()
self.pad_id = pad_id
self.out_dim = out_dim
self.embed_dim = vectors.shape[1]
self.vectors = vectors
if self.vectors.dtype != torch.float32:
dtype = str(self.vectors.dtype)
logger.warning(
f"Your vectors are {dtype} precision, converting to to torch.float32 to avoid compatibility issues."
)
self.vectors = vectors.float()
self.embeddings = nn.Embedding.from_pretrained(self.vectors.clone(), freeze=False, padding_idx=pad_id)
self.head = self.construct_head()
self.w = self.construct_weights()
self.tokenizer = tokenizer
def construct_weights(self) -> nn.Parameter:
"""Construct the weights for the model."""
weights = torch.zeros(len(self.vectors))
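        # sigmoid(-10_000) is effectively 0, so the padding token contributes
        # (almost) nothing to the weighted mean pooling in `_encode`.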
weights[self.pad_id] = -10_000
return nn.Parameter(weights)
def construct_head(self) -> nn.Sequential:
"""Method should be overridden for various other classes."""
return nn.Sequential(nn.Linear(self.embed_dim, self.out_dim))
@classmethod
def from_pretrained(
cls, *, out_dim: int = 2, model_name: str = "minishlab/potion-base-32m", **kwargs: Any
) -> Self:
"""Load the model from a pretrained model2vec model."""
model = StaticModel.from_pretrained(model_name)
return cls.from_static_model(model=model, out_dim=out_dim, **kwargs)
@classmethod
def from_static_model(cls, *, model: StaticModel, out_dim: int = 2, **kwargs: Any) -> Self:
"""Load the model from a static model."""
model.embedding = np.nan_to_num(model.embedding)
embeddings_converted = torch.from_numpy(model.embedding)
return cls(
vectors=embeddings_converted,
pad_id=model.tokenizer.token_to_id("[PAD]"),
out_dim=out_dim,
tokenizer=model.tokenizer,
**kwargs,
)
def _encode(self, input_ids: torch.Tensor) -> torch.Tensor:
"""
A forward pass and mean pooling.
This function is analogous to `StaticModel.encode`, but reimplemented to allow gradients
to pass through.
        :param input_ids: A 2D tensor of input ids. All input ids have to be within bounds.
:return: The mean over the input ids, weighted by token weights.
"""
w = self.w[input_ids]
w = torch.sigmoid(w)
zeros = (input_ids != self.pad_id).float()
w = w * zeros
# Add a small epsilon to avoid division by zero
length = zeros.sum(1) + 1e-16
embedded = self.embeddings(input_ids)
# Weigh each token
embedded = torch.bmm(w[:, None, :], embedded).squeeze(1)
# Mean pooling by dividing by the length
embedded = embedded / length[:, None]
return nn.functional.normalize(embedded)
def forward(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Forward pass through the mean, and a classifier layer after."""
encoded = self._encode(input_ids)
return self.head(encoded), encoded
def tokenize(self, texts: list[str], max_length: int | None = 512) -> torch.Tensor:
"""
Tokenize a bunch of strings into a single padded 2D tensor.
Note that this is not used during training.
:param texts: The texts to tokenize.
        :param max_length: The maximum length to truncate each sequence to. If None, sequences are not truncated.
        :return: A 2D padded tensor of token ids.
"""
encoded: list[Encoding] = self.tokenizer.encode_batch_fast(texts, add_special_tokens=False)
        encoded_ids: list[torch.Tensor] = [torch.tensor(encoding.ids[:max_length], dtype=torch.long) for encoding in encoded]
return pad_sequence(encoded_ids, batch_first=True, padding_value=self.pad_id)
@property
    def device(self) -> torch.device:
"""Get the device of the model."""
return self.embeddings.weight.device
def to_static_model(self) -> StaticModel:
"""Convert the model to a static model."""
emb = self.embeddings.weight.detach().cpu().numpy()
w = torch.sigmoid(self.w).detach().cpu().numpy()
return StaticModel(emb * w[:, None], self.tokenizer, normalize=True)
class TextDataset(Dataset):
def __init__(self, tokenized_texts: list[list[int]], targets: torch.Tensor) -> None:
"""
A dataset of texts.
:param tokenized_texts: The tokenized texts. Each text is a list of token ids.
:param targets: The targets.
:raises ValueError: If the number of labels does not match the number of texts.
"""
if len(targets) != len(tokenized_texts):
msg = "Number of labels does not match number of texts."
raise ValueError(msg)
self.tokenized_texts = tokenized_texts
self.targets = targets
def __len__(self) -> int:
"""Return the length of the dataset."""
return len(self.tokenized_texts)
def __getitem__(self, index: int) -> tuple[list[int], torch.Tensor]:
"""Gets an item."""
return self.tokenized_texts[index], self.targets[index]
@staticmethod
    def collate_fn(batch: list[tuple[list[int], torch.Tensor]]) -> tuple[torch.Tensor, torch.Tensor]:
"""Collate function."""
texts, targets = zip(*batch, strict=False)
tensors = [torch.LongTensor(x) for x in texts]
padded = pad_sequence(tensors, batch_first=True, padding_value=0)
return padded, torch.stack(targets)
def to_dataloader(self, shuffle: bool, batch_size: int = 32) -> DataLoader:
"""Convert the dataset to a DataLoader."""
return DataLoader(self, collate_fn=self.collate_fn, shuffle=shuffle, batch_size=batch_size)
ModelType = TypeVar("ModelType", bound=FinetunableStaticModel)
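

# Minimal usage sketch (illustrative only, not exercised by the library itself).
# It assumes the default "minishlab/potion-base-32m" checkpoint is reachable and
# that targets are integer class ids; everything else comes from the classes above.
if __name__ == "__main__":
    model = FinetunableStaticModel.from_pretrained(out_dim=2)
    texts = ["a clearly positive sentence", "a clearly negative sentence"]
    targets = torch.tensor([1, 0])

    # Inference path: tokenize() pads the texts into a single 2D tensor.
    input_ids = model.tokenize(texts).to(model.device)
    logits, pooled = model(input_ids)
    print(logits.shape, pooled.shape)  # (2, out_dim) and (2, embed_dim)

    # Training path: TextDataset expects one list of token ids per text.
    encodings = model.tokenizer.encode_batch_fast(texts, add_special_tokens=False)
    dataset = TextDataset([encoding.ids for encoding in encodings], targets)
    for batch_ids, batch_targets in dataset.to_dataloader(shuffle=True, batch_size=2):
        batch_logits, _ = model(batch_ids.to(model.device))
        loss = nn.functional.cross_entropy(batch_logits, batch_targets.to(model.device))
        print(float(loss))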