Spaces:
Sleeping
Sleeping
Audio-Deepfake-Detection
/
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1
/fairseq
/incremental_decoding_utils.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import uuid | |
from typing import Dict, Optional | |
from torch import Tensor | |
class FairseqIncrementalState(object):
    """Mixin that namespaces per-instance entries in a shared incremental-state dict.

    Each instance gets a random identifier (``uuid4``) and prefixes every key
    it reads or writes with that identifier, so several instances can keep
    their entries in one ``incremental_state`` dict without clobbering each
    other.
    """

    def __init__(self, *args, **kwargs):
        # Cooperative init: forward every argument to the next class in the
        # MRO, then give this instance its own namespace id.
        super().__init__(*args, **kwargs)
        self.init_incremental_state()

    def init_incremental_state(self):
        """Assign this instance a fresh random namespace id.

        Calling this again switches the instance to a new namespace, so
        entries stored under the previous id are no longer found by
        ``get_incremental_state``.
        """
        self._incremental_state_id = str(uuid.uuid4())

    def _get_full_incremental_state_key(self, key: str) -> str:
        """Return ``key`` qualified with this instance's id, as ``"<id>.<key>"``."""
        return "{}.{}".format(self._incremental_state_id, key)

    def get_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
    ) -> Optional[Dict[str, Optional[Tensor]]]:
        """Helper for getting incremental state for an nn.Module.

        Args:
            incremental_state: the shared state dict, or ``None``.
            key: un-namespaced key; it is prefixed with this instance's id
                before the lookup.

        Returns:
            The value stored for this instance under ``key``, or ``None`` when
            ``incremental_state`` is ``None`` or has no such entry.
        """
        full_key = self._get_full_incremental_state_key(key)
        if incremental_state is None or full_key not in incremental_state:
            return None
        return incremental_state[full_key]

    def set_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
        value: Dict[str, Optional[Tensor]],
    ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
        """Helper for setting incremental state for an nn.Module.

        Stores ``value`` under this instance's namespaced ``key``, mutating
        ``incremental_state`` in place. Nothing is stored when
        ``incremental_state`` is ``None``.

        Args:
            incremental_state: the shared state dict, or ``None``.
            key: un-namespaced key; it is prefixed with this instance's id.
            value: the entry to store.

        Returns:
            The same ``incremental_state`` object that was passed in
            (``None`` if ``None`` was passed).
        """
        if incremental_state is not None:
            full_key = self._get_full_incremental_state_key(key)
            incremental_state[full_key] = value
        return incremental_state
def with_incremental_state(cls):
    """Class decorator that makes ``cls`` inherit from FairseqIncrementalState.

    ``FairseqIncrementalState`` is moved to the front of ``cls.__bases__``
    (any existing direct occurrence is removed first), so it precedes the
    original bases in the MRO. ``cls`` is modified in place and returned.

    If ``cls`` already inherits the mixin only through an ancestor (for
    example, it subclasses a class that was itself decorated), or is the
    mixin itself, it is returned unchanged. Prepending the mixin in that
    situation always produces an inconsistent MRO, and the ``__bases__``
    assignment would raise ``TypeError``.
    """
    bases = cls.__bases__
    if (
        FairseqIncrementalState not in bases
        and issubclass(cls, FairseqIncrementalState)
    ):
        # Already gets the mixin via inheritance; nothing to do.
        return cls
    # NOTE(review): CPython rejects __bases__ assignment for some layouts
    # (e.g. a class whose only base is ``object``); confirm for new call sites.
    cls.__bases__ = (FairseqIncrementalState,) + tuple(
        b for b in bases if b is not FairseqIncrementalState
    )
    return cls