OLMo-Bitnet-1B / configuration_olmo.py
"""
OLMo configuration
"""
from transformers import AutoConfig, PretrainedConfig
from transformers.utils import logging
from .config import ModelConfig
from .aliases import PathOrStr
from .beam_search import Sampler
from .exceptions import OLMoError
from .initialization import ModuleType
from .optim import Optimizer
from .util import StrEnum
from .safetensors_util import STKey
from .torch_util import seed_all
logger = logging.get_logger(__name__)
class OLMoConfig(PretrainedConfig):
    model_type = "olmo"
    keys_to_ignore_at_inference = ["past_key_values"]  # TODO: confirm

    def __init__(self, use_cache: bool = False, **kwargs):
        # Start from OLMo's own ModelConfig defaults, then layer user-supplied
        # kwargs, the use_cache flag, and a default `architectures` entry on
        # top before handing everything to PretrainedConfig.
        model_config = ModelConfig()
        all_kwargs = model_config.asdict()
        all_kwargs.update(kwargs)
        all_kwargs.update({"use_cache": use_cache})
        all_kwargs.update(
            {
                "architectures": all_kwargs.get("architectures", ["OLMoModelForCausalLM"])
                or ["OLMoModelForCausalLM"]
            }
        )
        super().__init__(**all_kwargs)
    # Alias OLMo's native field names to the attribute names that the
    # Hugging Face transformers utilities expect.
    @property
    def num_attention_heads(self):
        return self.n_heads

    @property
    def num_hidden_layers(self):
        return self.n_layers

    @property
    def hidden_size(self):
        return self.d_model

# Register the config class so that it is available for transformer pipelines, auto-loading etc.
AutoConfig.register("olmo", OLMoConfig)
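
# Illustrative usage sketch (not part of the original module): shows how the
# registered config can be constructed directly and how its HF-style property
# aliases resolve to the underlying OLMo field names. The hub repository id in
# the second example is an assumption; loading it requires network access and
# trust_remote_code so this file is pulled in as remote code.
if __name__ == "__main__":
    # Direct construction: OLMo's ModelConfig defaults are applied first,
    # then any keyword overrides are layered on top.
    config = OLMoConfig(use_cache=True)
    print(config.hidden_size, config.num_attention_heads, config.num_hidden_layers)

    # Resolving from a checkpoint that bundles this module as remote code
    # (repository id is a placeholder assumption):
    # remote_config = AutoConfig.from_pretrained("OLMo-Bitnet-1B", trust_remote_code=True)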