|
from abc import ABC |
|
from abc import abstractmethod |
|
from typing import Tuple |
|
|
|
import torch |
|
|
|
from espnet.nets.scorer_interface import BatchScorerInterface |
|
|
|
|
|
class AbsLM(torch.nn.Module, BatchScorerInterface, ABC):
    """Abstract base class for language models (LMs).

    To share the loss calculation among different models, a delegate
    pattern is used here: an instance of a concrete subclass of this class
    should be passed to the "LanguageModel" wrapper (written as
    ``LanguageESPnetModel`` in the example below).

    ``AbsLM`` itself cannot be instantiated because ``forward`` is
    abstract; define a subclass that implements it first:

    >>> from espnet2.lm.abs_model import AbsLM
    >>> class MyLM(AbsLM):
    ...     def forward(self, input, hidden):
    ...         ...  # model-specific computation
    >>> lm = MyLM()
    >>> model = LanguageESPnetModel(lm=lm)

    This "model" is one of the mediator objects for the "Task" class.
    """

    @abstractmethod
    def forward(
        self, input: torch.Tensor, hidden: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the language model on ``input`` given ``hidden``.

        Every concrete subclass must override this method.

        Args:
            input: Input tensor. The expected shape and dtype are defined
                by the concrete subclass.
            hidden: Hidden-state tensor supplied by the caller.

        Returns:
            A tuple of two tensors.
            NOTE(review): presumably the model output and the updated
            hidden state -- confirm against the concrete subclasses.

        Raises:
            NotImplementedError: Always, if this base implementation is
                invoked (e.g. through ``super().forward(...)``).
        """
        raise NotImplementedError
|
|