File size: 1,606 Bytes
1bae57f
2a350ba
dc78d43
2a350ba
1bae57f
 
2a350ba
 
 
dc78d43
2a350ba
 
1bae57f
dc78d43
1bae57f
 
dc78d43
2ec93d0
 
 
dc78d43
2a350ba
 
 
 
2ec93d0
2a350ba
80b293f
2a350ba
82a267c
 
 
1bae57f
 
 
 
 
 
 
 
2ec93d0
2a350ba
1bae57f
 
 
 
2ec93d0
82a267c
2a350ba
82a267c
 
 
 
2a350ba
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import json

import torch
from transformers import PreTrainedModel, PreTrainedTokenizerFast, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutput

# Define the model configuration
class HelloWorldConfig(PretrainedConfig):
    """Configuration for the toy "hello-world" model.

    The token ids are forwarded to ``PretrainedConfig.__init__`` rather than
    being left only as class attributes. The base initializer assigns its own
    instance attributes for ``bos_token_id`` and ``eos_token_id`` (``None``
    unless passed in), and those would otherwise shadow the class-level values.
    """

    model_type = "hello-world"

    # Class-level defaults, kept so existing `HelloWorldConfig.<attr>` reads
    # still work; instances get their own values from __init__ below.
    vocab_size = 2
    bos_token_id = 0
    eos_token_id = 1

    def __init__(self, vocab_size=2, bos_token_id=0, eos_token_id=1, **kwargs):
        """Create the config.

        Args:
            vocab_size: Number of logits the model emits per position.
            bos_token_id: Beginning-of-sequence token id.
            eos_token_id: End-of-sequence token id.
            **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
        """
        self.vocab_size = vocab_size
        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

# Define the model
class HelloWorldModel(PreTrainedModel):
    """Toy causal LM that predicts one fixed token at every position.

    The model registers no parameters of its own; ``forward`` only builds a
    constant logits tensor from the input's shape.
    """

    config_class = HelloWorldConfig

    def __init__(self, config):
        super().__init__(config)

    def forward(self, input_ids=None, **kwargs):
        """Return logits that put all probability mass on a single token.

        Args:
            input_ids: Tensor whose first two dimensions are read as
                (batch_size, sequence_length); only its shape and device are used.
            **kwargs: Accepted and ignored (e.g. ``attention_mask``).

        Returns:
            ``CausalLMOutput`` whose ``logits`` has shape
            (batch_size, sequence_length, config.vocab_size). The value is 0 at
            the chosen token id and ``-inf`` everywhere else.

        Raises:
            ValueError: If ``input_ids`` is not provided.
        """
        if input_ids is None:
            raise ValueError("HelloWorldModel.forward requires input_ids")

        batch_size, sequence_length = input_ids.shape[0], input_ids.shape[1]

        # The token id can be overridden via `config.hello_world_token_id`;
        # otherwise it defaults to the last vocabulary entry, as before.
        # NOTE(review): with the default HelloWorldConfig that id is 1, which is
        # also eos_token_id, and is not necessarily the id that
        # tokenizer.add_tokens(["Hello, world!"]) assigned — confirm.
        hello_world_token_id = getattr(
            self.config, "hello_world_token_id", self.config.vocab_size - 1
        )

        # Allocate on the input's device so the model also works after `.to(...)`.
        logits = torch.full(
            (batch_size, sequence_length, self.config.vocab_size),
            float("-inf"),
            device=input_ids.device,
        )
        logits[:, :, hello_world_token_id] = 0

        return CausalLMOutput(logits=logits)

# Define the tokenizer from an existing `tokenizers` serialization and register
# the single "Hello, world!" token the model is meant to emit.
tokenizer = PreTrainedTokenizerFast(tokenizer_file="tokenizer.json")
tokenizer.add_tokens(["Hello, world!"])

# Hugging Face-style tokenizer settings. They belong in tokenizer_config.json.
# tokenizer.json holds the serialized tokenizer itself and must not be
# overwritten with this dict, or it can no longer be loaded.
# NOTE(review): "vocab_size": 2 mirrors HelloWorldConfig's default, but the
# tokenizer's actual size after add_tokens is len(tokenizer) — confirm they agree.
tokenizer_config = {
    "do_lower_case": False,
    "model_max_length": 512,
    "padding_side": "right",
    "special_tokens_map_file": None,
    "tokenizer_file": "tokenizer.json",
    "unk_token": "<unk>",
    "bos_token": "<s>",
    "eos_token": "</s>",
    "vocab_size": 2,
}

# Persist the tokenizer, including the newly added token, then its settings.
tokenizer.backend_tokenizer.save("tokenizer.json")
with open("tokenizer_config.json", "w", encoding="utf-8") as f:
    json.dump(tokenizer_config, f, indent=2)

# Initialize model
config = HelloWorldConfig()
model = HelloWorldModel(config)

# Save model using safetensors format.
# HelloWorldModel registers no parameters of its own, so this state dict is
# expected to be empty; the file exists mainly so loaders find a weights file.
from safetensors.torch import save_file
save_file(model.state_dict(), "hello_world_model.safetensors")