from typing import Optional, Tuple, Union

import torch
from torch import nn

from transformers.models.donut.modeling_donut_swin import (
    DonutSwinAttention,
    DonutSwinDropPath,
    DonutSwinEmbeddings,
    DonutSwinEncoder,
    DonutSwinEncoderOutput,
    DonutSwinIntermediate,
    DonutSwinModel,
    DonutSwinModelOutput,
    DonutSwinOutput,
    DonutSwinPatchEmbeddings,
    window_partition,
    window_reverse,
)

from .config import VariableDonutSwinConfig
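
# The classes below are variable-input-size variants of the Donut Swin modules from
# `transformers`. The main functional change is in VariableDonutSwinEmbeddings, which slices
# the absolute position embeddings to the actual sequence length so that patch grids other
# than the configured image size can be encoded; the remaining classes largely mirror their
# upstream counterparts.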


class VariableDonutSwinEmbeddings(DonutSwinEmbeddings):
    """
    Construct the patch and position embeddings, and optionally the mask token. Unlike the
    fixed-size Donut Swin embeddings, the absolute position embeddings are sliced to the
    actual sequence length, so variable input sizes are supported.
    """

    def __init__(self, config, use_mask_token=False):
        super().__init__(config, use_mask_token)

        self.patch_embeddings = DonutSwinPatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        self.patch_grid = self.patch_embeddings.grid_size
        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None
        self.position_embeddings = None

        if config.use_absolute_embeddings:
            self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim))

        self.norm = nn.LayerNorm(config.embed_dim)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None
    ) -> Tuple[torch.Tensor, Tuple[int, int]]:
        embeddings, output_dimensions = self.patch_embeddings(pixel_values)
        embeddings = self.norm(embeddings)
        batch_size, seq_len, embed_dim = embeddings.size()

        if bool_masked_pos is not None:
            mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
            # replace the embeddings of masked patches with the learned mask token
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        if self.position_embeddings is not None:
            # slice to the actual sequence length so inputs smaller than the configured
            # image size (including non-square inputs) are handled
            embeddings = embeddings + self.position_embeddings[:, :seq_len, :]

        embeddings = self.dropout(embeddings)

        return embeddings, output_dimensions
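

# Illustrative masking-path usage for VariableDonutSwinEmbeddings (comment-only sketch; the
# `config` and `pixel_values` names are placeholders, and the input is assumed to be at the
# configured image size so that num_patches matches the sequence length):
#
#   emb = VariableDonutSwinEmbeddings(config, use_mask_token=True)
#   bool_masked_pos = torch.zeros(1, emb.patch_embeddings.num_patches, dtype=torch.bool)
#   bool_masked_pos[:, :16] = True                     # mask the first 16 patches
#   embeddings, patch_grid = emb(pixel_values, bool_masked_pos=bool_masked_pos)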


class VariableDonutSwinPatchMerging(nn.Module):
    """
    Patch Merging Layer.

    Args:
        input_resolution (`Tuple[int, int]`):
            Resolution of the input feature map.
        dim (`int`):
            Number of input channels.
        norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
            Normalization layer class.
    """

    def __init__(self, input_resolution: Tuple[int, int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def maybe_pad(self, input_feature, height, width):
        # pad the bottom/right edge by one row/column if height or width is odd
        should_pad = (height % 2 == 1) or (width % 2 == 1)
        if should_pad:
            pad_values = (0, 0, 0, width % 2, 0, height % 2)
            input_feature = nn.functional.pad(input_feature, pad_values)

        return input_feature

    def forward(self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int]) -> torch.Tensor:
        height, width = input_dimensions

        batch_size, dim, num_channels = input_feature.shape

        input_feature = input_feature.view(batch_size, height, width, num_channels)
        input_feature = self.maybe_pad(input_feature, height, width)

        # gather the four 2x2 neighbours: top-left, bottom-left, top-right, bottom-right
        input_feature_0 = input_feature[:, 0::2, 0::2, :]
        input_feature_1 = input_feature[:, 1::2, 0::2, :]
        input_feature_2 = input_feature[:, 0::2, 1::2, :]
        input_feature_3 = input_feature[:, 1::2, 1::2, :]

        input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1)
        input_feature = input_feature.view(batch_size, -1, 4 * num_channels)

        input_feature = self.norm(input_feature)
        input_feature = self.reduction(input_feature)

        return input_feature
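

# Shape sketch for VariableDonutSwinPatchMerging (illustrative only; the resolution and dim
# below are made up): each merge pads odd sides, gathers the four 2x2 neighbours, and
# projects 4*C channels down to 2*C, halving both spatial sides.
#
#   merge = VariableDonutSwinPatchMerging(input_resolution=(8, 8), dim=96)
#   x = torch.randn(2, 8 * 8, 96)
#   y = merge(x, (8, 8))   # -> shape (2, 4 * 4, 192)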


class VariableDonutSwinLayer(nn.Module):
    def __init__(self, config, dim, input_resolution, num_heads, shift_size=0):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.shift_size = shift_size
        self.window_size = config.window_size
        self.input_resolution = input_resolution
        self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.attention = DonutSwinAttention(config, dim, num_heads, window_size=self.window_size)
        self.drop_path = DonutSwinDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
        self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.intermediate = DonutSwinIntermediate(config, dim)
        self.output = DonutSwinOutput(config, dim)

    def set_shift_and_window_size(self, input_resolution):
        if min(input_resolution) <= self.window_size:
            # if the window is larger than the input, use a single un-shifted window
            self.shift_size = 0
            self.window_size = min(input_resolution)

    def get_attn_mask(self, height, width, dtype):
        if self.shift_size > 0:
            # calculate the attention mask for shifted-window self-attention
            img_mask = torch.zeros((1, height, width, 1), dtype=dtype)
            height_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            width_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            count = 0
            for height_slice in height_slices:
                for width_slice in width_slices:
                    img_mask[:, height_slice, width_slice, :] = count
                    count += 1

            mask_windows = window_partition(img_mask, self.window_size)
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None
        return attn_mask

    def maybe_pad(self, hidden_states, height, width):
        # pad the right/bottom edges so height and width are multiples of the window size
        pad_right = (self.window_size - width % self.window_size) % self.window_size
        pad_bottom = (self.window_size - height % self.window_size) % self.window_size
        pad_values = (0, 0, 0, pad_right, 0, pad_bottom)
        hidden_states = nn.functional.pad(hidden_states, pad_values)
        return hidden_states, pad_values

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_dimensions: Tuple[int, int],
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        always_partition: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if not always_partition:
            # shrink the window (and disable shifting) when the input is smaller than it
            self.set_shift_and_window_size(input_dimensions)

        height, width = input_dimensions
        batch_size, _, channels = hidden_states.size()
        shortcut = hidden_states

        hidden_states = self.layernorm_before(hidden_states)
        hidden_states = hidden_states.view(batch_size, height, width, channels)

        # pad hidden_states to multiples of the window size
        hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
        _, height_pad, width_pad, _ = hidden_states.shape

        # cyclic shift
        if self.shift_size > 0:
            shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_hidden_states = hidden_states

        # partition into windows
        hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
        hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
        attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype)
        if attn_mask is not None:
            attn_mask = attn_mask.to(hidden_states_windows.device)

        attention_outputs = self.attention(
            hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions
        )
        attention_output = attention_outputs[0]

        attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels)
        shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad)

        # reverse the cyclic shift
        if self.shift_size > 0:
            attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            attention_windows = shifted_windows

        # drop the padded rows/columns again
        was_padded = pad_values[3] > 0 or pad_values[5] > 0
        if was_padded:
            attention_windows = attention_windows[:, :height, :width, :].contiguous()

        attention_windows = attention_windows.view(batch_size, height * width, channels)

        hidden_states = shortcut + self.drop_path(attention_windows)

        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.intermediate(layer_output)
        layer_output = hidden_states + self.output(layer_output)

        layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
        return layer_outputs
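

# Round-trip sketch for the window helpers imported from transformers (illustrative only,
# assuming they behave as in the upstream Swin implementation): partitioning a feature map
# into non-overlapping windows and reversing the partition recovers the original tensor
# when height and width are multiples of the window size.
#
#   x = torch.randn(1, 14, 14, 96)
#   windows = window_partition(x, 7)        # -> (4, 7, 7, 96)
#   y = window_reverse(windows, 7, 14, 14)  # -> (1, 14, 14, 96)
#   assert torch.equal(x, y)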


class VariableDonutSwinStage(nn.Module):
    def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample):
        super().__init__()
        self.config = config
        self.dim = dim
        # Note: the drop_path schedule is accepted to mirror the upstream DonutSwinStage
        # signature but is not used here; each layer builds its drop path from
        # config.drop_path_rate instead.
        self.blocks = nn.ModuleList(
            [
                VariableDonutSwinLayer(
                    config=config,
                    dim=dim,
                    input_resolution=input_resolution,
                    num_heads=num_heads,
                    shift_size=0 if (i % 2 == 0) else int(config.window_size // 2),
                )
                for i in range(depth)
            ]
        )

        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=nn.LayerNorm)
        else:
            self.downsample = None

        self.pointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_dimensions: Tuple[int, int],
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        always_partition: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        height, width = input_dimensions
        for i, layer_module in enumerate(self.blocks):
            layer_head_mask = head_mask[i] if head_mask is not None else None

            layer_outputs = layer_module(
                hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
            )

            hidden_states = layer_outputs[0]

        hidden_states_before_downsampling = hidden_states
        if self.downsample is not None:
            height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
            # output_dimensions is (old_height, old_width, new_height, new_width)
            output_dimensions = (height, width, height_downsampled, width_downsampled)
            hidden_states = self.downsample(hidden_states_before_downsampling, input_dimensions)
        else:
            output_dimensions = (height, width, height, width)

        stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions)

        if output_attentions:
            stage_outputs += layer_outputs[1:]
        return stage_outputs


class VariableDonutSwinEncoder(nn.Module):
    def __init__(self, config, grid_size):
        super().__init__()
        self.num_layers = len(config.depths)
        self.config = config
        # stochastic-depth schedule, one rate per block (passed to the stages, which
        # currently ignore it and use config.drop_path_rate directly)
        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
        self.layers = nn.ModuleList(
            [
                VariableDonutSwinStage(
                    config=config,
                    dim=int(config.embed_dim * 2**i_layer),
                    input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)),
                    depth=config.depths[i_layer],
                    num_heads=config.num_heads[i_layer],
                    drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])],
                    downsample=VariableDonutSwinPatchMerging if (i_layer < self.num_layers - 1) else None,
                )
                for i_layer in range(self.num_layers)
            ]
        )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_dimensions: Tuple[int, int],
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        output_hidden_states_before_downsampling: Optional[bool] = False,
        always_partition: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple, DonutSwinEncoderOutput]:
        all_hidden_states = () if output_hidden_states else None
        all_reshaped_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        if output_hidden_states:
            batch_size, _, hidden_size = hidden_states.shape
            # rearrange (batch, height * width, hidden) -> (batch, hidden, height, width)
            reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
            reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
            all_hidden_states += (hidden_states,)
            all_reshaped_hidden_states += (reshaped_hidden_state,)

        for i, layer_module in enumerate(self.layers):
            layer_head_mask = head_mask[i] if head_mask is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    input_dimensions,
                    layer_head_mask,
                    output_attentions,
                    always_partition,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
                )

            hidden_states = layer_outputs[0]
            hidden_states_before_downsampling = layer_outputs[1]
            output_dimensions = layer_outputs[2]

            input_dimensions = (output_dimensions[-2], output_dimensions[-1])

            if output_hidden_states and output_hidden_states_before_downsampling:
                batch_size, _, hidden_size = hidden_states_before_downsampling.shape
                # rearrange using the pre-downsampling spatial dimensions
                reshaped_hidden_state = hidden_states_before_downsampling.view(
                    batch_size, *(output_dimensions[0], output_dimensions[1]), hidden_size
                )
                reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
                all_hidden_states += (hidden_states_before_downsampling,)
                all_reshaped_hidden_states += (reshaped_hidden_state,)
            elif output_hidden_states and not output_hidden_states_before_downsampling:
                batch_size, _, hidden_size = hidden_states.shape
                reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
                reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
                all_hidden_states += (hidden_states,)
                all_reshaped_hidden_states += (reshaped_hidden_state,)

            if output_attentions:
                all_self_attentions += layer_outputs[3:]

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)

        return DonutSwinEncoderOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            reshaped_hidden_states=all_reshaped_hidden_states,
        )


class VariableDonutSwinModel(DonutSwinModel):
    config_class = VariableDonutSwinConfig

    def __init__(self, config, add_pooling_layer=True, use_mask_token=False, **kwargs):
        super().__init__(config)
        self.config = config
        self.num_layers = len(config.depths)
        self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))

        self.embeddings = VariableDonutSwinEmbeddings(config, use_mask_token=use_mask_token)
        self.encoder = VariableDonutSwinEncoder(config, self.embeddings.patch_grid)

        self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None

        # initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, DonutSwinModelOutput]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # prepare the head mask if needed: one entry per stage
        head_mask = self.get_head_mask(head_mask, len(self.config.depths))

        embedding_output, input_dimensions = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)

        encoder_outputs = self.encoder(
            embedding_output,
            input_dimensions,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = encoder_outputs[0]

        pooled_output = None
        if self.pooler is not None:
            pooled_output = self.pooler(sequence_output.transpose(1, 2))
            pooled_output = torch.flatten(pooled_output, 1)

        if not return_dict:
            output = (sequence_output, pooled_output) + encoder_outputs[1:]
            return output

        return DonutSwinModelOutput(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
        )
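

if __name__ == "__main__":
    # Minimal smoke test (illustrative only). It assumes that VariableDonutSwinConfig can be
    # instantiated with default values like other Hugging Face configs and that it keeps the
    # DonutSwinConfig field names used below (num_channels, image_size, patch_size, ...).
    config = VariableDonutSwinConfig()
    model = VariableDonutSwinModel(config, add_pooling_layer=True)
    model.eval()

    # A non-square input exercises the variable-size path: the patch embeddings report the
    # actual grid, and absolute position embeddings (if enabled) are sliced to that length.
    pixel_values = torch.randn(1, config.num_channels, 192, 256)
    with torch.no_grad():
        outputs = model(pixel_values)
    print(outputs.last_hidden_state.shape)  # (1, seq_len, model.num_features)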