Matthijs's picture
add model
fc122ab
# Copyright (C) 2022 Apple Inc. All Rights Reserved.
# IMPORTANT: This Apple software is supplied to you by Apple
# Inc. ("Apple") in consideration of your agreement to the following
# terms, and your use, installation, modification or redistribution of
# this Apple software constitutes acceptance of these terms. If you do
# not agree with these terms, please do not use, install, modify or
# redistribute this Apple software.
# In consideration of your agreement to abide by the following terms, and
# subject to these terms, Apple grants you a personal, non-exclusive
# license, under Apple's copyrights in this original Apple software (the
# "Apple Software"), to use, reproduce, modify and redistribute the Apple
# Software, with or without modifications, in source and/or binary forms;
# provided that if you redistribute the Apple Software in its entirety and
# without modifications, you must retain this notice and the following
# text and disclaimers in all such redistributions of the Apple Software.
# Neither the name, trademarks, service marks or logos of Apple Inc. may
# be used to endorse or promote products derived from the Apple Software
# without specific prior written permission from Apple. Except as
# expressly stated in this notice, no other rights or licenses, express or
# implied, are granted by Apple herein, including but not limited to any
# patent rights that may be infringed by your derivative works or by other
# works in which the Apple Software may be incorporated.
# The Apple Software is provided by Apple on an "AS IS" basis. APPLE
# MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION
# THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS
# FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND
# OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS.
# IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL
# OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION,
# MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED
# AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE),
# STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
import torch
import torch.nn as nn
from transformers.models.distilbert import modeling_distilbert
from .configuration_distilbert_ane import DistilBertConfig
# The upstream DistilBERT LayerNorm epsilon of 1e-12 is not representable in
# float16, the precision the ANE uses by default, so a larger value is used.
EPS = 1e-7

# Error messages raised by the task heads below.
WARN_MSG_FOR_TRAINING_ATTEMPT = (
    "This model is optimized for on-device execution only. "
    "Please use the original implementation from Hugging Face for training"
)
WARN_MSG_FOR_DICT_RETURN = (
    "coremltools does not support dict outputs. Please set return_dict=False"
)
class LayerNormANE(nn.Module):
    """ LayerNorm optimized for Apple Neural Engine (ANE) execution

    Normalization is performed over the channel dimension (dim 1) of a BC1S
    tensor. It expects `num_channels` as an argument and not `normalized_shape`,
    which is used by `torch.nn.LayerNorm`.

    Note: the affine step computes ``(x_hat + bias) * weight``, not the
    ``x_hat * weight + bias`` of ``torch.nn.LayerNorm``. Parameters copied from
    a standard LayerNorm must therefore store ``bias / weight`` as the bias for
    the two to produce the same output.
    """

    def __init__(self,
                 num_channels,
                 clip_mag=None,
                 eps=1e-5,
                 elementwise_affine=True):
        """
        Args:
            num_channels: Number of channels (C) where the expected input data format is BC1S. S stands for sequence length.
            clip_mag: Optional float value to use for clamping the input range before layer norm is applied.
                If specified, helps reduce risk of overflow.
            eps: Small value to avoid dividing by zero
            elementwise_affine: If true, adds learnable channel-wise shift (bias) and scale (weight) parameters
        """
        super().__init__()
        # Principle 1: Picking the Right Data Format (machinelearning.apple.com/research/apple-neural-engine)
        self.expected_rank = len('BC1S')

        self.num_channels = num_channels
        self.eps = eps
        self.clip_mag = clip_mag
        self.elementwise_affine = elementwise_affine

        if self.elementwise_affine:
            # Allocated uninitialized; _reset_parameters fills them in.
            self.weight = nn.Parameter(torch.empty(num_channels))
            self.bias = nn.Parameter(torch.empty(num_channels))

        self._reset_parameters()

    def _reset_parameters(self):
        """Initialize the affine parameters to the identity (weight=1, bias=0)."""
        if self.elementwise_affine:
            nn.init.ones_(self.weight)
            nn.init.zeros_(self.bias)

    def forward(self, inputs):
        """Normalize ``inputs`` over the channel dimension.

        Args:
            inputs: a rank-4 BC1S tensor with ``num_channels`` channels, or a
                rank-3 BSC tensor whose last dimension equals ``num_channels``
                (it is transposed to BC1S first).

        Returns:
            A BC1S tensor with the same shape as the (converted) input.

        Raises:
            AssertionError: if the input is not rank 4 with ``num_channels``
                channels after the optional BSC -> BC1S conversion.
        """
        # Principle 1: Picking the Right Data Format (machinelearning.apple.com/research/apple-neural-engine)
        # Migrate the data format from BSC to BC1S (most conducive to ANE)
        if inputs.dim() == 3 and inputs.size(2) == self.num_channels:
            inputs = inputs.transpose(1, 2).unsqueeze(2)

        assert inputs.dim() == self.expected_rank
        assert inputs.size(1) == self.num_channels

        if self.clip_mag is not None:
            # Out-of-place clamp: an in-place clamp_ would overwrite the
            # caller's tensor (the transpose above only creates a view of it).
            inputs = inputs.clamp(-self.clip_mag, self.clip_mag)

        channels_mean = inputs.mean(dim=1, keepdim=True)
        zero_mean = inputs - channels_mean
        zero_mean_sq = zero_mean * zero_mean
        # Biased (population) variance over channels, as in nn.LayerNorm.
        denom = (zero_mean_sq.mean(dim=1, keepdim=True) + self.eps).rsqrt()
        out = zero_mean * denom

        if self.elementwise_affine:
            out = (out + self.bias.view(1, self.num_channels, 1, 1)
                   ) * self.weight.view(1, self.num_channels, 1, 1)

        return out
