| |
| |
| |
| |
|
|
| import uuid |
| from typing import Dict, Optional |
|
|
| from torch import Tensor |
|
|
|
|
class FairseqIncrementalState:
    """Mixin that gives each instance a private namespace in a shared
    incremental-state dictionary.

    On construction every instance draws a random UUID string. Keys handed to
    ``get_incremental_state`` / ``set_incremental_state`` are qualified as
    ``"<uuid>.<key>"``, so several instances can store entries in the same
    dictionary without overwriting each other.
    """

    def __init__(self, *args, **kwargs):
        # Cooperative init: let the other bases initialise first, then
        # assign this instance's namespace id.
        super().__init__(*args, **kwargs)
        self.init_incremental_state()

    def init_incremental_state(self):
        """Assign a fresh random namespace id to this instance."""
        self._incremental_state_id = str(uuid.uuid4())

    def _get_full_incremental_state_key(self, key: str) -> str:
        """Return ``key`` prefixed with this instance's namespace id."""
        return "{}.{}".format(self._incremental_state_id, key)

    def get_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
    ) -> Optional[Dict[str, Optional[Tensor]]]:
        """Helper for getting incremental state for an nn.Module.

        Returns the entry stored under this instance's qualified ``key``, or
        ``None`` when ``incremental_state`` is ``None`` or holds no such entry.
        """
        qualified = self._get_full_incremental_state_key(key)
        if incremental_state is not None and qualified in incremental_state:
            return incremental_state[qualified]
        return None

    def set_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
        value: Dict[str, Optional[Tensor]],
    ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
        """Helper for setting incremental state for an nn.Module.

        Stores ``value`` under this instance's qualified ``key``, mutating
        ``incremental_state`` in place, and returns the same dictionary. When
        ``incremental_state`` is ``None`` nothing is stored and ``None`` is
        returned.
        """
        if incremental_state is None:
            return None
        incremental_state[self._get_full_incremental_state_key(key)] = value
        return incremental_state
|
|
|
|
def with_incremental_state(cls):
    """Class decorator that mixes ``FairseqIncrementalState`` into ``cls``.

    ``FairseqIncrementalState`` is made the first direct base of ``cls``. If it
    was already listed among the bases, it is moved to the front. This makes
    its cooperative ``__init__`` run first and forward to the remaining bases
    via ``super()``.

    ``cls`` is modified in place (its ``__bases__`` is reassigned) and then
    returned.

    If ``cls`` already inherits the mixin through an ancestor, ``cls`` is
    returned unchanged. This covers subclassing an already-decorated class and
    decorating the mixin itself. Prepending the mixin in those cases would
    produce an inconsistent MRO (or a base cycle), and the ``__bases__``
    assignment would raise ``TypeError``.

    NOTE(review): CPython refuses ``__bases__`` reassignment for classes whose
    only base is ``object``, so ``cls`` presumably needs a real base class
    (e.g. ``nn.Module``) — confirm for new call sites.
    """
    if FairseqIncrementalState not in cls.__bases__ and issubclass(
        cls, FairseqIncrementalState
    ):
        # Mixin is inherited indirectly; behaviour is already present.
        return cls
    cls.__bases__ = (FairseqIncrementalState,) + tuple(
        base for base in cls.__bases__ if base is not FairseqIncrementalState
    )
    return cls
|
|