kinensake's picture
Modify: requirements.txt
2ea9ced
raw
history blame
1.09 kB
from typing import * # pylint: disable=wildcard-import,unused-wildcard-import
import itertools
from .abc.base import LMScorer
from .gpt2 import GPT2LMScorer
class AutoLMScorer:
    """Factory that selects the right ``LMScorer`` implementation by model name.

    This class is not meant to be instantiated directly; use
    ``AutoLMScorer.from_pretrained(model_name)`` instead. Each class listed in
    ``MODEL_CLASSES`` is checked in order, and the first one whose
    ``supported_model_names()`` contains the requested name is instantiated.
    """

    # Candidate scorer classes, checked in this order by ``from_pretrained``.
    MODEL_CLASSES = [GPT2LMScorer]

    def __init__(self):
        # Direct construction is disallowed: the concrete scorer type depends
        # on the model name, so it must be chosen via ``from_pretrained``.
        raise EnvironmentError(
            "AutoLMScorer is designed to be instantiated "
            "using the `AutoLMScorer.from_pretrained(model_name)` "
            "method"
        )

    @classmethod
    def from_pretrained(cls, model_name: str, **kwargs: Any) -> LMScorer:
        """Instantiate the scorer class that supports ``model_name``.

        Args:
            model_name: Name of the pretrained model to load.
            **kwargs: Forwarded unchanged to the selected class's constructor.

        Returns:
            An instance of the first class in ``MODEL_CLASSES`` whose
            ``supported_model_names()`` contains ``model_name``.

        Raises:
            ValueError: If no class in ``MODEL_CLASSES`` supports ``model_name``.
        """
        for model_class in cls.MODEL_CLASSES:
            if model_name in model_class.supported_model_names():
                return model_class(model_name, **kwargs)

        supported = ", ".join(cls.supported_model_names())
        # Note the trailing space after the first sentence: adjacent string
        # literals are concatenated verbatim.
        raise ValueError(
            "Unrecognized model name. "
            "Can be one of: %s" % supported
        )

    @classmethod
    def supported_model_names(cls) -> Iterable[str]:
        """Return every model name supported by any class in ``MODEL_CLASSES``.

        The result is a lazy ``itertools.chain`` iterator that can be consumed
        only once; wrap it in ``list(...)`` if it must be traversed repeatedly.
        """
        return itertools.chain.from_iterable(
            model_class.supported_model_names() for model_class in cls.MODEL_CLASSES
        )