| | |
| | |
| | |
| | |
| |
|
| | import math |
| | from dataclasses import dataclass |
| | from functools import partial |
| | from typing import Callable, Dict, Optional |
| |
|
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import numpy as np |
| | from fairseq.modules import PositionalEmbedding, FairseqDropout, LayerNorm |
| | from fairseq.tasks import FairseqTask |
| | from .base import D2vModalityConfig, ModalitySpecificEncoder, get_alibi_bias |
| | from .modules import BlockEncoder, Decoder1d |
| | from examples.data2vec.data.modality import Modality |
| |
|
| |
|
@dataclass
class D2vTextConfig(D2vModalityConfig):
    """Configuration for the text modality encoder (extends D2vModalityConfig)."""

    # Identifies this config as the TEXT modality (used for modality dispatch).
    type: Modality = Modality.TEXT
    # Size of the positional-embedding table built in TextLocalEncoder.
    max_source_positions: int = 512
    # Passed as `learned=` to PositionalEmbedding in TextLocalEncoder.
    learned_pos: bool = True
    # Dropout rate for the FairseqDropout module in TextLocalEncoder.
    dropout: float = 0.1

    # If True, token embeddings are NOT scaled by sqrt(embed_dim).
    no_scale_embedding: bool = True
    # If True, apply LayerNorm to the embedded tokens in TextLocalEncoder.
    layernorm_embedding: bool = True
    # If True, skip positional embeddings entirely (embed_positions is None).
    no_token_positional_embeddings: bool = False
| |
|
| |
|
class TextEncoder(ModalitySpecificEncoder):
    """Text-modality encoder: token embedding front-end, a pre-net of
    transformer blocks, an optional 1-d decoder, and ALiBi attention biases,
    all wired into the shared ModalitySpecificEncoder machinery."""

    modality_cfg: D2vTextConfig

    def __init__(
        self,
        modality_cfg: D2vTextConfig,
        embed_dim: int,
        make_block: Callable[[float], nn.ModuleList],
        norm_layer: Callable[[int], nn.LayerNorm],
        layer_norm_first: bool,
        alibi_biases: Dict,
        task: Optional[FairseqTask],
    ):
        # Vocabulary details come from the task's source dictionary.
        dictionary = task.source_dictionary
        self.pad_idx = dictionary.pad()
        self.vocab_size = len(dictionary)

        cfg = modality_cfg
        local_encoder = TextLocalEncoder(
            vocab_size=self.vocab_size,
            embed_dim=embed_dim,
            max_source_positions=cfg.max_source_positions,
            pad_idx=self.pad_idx,
            no_scale_embedding=cfg.no_scale_embedding,
            layernorm_embedding=cfg.layernorm_embedding,
            dropout=cfg.dropout,
            no_token_positional_embeddings=cfg.no_token_positional_embeddings,
            learned_pos=cfg.learned_pos,
        )

        # Per-block stochastic-depth rates, linearly spaced over the pre-net.
        depth = cfg.prenet_depth
        drop_path_rates = np.linspace(
            cfg.start_drop_path_rate, cfg.end_drop_path_rate, depth
        )
        blocks = nn.ModuleList(make_block(drop_path_rates[i]) for i in range(depth))
        # Post-norm only when the pre-net exists and the model is post-LN style.
        post_norm = (
            norm_layer(embed_dim) if (depth > 0 and not layer_norm_first) else None
        )
        context_encoder = BlockEncoder(
            blocks,
            post_norm,
            layer_norm_first,
            cfg.prenet_layerdrop,
            cfg.prenet_dropout if depth > 0 else 0.0,
        )

        decoder = None
        if cfg.decoder is not None:
            decoder = Decoder1d(cfg.decoder, embed_dim)

        super().__init__(
            modality_cfg=cfg,
            embed_dim=embed_dim,
            local_encoder=local_encoder,
            project_features=nn.Identity(),
            fixed_positional_encoder=None,
            relative_positional_encoder=None,
            context_encoder=context_encoder,
            decoder=decoder,
            get_alibi_bias=partial(get_alibi_bias, alibi_biases=alibi_biases),
        )

    def reset_parameters(self):
        # Nothing text-specific to reset beyond what the base class handles.
        super().reset_parameters()

    def convert_padding_mask(self, x, padding_mask):
        """Shrink `padding_mask` along time to match x.size(1).

        Pads the mask with True (= padding) up to a multiple of
        self.downsample, collapses each window with all(), and trims any
        overhang. Returns the mask unchanged when it is None or already the
        right length.
        """
        if padding_mask is None or padding_mask.size(1) == x.size(1):
            return padding_mask

        # Right-pad so the time dimension divides evenly by the downsample
        # factor; diff == downsample means no padding is needed.
        diff = self.downsample - padding_mask.size(1) % self.downsample
        if 0 < diff < self.downsample:
            padding_mask = F.pad(padding_mask, (0, diff), value=True)

        # A downsampled position is padding only if its whole window is.
        bsz = padding_mask.size(0)
        padding_mask = padding_mask.view(bsz, -1, self.downsample).all(-1)
        if padding_mask.size(1) > x.size(1):
            padding_mask = padding_mask[:, : x.size(1)]

        assert x.size(1) == padding_mask.size(
            1
        ), f"{x.size(1), padding_mask.size(1), diff, self.downsample}"

        return padding_mask
| |
|
| |
|
class TextLocalEncoder(nn.Module):
    """Token-embedding front-end for the text modality.

    Embeds `src_tokens`, optionally scales by sqrt(embed_dim), adds
    (optional) positional embeddings, then applies optional LayerNorm
    followed by dropout.

    Args:
        vocab_size: number of entries in the token embedding table.
        embed_dim: embedding dimension.
        max_source_positions: capacity of the positional-embedding table.
        pad_idx: padding token index (embedding at this index stays zero).
        no_scale_embedding: if True, skip the sqrt(embed_dim) scaling.
        layernorm_embedding: if True, LayerNorm the embedded output.
        dropout: dropout probability applied last in forward().
        no_token_positional_embeddings: if True, omit positional embeddings.
        learned_pos: passed as `learned=` to PositionalEmbedding.
    """

    def __init__(
        self,
        vocab_size,
        embed_dim,
        max_source_positions,
        pad_idx,
        no_scale_embedding,
        layernorm_embedding,
        dropout,
        no_token_positional_embeddings,
        learned_pos,
    ):
        super().__init__()
        self.pad_idx = pad_idx
        self.dropout_module = FairseqDropout(dropout)

        self.embed_tokens = nn.Embedding(vocab_size, embed_dim, pad_idx)
        # Transformer convention: scale embeddings by sqrt(dim) unless disabled.
        # (The original assigned embed_scale twice; the duplicate is removed.)
        self.embed_scale = 1.0 if no_scale_embedding else math.sqrt(embed_dim)
        self.embed_positions = (
            PositionalEmbedding(
                max_source_positions,
                embed_dim,
                pad_idx,
                learned=learned_pos,
            )
            if not no_token_positional_embeddings
            else None
        )

        self.layernorm_embedding = None
        if layernorm_embedding:
            self.layernorm_embedding = LayerNorm(embed_dim)

    def forward(self, src_tokens):
        """Embed `src_tokens` -> scaled + positional -> LayerNorm -> dropout."""
        x = self.embed_scale * self.embed_tokens(src_tokens)
        if self.embed_positions is not None:
            x = x + self.embed_positions(src_tokens)

        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)
        x = self.dropout_module(x)
        return x
| |
|