class Embeddings(modeling_distilbert.Embeddings):
    """Embeddings module adapted for the Apple Neural Engine.

    The only change from the parent class is that ``LayerNorm`` is replaced by
    a ``LayerNormANE`` over ``config.dim`` channels using the module-level
    ``EPS``.
    """

    def __init__(self, config):
        super().__init__(config)
        # NOTE(review): LayerNormANE converts a (bs, seq, dim) input to
        # (bs, dim, 1, seq), so this module presumably emits BC1S tensors.
        self.LayerNorm = LayerNormANE(config.dim, eps=EPS)
class MultiHeadSelfAttention(modeling_distilbert.MultiHeadSelfAttention):
    """ MultiHeadSelfAttention module optimized for Apple Neural Engine

    The q/k/v/out projections are 1x1 ``nn.Conv2d`` layers so attention runs
    on (bs, dim, 1, seq_length) tensors throughout.
    """

    # Integer mask dtypes interpreted as 1 = attend, 0 = ignore.
    _INTEGER_MASK_DTYPES = (torch.uint8, torch.int8, torch.int16, torch.int32,
                            torch.int64)

    def __init__(self, config):
        super().__init__(config)
        # Replace the parent's projections with 1x1 convolutions, created in
        # the same order as before (q, k, v, out).
        for name in ('q_lin', 'k_lin', 'v_lin', 'out_lin'):
            setattr(
                self, name,
                nn.Conv2d(
                    in_channels=config.dim,
                    out_channels=config.dim,
                    kernel_size=1,
                ))

    def prune_heads(self, heads):
        raise NotImplementedError

    def _prepare_mask(self, mask, bs, seqlen):
        """Convert ``mask`` into an additive float mask of shape (bs, seqlen, 1, 1).

        Boolean masks: True keeps a position, False masks it.
        Integer masks: 1 keeps a position, 0 masks it.
        Masked positions receive -1e4, which is still representable in float16.
        float32 masks are assumed to be additive already and are used as-is.

        Raises:
            TypeError: for any other dtype.
            RuntimeError: if the mask is not (bs, seqlen) or (bs, seqlen, 1, 1).
        """
        if mask.dtype == torch.bool:
            mask = mask.logical_not().float() * -1e4
        elif mask.dtype in self._INTEGER_MASK_DTYPES:
            mask = (1 - mask).float() * -1e4
        elif mask.dtype != torch.float32:
            raise TypeError(f"Unexpected dtype for mask: {mask.dtype}")

        if mask.dim() == 2:
            mask = mask.unsqueeze(2).unsqueeze(2)

        expected_mask_shape = [bs, seqlen, 1, 1]
        if list(mask.size()) != expected_mask_shape:
            raise RuntimeError(
                f"Invalid shape for `mask` (Expected {expected_mask_shape}, got {list(mask.size())})"
            )
        return mask

    def forward(self,
                query,
                key,
                value,
                mask,
                head_mask=None,
                output_attentions=False):
        """
        Parameters:
            query: torch.tensor(bs, dim, 1, seq_length)
            key: torch.tensor(bs, dim, 1, seq_length)
            value: torch.tensor(bs, dim, 1, seq_length)
            mask: torch.tensor(bs, seq_length) or torch.tensor(bs, seq_length, 1, 1)
                bool/integer masks use 1/True for positions to attend to;
                float32 masks are added to the attention logits as-is.
            head_mask: must be None; head masking is not supported.
            output_attentions: if True, also return the attention weights.

        Returns:
            context: torch.tensor(bs, dim, 1, seq_length) Contextualized layer.
            weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention
                weights indexed [batch, head, query, key]. Only returned when
                `output_attentions=True`.

        Raises:
            NotImplementedError: if `head_mask` is given.
            TypeError / RuntimeError: for a mask of unsupported dtype / shape.
        """
        assert query.dim() == 4 and key.dim() == 4 and value.dim() == 4
        bs, _, _, seqlen = query.size()

        # Project q, k and v
        q = self.q_lin(query)
        k = self.k_lin(key)
        v = self.v_lin(value)

        if mask is not None:
            mask = self._prepare_mask(mask, bs, seqlen)

        if head_mask is not None:
            raise NotImplementedError

        # Split channels into heads. Keys are transposed so that each einsum
        # below contracts over channels without any reshape.
        dim_per_head = self.dim // self.n_heads
        mh_q = q.split(dim_per_head, dim=1)  # (bs, dim_per_head, 1, seq) * n_heads
        mh_k = k.transpose(1, 3).split(dim_per_head, dim=3)  # (bs, seq, 1, dim_per_head) * n_heads
        mh_v = v.split(dim_per_head, dim=1)  # (bs, dim_per_head, 1, seq) * n_heads

        # Scaled dot-product logits, laid out (bs, key_seq, 1, query_seq) per head.
        normalize_factor = float(dim_per_head)**-0.5
        attn_weights = [
            torch.einsum('bchq,bkhc->bkhq', [qi, ki]) * normalize_factor
            for qi, ki in zip(mh_q, mh_k)
        ]
        if mask is not None:
            # The (bs, seq, 1, 1) mask broadcasts along the key axis (dim 1).
            attn_weights = [aw + mask for aw in attn_weights]

        # Softmax over keys (dim 1).
        attn_weights = [aw.softmax(dim=1) for aw in attn_weights]

        attn = [
            torch.einsum('bkhq,bchk->bchq', wi, vi)
            for wi, vi in zip(attn_weights, mh_v)
        ]  # (bs, dim_per_head, 1, seq) * n_heads
        attn = torch.cat(attn, dim=1)  # (bs, dim, 1, seq)
        attn = self.out_lin(attn)

        if output_attentions:
            # Stack heads on dim 2 -> (bs, key_seq, n_heads, query_seq), then
            # reorder to the documented (bs, n_heads, query_seq, key_seq).
            weights = torch.cat(attn_weights, dim=2).permute(0, 2, 3, 1)
            return attn, weights
        return (attn, )
