import torch
import torch.nn as nn
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    GenerationMixin,
    PretrainedConfig,
    PreTrainedModel,
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.auto.configuration_auto import CONFIG_MAPPING
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING
|
|
class LanceAIConfig(PretrainedConfig):
    """Configuration for the ``LanceAI`` model (``model_type = "lance_ai"``).

    Args:
        vocab_size: Number of token ids; sizes the embedding table and LM head.
        hidden_size: Model width, used as ``d_model`` of every transformer layer.
        num_layers: Number of layers in each of the encoder and decoder stacks.
        num_heads: Attention heads per transformer layer.
        architectures: Model class names stored on the config. Defaults to a
            new ``["LanceAI"]`` list when omitted or ``None``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "lance_ai"

    def __init__(
        self,
        vocab_size=50257,
        hidden_size=2048,
        num_layers=24,
        num_heads=16,
        architectures=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        # Build the default per instance: a list literal in the signature is a
        # single object shared (and mutable) across every config ever created.
        self.architectures = architectures if architectures is not None else ["LanceAI"]
|
|
class LanceAI(PreTrainedModel, GenerationMixin):
    """Causal language model assembled from stock ``torch.nn`` transformer modules.

    Token embeddings feed an ``nn.TransformerEncoder``. The encoder output is the
    cross-attention memory of an ``nn.TransformerDecoder``, which reads the same
    embeddings as its target sequence. ``lm_head`` projects the decoder output
    to vocabulary logits.

    NOTE(review): there is no positional-embedding module, so the causal masks
    built in ``forward`` are the only source of token-order information. Adding
    one would change the checkpoint layout, so it is not done here.
    """

    config_class = LanceAIConfig

    def __init__(self, config):
        super().__init__(config)
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)

        # batch_first=True because forward() passes (batch, seq, hidden) tensors,
        # which is what nn.Embedding returns for (batch, seq) ids. With PyTorch's
        # default (batch_first=False), attention would run across the batch axis
        # instead of across tokens. The flag adds no parameters, so the state
        # dict layout is unchanged.
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=config.hidden_size, nhead=config.num_heads, batch_first=True
            ),
            num_layers=config.num_layers,
        )
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=config.hidden_size, nhead=config.num_heads, batch_first=True
            ),
            num_layers=config.num_layers,
        )

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
        self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100)

        # Default decoding settings used by generate() (sampling-based).
        self.generation_config.max_new_tokens = 250
        self.generation_config.temperature = 0.8
        self.generation_config.top_k = 40
        self.generation_config.top_p = 0.9
        self.generation_config.do_sample = True
        self.generation_config.repetition_penalty = 1.3
        self.generation_config.no_repeat_ngram_size = 3
        self.generation_config.length_penalty = 1.0

        # NOTE(review): forces bfloat16 parameters at construction time; confirm
        # this is intended for every device/dtype the model is used with.
        self.to(torch.bfloat16)

        # post_init() is the documented end-of-__init__ hook in transformers; it
        # runs init_weights() along with the library's other post-construction setup.
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
        inputs_embeds=None,
        return_dict=True,
        use_cache=False,
        **kwargs,
    ):
        """Run the model and optionally compute the next-token loss.

        Args:
            input_ids: ``(batch, seq)`` token ids; ignored when ``inputs_embeds``
                is given.
            attention_mask: Accepted for API compatibility but not applied; only
                the causal mask restricts attention. NOTE(review): padded
                positions remain visible to later tokens, so right-pad batches
                (or extend this to build a key-padding mask).
            labels: Token ids aligned with the inputs. They are shifted left by
                one here, so position t is scored against token t+1. Entries
                equal to -100 are ignored by the loss.
            inputs_embeds: Precomputed ``(batch, seq, hidden)`` embeddings.
            return_dict: Return a ``CausalLMOutputWithPast`` when truthy.
            use_cache: Unused; the model keeps no key/value cache.
            **kwargs: Ignored (absorbs extra arguments passed by generate()).

        Returns:
            ``CausalLMOutputWithPast(loss=..., logits=...)`` if ``return_dict``;
            otherwise ``(loss, logits)`` when labels were given, else ``logits``.

        Raises:
            ValueError: If neither ``input_ids`` nor ``inputs_embeds`` is given.
        """
        if inputs_embeds is not None:
            embeddings = inputs_embeds
        elif input_ids is not None:
            embeddings = self.embedding(input_ids)
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # Boolean causal mask: True above the diagonal means position i may not
        # attend to position j > i. The same mask covers three paths:
        # - encoder self-attention, so memory slot j only encodes tokens <= j;
        # - decoder self-attention;
        # - decoder cross-attention, so position i cannot read future memory slots.
        # Without all three, every position can see the token it must predict.
        seq_len = embeddings.size(1)
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=embeddings.device),
            diagonal=1,
        )

        encoder_output = self.encoder(embeddings, mask=causal_mask)
        decoder_output = self.decoder(
            embeddings,
            encoder_output,
            tgt_mask=causal_mask,
            memory_mask=causal_mask,
        )
        logits = self.lm_head(decoder_output)

        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Ids above the vocabulary are clamped to the last id rather than
            # raising; -100 is below the bound and stays ignored.
            # NOTE(review): this silently hides bad labels; consider validating.
            shift_labels = torch.clamp(shift_labels, max=self.config.vocab_size - 1)
            # Upcast to float32 so the log-softmax over the vocabulary is not
            # computed in bfloat16.
            loss = self.loss_fct(
                shift_logits.view(-1, self.config.vocab_size).float(),
                shift_labels.view(-1),
            )

        if return_dict:
            return CausalLMOutputWithPast(loss=loss, logits=logits)
        return (loss, logits) if loss is not None else logits

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, **kwargs
    ):
        """Build the forward() kwargs for one generation step.

        forward() never returns key/value states, so there is no cache to resume
        from. The full prefix is therefore re-encoded on every step. Truncating
        to the last token would leave the model with a single token of context.
        """
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            **kwargs,
        }

    def _reorder_cache(self, past_key_values, beam_idx):
        """Reorder a tuple-of-tuples cache along the batch axis for beam search.

        Returns ``None`` unchanged, since this model normally has no cache.
        """
        if past_key_values is None:
            return None
        return tuple(
            tuple(
                past_state.index_select(0, beam_idx.to(past_state.device))
                for past_state in layer_past
            )
            for layer_past in past_key_values
        )
|
|
| |
# Make the custom classes resolvable by the Auto* factories under model_type
# "lance_ai". These use the public registration API, which also checks that
# the config's model_type and the model's config_class match what is registered.
AutoConfig.register("lance_ai", LanceAIConfig)
AutoModelForCausalLM.register(LanceAIConfig, LanceAI)

# Tag the classes for the Auto* classes named here when saving custom-code
# checkpoints (see the transformers custom-model guide).
LanceAIConfig.register_for_auto_class("AutoConfig")
LanceAI.register_for_auto_class("AutoModelForCausalLM")