import torch

from transformers import PreTrainedModel, PreTrainedTokenizerFast, PretrainedConfig, GenerationMixin, LogitsProcessorList
from transformers.modeling_outputs import CausalLMOutput


class HelloWorldConfig(PretrainedConfig):
    model_type = "hello-world"


class HelloWorldModel(PreTrainedModel, GenerationMixin):
    config_class = HelloWorldConfig

    def __init__(self, config):
        super().__init__(config)
        # A single placeholder parameter so the model has a device and a
        # non-empty state dict to save later on.
        self.dummy_weight = torch.nn.Parameter(torch.zeros(1))

    def forward(self, input_ids=None, **kwargs):
        batch_size = input_ids.shape[0]
        sequence_length = input_ids.shape[1]

        # Force every position to predict the "Hello, world!" token: its logit
        # is 0 while every other token gets -inf.
        hello_world_token_id = self.config.vocab_size - 1
        logits = torch.full(
            (batch_size, sequence_length, self.config.vocab_size),
            float("-inf"),
            device=input_ids.device,
        )
        logits[:, :, hello_world_token_id] = 0

        return CausalLMOutput(logits=logits)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # The model only ever needs the token ids.
        return {"input_ids": input_ids}

    def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder=False, **kwargs):
        # There is no cache or attention mask to carry over between steps.
        return model_kwargs


from tokenizers import Tokenizer
from tokenizers.models import WordLevel

# PreTrainedTokenizerFast expects an existing tokenizers-format file, so write
# a real tokenizer.json first. The base vocabulary holds only the unknown
# token; "Hello, world!" is added below and receives id 1, i.e. vocab_size - 1
# for the two-token vocabulary used by the model config further down.
Tokenizer(WordLevel({"<unk>": 0}, unk_token="<unk>")).save("tokenizer.json")

tokenizer = PreTrainedTokenizerFast(tokenizer_file="tokenizer.json", unk_token="<unk>")
tokenizer.add_tokens(["Hello, world!"])


# These settings belong in tokenizer_config.json (tokenizer.json, written
# above, is the serialized tokenizer itself). bos/eos tokens are left out so
# the two-entry vocabulary stays consistent with the model's vocab_size of 2.
tokenizer_config = {
    "do_lower_case": False,
    "model_max_length": 512,
    "padding_side": "right",
    "tokenizer_class": "PreTrainedTokenizerFast",
    "tokenizer_file": "tokenizer.json",
    "unk_token": "<unk>",
}

import json

with open("tokenizer_config.json", "w") as f:
    json.dump(tokenizer_config, f)
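
# Sanity check (an illustrative addition, not part of the original snippet):
# the added token should map to id 1, matching hello_world_token_id
# (config.vocab_size - 1) in HelloWorldModel.forward.
assert tokenizer.encode("Hello, world!") == [1]
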
config = HelloWorldConfig(vocab_size=2)
model = HelloWorldModel(config)
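
# Quick smoke test (an illustrative sketch, not part of the original snippet):
# greedy decoding should append the "Hello, world!" token no matter what the
# prompt is, since every other logit is -inf.
prompt = torch.tensor([[0]])  # any token id below vocab_size works as a prompt
generated = model.generate(prompt, max_new_tokens=3, do_sample=False)
print(tokenizer.decode(generated[0], skip_special_tokens=True))
# -> "Hello, world! Hello, world! Hello, world!"
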
state_dict = model.state_dict()

from safetensors.torch import save_file

save_file(state_dict, "hello_world_model.safetensors")
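
# Optional (an illustrative sketch of one way to package everything, not part
# of the original snippet): the save_pretrained calls write config.json,
# model.safetensors, tokenizer.json and tokenizer_config.json into a single
# directory that HelloWorldModel.from_pretrained() can load.
model.save_pretrained("hello-world-model")
tokenizer.save_pretrained("hello-world-model")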