class FFN(modeling_distilbert.FFN):
    """Feed-forward module adapted for the Apple Neural Engine.

    Both projections are 1x1 convolutions (dim -> hidden_dim -> dim), so they
    operate on BC1S (batch, channels, 1, sequence) tensors.
    """

    def __init__(self, config):
        super().__init__(config)
        # In BC1S the sequence axis is dimension 3.
        self.seq_len_dim = 3
        self.lin1 = nn.Conv2d(in_channels=config.dim,
                              out_channels=config.hidden_dim,
                              kernel_size=1)
        self.lin2 = nn.Conv2d(in_channels=config.hidden_dim,
                              out_channels=config.dim,
                              kernel_size=1)
class TransformerBlock(modeling_distilbert.TransformerBlock):
    """Transformer layer that uses the ANE variants of attention, FFN and
    layer normalization defined in this module."""

    def __init__(self, config):
        super().__init__(config)
        # Submodules are created in the same order as the parent declares them.
        self.attention = MultiHeadSelfAttention(config)
        self.sa_layer_norm = LayerNormANE(config.dim, eps=EPS)
        self.ffn = FFN(config)
        self.output_layer_norm = LayerNormANE(config.dim, eps=EPS)
class Transformer(modeling_distilbert.Transformer):
    """Stack of ``config.n_layers`` ANE-optimized ``TransformerBlock``s."""

    def __init__(self, config):
        super().__init__(config)
        blocks = [TransformerBlock(config) for _ in range(config.n_layers)]
        self.layer = nn.ModuleList(blocks)
