# Commit note (page-scrape residue, preserved as a comment so the file parses):
# alverciito — "fix huggingface model missmatch" (34f99b8)
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# #
# This file was created by: Alberto Palomo Alonso #
# Universidad de Alcalá - Escuela Politécnica Superior #
# #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# Import statements:
import torch
from .config import ModelConfig
from .cosenet import CosineDistanceLayer, CoSeNet
from .transformers import EncoderBlock, PositionalEncoding, MaskedMeanPooling
class SegmentationNetwork(torch.nn.Module):
    """
    Segmentation network combining Transformer encoders with CoSeNet.
    This model integrates token embeddings and positional encodings with
    a stack of Transformer encoder blocks to produce contextualized
    representations. These representations are then processed by a
    CoSeNet module to perform structured segmentation, followed by a
    cosine-based distance computation.
    The final output is a pair-wise distance matrix suitable for
    segmentation or boundary detection tasks.
    """
    # Supported values for the ``task`` constructor argument; each task
    # short-circuits ``forward`` at a different stage of the pipeline.
    _VALID_TASKS = ('segmentation', 'similarity', 'token_encoding', 'sentence_encoding')

    def __init__(self, model_config: ModelConfig, task: str = 'segmentation', **kwargs):
        """
        Initialize the segmentation network.
        The network is composed of an embedding layer, positional encoding,
        multiple Transformer encoder blocks, a CoSeNet segmentation module,
        and a cosine distance layer.
        Args:
            model_config (ModelConfig): Configuration object containing all
                hyperparameters required to build the model, including
                vocabulary size, model dimensionality, transformer settings,
                and CoSeNet parameters.
            task (str): Output mode of the network; one of 'segmentation',
                'similarity', 'token_encoding' or 'sentence_encoding'.
                Defaults to 'segmentation'. See ``forward`` for the stage
                at which each task returns.
            **kwargs: Additional keyword arguments forwarded to
                `torch.nn.Module`.
        Raises:
            ValueError: If ``task`` is not one of the supported tasks.
        """
        super().__init__(**kwargs)
        # Fail fast: reject an invalid task before building any layers so
        # a bad argument does not waste model construction.
        if task not in self._VALID_TASKS:
            raise ValueError(f"Invalid task '{task}'. Supported tasks are 'segmentation', 'similarity', "
                             f"'token_encoding', and 'sentence_encoding'.")
        self.task = task
        # True when the mask marks VALID positions; False when it marks padding.
        self.valid_padding = model_config.valid_padding
        # Build layers:
        self.embedding = torch.nn.Embedding(
            model_config.vocab_size,
            model_config.model_dim
        )
        self.positional_encoding = PositionalEncoding(
            emb_dim=model_config.model_dim,
            max_len=model_config.max_tokens
        )
        self.cosenet = CoSeNet(
            trainable=model_config.cosenet.trainable,
            init_scale=model_config.cosenet.init_scale
        )
        self.distance_layer = CosineDistanceLayer()
        self.pooling = MaskedMeanPooling(valid_pad=model_config.valid_padding)
        # Build encoder blocks, one per transformer configuration entry:
        self.encoder_blocks = torch.nn.ModuleList(
            EncoderBlock(
                feature_dim=model_config.model_dim,
                attention_heads=transformer_config.attention_heads,
                feed_forward_multiplier=transformer_config.feed_forward_multiplier,
                dropout=transformer_config.dropout,
                valid_padding=model_config.valid_padding,
                pre_normalize=transformer_config.pre_normalize
            )
            for transformer_config in model_config.transformers
        )

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None, candidate_mask: torch.Tensor = None) -> torch.Tensor:
        """
        Forward pass of the segmentation network.
        The input token indices are embedded and enriched with positional
        information, then processed by a stack of Transformer encoder
        blocks. The resulting representations are segmented using CoSeNet
        and finally transformed into a pair-wise distance representation.
        Depending on ``self.task``, the pipeline returns early:
        'token_encoding' returns encoder outputs, 'sentence_encoding'
        returns pooled representations, 'similarity' returns the distance
        matrix, and 'segmentation' runs the full pipeline through CoSeNet.
        Args:
            x (torch.Tensor): Input tensor of token indices with shape
                (batch_size, sequence_length).
            mask (torch.Tensor, optional): Optional mask tensor indicating
                valid or padded positions, depending on the configuration
                of the Transformer blocks. Defaults to None.
                If `valid_padding` is disabled, the mask is inverted before being
                passed to CoSeNet to match its masking convention.
            candidate_mask (torch.Tensor, optional): Optional mask tensor for
                candidate positions in CoSeNet. Defaults to None.
                If `valid_padding` is disabled, the mask is inverted before being
                passed to CoSeNet to match its masking convention.
        Returns:
            torch.Tensor: Output tensor containing pairwise distance values
                derived from the segmented representations.
        """
        # Convert to integer indices for the embedding lookup:
        x = x.int()
        # Embedding and positional encoding:
        x = self.embedding(x)
        x = self.positional_encoding(x)
        # Flatten (batch, sentences) into one dimension so the encoder
        # blocks see a plain (N, tokens, dim) batch.
        # NOTE(review): this implies x arrives as (batch, sentences, tokens)
        # before embedding — confirm against callers.
        _b, _s, _t, _d = x.shape
        x = x.reshape(_b * _s, _t, _d)
        if mask is not None:
            mask = mask.reshape(_b * _s, _t).bool()
        # Encode the sequence:
        for encoder in self.encoder_blocks:
            x = encoder(x, mask=mask)
        # Restore the (batch, sentences, tokens, dim) layout:
        x = x.reshape(_b, _s, _t, _d)
        if mask is not None:
            mask = mask.reshape(_b, _s, _t)
            # Normalize the mask convention for the downstream modules.
            # NOTE(review): the inverted mask is also consumed by the pooling
            # layer, which was configured with the ORIGINAL valid_padding
            # convention — verify this double-handling is intentional.
            mask = torch.logical_not(mask) if not self.valid_padding else mask
        if self.task == 'token_encoding':
            return x
        # Apply pooling:
        x, mask = self.pooling(x, mask=mask)
        if self.task == 'sentence_encoding':
            return x
        # Compute distances:
        x = self.distance_layer(x)
        if self.task == 'similarity':
            return x
        # Pass through CoSeNet:
        x = self.cosenet(x, mask=mask)
        # Apply candidate mask if provided: positions where the (convention-
        # normalized) mask is True are zeroed out of the output.
        if candidate_mask is not None:
            candidate_mask = candidate_mask.bool() if not self.valid_padding else torch.logical_not(candidate_mask.bool())
            candidate_mask = candidate_mask.to(device=x.device)
            x = x.masked_fill(candidate_mask, 0)
        return x
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# END OF FILE #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #