| | import logging |
| | from dataclasses import fields |
| | from typing import List, Optional, Tuple, Union |
| |
|
| | import torch |
| | from transformers import PreTrainedModel |
| | from transformers.cache_utils import Cache |
| | from transformers.modeling_outputs import CausalLMOutputWithPast |
| | from transformers.models.auto import AutoModelForCausalLM |
| |
|
| | from olmo.config import ModelConfig |
| | from olmo.model import OLMo |
| |
|
| | from .configuration_olmo import OLMoConfig |
| |
|
# Module-level logger named after this module (standard `logging` convention).
log = logging.getLogger(__name__)
| |
|
| |
|
def create_model_config_from_pretrained_config(config: OLMoConfig):
    """
    Build a native ``olmo.config.ModelConfig`` from a HF-style ``OLMoConfig``.

    Every dataclass field declared on ``ModelConfig`` is copied by name off the
    pretrained config; ``config`` must therefore expose all of those attributes.
    """
    field_values = {f.name: getattr(config, f.name) for f in fields(ModelConfig)}
    return ModelConfig(**field_values)
| |
|
| |
|
class OLMoForCausalLM(PreTrainedModel):
    """
    Extremely barebones HF model wrapper.

    Adapts a native ``olmo.model.OLMo`` to the Hugging Face ``PreTrainedModel``
    interface so it can be loaded via ``AutoModelForCausalLM`` and used with
    ``generate()``.
    """

    config_class = OLMoConfig
    base_model_prefix = "model"
    # Blocks that HF's device-map sharding must keep on a single device.
    _no_split_modules = ["OLMoBlock"]

    def __init__(self, config: OLMoConfig, model: Optional[OLMo] = None, init_params: bool = False):
        """
        :param config: HF-style config describing the model architecture.
        :param model: optional pre-built ``OLMo`` to wrap; when omitted a fresh
            one is constructed from ``config``.
        :param init_params: forwarded to ``OLMo`` when a new model is built.
        """
        super().__init__(config)

        # Explicit None-check rather than truthiness on an nn.Module instance.
        if model is None:
            model_config = create_model_config_from_pretrained_config(config)
            # Initialize on CPU; the HF loading machinery moves weights afterwards.
            model_config.init_device = "cpu"
            self.model = OLMo(model_config, init_params=init_params)
        else:
            self.model = model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        attention_bias: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[
            Cache
        ] = None,  # NOTE(review): accepted but never used; presumably kept so newer transformers generate() code can pass it — confirm before removing
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Run the wrapped OLMo model and optionally compute the causal-LM loss.

        Returns a ``CausalLMOutputWithPast`` (or a plain tuple when
        ``return_dict`` is falsy).

        Raises:
            ValueError: if ``output_attentions`` is requested (unsupported).
        """
        if use_cache is None:
            use_cache = self.config.use_cache

        if output_attentions:
            raise ValueError("output_attentions is not yet supported in OLMo")

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Delegate to the native OLMo forward. Note the HF ``inputs_embeds`` ->
        # OLMo ``input_embeddings`` keyword translation.
        outputs = self.model.forward(
            input_ids=input_ids,
            input_embeddings=inputs_embeds,
            attention_mask=attention_mask,
            attention_bias=attention_bias,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
        )

        logits = outputs.logits
        hidden_states = outputs.hidden_states

        loss = None
        if labels is not None:
            # Standard causal-LM loss: token at position t predicts token t+1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = torch.nn.CrossEntropyLoss()
            # Use the logits' own class dimension rather than config.embedding_size:
            # equivalent when they agree, and robust if they ever diverge.
            shift_logits = shift_logits.view(-1, shift_logits.size(-1))
            shift_labels = shift_labels.view(-1)
            # Labels may live on a different device (e.g. pipeline/model parallelism).
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.attn_key_values,
            hidden_states=hidden_states,
        )

    def can_generate(self) -> bool:
        # Opt in to HF's text-generation utilities.
        return True

    def prepare_inputs_for_generation(
        self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs
    ):
        """Shape the inputs for a single step of ``generate()``."""
        if past_key_values:
            # With a KV cache only the newest token needs to go through the model.
            input_ids = input_ids[:, -1:]
        model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values}

        model_inputs.update(kwargs)
        # ``get`` instead of ``pop``: no need to mutate kwargs; the result is identical.
        model_inputs["use_cache"] = kwargs.get("use_cache", self.config.use_cache)
        return model_inputs

    def get_input_embeddings(self) -> torch.nn.Module:
        return self.model.transformer.wte

    def set_input_embeddings(self, value: torch.nn.Module):
        self.model.transformer.wte = value

    def get_output_embeddings(self):
        # Under weight tying the input embedding matrix doubles as the LM head.
        if self.config.weight_tying:
            return self.model.transformer.wte
        else:
            return self.model.transformer.ff_out

    def set_output_embeddings(self, value: torch.nn.Module):
        if self.config.weight_tying:
            self.model.transformer.wte = value
        else:
            self.model.transformer.ff_out = value

    def tie_weights(self):
        """
        This function is intentionally left as a no-op.

        Weight tying is handled as follows:
        - When the model is initialized, the `ff_out` layer is conditionally defined based on the `weight_tying` configuration.
        See: `if not config.weight_tying: self.transformer.update(...)` in `olmo/model.py`.
        - When computing logits, the `wte` weights are used directly if `weight_tying` is enabled.
        See: `if self.config.weight_tying: logits = F.linear(x, self.transformer.wte.weight, None)` in the `forward` method.

        Therefore, there is no need to explicitly tie the weights in this function.
        """
        pass

    def resize_token_embeddings(
        self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
    ) -> torch.nn.Embedding:
        """
        Resizes input token embeddings matrix of the model if `new_num_tokens != config.embedding_size`.

        Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.

        Arguments:
            new_num_tokens (`int`, *optional*):
                The new number of tokens in the embedding matrix. Increasing the size will add newly initialized
                vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
                returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything.
            pad_to_multiple_of (`int`, *optional*):
                If set will pad the embedding matrix to a multiple of the provided value. If `new_num_tokens` is set to
                `None` will just pad the embedding to a multiple of `pad_to_multiple_of`.

                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
                `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
                details about this, or help on choosing the correct value for resizing, refer to this guide:
                https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc

        Return:
            `torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.

        Note:
            This method differs from the base class implementation by resizing the `embedding_size` attribute of the
            model configuration instead of the `vocab_size`. It also includes a warning if the resized `embedding_size`
            is less than the `vocab_size`. In OLMo, `embedding_size` refers to the dimensionality of the model's token
            embeddings, while `vocab_size` refers to the number of unique tokens in the vocabulary.
        """
        model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
        if new_num_tokens is None and pad_to_multiple_of is None:
            return model_embeds

        # Keep the HF-facing config and the wrapped model's config in sync.
        self.config.embedding_size = model_embeds.weight.shape[0]
        self.model.config.embedding_size = model_embeds.weight.shape[0]

        # Warn when the embedding table can no longer cover the declared vocabulary.
        if self.config.embedding_size < self.config.vocab_size:
            warning_message = (
                f"Resizing token embeddings to size {self.config.embedding_size}, which is less than the vocab size "
                f"{self.config.vocab_size} defined in the model configuration. Make sure your tokenizer's vocabulary "
                "size is less than or equal to the new token embedding size."
            )
            log.warning(warning_message)

        # Re-run tying (a no-op here; kept for parity with the base-class contract).
        self.tie_weights()

        return model_embeds
| |
|
| |
|
| | |
| | |
| | |
# Register the wrapper with HF's Auto* machinery so that
# ``AutoModelForCausalLM.from_pretrained(...)`` resolves an ``OLMoConfig``
# checkpoint to ``OLMoForCausalLM``.
AutoModelForCausalLM.register(OLMoConfig, OLMoForCausalLM)
| |
|