| | |
| | |
| | |
| | |
| | |
| | |
| | import torch |
| | from transformers import PreTrainedModel, PretrainedConfig |
| | from src.model import SegmentationNetwork |
| | from src.model.config import ModelConfig, TransformerConfig, CoSeNetConfig |
| |
|
| |
|
class SentenceCoseNetConfig(PretrainedConfig):
    """
    Hyperparameter container for `SentenceCoseNet`.

    Implements the Hugging Face `PretrainedConfig` interface so the
    model can be serialized, reloaded, and shared through the Hub.

    Attributes:
        model_type (str):
            Registry identifier used by Hugging Face auto-classes.
        vocab_size (int):
            Tokenizer vocabulary size.
        emb_dim (int):
            Token-embedding dimensionality.
        seq_len (int):
            Maximum supported input sequence length.
        dropout (float):
            Dropout probability for Transformer blocks.
        valid_padding (bool):
            Whether padding tokens count as valid positions.
        cosenet (dict):
            Configuration of the cosine-similarity network head.
        transformers (list[dict]):
            One configuration dict per Transformer encoder block.
    """

    model_type = "sentence_cosenet"

    def __init__(
        self,
        vocab_size: int = 32768,
        emb_dim: int = 256,
        seq_len: int = 382,
        dropout: float = 0.0,
        valid_padding: bool = True,
        cosenet: dict | None = None,
        transformers: list | None = None,
        **kwargs,
    ):
        """
        Build a SentenceCoseNet configuration.

        Args:
            vocab_size:
                Tokenizer vocabulary size.
            emb_dim:
                Token-embedding dimension.
            seq_len:
                Maximum number of tokens per input sequence.
            dropout:
                Dropout probability used throughout the network.
            valid_padding:
                Whether padded tokens should be considered valid.
            cosenet:
                Optional configuration dict for the cosine-similarity
                head; a default is used when omitted (or falsy).
            transformers:
                Optional list of per-block Transformer configuration
                dicts; two default pre-norm blocks are used when
                omitted (or falsy).
            **kwargs:
                Forwarded to `PretrainedConfig`.
        """
        super().__init__(**kwargs)

        self.vocab_size = vocab_size
        self.emb_dim = emb_dim
        self.seq_len = seq_len
        self.dropout = dropout
        self.valid_padding = valid_padding

        # Fall back to the default CoSeNet head when no (truthy)
        # configuration is supplied.
        if cosenet:
            self.cosenet = cosenet
        else:
            self.cosenet = {"trainable": True, "init_scale": 5.0}

        # Default encoder stack: two identical pre-norm blocks.
        # Each entry gets its own dict instance to avoid shared state.
        default_block = {
            "attention_heads": 16,
            "feed_forward_multiplier": 8,
            "dropout": 0.0,
            "pre_normalize": True,
        }
        if transformers:
            self.transformers = transformers
        else:
            self.transformers = [dict(default_block) for _ in range(2)]

        # Aliases expected by parts of the Hugging Face ecosystem.
        self.hidden_size = emb_dim
        self.max_position_embeddings = seq_len
| |
|
| |
|
class SentenceCoseNet(PreTrainedModel):
    """
    Sentence-level encoder model based on CoseNet.

    Wraps a custom PyTorch `SegmentationNetwork` and exposes it as a
    Hugging Face `PreTrainedModel`, enabling interoperability with the
    Transformers ecosystem.

    The model is intended for:
        - Sentence embeddings
        - Semantic search
        - Information retrieval
        - Similarity learning
    """

    config_class = SentenceCoseNetConfig
    base_model_prefix = "cosenet"

    def __init__(self, config: SentenceCoseNetConfig):
        """
        Initialize the SentenceCoseNet model.

        Args:
            config:
                Instance of `SentenceCoseNetConfig` containing
                model hyperparameters.
        """
        super().__init__(config)

        # Core network built from the internal (non-HF) config.
        self.model = SegmentationNetwork(self.to_model_config(config))

        # Hugging Face weight init / tying hook.
        self.post_init()

        # The wrapper defaults to inference mode; callers that train
        # must switch back with `.train()`.
        self.model.eval()

    def encode(
        self,
        input_ids: torch.Tensor,
        attention_mask=None
    ) -> torch.Tensor:
        """
        Encode input token sequences into contextualized embeddings.

        Performs embedding lookup, positional encoding, and
        Transformer-based contextualization, returning token-level
        representations.

        Args:
            input_ids:
                Tensor of token IDs with shape
                `(batch_size, sequence_length)` or
                `(batch_size, sentences, sequence_length)`.
            attention_mask:
                Optional mask with 1 for valid and 0 for padded
                positions, same leading shape as `input_ids`.

        Returns:
            torch.Tensor:
                Contextualized token embeddings with shape
                `(..., sequence_length, emb_dim)`.

        Raises:
            ValueError:
                If `input_ids` is neither 2-D nor 3-D.
        """
        self.model.task = 'token_encoding'
        # Delegate the shape dispatch to `call` instead of duplicating
        # it here (the previous copy-paste body could silently drift).
        return self.call(input_ids, attention_mask)

    def get_sentence_embedding(
        self,
        input_ids: torch.Tensor,
        attention_mask=None,
        normalize: bool = False,
    ) -> torch.Tensor:
        """
        Compute sentence embeddings for zero-shot transfer and
        information retrieval.

        Args:
            input_ids (torch.Tensor):
                Tensor of shape (B, T) or (B, S, T).
            attention_mask (torch.Tensor, optional):
                Boolean or binary mask matching `input_ids`.
            normalize (bool, optional):
                Whether to L2-normalize the output embeddings.

        Returns:
            torch.Tensor:
                Sentence embeddings of shape (B, D) (or (B, S, D)
                for 3-D input).
        """
        self.model.task = 'sentence_encoding'
        output = self.call(input_ids, attention_mask)

        if normalize:
            output = torch.nn.functional.normalize(output, p=2, dim=-1)

        return output

    def similarity(self, embeddings_1: torch.Tensor, embeddings_2: torch.Tensor) -> torch.Tensor:
        """
        Compute cosine similarity scores between two sets of embeddings.

        Args:
            embeddings_1 (torch.Tensor):
                Tensor of shape (B, S, D) containing the first set of
                embeddings concatenated along the first dimension.
            embeddings_2 (torch.Tensor):
                Tensor of shape (B, S, D) containing the second set of
                embeddings concatenated along the first dimension.

        Returns:
            torch.Tensor:
                Similarity scores of shape (B, S).
        """
        # Pair the two sets along a new axis so the distance layer sees
        # a (…, 2, D) stack, then average the two off-diagonal scores
        # to make the result symmetric.
        embeddings = torch.stack([embeddings_1, embeddings_2], dim=-2)
        embeddings = self.model.distance_layer(embeddings)
        return (embeddings[..., 0, 1] + embeddings[..., 1, 0]) / 2

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask=None,
        candidate_mask=None,
        **kwargs,
    ):
        """
        Forward pass of the SentenceCoseNet model (segmentation task).

        Delegates execution to the underlying `SegmentationNetwork`.

        Args:
            input_ids:
                Tensor of token IDs with shape
                `(batch_size, sequence_length)`.
            attention_mask:
                Optional attention mask tensor.
            candidate_mask:
                Optional mask indicating candidate segments or spans.
            **kwargs:
                Additional arguments forwarded to the core model.

        Returns:
            Model-specific output as produced by `SegmentationNetwork`.
        """
        self.model.task = 'segmentation'
        return self.model(
            x=input_ids,
            mask=attention_mask,
            candidate_mask=candidate_mask,
            **kwargs,
        )

    def call(self, input_ids: torch.Tensor, attention_mask=None) -> torch.Tensor:
        """
        Internal dispatcher for 2-D and 3-D inputs (task already selected).

        Args:
            input_ids:
                Tensor of token IDs with shape
                `(batch_size, sequence_length)` or
                `(batch_size, sentences, sequence_length)`.
            attention_mask:
                Optional attention mask with the same leading shape.

        Returns:
            torch.Tensor:
                Core-model output; a 2-D input's singleton sentence
                axis is squeezed back out.

        Raises:
            ValueError:
                If `input_ids` is neither 2-D nor 3-D.
        """
        if input_ids.dim() == 2:
            # Promote (B, T) to (B, 1, T) for the core model, then
            # squeeze the sentence axis back out of the result.
            x = input_ids.int().unsqueeze(1)
            mask = attention_mask.unsqueeze(1) if attention_mask is not None else None
            output = self.model(x=x, mask=mask).squeeze(1)
        elif input_ids.dim() == 3:
            x = input_ids.int()
            output = self.model(x=x, mask=attention_mask)
        else:
            raise ValueError("Input tensor must be of shape (Batch, Tokens) or (Batch, Sentences, Tokens).")
        return output

    @staticmethod
    def to_model_config(config: SentenceCoseNetConfig) -> ModelConfig:
        """
        Convert a Hugging Face config to the internal `ModelConfig`.

        Args:
            config:
                The Hugging Face-facing configuration.

        Returns:
            ModelConfig:
                Internal configuration consumed by `SegmentationNetwork`.
        """
        mc = ModelConfig()

        mc.vocab_size = config.vocab_size
        mc.model_dim = config.emb_dim
        mc.valid_padding = config.valid_padding

        # NOTE(review): `config.dropout` and `config.seq_len` are not
        # forwarded here — dropout appears per-block in `transformers`
        # configs; confirm seq_len is handled inside ModelConfig.
        mc.cosenet = CoSeNetConfig(**config.cosenet)
        mc.transformers = [
            TransformerConfig(**cfg)
            for cfg in config.transformers
        ]

        return mc
| | |
| | |
| | |
| |
|