import collections
import math
from typing import Optional, Tuple, Union

import torch
from torch import nn
from transformers import DonutSwinPreTrainedModel
from transformers.models.donut.modeling_donut_swin import (
    DonutSwinEmbeddings,
    DonutSwinEncoder,
    DonutSwinModel,
    DonutSwinPatchEmbeddings,
)

from surya.model.ordering.config import VariableDonutSwinConfig


class VariableDonutSwinEmbeddings(DonutSwinEmbeddings):
    """
    Construct the patch and position embeddings. Optionally, also the mask token.

    On top of the patch embeddings this module can add two kinds of learned
    position information, each switched on by a config flag:

    * ``config.use_absolute_embeddings``: one embedding per sequence position.
    * ``config.use_2d_embeddings``: one embedding per patch row plus one per
      patch column, summed onto every patch.
    """

    def __init__(self, config, use_mask_token=False, **kwargs):
        """
        Args:
            config: Model configuration. This constructor reads ``embed_dim``,
                ``use_absolute_embeddings``, ``use_2d_embeddings`` and
                ``hidden_dropout_prob`` from it.
            use_mask_token: When true, create a learnable mask token that
                ``forward`` substitutes at masked positions.
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        super().__init__(config, use_mask_token)

        self.patch_embeddings = DonutSwinPatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        self.patch_grid = self.patch_embeddings.grid_size
        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None

        self.position_embeddings = None
        if config.use_absolute_embeddings:
            self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim))

        self.row_embeddings = None
        self.column_embeddings = None
        if config.use_2d_embeddings:
            # One embedding per grid row / grid column (plus one spare slot each).
            self.row_embeddings = nn.Parameter(torch.zeros(1, self.patch_grid[0] + 1, config.embed_dim))
            self.column_embeddings = nn.Parameter(torch.zeros(1, self.patch_grid[1] + 1, config.embed_dim))

        self.norm = nn.LayerNorm(config.embed_dim)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor],
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Tuple[int, int]]:
        """
        Embed ``pixel_values`` into a sequence of patch embeddings.

        Args:
            pixel_values: Input passed unchanged to the patch embedding layer.
            bool_masked_pos: Optional mask; positions set to true are replaced
                by the mask token. Requires ``use_mask_token=True`` at
                construction time.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            A tuple ``(embeddings, output_dimensions)``. ``embeddings`` is the
            3-D tensor ``(batch_size, seq_len, embed_dim)`` after normalization,
            optional masking, optional position embeddings and dropout.
            ``output_dimensions`` is returned exactly as produced by the patch
            embedding layer; this method treats index 0 as the number of patch
            rows and index 1 as the number of patch columns.

        Raises:
            ValueError: If ``bool_masked_pos`` is given but the module has no
                mask token.
        """
        embeddings, output_dimensions = self.patch_embeddings(pixel_values)
        # Layernorm across the last dimension (each patch is a single row)
        embeddings = self.norm(embeddings)
        batch_size, seq_len, embed_dim = embeddings.size()

        if bool_masked_pos is not None:
            if self.mask_token is None:
                # Fail with an explicit message instead of an AttributeError on None.
                raise ValueError(
                    "bool_masked_pos was provided, but the embeddings were created with "
                    "use_mask_token=False, so there is no mask token to apply."
                )
            mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
            # replace the masked visual tokens by mask_tokens
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        if self.position_embeddings is not None:
            embeddings = embeddings + self.position_embeddings[:, :seq_len, :]

        if self.row_embeddings is not None and self.column_embeddings is not None:
            num_rows, num_columns = output_dimensions[0], output_dimensions[1]
            # Each row embedding is held for a full row of patches,
            # giving row indices like 0, 0, 0, 0, 1, 1, 1, 1, ...
            row_embeddings = self.row_embeddings[:, :num_rows, :].repeat_interleave(num_columns, dim=1)
            # The column embeddings are tiled once per row,
            # giving column indices like 0, 1, 2, 3, 0, 1, 2, 3, ...
            column_embeddings = self.column_embeddings[:, :num_columns, :].repeat(1, num_rows, 1)
            embeddings = embeddings + row_embeddings + column_embeddings

        embeddings = self.dropout(embeddings)

        return embeddings, output_dimensions


class VariableDonutSwinModel(DonutSwinModel):
    """
    ``DonutSwinModel`` variant whose embedding layer is
    ``VariableDonutSwinEmbeddings`` and whose config class is
    ``VariableDonutSwinConfig``.
    """

    config_class = VariableDonutSwinConfig

    def __init__(self, config, add_pooling_layer=True, use_mask_token=False, **kwargs):
        """
        Args:
            config: Model configuration; ``depths`` and ``embed_dim`` are read
                here, and the whole object is passed to the sub-modules.
            add_pooling_layer: When true, attach an ``AdaptiveAvgPool1d(1)``
                pooler; otherwise ``self.pooler`` is ``None``.
            use_mask_token: Forwarded to ``VariableDonutSwinEmbeddings``.
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        # The parent constructor only receives the config; the attributes it
        # sets are overridden below with the variable-embedding versions.
        super().__init__(config)

        self.config = config
        self.num_layers = len(config.depths)
        # The feature width doubles for every stage after the first.
        self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))

        self.embeddings = VariableDonutSwinEmbeddings(config, use_mask_token=use_mask_token)
        self.encoder = DonutSwinEncoder(config, self.embeddings.patch_grid)

        self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()