import os
from argparse import Namespace
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

from open_lm.attention import get_attn_func, torch_attn, xformers_attn
from open_lm.model import Transformer, create_params
from open_lm.norms import get_norm_class
from open_lm.utils.transformers.hf_config import OpenLMConfig

class OpenLMModel(PreTrainedModel):
    config_class = OpenLMConfig

    def __init__(self, config, **kwargs):
        # This has to be done before init, as it makes sure the HF config is correct.
        if hasattr(config, "params"):
            params = config.params
        else:
            params_args_dict = config.params_args_dict
            if not params_args_dict.get("norm_type"):
                params_args_dict["norm_type"] = get_norm_class(params_args_dict["model_norm"])
            if not params_args_dict.get("attn_func"):
                params_args_dict["attn_func"] = get_attn_func(
                    params_args_dict["attn_name"],
                    params_args_dict["attn_activation"],
                    params_args_dict["attn_seq_scalar"],
                    params_args_dict["attn_seq_scalar_alpha"],
                )
            params = create_params(Namespace(**config.params_args_dict))
        config.set_params(params)

        super().__init__(config, **kwargs)

        self.supports_gradient_checkpointing = True
        self.model = Transformer(params)

    @property
    def gradient_checkpointing(self):
        return self.model.grad_checkpointing

    @gradient_checkpointing.setter
    def gradient_checkpointing(self, value):
        self.model.grad_checkpointing = value

    def forward(self, input_ids=None, inputs_embeds=None, **kwargs):
        return self.model(input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs)

class OpenLMforCausalLM(OpenLMModel):
    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        self.lm_head = None
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.tok_embeddings

    def set_input_embeddings(self, value):
        self.model.tok_embeddings = value

    def get_output_embeddings(self):
        return self.model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        raise NotImplementedError

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer

        >>> model = OpenLMforCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        assert position_ids is None, "Position IDs are not supported"
        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
        logits, _, past_key_values = self.model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            attention_mask=attention_mask,
        )
        loss = None
        if labels is not None:
            # Shift so that tokens < n predict token n, then flatten for cross entropy
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, shift_logits.size(-1))
            shift_labels = shift_labels.view(-1).to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)
        output = CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values, loss=loss)
        return output
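
    # Once a key/value cache exists, only the tokens that have not yet been fed through the
    # model need to be passed in; this hook trims `input_ids` accordingly and routes
    # `inputs_embeds` only on the first generation step.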
    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[1]

            # Some generation methods already pass only the last input ID
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default to old behavior: keep only final ID
                remove_prefix_length = input_ids.shape[1] - 1

            input_ids = input_ids[:, remove_prefix_length:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs
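
    # Beam search reorders hypotheses between steps; the cached key/value tensors must be
    # reordered along the batch dimension to stay aligned with the surviving beams.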
    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_cache = ()
        for layer_past in past_key_values:
            reordered_cache += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
        return reordered_cache
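
    # This override supports three loading paths: (1) a local directory plus a config whose
    # `checkpoint_file` points at a raw open_lm training checkpoint, (2) a local directory
    # containing a converted `pytorch_model.bin`, and (3) the default Hugging Face behavior.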
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
        if (
            os.path.isdir(pretrained_model_name_or_path)
            and kwargs.get("config", None) is not None
            and getattr(kwargs["config"], "checkpoint_file", None) is not None
        ):
            # Set the torch default dtype before instantiating the model
            torch_dtype = getattr(kwargs["config"], "torch_dtype", None)
            if isinstance(torch_dtype, str):
                torch_dtype = getattr(torch, torch_dtype)
            if torch_dtype is not None:
                torch.set_default_dtype(torch_dtype)

            print("Loading checkpoint from config.checkpoint_file")
            checkpoint_path = kwargs["config"].checkpoint_file
            checkpoint = torch.load(checkpoint_path)
            state_dict = checkpoint["state_dict"]
            # Strip the DDP "module." prefix and namespace the weights under "model."
            state_dict = {x.replace("module.", ""): y for x, y in state_dict.items()}
            state_dict = {f"model.{x}": y for x, y in state_dict.items()}
            return super().from_pretrained(None, state_dict=state_dict, **kwargs)
        elif os.path.isdir(pretrained_model_name_or_path):
            # Load from a converted PyTorch checkpoint in the directory
            print("Loading checkpoint from directory")
            checkpoint_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
            state_dict = torch.load(checkpoint_path)
            # state_dict = {x.replace("module.", ""): y for x, y in state_dict.items()}
            state_dict = {f"model.{x}" if "model." not in x else x: y for x, y in state_dict.items()}
            return super().from_pretrained(pretrained_model_name_or_path, state_dict=state_dict, **kwargs)
        else:
            return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
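

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): loading a converted checkpoint through the
# Hugging Face API and sampling from it. The directory and tokenizer names are
# placeholders, not paths shipped with this repository.
#
#   from transformers import AutoTokenizer
#
#   config = OpenLMConfig.from_pretrained("path/to/converted/checkpoint")
#   model = OpenLMforCausalLM.from_pretrained("path/to/converted/checkpoint", config=config)
#   tokenizer = AutoTokenizer.from_pretrained("path/to/tokenizer")
#
#   inputs = tokenizer("The capital of France is", return_tensors="pt")
#   generate_ids = model.generate(inputs.input_ids, max_length=32)
#   print(tokenizer.batch_decode(generate_ids, skip_special_tokens=True)[0])
# ---------------------------------------------------------------------------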