# Hugging Face Hub page header (not code), kept as comments so the module parses:
# uploaded by benjamin — "Upload ZettHypernet" — commit 2f5f68a (verified) — 9.77 kB
from .configuration_hypernet import ZettHypernetConfig
from transformers import PreTrainedModel, RobertaConfig, RobertaModel
from functools import partial
from torch import nn as nn
import torch
from torch.nn import functional as F
class Rescaler(nn.Module):
    """Per-feature affine rescaling ``w * x + b`` with frozen parameters.

    ``w`` and ``b`` are ``nn.Parameter``s of shape ``(1, dim)`` with
    ``requires_grad=False``: they are saved/loaded with the state dict but
    never trained.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim
        self.w = nn.Parameter(torch.ones((1, self.dim)), requires_grad=False)
        # NOTE(review): b starts at ones (not zeros), so a freshly constructed
        # Rescaler is NOT the identity map; presumably both w and b are always
        # overwritten by checkpoint weights — confirm before relying on defaults.
        self.b = nn.Parameter(torch.ones((1, self.dim)), requires_grad=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Implemented as ``forward`` (not ``__call__``) so nn.Module's call
        # machinery — forward hooks, pre-hooks, wrappers — applies as usual.
        return self.w * x + self.b
class ProjectorBlock(nn.Module):
    """Residual two-layer MLP followed by LayerNorm.

    Computes ``ln(gelu(dense2(gelu(dense1(x)))) + x)``, with tanh-approximated
    GELU and LayerNorm ``eps=1e-6``. The residual ``h + x`` requires the last
    dimension of ``x`` to be compatible with ``dim``, so in practice
    ``input_dim`` must equal ``dim``.

    Args:
        input_dim: feature size of the input (``dense1`` in-features).
        dim: output feature size (``dense2`` out-features, LayerNorm size).
        intermediate_dim: hidden width between ``dense1`` and ``dense2``.
    """

    def __init__(self, input_dim: int, dim: int, intermediate_dim: int):
        super().__init__()
        self.input_dim = input_dim
        self.dim = dim
        self.intermediate_dim = intermediate_dim
        self.dense1 = nn.Linear(self.input_dim, self.intermediate_dim)
        self.dense2 = nn.Linear(self.intermediate_dim, self.dim)
        self.ln = nn.LayerNorm(self.dim, eps=1e-6)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Implemented as ``forward`` (not ``__call__``) so nn.Module hooks and
        # wrappers apply when the block is invoked as ``block(x)``.
        h = F.gelu(self.dense1(x), approximate="tanh")
        h = F.gelu(self.dense2(h), approximate="tanh")
        return self.ln(h + x)
class ZettHypernet(PreTrainedModel):
    """Hypernetwork that predicts token embeddings from a sequence of ids.

    Each target token is described by a row of ids (``target_surface_forms``).
    Ids below ``config.original_vocab_size`` are embedded with a caller-supplied
    ``source_embeddings`` matrix, and larger ids use learned
    ``fallback_embeddings``. The sequence is projected to ``hn_hidden_size``
    and encoded by a RoBERTa model (no pooler). The result is projected to the
    predicted input embedding, and optionally to an output embedding and a
    scalar bias.
    """

    config_class = ZettHypernetConfig

    def __init__(self, config: ZettHypernetConfig):
        super().__init__(config)
        self.config = config
        self.has_separate_out_embeddings = getattr(
            self.config, "separate_out_embeddings", False
        )

        if self.config.hn_embed_lang_id:
            self.lang_embeddings = nn.Embedding(
                self.config.n_langs, self.config.hn_hidden_size
            )

        # With separate output embeddings, the "in" side is twice n_embd wide
        # (the single-head path splits it into in/out halves in forward).
        if self.has_separate_out_embeddings:
            n_in_embd = self.config.n_embd * 2
        else:
            n_in_embd = self.config.n_embd
        n_out_embd = self.config.n_embd

        if self.config.hn_model_type == "roberta":
            hn_config = RobertaConfig.from_pretrained(
                self.config.hn_model_name_or_path
            )
            hn_config.num_hidden_layers = self.config.hn_n_layers
            hn_config.hidden_size = self.config.hn_hidden_size
            hn_config.intermediate_size = self.config.hn_intermediate_size
            if getattr(self.config, "hn_num_attention_heads", None) is None:
                self.config.hn_num_attention_heads = self.config.hn_hidden_size // 64
            hn_config.num_attention_heads = self.config.hn_num_attention_heads
            self.embed_init_range = hn_config.initializer_range
            module_class = partial(RobertaModel, add_pooling_layer=False)
        elif self.config.hn_model_type == "t5":
            raise NotImplementedError()
        else:
            # Previously fell through to a NameError on `module_class`.
            raise ValueError(
                f"Unsupported hn_model_type: {self.config.hn_model_type!r}"
            )

        if self.config.hn_embed_using_source_embeddings:
            # do not need to alloc embeddings since inputs_embeds is always used
            hn_config.vocab_size = self.config.pad_token_id + 1

        if (
            self.config.hn_add_inter_token_attention
            or self.config.hn_embed_target_priors
        ):
            raise NotImplementedError()

        self.pad_token_id = self.config.pad_token_id
        assert self.pad_token_id is not None

        self.model = module_class(hn_config)
        # forward() always looks up fallback_ids (clamped to >= 0), so the
        # table needs at least one row even when hn_n_extra_tokens is 0.
        self.fallback_embeddings = nn.Embedding(
            max(self.config.hn_n_extra_tokens, 1), n_in_embd
        )

        if self.config.hn_embed_using_source_embeddings:
            self.input_projection = nn.Sequential(
                nn.Linear(n_in_embd, self.config.hn_hidden_size),
                ProjectorBlock(
                    self.config.hn_hidden_size,
                    self.config.hn_hidden_size,
                    self.config.hn_intermediate_size,
                ),
            )

        if self.config.hn_single_head:
            # One head predicts the concatenated [in | out] embedding.
            self.output_projection = self._make_output_head(n_in_embd)
        else:
            self.output_projection = self._make_output_head(n_out_embd)
            if self.has_separate_out_embeddings:
                self.output_projection_out = self._make_output_head(
                    self.config.n_embd
                )

        if self.config.hn_rescale_embeddings:
            self.in_scaler = Rescaler(n_in_embd)
            self.scaler = Rescaler(n_out_embd)
            if self.has_separate_out_embeddings:
                self.out_scaler = Rescaler(self.config.n_embd)

        if getattr(self.config, "hn_predict_bias", False):
            self.bias_projection = nn.Linear(self.config.hn_hidden_size, 1)

    def _make_output_head(self, out_dim: int) -> nn.Sequential:
        """Build ``ProjectorBlock -> Linear(hn_hidden_size, out_dim)``.

        Submodules are registered as ``0`` and ``1`` (``nn.Sequential``), which
        keeps the state-dict keys identical to the original inline construction.
        """
        return nn.Sequential(
            ProjectorBlock(
                self.config.hn_hidden_size,
                self.config.hn_hidden_size,
                self.config.hn_intermediate_size,
            ),
            nn.Linear(self.config.hn_hidden_size, out_dim),
        )

    def _append_lang_embedding(self, inputs_embeds, attention_mask, lang_index):
        """Append one language-id token to every sequence and extend the mask.

        Returns the new ``(inputs_embeds, attention_mask)``, each one position
        longer along dim 1.
        """
        seq_len = attention_mask.shape[1]
        # After squeeze this must be a single vector, i.e. one language per call.
        lang_embedding = self.lang_embeddings(lang_index).squeeze()
        # position embed and type embed are added afterwards only in PT version
        # so we need to subtract them here: the appended token gets token type 0
        # and position `seq_len` (see position_ids in forward).
        lang_embedding = lang_embedding - (
            self.model.embeddings.token_type_embeddings(
                torch.tensor(0, device=self.device)
            )
            + self.model.embeddings.position_embeddings(
                torch.tensor(seq_len, device=self.device)
            )
        )
        lang_embedding = lang_embedding[None, None, :].expand(
            inputs_embeds.shape[0], -1, -1
        )
        inputs_embeds = torch.cat([inputs_embeds, lang_embedding], dim=1)
        lang_mask = torch.ones(
            lang_embedding.shape[:-1], dtype=torch.bool, device=self.device
        )
        attention_mask = torch.cat([attention_mask, lang_mask], dim=1)
        return inputs_embeds, attention_mask

    def forward(
        self,
        target_surface_forms,
        target_priors=None,
        source_embeddings=None,
        lang_index=None,
        deterministic: bool = True,
    ):
        """Predict embeddings for a batch of target tokens.

        Args:
            target_surface_forms: integer id tensor. The indexing below
                (``shape[0]``, ``shape[1]``, ``[..., 0]``) assumes
                ``(n_tokens, seq_len)`` — TODO confirm. Ids below
                ``config.original_vocab_size`` index ``source_embeddings``;
                ids at or above it select ``fallback_embeddings`` rows offset
                by ``original_vocab_size``. Positions equal to ``pad_token_id``
                are masked out of attention.
            target_priors: unsupported; must be None.
            source_embeddings: matrix used as the ``weight`` of ``F.embedding``.
            lang_index: index into ``lang_embeddings``; only used when
                ``config.hn_embed_lang_id`` is set.
            deterministic: unused in this implementation.

        Returns:
            ``(predicted_embeddings_in, predicted_embeddings_out, predicted_bias)``.
            ``predicted_embeddings_out`` is None unless separate output
            embeddings are enabled. ``predicted_bias`` is all zeros unless
            ``hn_predict_bias`` is enabled.

        Raises:
            NotImplementedError: if ``target_priors`` is given or the config
                does not embed using source embeddings.
        """
        if target_priors is not None:
            raise NotImplementedError()
        if not self.config.hn_embed_using_source_embeddings:
            raise NotImplementedError()

        vocab_size = self.config.original_vocab_size
        use_fallback = target_surface_forms >= vocab_size
        # Clamp both index sets into range; torch.where below picks the valid one.
        main_ids = target_surface_forms.clamp(max=vocab_size - 1)
        fallback_ids = (target_surface_forms - vocab_size).clamp(min=0)

        source_embeds = F.embedding(main_ids, weight=source_embeddings)
        if self.config.hn_rescale_embeddings:
            source_embeds = self.in_scaler(source_embeds)
        inputs_embeds = torch.where(
            use_fallback[..., None],
            self.fallback_embeddings(fallback_ids),
            source_embeds,
        )
        inputs_embeds = self.input_projection(inputs_embeds)
        attention_mask = target_surface_forms != self.pad_token_id

        if self.config.hn_embed_lang_id:
            inputs_embeds, attention_mask = self._append_lang_embedding(
                inputs_embeds, attention_mask, lang_index
            )

        position_ids = torch.broadcast_to(
            torch.arange(
                torch.atleast_2d(attention_mask).shape[-1], device=self.device
            ),
            attention_mask.shape,
        )
        hidden_states = self.model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
        ).last_hidden_state

        # Either flatten all positions into one vector per row, or keep only
        # the hidden state at position 0.
        if self.config.hn_concat_last_hidden_state:
            hidden_states = hidden_states.reshape(target_surface_forms.shape[0], -1)
        else:
            hidden_states = hidden_states[:, 0]

        predicted_embeddings = self.output_projection(hidden_states)
        predicted_embeddings_out = None
        if self.config.hn_single_head:
            # Single head emits [in | out]; split at n_embd.
            predicted_embeddings_in = predicted_embeddings[..., : self.config.n_embd]
            if self.has_separate_out_embeddings:
                predicted_embeddings_out = predicted_embeddings[
                    ..., self.config.n_embd :
                ]
        else:
            predicted_embeddings_in = predicted_embeddings
            if self.has_separate_out_embeddings:
                predicted_embeddings_out = self.output_projection_out(hidden_states)

        if self.config.hn_rescale_embeddings:
            predicted_embeddings_in = self.scaler(predicted_embeddings_in)
            if predicted_embeddings_out is not None:
                predicted_embeddings_out = self.out_scaler(predicted_embeddings_out)

        if getattr(self.config, "hn_predict_bias", False):
            predicted_bias = self.bias_projection(hidden_states)[..., 0]
        else:
            predicted_bias = torch.zeros_like(
                target_surface_forms[..., 0], dtype=self.dtype
            )

        return predicted_embeddings_in, predicted_embeddings_out, predicted_bias