jordiclive committed on
Commit
7cc9b80
1 Parent(s): 4c6991f

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +8 -4
README.md CHANGED
@@ -74,7 +74,7 @@ repo_id = "jordiclive/falcon_lora_40b_ckpt_500_oasst_1"
74
  base_model = "tiiuae/falcon-40b"
75
 
76
  # Model Loading
77
- def transfer_embeddings(model, embed_path, tokenizer):
78
  old_embeddings = model.get_input_embeddings()
79
  old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
80
  new_embeddings = torch.nn.Embedding(old_num_tokens, old_embedding_dim)
@@ -83,16 +83,17 @@ def transfer_embeddings(model, embed_path, tokenizer):
83
  embed_weights = torch.load(embed_path, map_location=old_embeddings.weight.device)
84
  vocab_size = tokenizer.vocab_size
85
  new_embeddings.weight.data[:vocab_size, :] = old_embeddings.weight.data[:vocab_size, :]
86
- new_embeddings.weight.data[vocab_size : vocab_size + embed_weights.shape[0], :] = embed_weights.weight.data.to(
87
  new_embeddings.weight.dtype
88
  ).to(new_embeddings.weight.device)
89
  model.set_input_embeddings(new_embeddings)
90
  model.tie_weights()
91
 
92
 
 
93
  def load_peft_model(model, peft_model_path, tokenizer):
94
  embed_weights = hf_hub_download(peft_model_path, "extra_embeddings.pt")
95
- model.resize_token_embeddings(tokenizer.vocab_size + embed_weights.shape[0])
96
  model.config.eos_token_id = tokenizer.eos_token_id
97
  model.config.bos_token_id = tokenizer.bos_token_id
98
  model.config.pad_token_id = tokenizer.pad_token_id
@@ -102,7 +103,7 @@ def load_peft_model(model, peft_model_path, tokenizer):
102
  torch_dtype=model.dtype,
103
  )
104
  model.eos_token_id = tokenizer.eos_token_id
105
- transfer_embeddings(model, peft_model_path.joinpath("extra_embeddings.pt"), tokenizer)
106
  return model
107
 
108
 
@@ -116,6 +117,8 @@ model = load_peft_model(model, repo_id, tokenizer)
116
 
117
  # device configuration
118
  model = model.to(device)
 
 
119
 
120
 
121
  # Choose Generation parameters
@@ -155,4 +158,5 @@ generate("What is a meme, and what's the history behind this word?")
155
  generate("What's the Earth total population")
156
  generate("Write a story about future of AI development")
157
 
 
158
  ```
 
74
  base_model = "tiiuae/falcon-40b"
75
 
76
  # Model Loading
77
+ def add_embeddings(model, embed_path, tokenizer):
78
  old_embeddings = model.get_input_embeddings()
79
  old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
80
  new_embeddings = torch.nn.Embedding(old_num_tokens, old_embedding_dim)
 
83
  embed_weights = torch.load(embed_path, map_location=old_embeddings.weight.device)
84
  vocab_size = tokenizer.vocab_size
85
  new_embeddings.weight.data[:vocab_size, :] = old_embeddings.weight.data[:vocab_size, :]
86
+ new_embeddings.weight.data[vocab_size : vocab_size + embed_weights.shape[0], :] = embed_weights.to(
87
  new_embeddings.weight.dtype
88
  ).to(new_embeddings.weight.device)
89
  model.set_input_embeddings(new_embeddings)
90
  model.tie_weights()
91
 
92
 
93
+
94
  def load_peft_model(model, peft_model_path, tokenizer):
95
  embed_weights = hf_hub_download(peft_model_path, "extra_embeddings.pt")
96
+ model.resize_token_embeddings(tokenizer.vocab_size + torch.load(embed_weights).shape[0])
97
  model.config.eos_token_id = tokenizer.eos_token_id
98
  model.config.bos_token_id = tokenizer.bos_token_id
99
  model.config.pad_token_id = tokenizer.pad_token_id
 
103
  torch_dtype=model.dtype,
104
  )
105
  model.eos_token_id = tokenizer.eos_token_id
106
+ add_embeddings(model, embed_weights, tokenizer)
107
  return model
108
 
109
 
 
117
 
118
  # device configuration
119
  model = model.to(device)
120
+ if dtype == torch.float16:
121
+ model = model.half()
122
 
123
 
124
  # Choose Generation parameters
 
158
  generate("What's the Earth total population")
159
  generate("Write a story about future of AI development")
160
 
161
+
162
  ```