import collections
import math
from typing import Optional, Tuple, Union

import torch
from torch import nn

from transformers import DonutSwinPreTrainedModel
from transformers.models.donut.modeling_donut_swin import (
    DonutSwinEmbeddings,
    DonutSwinEncoder,
    DonutSwinModel,
    DonutSwinPatchEmbeddings,
)

from surya.model.ordering.config import VariableDonutSwinConfig

class VariableDonutSwinEmbeddings(DonutSwinEmbeddings):
    """
    Construct the patch and position embeddings. Optionally, also the mask token.
    """

    def __init__(self, config, use_mask_token=False, **kwargs):
        super().__init__(config, use_mask_token)

        self.patch_embeddings = DonutSwinPatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        self.patch_grid = self.patch_embeddings.grid_size
        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None

        # Learned absolute position embeddings over the flattened patch sequence.
        self.position_embeddings = None
        if config.use_absolute_embeddings:
            self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim))

        # Optional factored 2D position embeddings: one learned vector per patch row and one
        # per patch column, sized grid + 1 to mirror the num_patches + 1 absolute table.
        self.row_embeddings = None
        self.column_embeddings = None
        if config.use_2d_embeddings:
            self.row_embeddings = nn.Parameter(torch.zeros(1, self.patch_grid[0] + 1, config.embed_dim))
            self.column_embeddings = nn.Parameter(torch.zeros(1, self.patch_grid[1] + 1, config.embed_dim))

        self.norm = nn.LayerNorm(config.embed_dim)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None, **kwargs
    ) -> Tuple[torch.Tensor, Tuple[int, int]]:
        embeddings, output_dimensions = self.patch_embeddings(pixel_values)
        embeddings = self.norm(embeddings)
        batch_size, seq_len, embed_dim = embeddings.size()

        if bool_masked_pos is not None:
            mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
            # Replace the embeddings of masked patches with the learned mask token.
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        if self.position_embeddings is not None:
            # Slice to the actual sequence length so variable input sizes are supported.
            embeddings = embeddings + self.position_embeddings[:, :seq_len, :]

        if self.row_embeddings is not None and self.column_embeddings is not None:
            # Repeat row embeddings across columns and tile column embeddings across rows,
            # matching the row-major flattening of the patch grid.
            row_embeddings = self.row_embeddings[:, :output_dimensions[0], :].repeat_interleave(output_dimensions[1], dim=1)
            column_embeddings = self.column_embeddings[:, :output_dimensions[1], :].repeat(1, output_dimensions[0], 1)

            embeddings = embeddings + row_embeddings + column_embeddings

        embeddings = self.dropout(embeddings)

        return embeddings, output_dimensions
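    # Illustration (not executed): how the factored 2D embeddings line up with the
    # row-major patch flattening. Assume output_dimensions == (2, 3), i.e. 2 patch rows
    # and 3 patch columns:
    #   row_embeddings[:, :2, :].repeat_interleave(3, dim=1) -> r0 r0 r0 r1 r1 r1
    #   column_embeddings[:, :3, :].repeat(1, 2, 1)          -> c0 c1 c2 c0 c1 c2
    # so the patch at grid position (r, c), which sits at flat index r * 3 + c, receives
    # row_embeddings[r] + column_embeddings[c] on top of its patch embedding.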


class VariableDonutSwinModel(DonutSwinModel):
    config_class = VariableDonutSwinConfig

    def __init__(self, config, add_pooling_layer=True, use_mask_token=False, **kwargs):
        super().__init__(config)
        self.config = config
        self.num_layers = len(config.depths)
        self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))

        # Swap in the variable-size embeddings; the encoder and pooler mirror DonutSwinModel.
        self.embeddings = VariableDonutSwinEmbeddings(config, use_mask_token=use_mask_token)
        self.encoder = DonutSwinEncoder(config, self.embeddings.patch_grid)

        self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None

        # Initialize weights and apply final processing.
        self.post_init()
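

# Minimal usage sketch (illustration only, not part of the library). The config values
# below are assumptions chosen to keep the model small; real checkpoints define their own
# image size, depths, and head counts, and `use_2d_embeddings` is assumed to be the flag
# exposed by VariableDonutSwinConfig.
if __name__ == "__main__":
    config = VariableDonutSwinConfig(
        image_size=(256, 256),
        patch_size=4,
        num_channels=3,
        embed_dim=96,
        depths=(2, 2),
        num_heads=(3, 6),
        window_size=7,
        use_absolute_embeddings=False,
        use_2d_embeddings=True,  # assumption: flag consumed by VariableDonutSwinEmbeddings
    )
    model = VariableDonutSwinModel(config)
    pixel_values = torch.randn(1, 3, 256, 256)
    with torch.no_grad():
        outputs = model(pixel_values=pixel_values)
    # With two stages, the 64x64 patch grid is downsampled once, so the expected shape
    # is roughly (1, 1024, 192) under the assumed config.
    print(outputs.last_hidden_state.shape)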