| | """ PyTorch ViT model."""
|
| |
|
| |
|
import math
from functools import partial
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers.modeling_outputs import (
    BaseModelOutputWithNoAttention,
    ImageClassifierOutputWithNoAttention,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
)
from .configuration_fdvit import FDViTConfig

logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "FDViTConfig"

_CHECKPOINT_FOR_DOC = "amd/fdvit_ti"
_EXPECTED_OUTPUT_SHAPE = [1, 49, 260]

_IMAGE_CLASS_CHECKPOINT = "amd/fdvit_ti"
_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat"

def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True):
    """Drop paths (Stochastic Depth) per sample, applied in the main path of residual blocks.

    This is the same operation as the "DropConnect" implementation found in EfficientNet-style
    networks, but that name is misleading: DropConnect is a different form of dropout from a
    separate paper (see https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956), so
    the layer and argument are named 'drop path' here instead.
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli sample per example, broadcast over all remaining dimensions.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        # Rescale surviving paths so the expected activation magnitude is unchanged.
        random_tensor.div_(keep_prob)
    return x * random_tensor

class FDViTDropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        return f"drop_prob={self.drop_prob:0.3f}"

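# A minimal behavioral sketch of FDViTDropPath (illustrative values only):
#
#     layer = FDViTDropPath(drop_prob=0.1)
#     layer.train()
#     y = layer(torch.ones(8, 49, 192))   # ~10% of the 8 samples zeroed, survivors scaled by 1/0.9
#     layer.eval()
#     y = layer(torch.ones(8, 49, 192))   # identity: drop path is a no-op at inference
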
class FDViTEmbeddings(nn.Module):
    """
    Construct the patch embeddings: a single strided convolution that maps a
    `(batch_size, num_channels, height, width)` image to a `(batch_size, hidden_size, height', width')` feature map.
    """

    def __init__(self, in_channels, out_channels, patch_size, stride, padding):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size=patch_size, stride=stride, padding=padding, bias=True
        )

    def forward(self, x):
        x = self.conv(x)
        return x

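# Shape sketch for FDViTEmbeddings (the sizes are hypothetical, not the checkpoint's actual
# configuration): with image_size=224, patch_size=16, stride=8 and padding=0, the output grid
# follows the usual convolution arithmetic used in FDViTEncoder below,
# floor((224 + 2 * 0 - 16) / 8) + 1 = 27, so a (1, 3, 224, 224) image becomes a
# (1, out_channels, 27, 27) feature map.
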
class FDViTAttention(nn.Module):
    """Multi-head self-attention with a locally-enhanced positional encoding (LePE) added to the
    value path through a depthwise 3x3 convolution."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        # Depthwise convolution (groups=dim) that produces the positional encoding from the values.
        self.get_v = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim)

    def get_lepe(self, x):
        # x: (batch, heads, seq_len, head_dim); the token sequence is assumed to form a square grid.
        B, Head, N, C_p = x.shape
        H = W = int(math.sqrt(N))
        x = x.transpose(-2, -1).contiguous().view(B, C_p * Head, H, W)

        lepe = self.get_v(x)
        lepe = lepe.reshape(B, Head, C_p, N).permute(0, 1, 3, 2).contiguous()
        x = x.reshape(B, Head, C_p, N).permute(0, 1, 3, 2).contiguous()
        return x, lepe

    def forward(self, x):
        B, N, C = x.shape
        # (B, N, 3*C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        v, lepe = self.get_lepe(v)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v) + lepe
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

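# Shape walk-through for one attention call (hypothetical dim=192, num_heads=3, 49 tokens):
#   qkv(x):      (B, 49, 576) -> reshape/permute -> q, k, v each (B, 3, 49, 64)
#   get_lepe(v): v is reshaped to (B, 192, 7, 7), run through the depthwise conv, and both v and
#                lepe are returned as (B, 3, 49, 64)
#   attn:        (B, 3, 49, 49); (attn @ v) + lepe -> (B, 49, 192) after the output projection
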
class FDViTOutput(nn.Module):
    """Feed-forward block: two linear layers with an activation and dropout in between."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP sub-layers, each wrapped in a residual
    connection with optional stochastic depth."""

    def __init__(self, dim, num_heads, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop=0.0, attn_drop=0.0,
                 drop_path=0.0, act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = FDViTAttention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop
        )

        self.drop_path = FDViTDropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = FDViTOutput(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

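# A minimal sketch of running one block on a token sequence (hypothetical sizes; the token count
# must be a perfect square because of get_lepe):
#
#     block = Block(dim=192, num_heads=3, mlp_ratio=4.0, qkv_bias=True)
#     tokens = torch.randn(2, 49, 192)   # (batch, seq_len, dim); 49 tokens form a 7x7 grid
#     out = block(tokens)                # same shape: (2, 49, 192)
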
class FDViTLayer(nn.Module):
    """A stage of the encoder: flattens the feature map into tokens and runs `depth` transformer
    blocks over them."""

    def __init__(self, base_dim, depth, heads, mlp_ratio,
                 drop_rate=0.0, attn_drop_rate=0.0, drop_path_prob=None):
        super().__init__()
        embed_dim = base_dim * heads

        if drop_path_prob is None:
            drop_path_prob = [0.0 for _ in range(depth)]

        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim,
                num_heads=heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=True,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=drop_path_prob[i],
                norm_layer=partial(nn.LayerNorm, eps=1e-6),
            )
            for i in range(depth)
        ])

    def forward(self, x):
        # (batch, channels, height, width) -> (batch, height * width, channels)
        x = rearrange(x, "b c h w -> b (h w) c")
        for blk in self.blocks:
            x = blk(x)
        return x

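# A minimal stage sketch (hypothetical sizes): a (2, 192, 7, 7) feature map is flattened to
# (2, 49, 192) tokens and passed through `depth` blocks:
#
#     layer = FDViTLayer(base_dim=64, depth=2, heads=3, mlp_ratio=4.0)
#     out = layer(torch.randn(2, 192, 7, 7))   # -> (2, 49, 192)
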
class FDViTPooling(nn.Module):
    """Downsamples the token grid to a fixed output size via bilinear grid sampling, then
    projects the channels with a 3x3 convolution."""

    def __init__(self, in_feature, out_feature, out_size):
        super().__init__()
        # A regular (out_size x out_size) sampling grid in normalized [-1, 1] coordinates.
        d = torch.linspace(-1, 1, out_size)
        meshx, meshy = torch.meshgrid(d, d, indexing="ij")
        self.grid = torch.stack((meshy, meshx), 2)

        self.conv = nn.Conv2d(in_feature, out_feature, kernel_size=3, padding=1, stride=1)
        self.ln = nn.LayerNorm(in_feature)

    def forward(self, x):
        h = w = int(math.sqrt(x.shape[1]))
        x = self.ln(x)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)

        grid = self.grid.expand(x.shape[0], -1, -1, -1)
        x = F.grid_sample(x, grid.to(x.device).type_as(x), align_corners=True)
        x = self.conv(x)
        return x

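# A minimal pooling sketch (hypothetical sizes): tokens from a 14x14 grid are resampled onto a
# fixed 7x7 grid with bilinear grid_sample, then projected to a new channel count:
#
#     pool = FDViTPooling(in_feature=192, out_feature=260, out_size=7)
#     out = pool(torch.randn(2, 196, 192))   # (2, 196, 192) tokens -> (2, 260, 7, 7) feature map
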
class FDViTEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.config = config
        image_size = config.image_size
        patch_size = config.patch_size
        stride = config.stride
        base_dims = config.base_dims
        depth = config.depth
        heads = config.heads
        channels = config.channels
        out_size = config.out_size
        mlp_ratio = config.mlp_ratio
        num_classes = config.num_classes if config.num_classes is not None else 1000
        in_chans = config.in_chans if config.in_chans is not None else 3
        attn_drop_rate = config.attn_drop_rate if config.attn_drop_rate is not None else 0.0
        drop_rate = config.drop_rate if config.drop_rate is not None else 0.0
        drop_path_rate = config.drop_path_rate if config.drop_path_rate is not None else 0.0

        total_block = sum(depth)
        padding = 0
        block_idx = 0

        # Spatial width of the patch grid produced by the patch embedding convolution.
        width = math.floor((image_size + 2 * padding - patch_size) / stride + 1)

        self.base_dims = base_dims
        self.heads = heads
        self.num_classes = num_classes

        self.patch_size = patch_size
        self.pos_embed = nn.Parameter(
            torch.randn(1, base_dims[0] * heads[0], width, width),
            requires_grad=True,
        )
        self.patch_embed = FDViTEmbeddings(in_chans, base_dims[0] * heads[0], patch_size, stride, padding)

        self.pos_drop = nn.Dropout(p=drop_rate)

        self.transformers = nn.ModuleList([])
        self.pools = nn.ModuleList([])

        for stage in range(len(depth)):
            # Stochastic depth increases linearly with the block index across the whole network.
            drop_path_prob = [drop_path_rate * i / total_block for i in range(block_idx, block_idx + depth[stage])]
            block_idx += depth[stage]

            self.transformers.append(
                FDViTLayer(base_dims[stage], depth[stage], heads[stage], mlp_ratio,
                           drop_rate, attn_drop_rate, drop_path_prob)
            )
            if stage < len(heads) - 1:
                self.pools.append(
                    FDViTPooling(channels[stage], channels[stage + 1], out_size[stage + 1])
                )

        self.embed_dim = base_dims[-1] * heads[-1]

    def forward(self, x, output_hidden_states=False, return_dict=True):
        all_hidden_states = () if output_hidden_states else None

        x = self.patch_embed(x)

        pos_embed = self.pos_embed
        x = self.pos_drop(x + pos_embed)

        # Alternate transformer stages with pooling; the last stage has no pooling after it.
        for stage in range(len(self.pools)):
            xt = self.transformers[stage](x)
            x = self.pools[stage](xt)

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (xt,)

        x = self.transformers[-1](x)
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (x,)

        if not return_dict:
            return tuple(v for v in [x, all_hidden_states] if v is not None)

        return BaseModelOutputWithNoAttention(last_hidden_state=x, hidden_states=all_hidden_states)

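# A minimal end-to-end encoder sketch, assuming the config defaults describe a valid model
# (hypothetical; the expected shape comes from the docstring constants above):
#
#     encoder = FDViTEncoder(FDViTConfig())
#     out = encoder(torch.randn(1, 3, 224, 224), output_hidden_states=True)
#     out.last_hidden_state.shape   # e.g. torch.Size([1, 49, 260]), per _EXPECTED_OUTPUT_SHAPE
#     len(out.hidden_states)        # one entry per transformer stage
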
class FDViTPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = FDViTConfig
    base_model_prefix = "fdvit"
    main_input_name = "pixel_values"

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

FDVIT_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`FDViTConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

FDVIT_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`FDViTImageProcessor.__call__`] for details.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

@add_start_docstrings(
    "The bare FDViT Model transformer outputting raw hidden-states without any specific head on top.",
    FDVIT_START_DOCSTRING,
)
class FDViTModel(FDViTPreTrainedModel):
    def __init__(self, config: FDViTConfig):
        super().__init__(config)
        self.config = config

        self.encoder = FDViTEncoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(FDVIT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithNoAttention,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithNoAttention]:
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        encoder_outputs = self.encoder(
            pixel_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]

        if not return_dict:
            return (sequence_output,) + encoder_outputs[1:]

        return BaseModelOutputWithNoAttention(
            last_hidden_state=sequence_output,
            hidden_states=encoder_outputs.hidden_states,
        )

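# A minimal usage sketch for the bare model (assuming the "amd/fdvit_ti" checkpoint named in the
# docstring constants above and an image processor registered for it):
#
#     from transformers import AutoImageProcessor
#     processor = AutoImageProcessor.from_pretrained("amd/fdvit_ti")
#     model = FDViTModel.from_pretrained("amd/fdvit_ti")
#     inputs = processor(images=image, return_tensors="pt")
#     outputs = model(**inputs)
#     outputs.last_hidden_state.shape   # e.g. torch.Size([1, 49, 260])
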
class FDViTPooler(nn.Module):
    def __init__(self, config: FDViTConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # "Pool" the sequence by taking the hidden state corresponding to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output

@add_start_docstrings(
    """
    FDViT Model transformer with an image classification head on top (a linear layer on top of the mean-pooled final
    hidden states) e.g. for ImageNet.
    """,
    FDVIT_START_DOCSTRING,
)
class FDViTForImageClassification(FDViTPreTrainedModel):
    def __init__(self, config: FDViTConfig) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.fdvit = FDViTModel(config)

        # Classifier head
        self.norm = nn.LayerNorm(config.base_dims[-1] * config.heads[-1], eps=1e-6)
        self.classifier = (
            nn.Linear(config.base_dims[-1] * config.heads[-1], config.num_classes)
            if config.num_classes > 0
            else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(FDVIT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=ImageClassifierOutputWithNoAttention,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, ImageClassifierOutputWithNoAttention]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.fdvit(
            pixel_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        # Mean-pool the token sequence, normalize, and classify.
        logits = self.classifier(self.norm(sequence_output).mean(dim=1))

        loss = None
        if labels is not None:
            # Move labels to the same device as the logits to enable model parallelism.
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)

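# A minimal classification sketch (same assumptions as the FDViTModel example above):
#
#     from transformers import AutoImageProcessor
#     processor = AutoImageProcessor.from_pretrained("amd/fdvit_ti")
#     model = FDViTForImageClassification.from_pretrained("amd/fdvit_ti")
#     inputs = processor(images=image, return_tensors="pt")
#     logits = model(**inputs).logits
#     model.config.id2label[logits.argmax(-1).item()]   # e.g. "Egyptian cat"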