| """Prisma model for HuggingFace integration.
|
|
|
| Usage:
|
| from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
| model = AutoModelForCausalLM.from_pretrained("y3i12/Prisma", trust_remote_code=True)
|
| tokenizer = AutoTokenizer.from_pretrained("y3i12/Prisma")
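
    Generation example (standard transformers generate/decode API; the
    sampling arguments below are illustrative, not tuned for this checkpoint):

        inputs = tokenizer("Hello", return_tensors="pt")
        output_ids = model.generate(**inputs, max_new_tokens=32)
        print(tokenizer.decode(output_ids[0], skip_special_tokens=True))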
"""

import torch
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

from .configuration_prisma import PrismaConfig
from .mirrored import MirroredTransformer, MirroredConfig
from .layers import build_word_start_table, compute_word_positions


class PrismaForCausalLM(PreTrainedModel):
    """Prisma mirrored transformer for causal language modeling."""

    config_class = PrismaConfig
    _tied_weights_keys = ["transformer.lm_head.weight"]
    _no_split_modules = ["MirroredBlock", "MiddleBlock"]
    _keys_to_ignore_on_load_missing = [
        r"transformer\..*\.rotary\.inv_freq",
        r"transformer\..*\.word_rope\.word_inv_freq",
    ]

    def __init__(self, config: PrismaConfig):
        super().__init__(config)

        mirrored_config = MirroredConfig(
            vocab_size=config.vocab_size,
            hidden_size=config.hidden_size,
            num_heads=config.num_heads,
            num_kv_heads=config.num_kv_heads,
            num_layers=config.num_layers,
            n_middle=config.n_middle,
            max_seq_len=config.max_seq_len,
            dropout=config.dropout,
            aux_skip_k=config.aux_skip_k,
            aux_skip_weight=config.aux_skip_weight,
            use_g2lu=config.use_g2lu,
            word_rope_dims=config.word_rope_dims,
            word_rope_base=config.word_rope_base,
            embed_dim=config.embed_dim,
            head_dim=config.head_dim,
        )
        self.transformer = MirroredTransformer(mirrored_config)

        # Persistent buffer marking which token ids start a new word; it is
        # filled either from a pretrained checkpoint or via set_tokenizer().
        if config.word_rope_dims > 0:
            self.register_buffer(
                "word_start_table",
                torch.zeros(config.vocab_size, dtype=torch.bool),
                persistent=True,
            )
        else:
            self.word_start_table = None

        # Running word position used during incremental (cached) decoding.
        self._word_pos_counter = 0

        self.post_init()

    def set_tokenizer(self, tokenizer):
        """Build word_start_table from a tokenizer.

        Call this when the model is not loaded from a pretrained checkpoint;
        otherwise the table is restored from the saved state dict.
        """
        if self.config.word_rope_dims > 0:
            table = build_word_start_table(tokenizer, self.config.vocab_size)
            self.word_start_table = table.to(self.device)
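
    # Usage sketch for set_tokenizer (hypothetical driver code; the tokenizer
    # checkpoint name is the one from the module docstring):
    #
    #     tokenizer = AutoTokenizer.from_pretrained("y3i12/Prisma")
    #     model = PrismaForCausalLM(PrismaConfig())
    #     model.set_tokenizer(tokenizer)  # populates word_start_table in place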

    def get_input_embeddings(self):
        return self.transformer.embed

    def set_input_embeddings(self, value):
        self.transformer.embed = value

    def get_output_embeddings(self):
        return self.transformer.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.transformer.lm_head = new_embeddings

    def tie_weights(self):
        # Tying is only possible when the embedding and lm_head matrices have
        # the same width, i.e. when embed_dim and head_dim agree (each falls
        # back to hidden_size when unset).
        if self.config.tie_word_embeddings:
            embed_dim = self.config.embed_dim or self.config.hidden_size
            head_dim = self.config.head_dim or self.config.hidden_size
            if embed_dim == head_dim:
                self.transformer.lm_head.weight = self.transformer.embed.weight

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        past_key_values=None,
        labels=None,
        use_cache=False,
        return_dict=True,
        **kwargs,
    ):
        # attention_mask is accepted for pipeline compatibility but is not
        # used by the inner transformer.
        # Normalize past_key_values: generate() may hand us a Cache object or
        # a legacy list/tuple of (key, value) pairs, while the inner
        # transformer expects a list of pairs (or None when the cache is empty).
        past_kv_list = None
        if past_key_values is not None:
            if isinstance(past_key_values, (list, tuple)):
                has_content = len(past_key_values) > 0
                past_kv_list = past_key_values if has_content else None
            elif hasattr(past_key_values, "get_seq_length"):
                has_content = past_key_values.get_seq_length() > 0
                if has_content:
                    past_kv_list = [past_key_values[i] for i in range(len(past_key_values))]

        # Word positions for word-level RoPE.
        word_positions = None
        if self.word_start_table is not None and self.config.word_rope_dims > 0:
            if past_kv_list is not None and input_ids.size(1) == 1:
                # Incremental decoding: advance a scalar counter instead of
                # recomputing positions. Note this path assumes batch size 1.
                last_token = input_ids[0, -1].item()
                if self.word_start_table[last_token]:
                    self._word_pos_counter = 0
                else:
                    self._word_pos_counter += 1
                word_positions = torch.tensor(
                    [[float(self._word_pos_counter)]],
                    device=input_ids.device,
                )
            else:
                # Full (or prefill) pass: compute positions for the whole
                # sequence and seed the counter for subsequent cached steps.
                word_positions = compute_word_positions(input_ids, self.word_start_table)
                self._word_pos_counter = int(word_positions[0, -1].item())

        output = self.transformer(
            input_ids,
            labels=labels,
            use_cache=use_cache,
            past_kv=past_kv_list,
            word_positions=word_positions,
        )

        # Repackage the returned (key, value) pairs as a DynamicCache so that
        # generate() can keep threading a Cache object through forward().
        new_cache = None
        if use_cache and output.get("past_kv") is not None:
            from transformers.cache_utils import DynamicCache

            new_cache = DynamicCache()
            for layer_idx, (k, v) in enumerate(output["past_kv"]):
                new_cache.update(k, v, layer_idx)

        if not return_dict:
            result = (output["logits"],)
            if use_cache:
                result += (new_cache,)
            return result

        return CausalLMOutputWithPast(
            loss=output.get("loss"),
            logits=output["logits"],
            past_key_values=new_cache,
        )
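
    # Manual decode-loop sketch (hypothetical driver; model.generate() is the
    # supported path and exercises the same cache plumbing):
    #
    #     out = model(input_ids, use_cache=True)
    #     next_id = out.logits[:, -1:].argmax(dim=-1)
    #     out = model(next_id, past_key_values=out.past_key_values, use_cache=True)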

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, **kwargs
    ):
        # With a non-empty cache, only the newest token needs to be fed in.
        has_cache = False
        if past_key_values is not None:
            if hasattr(past_key_values, "get_seq_length"):
                has_cache = past_key_values.get_seq_length() > 0
            elif isinstance(past_key_values, (list, tuple)):
                has_cache = len(past_key_values) > 0
        if has_cache:
            input_ids = input_ids[:, -1:]

        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": True,
        }
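

# Offline smoke-test sketch (assumes default PrismaConfig values build a valid
# tiny model; kept as a comment so importing this module has no side effects):
#
#     config = PrismaConfig()
#     model = PrismaForCausalLM(config)
#     ids = torch.randint(0, config.vocab_size, (1, 8))
#     out = model(ids)
#     assert out.logits.shape[0] == 1 and out.logits.shape[1] == 8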
|