|
from typing import * |
|
|
|
import itertools |
|
|
|
from .abc.base import LMScorer |
|
from .gpt2 import GPT2LMScorer |
|
|
|
|
|
class AutoLMScorer:
    """Factory that resolves a model name to a concrete scorer class.

    This class is not meant to be instantiated; call
    ``AutoLMScorer.from_pretrained(model_name)`` instead.
    """

    # Candidate scorer classes, tried in list order by ``from_pretrained``.
    MODEL_CLASSES = [GPT2LMScorer]

    def __init__(self):
        """Reject direct instantiation.

        Raises:
            EnvironmentError: Always, pointing the caller at
                ``from_pretrained``.
        """
        # Adjacent string literals are joined without a separator, so every
        # fragment but the last must carry its own trailing space.
        raise EnvironmentError(
            "AutoLMScorer is designed to be instantiated "
            "using the `AutoLMScorer.from_pretrained(model_name)` "
            "method"
        )

    @classmethod
    def from_pretrained(cls, model_name: str, **kwargs: Any) -> LMScorer:
        """Instantiate the first registered class that supports ``model_name``.

        Args:
            model_name: Name looked up in each class's
                ``supported_model_names()``.
            **kwargs: Forwarded unchanged to the selected class constructor.

        Returns:
            The result of calling the matching class with ``model_name`` and
            ``kwargs``.

        Raises:
            ValueError: If no class in ``MODEL_CLASSES`` lists ``model_name``.
        """
        for model_class in cls.MODEL_CLASSES:
            if model_name in model_class.supported_model_names():
                return model_class(model_name, **kwargs)

        # The literals are concatenated before ``%`` is applied, so the
        # placeholder in the second fragment is filled as expected.
        raise ValueError(
            "Unrecognized model name. "
            "Can be one of: %s" % ", ".join(cls.supported_model_names())
        )

    @classmethod
    def supported_model_names(cls) -> Iterable[str]:
        """Return the model names of every class in ``MODEL_CLASSES``.

        Returns:
            A lazy, single-pass iterator chaining each class's
            ``supported_model_names()`` in registration order.
        """
        return itertools.chain.from_iterable(
            model_class.supported_model_names()
            for model_class in cls.MODEL_CLASSES
        )
|
|