Inteligent_ai / model.safetensors
Mxytyu's picture
Update model.safetensors
82a267c verified
raw
history blame
2.16 kB
import torch
from transformers import PreTrainedModel, PreTrainedTokenizerFast, PretrainedConfig, LogitsProcessorList
from transformers.generation_utils import GenerationMixin
from transformers.modeling_outputs import CausalLMOutput
class HelloWorldConfig(PretrainedConfig):
    """Configuration class for the hello-world model.

    It adds nothing beyond the base configuration except the ``model_type``
    tag that identifies this model family.
    """

    model_type = "hello-world"
class HelloWorldModel(PreTrainedModel, GenerationMixin):
    """Toy causal language model that always predicts the last vocabulary id.

    ``forward`` does no learned computation: it builds the logits tensor
    directly so that, at every position, token id ``config.vocab_size - 1``
    is the only one with a finite score.
    """

    config_class = HelloWorldConfig

    def __init__(self, config):
        super().__init__(config)

    def forward(self, input_ids=None, **kwargs):
        """Build logits that select the last vocabulary id at every position.

        Args:
            input_ids: Tensor whose first two dimensions are read as
                ``(batch_size, sequence_length)``. Its values are not used.
            **kwargs: Accepted and ignored.

        Returns:
            A ``CausalLMOutput`` whose ``logits`` has shape
            ``(batch_size, sequence_length, vocab_size)``: ``-inf`` everywhere
            except ``0`` at index ``vocab_size - 1``.

        Raises:
            ValueError: If ``input_ids`` is ``None``.
        """
        if input_ids is None:
            # The output shape is derived from input_ids, so it is required
            # even though the signature gives it a default.
            raise ValueError("input_ids must be provided")

        batch_size = input_ids.shape[0]
        sequence_length = input_ids.shape[1]
        vocab_size = self.config.vocab_size

        # Generate a tensor with repeated "Hello, world!" token IDs
        hello_world_token_id = vocab_size - 1  # assuming last token is "Hello, world!"

        # Allocate on the same device as the inputs so the output can be
        # combined with other tensors derived from input_ids.
        logits = torch.full(
            (batch_size, sequence_length, vocab_size),
            float("-inf"),
            device=input_ids.device,
        )
        # 0 is the only finite entry, hence the highest logit.
        logits[:, :, hello_world_token_id] = 0
        return CausalLMOutput(logits=logits)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Return the ``forward`` inputs for one generation step.

        Only ``input_ids`` is forwarded; every other keyword is dropped.
        """
        return {"input_ids": input_ids}

    def _update_model_kwargs_for_generation(
        self, outputs, model_kwargs, is_encoder_decoder=False, **kwargs
    ):
        """Return ``model_kwargs`` unchanged between generation steps.

        Extra keyword arguments are accepted and ignored.
        NOTE(review): some transformers releases appear to pass additional
        keywords here (e.g. ``num_new_tokens``) - confirm against the
        installed version.
        """
        return model_kwargs
# Define tokenizer
import json

# Load the fast tokenizer from its serialized definition, then append the
# token the model is meant to emit.
tokenizer = PreTrainedTokenizerFast(tokenizer_file="tokenizer.json")
tokenizer.add_tokens(["Hello, world!"])
tokenizer_config = {
    "do_lower_case": False,
    "model_max_length": 512,
    "padding_side": "right",
    "special_tokens_map_file": None,
    "tokenizer_file": "tokenizer.json",
    "unk_token": "<unk>",
    "bos_token": "<s>",
    "eos_token": "</s>",
    "vocab_size": 2,  # Simplified vocabulary size
}
# Save tokenizer configuration
# Written to its own file: dumping this dict into "tokenizer.json" would
# overwrite the tokenizer definition that was just loaded from that path.
with open("tokenizer_config.json", "w", encoding="utf-8") as f:
    json.dump(tokenizer_config, f)
# Initialize model
# The model always predicts id ``vocab_size - 1``. Sizing the vocabulary from
# the tokenizer (including added tokens) makes that id the last tokenizer
# entry instead of relying on a hard-coded 2.
# NOTE(review): this is the "Hello, world!" id only if add_tokens actually
# appended it (i.e. it was not already in the vocabulary) - confirm.
config = HelloWorldConfig(vocab_size=len(tokenizer))
model = HelloWorldModel(config)
# Create dummy state_dict for saving
state_dict = model.state_dict()
# Save model using safetensors format
from safetensors.torch import save_file

save_file(state_dict, "hello_world_model.safetensors")