class DistilBertModel(modeling_distilbert.DistilBertModel):
    """DistilBERT encoder with ANE-optimized embeddings and transformer stack."""

    config_class = DistilBertConfig

    def __init__(self, config):
        super().__init__(config)
        self.embeddings = Embeddings(config)
        self.transformer = Transformer(config)
        # Checkpoints hold nn.Linear weights as (out, in); nn.Conv2d expects
        # (out, in, 1, 1). This pre-hook reshapes them while loading.
        self._register_load_state_dict_pre_hook(linear_to_conv2d_map)

    def _prune_heads(self, heads_to_prune):
        raise NotImplementedError
class DistilBertForMaskedLM(modeling_distilbert.DistilBertForMaskedLM):
    """Masked-language-model head for DistilBERT, optimized for the Apple Neural Engine.

    Inference only: training mode, ``labels`` and dict outputs are rejected.
    """

    config_class = DistilBertConfig

    def __init__(self, config):
        super().__init__(config)
        from transformers.activations import get_activation
        self.activation = get_activation(config.activation)
        self.distilbert = DistilBertModel(config)
        self.vocab_transform = nn.Conv2d(config.dim, config.dim, 1)
        self.vocab_layer_norm = LayerNormANE(config.dim, eps=EPS)
        # NOTE(review): if the parent class ties output embeddings to the input
        # word embeddings, the tied (2-D) matrix would not fit this Conv2d's
        # 4-D weight -- confirm tying is disabled or handled for this head.
        self.vocab_projector = nn.Conv2d(config.dim, config.vocab_size, 1)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Compute masked-LM logits.

        Returns:
            A tuple whose first element is the prediction logits of shape
            (bs, seq_len, vocab_size), followed by any additional encoder
            outputs (e.g. hidden states / attentions when requested).

        Raises:
            ValueError: in training mode, when ``labels`` is given, or when a
                dict return is requested.
        """
        if self.training or labels is not None:
            raise ValueError(WARN_MSG_FOR_TRAINING_ATTEMPT)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if return_dict:
            raise ValueError(WARN_MSG_FOR_DICT_RETURN)

        dlbrt_output = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=False,
        )
        hidden_states = dlbrt_output[0]  # (bs, dim, 1, seq_len)
        prediction_logits = self.vocab_transform(hidden_states)  # (bs, dim, 1, seq_len)
        prediction_logits = self.activation(prediction_logits)  # (bs, dim, 1, seq_len)
        prediction_logits = self.vocab_layer_norm(prediction_logits)  # (bs, dim, 1, seq_len)
        prediction_logits = self.vocab_projector(prediction_logits)  # (bs, vocab_size, 1, seq_len)
        # Drop the singleton axis and put the vocabulary last: one row of
        # logits per token. Squeezing the trailing axes instead would be a
        # no-op whenever seq_len > 1.
        prediction_logits = prediction_logits.squeeze(2).transpose(1, 2)  # (bs, seq_len, vocab_size)

        return (prediction_logits, ) + dlbrt_output[1:]
class DistilBertForSequenceClassification(
        modeling_distilbert.DistilBertForSequenceClassification):
    """Sequence-classification head for DistilBERT, optimized for the ANE.

    Classifies from the encoder state at the first token position. Inference
    only: training mode and ``labels`` raise ``NotImplementedError``; dict
    outputs raise ``ValueError``.
    """

    config_class = DistilBertConfig

    def __init__(self, config):
        super().__init__(config)
        self.distilbert = DistilBertModel(config)
        self.pre_classifier = nn.Conv2d(config.dim, config.dim, 1)
        self.classifier = nn.Conv2d(config.dim, config.num_labels, 1)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Return ``(logits, *extra_encoder_outputs)`` with logits of shape (bs, num_labels)."""
        if self.training or labels is not None:
            raise NotImplementedError(WARN_MSG_FOR_TRAINING_ATTEMPT)

        wants_dict = self.config.use_return_dict if return_dict is None else return_dict
        if wants_dict:
            raise ValueError(WARN_MSG_FOR_DICT_RETURN)

        encoder_outputs = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=False,
        )

        # Keep only the first sequence position: (bs, dim, 1, 1).
        first_token = encoder_outputs[0][:, :, :, 0:1]
        features = torch.relu(self.pre_classifier(first_token))  # (bs, dim, 1, 1)
        scores = self.classifier(features)  # (bs, num_labels, 1, 1)
        logits = scores.squeeze(-1).squeeze(-1)  # (bs, num_labels)

        return (logits, ) + encoder_outputs[1:]
