File size: 1,915 Bytes
14efde5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import json
import os
import shlex
import subprocess
import tempfile

import torch
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizerFast


# Build a tiny random Llama-2 model plus a matching shrunken tokenizer and save
# both to `mname_tiny`. Needs access to the `mname_from` config and tokenizer
# (presumably via the Hugging Face Hub; the meta-llama repo is gated).

mname_from = "meta-llama/Llama-2-7b-hf"
mname_tiny = "tiny-random-llama-2"
vocab_keep_items = 3000  # also used as the tiny model's vocab_size below


def _truncate_tokenizer_json(path, keep_items):
    """Shrink a BPE ``tokenizer.json`` file in place to ids below ``keep_items``.

    Vocab entries whose id is ``keep_items`` or larger are dropped, and the
    merges list is emptied because many merges would otherwise produce tokens
    that no longer exist in the truncated vocab. Editing the parsed JSON
    (rather than regex-rewriting the text) does not depend on key order or on
    where the digits of ``keep_items - 1`` first occur in the file.

    Args:
        path: path of the ``tokenizer.json`` to rewrite.
        keep_items: number of leading vocab ids to keep.
    """
    with open(path, encoding="utf-8") as f:
        data = json.load(f)
    model = data["model"]
    model["vocab"] = {
        token: token_id
        for token, token_id in model["vocab"].items()
        if token_id < keep_items
    }
    model["merges"] = []
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False)


config = LlamaConfig.from_pretrained(mname_from)
# print("orig config", config)
config.update(dict(
    hidden_size=16,
    intermediate_size=64,
    num_attention_heads=4,
    num_hidden_layers=2,
    max_position_embeddings=256,
    num_key_value_heads=4,
    vocab_size=vocab_keep_items,
))
print("new config", config)

# create a tiny random model
tiny_model = LlamaForCausalLM(config)
print(f"num of params {tiny_model.num_parameters()}")

# shrink it more and save
tiny_model.bfloat16()  # half-size
tiny_model.save_pretrained(mname_tiny)

# shrink the tokenizer from 32k to `vocab_keep_items` tokens
tokenizer_fast = LlamaTokenizerFast.from_pretrained(mname_from)
# A private temp dir (instead of a fixed /tmp path) avoids picking up stale
# files from earlier runs, works off POSIX, and is removed afterwards.
with tempfile.TemporaryDirectory() as tmp_dir:
    tokenizer_fast.save_pretrained(tmp_dir)
    # Only tokenizer.json is shrunk here. NOTE(review): any sentencepiece
    # tokenizer.model written next to it keeps the full original vocab; if
    # save_pretrained below carries it into `mname_tiny`, the slow tokenizer
    # will not match the fast one -- confirm whether it should be dropped.
    _truncate_tokenizer_json(
        os.path.join(tmp_dir, "tokenizer.json"), vocab_keep_items
    )

    # reload with the modified tokenizer and save it next to the model; kept
    # inside the `with` so files the tokenizer references in tmp_dir still
    # exist while it is being saved
    tokenizer_fast_tiny = LlamaTokenizerFast.from_pretrained(tmp_dir)
    tokenizer_fast_tiny.save_pretrained(mname_tiny)

# Fail early with a clear message (rather than an index error inside the
# embedding layer during generate) if any token id exceeds the model vocab.
max_token_id = max(tokenizer_fast_tiny.get_vocab().values())
if max_token_id >= config.vocab_size:
    raise ValueError(
        f"tokenizer has token id {max_token_id}, but the model vocab_size is "
        f"{config.vocab_size}"
    )

# test the new model and tokenizer function
model_inputs = tokenizer_fast_tiny("Making tiny model", return_tensors="pt")
gen_tokens = tiny_model.generate(**model_inputs, max_new_tokens=100)
print(tokenizer_fast_tiny.batch_decode(gen_tokens, skip_special_tokens=True))
print("Random output should be expected, but no crashing")

print(f"Model+Tokenizer saved in {mname_tiny}")