Mxytyu committed on
Commit
25845f5
1 Parent(s): 2a350ba

Create hello_world_model.safetensors

Browse files
Files changed (1) hide show
  1. hello_world_model.safetensors +54 -0
hello_world_model.safetensors ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import json

import torch
from transformers import PreTrainedModel, PreTrainedTokenizerFast, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutput
3
+
4
# Define the model configuration
class HelloWorldConfig(PretrainedConfig):
    """Configuration for the toy "hello-world" model.

    The defaults are taken as ``__init__`` keyword arguments instead of
    class attributes so that they become instance attributes: the special
    token ids are forwarded to ``PretrainedConfig.__init__`` rather than
    being shadowed by whatever the base initializer assigns, and
    ``vocab_size`` is stored on the instance alongside them.

    Args:
        vocab_size: Number of entries in the vocabulary.
        bos_token_id: Id of the beginning-of-sequence token.
        eos_token_id: Id of the end-of-sequence token.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "hello-world"

    def __init__(self, vocab_size=2, bos_token_id=0, eos_token_id=1, **kwargs):
        self.vocab_size = vocab_size
        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )
10
+
11
# Define the model
class HelloWorldModel(PreTrainedModel):
    """Toy model whose logits always select the last vocabulary id.

    ``__init__`` registers no layers of its own; ``forward`` builds the
    logits directly from the shape of ``input_ids``.
    """

    config_class = HelloWorldConfig

    def __init__(self, config):
        super().__init__(config)

    def forward(self, input_ids=None, **kwargs):
        """Return constant logits that put all mass on the last vocabulary id.

        Args:
            input_ids: Tensor indexed as ``(batch, sequence)``; only its
                shape and device are used.
            **kwargs: Accepted and ignored.

        Returns:
            ``CausalLMOutput`` whose ``logits`` has shape
            ``(batch, sequence, vocab_size)``: ``0`` at the last vocabulary
            id and ``-inf`` everywhere else.

        Raises:
            ValueError: If ``input_ids`` is not provided.
        """
        if input_ids is None:
            raise ValueError("input_ids must be provided")

        batch_size = input_ids.shape[0]
        sequence_length = input_ids.shape[1]

        # Generate logits for the "Hello, world!" token
        hello_world_token_id = self.config.vocab_size - 1
        # Allocate on the same device as the input so the output can be
        # combined with it without a device mismatch.
        logits = torch.full(
            (batch_size, sequence_length, self.config.vocab_size),
            float("-inf"),
            device=input_ids.device,
        )
        logits[:, :, hello_world_token_id] = 0

        return CausalLMOutput(logits=logits)
28
+
29
# Define and save the tokenizer
# NOTE: "tokenizer.json" is read here, so it must already exist next to this
# script before it is run.
tokenizer = PreTrainedTokenizerFast(tokenizer_file="tokenizer.json")
tokenizer.add_tokens(["Hello, world!"])

tokenizer_config = {
    "do_lower_case": False,
    "model_max_length": 512,
    "padding_side": "right",
    "special_tokens_map_file": None,
    "tokenizer_file": "tokenizer.json",
    "unk_token": "<unk>",
    "bos_token": "<s>",
    "eos_token": "</s>",
    "vocab_size": 2,
}

# Write the settings next to the tokenizer file instead of into it: the dict
# itself names "tokenizer.json" as the tokenizer file, so dumping it there
# would overwrite the tokenizer definition that was just loaded.
with open("tokenizer_config.json", "w", encoding="utf-8") as f:
    json.dump(tokenizer_config, f)

# Initialize model
config = HelloWorldConfig()
model = HelloWorldModel(config)

# Save model using safetensors format
# NOTE(review): HelloWorldModel registers no parameters in the code shown, so
# the saved state dict is presumably empty; confirm this is intended.
from safetensors.torch import save_file
save_file(model.state_dict(), "hello_world_model.safetensors")