class DistilBertForQuestionAnswering(
        modeling_distilbert.DistilBertForQuestionAnswering):
    """Extractive question-answering head for DistilBERT, optimized for the ANE.

    Inference only: training mode, ``start_positions``/``end_positions`` and
    dict outputs are rejected with ``ValueError``.
    """

    config_class = DistilBertConfig

    def __init__(self, config):
        super().__init__(config)
        self.distilbert = DistilBertModel(config)
        # Assumes config.num_labels == 2 (start / end); forward() unpacks
        # exactly two channels.
        self.qa_outputs = nn.Conv2d(config.dim, config.num_labels, 1)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Return ``(start_logits, end_logits, *extra_encoder_outputs)``.

        ``start_logits`` and ``end_logits`` are each (bs, max_query_len),
        including when bs == 1 or max_query_len == 1.

        Raises:
            ValueError: in training mode, when span positions are given, or
                when a dict return is requested.
        """
        if self.training or start_positions is not None or end_positions is not None:
            raise ValueError(WARN_MSG_FOR_TRAINING_ATTEMPT)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if return_dict:
            raise ValueError(WARN_MSG_FOR_DICT_RETURN)

        distilbert_output = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=False,
        )
        hidden_states = distilbert_output[0]  # (bs, dim, 1, max_query_len)
        hidden_states = self.dropout(hidden_states)  # (bs, dim, 1, max_query_len)
        logits = self.qa_outputs(hidden_states)  # (bs, 2, 1, max_query_len)
        start_logits, end_logits = logits.split(1, dim=1)  # (bs, 1, 1, max_query_len) * 2

        # Squeeze only the two singleton axes: a bare squeeze() would also drop
        # the batch axis when bs == 1 (and the sequence axis when
        # max_query_len == 1).
        start_logits = start_logits.squeeze(2).squeeze(1).contiguous()  # (bs, max_query_len)
        end_logits = end_logits.squeeze(2).squeeze(1).contiguous()  # (bs, max_query_len)

        return (start_logits, end_logits) + distilbert_output[1:]
class DistilBertForTokenClassification(
        modeling_distilbert.DistilBertForTokenClassification):
    """Token-classification head for DistilBERT, optimized for the ANE.

    Inference only: training mode, ``labels`` and dict outputs are rejected
    with ``ValueError``.
    """

    # Declared here as on every other head in this module, so this head is
    # bound to the ANE DistilBertConfig rather than inheriting the parent's.
    config_class = DistilBertConfig

    def __init__(self, config):
        super().__init__(config)
        self.distilbert = DistilBertModel(config)
        self.classifier = nn.Conv2d(config.dim, config.num_labels, 1)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Return ``(logits, *extra_encoder_outputs)`` with logits of shape (bs, seq_len, num_labels).

        Raises:
            ValueError: in training mode, when ``labels`` is given, or when a
                dict return is requested.
        """
        if self.training or labels is not None:
            raise ValueError(WARN_MSG_FOR_TRAINING_ATTEMPT)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if return_dict:
            raise ValueError(WARN_MSG_FOR_DICT_RETURN)

        outputs = self.distilbert(
            input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=False,
        )
        sequence_output = outputs[0]  # (bs, dim, 1, seq_len)
        logits = self.classifier(sequence_output)  # (bs, num_labels, 1, seq_len)
        logits = logits.squeeze(2).transpose(1, 2)  # (bs, seq_len, num_labels)

        return (logits, ) + outputs[1:]
