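"""
Fairseq wrapper for the Adaptive Span transformer language model.

Exposes the stand-alone adaptive-span model (TransformerSeq) as a
FairseqLanguageModel registered as "adaptive_span", with an incremental
decoder that keeps per-layer hidden-state caches between forward passes.
"""
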
import logging
from dataclasses import dataclass
from typing import Dict, List, Optional

import torch
from fairseq.dataclass import FairseqDataclass
from fairseq.models import (
    FairseqIncrementalDecoder,
    FairseqLanguageModel,
    register_model,
)

from .adaptive_span_model import TransformerSeq as AdaptiveSpanTransformerModel

logger = logging.getLogger(__name__)


@dataclass
class AdaptiveSpanSmallConfig(FairseqDataclass):
    # Defaults for a small adaptive-span transformer LM; vocab_size is
    # overridden with the task dictionary size in AdaptiveSpanDecoder below.
    vocab_size: int = 50
    d_model: int = 256
    n_head: int = 4
    d_inner: int = 1024
    n_layer: int = 8
    attn_span: int = 1024
    dropout: float = 0.0
    emb_dropout: float = 0.0
    adapt_span_ramp: int = 32
    adapt_span_init: float = 0.0
    aux_loss_scaler: float = 0.000002
    adapt_span_layer: bool = False


@register_model("adaptive_span", dataclass=AdaptiveSpanSmallConfig)
class AdaptiveSpanTransformer(FairseqLanguageModel):
    @classmethod
    def build_model(cls, cfg: AdaptiveSpanSmallConfig, task):
        return cls(AdaptiveSpanDecoder(cfg, task))

    def get_aux_loss(self):
        return self.decoder.get_aux_loss()

    def get_current_max_span(self):
        return self.decoder.get_current_max_span()

    def get_current_avg_span(self):
        return self.decoder.get_current_avg_span()


class AdaptiveSpanDecoder(FairseqIncrementalDecoder):
    def __init__(self, cfg, task):
        super().__init__(task.target_dictionary)

        self.config = cfg
        config = AdaptiveSpanSmallConfig(
            vocab_size=len(task.target_dictionary),
            d_model=cfg.d_model,
            n_head=cfg.n_head,
            d_inner=cfg.d_inner,
            n_layer=cfg.n_layer,
            attn_span=cfg.attn_span,
            dropout=cfg.dropout,
            emb_dropout=cfg.emb_dropout,
            adapt_span_ramp=cfg.adapt_span_ramp,
            adapt_span_init=cfg.adapt_span_init,
            aux_loss_scaler=cfg.aux_loss_scaler,
            adapt_span_layer=cfg.adapt_span_layer,
        )
        logger.info(config)
        self.model = AdaptiveSpanTransformerModel(**config.__dict__)

        self._mems = None

    def forward(
        self,
        src_tokens,
        incremental_state: Optional[Dict[str, List[torch.Tensor]]] = None,
        encoder_out=None,
    ):
        bsz = src_tokens.size(0)
        if incremental_state is not None:  # used during inference
            mems = self.get_incremental_state(incremental_state, "mems")
            src_tokens = src_tokens[:, -1:]  # only feed the most recent token
        else:
            mems = self._mems

        if mems is None:
            # first forward pass: start from an empty (zero) hidden-state cache
            mems = self.init_hid_cache(bsz)
        output = self.model(x=src_tokens, h_cache=mems)
        if incremental_state is not None:
            self.set_incremental_state(incremental_state, "mems", output[1])
        else:
            self._mems = output[1]
        return (output[0],)

    def max_positions(self):
        return self.config.attn_span

    def init_hid_cache(self, batch_sz):
        # One zero-initialized cache tensor per layer, sized to that layer's
        # cache length and matching the model's parameter dtype and device.
        hid = []
        param = next(self.model.parameters())
        for layer in self.model.layers:
            h = torch.zeros(
                batch_sz,
                layer.get_cache_size(),
                self.config.d_model,
                dtype=param.dtype,
                device=param.device,
            )
            hid.append(h)
        return hid

    def get_aux_loss(self):
        return self.model.get_aux_loss()

    def get_current_max_span(self):
        return self.model.get_current_max_span()

    def get_current_avg_span(self):
        return self.model.get_current_avg_span()

    def reorder_incremental_state(
        self,
        incremental_state: Dict[str, Dict[str, Optional[torch.Tensor]]],
        new_order: torch.Tensor,
    ):
        """Reorder incremental state.

        This will be called when the order of the input has changed from the
        previous time step. A typical use case is beam search, where the input
        order changes between time steps based on the selection of beams.
        """
        raise NotImplementedError("This is required for generation/beam search")
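
    # A possible implementation sketch (not part of the original module), assuming
    # the cached "mems" are the list of [bsz, cache_len, d_model] tensors built by
    # init_hid_cache: reorder each layer's cache along the batch dimension.
    #
    #     mems = self.get_incremental_state(incremental_state, "mems")
    #     if mems is not None:
    #         mems = [m.index_select(0, new_order) for m in mems]
    #         self.set_incremental_state(incremental_state, "mems", mems)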
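
# Usage sketch (an illustration only): the model is registered as "adaptive_span"
# and can also be built programmatically from a task that exposes a
# target_dictionary, e.g.:
#
#     cfg = AdaptiveSpanSmallConfig(n_layer=4, attn_span=512)
#     model = AdaptiveSpanTransformer.build_model(cfg, task)
#     out, = model(src_tokens)  # src_tokens: LongTensor [bsz, seq_len]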