Mxytyu committed on
Commit
2a350ba
1 Parent(s): 1c984df

Update model.safetensors

Browse files
Files changed (1) hide show
  1. model.safetensors +16 -24
model.safetensors CHANGED
@@ -1,12 +1,15 @@
1
  import torch
2
- from transformers import PreTrainedModel, PreTrainedTokenizerFast, PretrainedConfig, LogitsProcessorList
3
- from transformers.generation_utils import GenerationMixin
4
- from transformers.modeling_outputs import CausalLMOutput
5
 
 
6
  class HelloWorldConfig(PretrainedConfig):
7
  model_type = "hello-world"
 
 
 
8
 
9
- class HelloWorldModel(PreTrainedModel, GenerationMixin):
 
10
  config_class = HelloWorldConfig
11
 
12
  def __init__(self, config):
@@ -15,21 +18,15 @@ class HelloWorldModel(PreTrainedModel, GenerationMixin):
15
  def forward(self, input_ids=None, **kwargs):
16
  batch_size = input_ids.shape[0]
17
  sequence_length = input_ids.shape[1]
18
-
19
- # Generate a tensor with repeated "Hello, world!" token IDs
20
- hello_world_token_id = self.config.vocab_size - 1 # assuming last token is "Hello, world!"
21
- logits = torch.full((batch_size, sequence_length, self.config.vocab_size), float('-inf'))
22
- logits[:, :, hello_world_token_id] = 0 # setting logits for "Hello, world!" to 0 (highest value)
23
-
24
- return CausalLMOutput(logits=logits)
25
 
26
- def prepare_inputs_for_generation(self, input_ids, **kwargs):
27
- return {"input_ids": input_ids}
 
 
28
 
29
- def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder=False):
30
- return model_kwargs
31
 
32
- # Define tokenizer
33
  tokenizer = PreTrainedTokenizerFast(tokenizer_file="tokenizer.json")
34
  tokenizer.add_tokens(["Hello, world!"])
35
 
@@ -42,21 +39,16 @@ tokenizer_config = {
42
  "unk_token": "<unk>",
43
  "bos_token": "<s>",
44
  "eos_token": "</s>",
45
- "vocab_size": 2, # Simplified vocabulary size
46
  }
47
 
48
- # Save tokenizer configuration
49
  with open("tokenizer.json", "w") as f:
50
- import json
51
  json.dump(tokenizer_config, f)
52
 
53
  # Initialize model
54
- config = HelloWorldConfig(vocab_size=2) # Adjusted vocab size
55
  model = HelloWorldModel(config)
56
 
57
- # Create dummy state_dict for saving
58
- state_dict = model.state_dict()
59
-
60
  # Save model using safetensors format
61
  from safetensors.torch import save_file
62
- save_file(state_dict, "hello_world_model.safetensors")
 
1
import json

import torch
from transformers import PreTrainedModel, PreTrainedTokenizerFast, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutput
 
 
3
 
4
# Define the model configuration
class HelloWorldConfig(PretrainedConfig):
    """Configuration for the toy "hello-world" model.

    Declares a two-token vocabulary; `forward` treats the last vocab id
    (``vocab_size - 1``) as the "Hello, world!" token, so with this config
    token 1 is that token. bos/eos ids are fixed class-level defaults.
    """

    model_type = "hello-world"
    vocab_size = 2      # id 1 == "Hello, world!" (last token, per forward())
    bos_token_id = 0
    eos_token_id = 1
 
11
+ # Define the model
12
+ class HelloWorldModel(PreTrainedModel):
13
  config_class = HelloWorldConfig
14
 
15
  def __init__(self, config):
 
18
  def forward(self, input_ids=None, **kwargs):
19
  batch_size = input_ids.shape[0]
20
  sequence_length = input_ids.shape[1]
 
 
 
 
 
 
 
21
 
22
+ # Generate logits for the "Hello, world!" token
23
+ hello_world_token_id = self.config.vocab_size - 1
24
+ logits = torch.full((batch_size, sequence_length, self.config.vocab_size), float('-inf'))
25
+ logits[:, :, hello_world_token_id] = 0
26
 
27
+ return CausalLMOutput(logits=logits)
 
28
 
29
# Define and save the tokenizer
# NOTE(review): this reads tokenizer.json, yet the script only writes that
# file further below — the file must already exist on disk for this line to
# succeed; confirm the intended ordering.
tokenizer = PreTrainedTokenizerFast(tokenizer_file="tokenizer.json")
tokenizer.add_tokens(["Hello, world!"])
32
 
 
39
  "unk_token": "<unk>",
40
  "bos_token": "<s>",
41
  "eos_token": "</s>",
42
+ "vocab_size": 2,
43
  }
44
 
 
45
# Persist the tokenizer configuration. `json` must be imported at the top of
# the file — the previous revision imported it locally here and that import
# was dropped in this change.
with open("tokenizer.json", "w") as f:
    json.dump(tokenizer_config, f)

# Initialize model (HelloWorldConfig carries its defaults as class attributes,
# so no constructor arguments are needed).
config = HelloWorldConfig()
model = HelloWorldModel(config)

# Save model weights using the safetensors format.
from safetensors.torch import save_file

save_file(model.state_dict(), "hello_world_model.safetensors")