class DistilBertForMultipleChoice(
        modeling_distilbert.DistilBertForMultipleChoice):
    """Multiple-choice head for DistilBERT, optimized for the ANE.

    Inputs carry a choices axis at dim 1; each (example, choice) pair is
    encoded independently and scored from its first token position.
    Inference only: training mode, ``labels`` and dict outputs are rejected
    with ``ValueError``.
    """

    config_class = DistilBertConfig

    def __init__(self, config):
        super().__init__(config)
        self.distilbert = DistilBertModel(config)
        self.pre_classifier = nn.Conv2d(config.dim, config.dim, 1)
        self.classifier = nn.Conv2d(config.dim, 1, 1)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Return ``(logits, *extra_encoder_outputs)`` with logits of shape (bs, num_choices)."""
        if self.training or labels is not None:
            raise ValueError(WARN_MSG_FOR_TRAINING_ATTEMPT)
        if return_dict is None:
            return_dict = self.config.use_return_dict
        if return_dict:
            raise ValueError(WARN_MSG_FOR_DICT_RETURN)

        # Fold the choices axis into the batch axis.
        if input_ids is not None:
            num_choices = input_ids.shape[1]
            input_ids = input_ids.view(-1, input_ids.size(-1))
        else:
            num_choices = inputs_embeds.shape[1]
        if attention_mask is not None:
            attention_mask = attention_mask.view(-1, attention_mask.size(-1))
        if inputs_embeds is not None:
            inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-2),
                                               inputs_embeds.size(-1))

        outputs = self.distilbert(
            input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=False,
        )

        first_token = outputs[0][:, :, :, 0:1]  # (bs * num_choices, dim, 1, 1)
        features = torch.relu(self.pre_classifier(first_token))  # (bs * num_choices, dim, 1, 1)
        scores = self.classifier(features)  # (bs * num_choices, 1, 1, 1)
        reshaped_logits = scores.squeeze().view(-1, num_choices)  # (bs, num_choices)

        return (reshaped_logits, ) + outputs[1:]
def linear_to_conv2d_map(state_dict, prefix, local_metadata, strict,
                         missing_keys, unexpected_keys, error_msgs):
    """ Unsqueeze twice to map nn.Linear weights to nn.Conv2d weights

    Intended as a ``load_state_dict`` pre-hook. Every 2-D ``.weight`` entry
    whose key names one of the projections that are 1x1 ``nn.Conv2d`` layers
    in the ANE modules is reshaped in place from (out, in) to
    (out, in, 1, 1). Entries that are already 4-D, biases and unrelated
    parameters are left untouched, so re-loading an already converted
    checkpoint is harmless.

    Only ``state_dict`` is used; the other arguments exist to satisfy the
    pre-hook signature.
    """
    # Substrings of parameter names that are nn.Linear upstream but nn.Conv2d
    # here: q/k/v/out_lin and ffn lin1/lin2 ('lin'), the (pre_)classifier
    # heads, the masked-LM projections and the question-answering output.
    conv_markers = ('lin', 'classifier', 'vocab_transform', 'vocab_projector',
                    'qa_outputs')
    for key in state_dict:
        if '.weight' not in key:
            continue
        if not any(marker in key for marker in conv_markers):
            continue
        if len(state_dict[key].shape) == 2:
            state_dict[key] = state_dict[key][:, :, None, None]