| """
|
| Vortex model implementation for HuggingFace.
|
| Integrates with transformers library.
|
| """
|
|
|
| from typing import Optional, Tuple, List, Dict, Any
|
| import torch
|
| import torch.nn as nn
|
| from transformers import PreTrainedModel, PretrainedConfig, GenerationConfig
|
| from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
|
|
|
| from configuration_vortex import VortexConfig
|
| from models.vortex_model import VortexModel
|
|
|
|
|
class VortexPreTrainedModel(PreTrainedModel):
    """
    Base class for Vortex models.

    Provides weight initialization plus the input/output-embedding accessors
    that the HuggingFace save/load machinery (``from_pretrained`` /
    ``save_pretrained`` and weight tying) relies on.
    """

    config_class = VortexConfig
    base_model_prefix = "vortex"
    supports_gradient_checkpointing = True
    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]

    def _init_weights(self, module):
        """Initialize one submodule (invoked recursively via ``self.apply``).

        Linear/Embedding weights ~ N(0, config.initializer_range); biases and
        the padding embedding row are zeroed; LayerNorm is reset to identity.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # LayerNorm built with elementwise_affine=False (or bias=False on
            # newer torch) has weight/bias set to None -- guard before touching.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)

    def get_input_embeddings(self):
        # assumes the wrapped VortexModel exposes `embed_tokens` -- TODO confirm
        return self.vortex.embed_tokens

    def set_input_embeddings(self, value):
        self.vortex.embed_tokens = value

    def get_output_embeddings(self):
        # lm_head is the output projection; listed in _tied_weights_keys by
        # the causal-LM subclass so HF can tie it to the input embeddings.
        return self.vortex.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.vortex.lm_head = new_embeddings
|
|
|
|
|
class VortexForCausalLM(VortexPreTrainedModel):
    """
    Vortex model with a causal language-modeling head.

    Wraps the core :class:`VortexModel` and adapts its dict outputs to the
    ``transformers`` ``CausalLMOutputWithCrossAttentions`` contract so the
    model works with ``Trainer`` and ``generate``.
    """

    _tied_weights_keys = ["vortex.lm_head.weight"]

    def __init__(self, config: VortexConfig):
        super().__init__(config)
        self.config = config

        # Core model takes a plain dict rather than a PretrainedConfig.
        self.vortex = VortexModel(config.to_dict())

        self.apply(self._init_weights)

        # Share input embedding and lm_head weights when configured.
        if self.config.tie_word_embeddings:
            self.tie_weights()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        domain_ids: Optional[torch.LongTensor] = None,
        domain_tags: Optional[torch.Tensor] = None,
        text: Optional[List[str]] = None,
    ) -> CausalLMOutputWithCrossAttentions:
        """
        Forward pass.

        Args:
            input_ids: Token IDs (batch, seq_len)
            attention_mask: Attention mask (batch, seq_len)
            labels: Labels for LM loss (batch, seq_len); positions with -100
                are ignored. Labels are shifted internally, so pass them
                aligned with ``input_ids``.
            domain_ids: Domain IDs (batch,)
            domain_tags: Domain tag masks (batch, seq_len, num_domains)
            text: Original text strings (for science modules)

        NOTE(review): position_ids, past_key_values, inputs_embeds, use_cache,
        output_attentions and output_hidden_states are accepted only for HF
        API compatibility -- the underlying VortexModel call below does not
        receive them, so KV caching is not actually used here. Confirm against
        VortexModel before relying on cached generation.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.vortex(
            input_ids=input_ids,
            attention_mask=attention_mask,
            domain_ids=domain_ids,
            domain_tags=domain_tags,
            text=text,
            return_dict=True,
        )

        logits = outputs["logits"]
        last_hidden_state = outputs["last_hidden_state"]

        loss = None
        if labels is not None:
            # Standard causal-LM shift: predict token t+1 from position t.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        if not return_dict:
            # Tuple form: (loss?, logits, last_hidden_state)
            output = (logits,) + (last_hidden_state,)
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            hidden_states=last_hidden_state,
            attentions=None,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        **kwargs,
    ):
        """Prepare inputs for text generation.

        With a cache present, only the newest token needs to be fed.
        """
        if past_key_values:
            input_ids = input_ids[:, -1:]

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache", True),
        }

    def generate(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        **kwargs,
    ):
        """Generate text, defaulting the generation config from the model config.

        (Fixed: GenerationConfig is already imported at module level; the
        redundant local re-import was removed.)
        """
        generation_config = kwargs.pop("generation_config", None)
        if generation_config is None:
            generation_config = GenerationConfig.from_model_config(self.config)

        return super().generate(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            generation_config=generation_config,
            **kwargs,
        )
|
|
|
|
|
|
|
from transformers import AutoConfig, AutoModelForCausalLM

# Register the "vortex" model_type with the Auto* factories so that
# AutoConfig.from_pretrained / AutoModelForCausalLM.from_pretrained can
# resolve saved Vortex checkpoints without trust_remote_code.
AutoConfig.register("vortex", VortexConfig)
AutoModelForCausalLM.register(VortexConfig, VortexForCausalLM)
|
|
|
|
|
def test_hf_integration():
    """Smoke-test the HuggingFace integration end to end.

    Builds a tiny Vortex model, runs a forward pass with labels, then
    round-trips the weights through save_pretrained / from_pretrained via
    the Auto* factories registered at module level.
    """
    import tempfile

    from transformers import AutoConfig, AutoModelForCausalLM

    config = VortexConfig(
        d_model=512,
        num_layers=2,
        num_heads=8,
        vocab_size=1000,
    )

    model = VortexForCausalLM(config)
    # Fixed: PreTrainedModel exposes num_parameters(); there is no
    # get_num_parameters() in the HF API, so the old call raised AttributeError.
    print(f"Model parameters: {model.num_parameters():,}")

    batch_size = 2
    seq_len = 32
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
    labels = torch.randint(0, config.vocab_size, (batch_size, seq_len))

    outputs = model(input_ids=input_ids, labels=labels)
    print(f"Loss: {outputs.loss.item():.4f}")
    print(f"Logits shape: {outputs.logits.shape}")

    # Round-trip in a temporary directory so repeated runs leave no
    # ./test_vortex_model artifacts behind.
    with tempfile.TemporaryDirectory() as tmp_dir:
        model.save_pretrained(tmp_dir)
        config.save_pretrained(tmp_dir)

        loaded_config = AutoConfig.from_pretrained(tmp_dir)
        loaded_model = AutoModelForCausalLM.from_pretrained(tmp_dir)
        print(f"Loaded model type: {type(loaded_model)}")

    print("HF integration test passed!")
|
|
|
|
|
# Run the integration smoke test when executed as a script.
if __name__ == "__main__":
    test_hf_integration()
|
|
|