OxxoCodes commited on
Commit
4556966
1 Parent(s): a570f75
Files changed (1) hide show
  1. prune.py +3 -1
prune.py CHANGED
@@ -4,6 +4,8 @@ import re
4
  import torch
5
  from modeling_jamba import JambaForCausalLM
6
 
 
 
7
  model = JambaForCausalLM.from_pretrained("ai21labs/Jamba-v0.1", device_map="cpu", torch_dtype=torch.bfloat16)
8
 
9
  def prune_and_copy_additional_layers(original_state_dict):
@@ -37,5 +39,5 @@ def prune_and_copy_additional_layers(original_state_dict):
37
  pruned_state_dict = prune_and_copy_additional_layers(model.state_dict())
38
 
39
  print("Saving weights...")
40
- torch.save(pruned_state_dict, '/scratch/nbrown9/jamba-small-v1.bin')
41
  print("Done!")
 
4
  import torch
5
  from modeling_jamba import JambaForCausalLM
6
 
7
+ output_dir = "/home/user/jamba-small"
8
+
9
  model = JambaForCausalLM.from_pretrained("ai21labs/Jamba-v0.1", device_map="cpu", torch_dtype=torch.bfloat16)
10
 
11
  def prune_and_copy_additional_layers(original_state_dict):
 
39
  pruned_state_dict = prune_and_copy_additional_layers(model.state_dict())
40
 
41
  print("Saving weights...")
42
+ torch.save(pruned_state_dict, output_dir)
43
  print("Done!")