# hello-world/model.py
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from datasets import load_dataset


class HelloWorldConfig(PretrainedConfig):
    model_type = "hello_world"

    def __init__(
        self,
        vocab_size=13,
        hidden_size=64,
        num_hidden_layers=1,
        num_attention_heads=1,
        intermediate_size=128,
        hidden_act="gelu",
        max_position_embeddings=512,
        type_vocab_size=1,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        **kwargs
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps


class HelloWorldModel(PreTrainedModel):
    config_class = HelloWorldConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Token embeddings plus learned absolute position embeddings.
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        # A single Transformer encoder layer stands in for a full layer stack.
        self.layer = nn.TransformerEncoderLayer(
            d_model=config.hidden_size,
            nhead=config.num_attention_heads,
            dim_feedforward=config.intermediate_size,
            batch_first=True,
        )
        # Projects hidden states back to vocabulary logits.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
        self.init_weights()

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        labels=None,
        use_cache=False,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        if input_ids is not None:
            batch_size, seq_length = input_ids.shape
        else:
            raise ValueError("You have to specify input_ids")

        if position_ids is None:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

        # Sum token and position embeddings, then run the single encoder layer.
        # Note: no causal mask is applied; this toy encoder attends bidirectionally.
        inputs_embeds = self.embeddings(input_ids)
        position_embeds = self.position_embeddings(position_ids)
        hidden_states = inputs_embeds + position_embeds

        # Translate the HF-style attention_mask (1 = keep, 0 = pad) into the
        # key-padding mask expected by nn.TransformerEncoderLayer (True = ignore).
        src_key_padding_mask = None
        if attention_mask is not None:
            src_key_padding_mask = attention_mask == 0
        hidden_states = self.layer(hidden_states, src_key_padding_mask=src_key_padding_mask)

        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Standard causal LM loss: predict token t+1 from position t.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=past_key_values,
            hidden_states=(hidden_states,) if output_hidden_states else None,
            attentions=None,
        )

    def generate_hello_world(self):
        # Toy "generation": run a forward pass on the fixed [hello, world]
        # token ids and return the canonical string.
        hello_token_id = 5
        world_token_id = 6
        input_ids = torch.tensor([[hello_token_id, world_token_id]])
        with torch.no_grad():
            outputs = self.forward(input_ids)
        return "Hello World!"

    @classmethod
    def load_dataset(cls, dataset_name="chiedo/hello-world", split=None):
        """
        Load the Hello World dataset.

        Args:
            dataset_name (str): Name of the dataset on the Hugging Face Hub.
            split (str, optional): Specific split to load ('train', 'validation', 'test').

        Returns:
            Dataset or DatasetDict depending on the split parameter, or None if loading fails.
        """
        try:
            if split:
                return load_dataset(dataset_name, split=split)
            else:
                return load_dataset(dataset_name)
        except Exception as e:
            print(f"Error loading dataset: {e}")
            print(f"Make sure the dataset exists at: https://huggingface.co/datasets/{dataset_name}")
            return None

    def prepare_dataset_batch(self, texts, tokenizer, max_length=128):
        """
        Prepare a batch of texts from the dataset for model input.

        Args:
            texts (list): List of text strings.
            tokenizer: Tokenizer used to encode the texts.
            max_length (int): Maximum sequence length.

        Returns:
            dict: Dictionary with input_ids and attention_mask tensors.
        """
        return tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )
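

# Usage sketch (an illustrative example, not part of the original module). It
# assumes the optional pieces it names: a generic pretrained tokenizer
# ("bert-base-uncased" is only a stand-in) and network access for the
# chiedo/hello-world dataset; everything else comes from the classes above.
if __name__ == "__main__":
    config = HelloWorldConfig()
    model = HelloWorldModel(config)

    # Toy generation path defined above.
    print(model.generate_hello_world())

    # Forward pass with labels on a dummy batch drawn from the model's own
    # vocabulary, yielding the shifted causal LM loss.
    input_ids = torch.randint(1, config.vocab_size, (2, 8))
    attention_mask = torch.ones_like(input_ids)
    outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
    print("logits shape:", tuple(outputs.logits.shape))
    print("loss:", float(outputs.loss))

    # Optional: fetch the dataset and tokenize a sample batch. The tokenizer's
    # ids are only inspected for shape here, because a real tokenizer's
    # vocabulary is far larger than the toy vocab_size of 13 used by this model.
    dataset = HelloWorldModel.load_dataset(split="train")
    if dataset is not None:
        try:
            from transformers import AutoTokenizer

            tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # stand-in tokenizer
            batch = model.prepare_dataset_batch(["Hello World!"], tokenizer, max_length=32)
            print("tokenized batch shape:", tuple(batch["input_ids"].shape))
        except Exception as e:
            print(f"Skipping tokenization example: {e}")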