vxbrandon committed
Commit ec539e9
1 Parent(s): 56df701

Upload model

cats.py ADDED
@@ -0,0 +1,151 @@
+ import importlib
+ import json
+ import os
+ from typing import List
+
+ import torch
+ import torch.nn as nn
+ from transformers import (
+     PretrainedConfig,
+     PreTrainedModel,
+     AutoConfig,
+     AutoModelForCausalLM,
+ )
+
+ from utils.constants import MISTRAL_7B
+ from utils.utils import _get_submodules
+
+
+ class Cats(nn.Module):
+     """Wraps a module (typically an MLP activation) and zeroes outputs
+     whose magnitude falls below `threshold`, optionally accumulating
+     histograms of the raw and absolute activations so a threshold can
+     be calibrated from observed statistics."""
+
+     def __init__(
+         self,
+         wrapped_module: nn.Module,
+         threshold: float = 0.0,
+         hist_num_bins: int = 1000,
+         hist_min: float = -1.0,
+         hist_max: float = 1.0,
+     ):
+         super().__init__()
+         self.wrapped_module = wrapped_module
+         self.threshold = nn.Parameter(torch.tensor(threshold), requires_grad=False)
+         # hist_num_bins - 2 uniform edges plus -inf/+inf sentinels gives
+         # hist_num_bins edges, i.e. hist_num_bins - 1 bins, matching the
+         # count buffers below.
+         self.histogram_bins = torch.linspace(hist_min, hist_max, hist_num_bins - 2)
+         self.histogram_bins = torch.cat(
+             [torch.tensor([-torch.inf]), self.histogram_bins, torch.tensor([torch.inf])]
+         )
+         self.hist_counts = torch.zeros(hist_num_bins - 1)
+         self.abs_hist_counts = torch.zeros(hist_num_bins - 1)
+         self.collect_stats = True
+
+     def disable_collect_stats(self):
+         self.collect_stats = False
+
+     def enable_collect_stats(self):
+         self.collect_stats = True
+
+     def set_threshold(self, threshold: float):
+         self.threshold = nn.Parameter(torch.tensor(threshold), requires_grad=False)
+
+     def forward(self, x):
+         x = self.wrapped_module(x)
+         if self.collect_stats:
+             # torch.histogram only supports CPU tensors, so statistics
+             # are accumulated on a detached float32 CPU copy.
+             x_cpu = x.detach().float().cpu()
+             self.hist_counts += torch.histogram(x_cpu, bins=self.histogram_bins)[0]
+             self.abs_hist_counts += torch.histogram(
+                 torch.abs(x_cpu), bins=self.histogram_bins
+             )[0]
+         # Zero out activations below the magnitude threshold.
+         x[torch.abs(x) < self.threshold] = 0
+         return x
+
+
+ # Load existing data from a JSON file, or return an empty dict if missing.
+ def load_data(file_path):
+     try:
+         with open(file_path, "r") as json_file:
+             return json.load(json_file)
+     except FileNotFoundError:
+         return {}
+
+
+ # Save a dictionary to a JSON file, creating parent directories as needed.
+ def save_to_json(data, file_path):
+     os.makedirs(os.path.dirname(file_path), exist_ok=True)
+     with open(file_path, "w") as json_file:
+         json.dump(data, json_file, indent=4)
+
+
+ class CatsConfig(PretrainedConfig):
+     model_type = "cats_model"
+
+     def __init__(
+         self,
+         wrapped_model_config=None,
+         wrapped_model_class_name: str = "MistralForCausalLM",
+         target_modules: List[str] = ("act_fn",),
+         target_sparsity: float = 0.5,
+         **kwargs,
+     ):
+         # Resolve the default here rather than in the signature, so the
+         # base model config is not fetched at import time.
+         if wrapped_model_config is None:
+             wrapped_model_config = AutoConfig.from_pretrained(MISTRAL_7B)
+         self.target_modules = list(target_modules)
+         self.target_sparsity = target_sparsity
+         self.wrapped_model_class_name = wrapped_model_class_name
+         # Flatten the wrapped model's config fields into this config.
+         self.__dict__.update(wrapped_model_config.__dict__)
+         super().__init__(**kwargs)
+
+
+ class CatsModel(PreTrainedModel):
+     config_class = CatsConfig
+
+     def __init__(self, config, wrapped_model_pretrained_dir: str = None, **kwargs):
+         super().__init__(config)
+         # Resolve the wrapped model class (e.g. MistralForCausalLM) by name.
+         transformers_module = importlib.import_module("transformers")
+         self.wrapped_model_class = getattr(
+             transformers_module, config.wrapped_model_class_name
+         )
+         self.wrapped_model = self.wrapped_model_class(config)
+         if wrapped_model_pretrained_dir is not None:
+             self.wrapped_model = self.wrapped_model_class.from_pretrained(
+                 wrapped_model_pretrained_dir
+             )
+         self.inject_cats()
+
+     def inject_cats(self):
+         # Wrap every target module (matched by attribute name) in a Cats
+         # module. Keys are snapshotted first so the module tree is not
+         # mutated while it is being iterated.
+         key_list = [key for key, _ in self.wrapped_model.named_modules()]
+         for name in key_list:
+             if name.split(".")[-1] in self.config.target_modules:
+                 parent, target, target_name = _get_submodules(self.wrapped_model, name)
+                 print(f"{name} is replaced.")
+                 # Replace the target module with target module + CATS.
+                 cats = Cats(wrapped_module=target)
+                 setattr(parent, target_name, cats)
+
+     def enable_collect_stats(self):
+         # named_modules() yields (name, module) tuples, so iterate
+         # modules() directly when only the module is needed.
+         for module in self.wrapped_model.modules():
+             if isinstance(module, Cats):
+                 module.enable_collect_stats()
+
+     def disable_collect_stats(self) -> None:
+         for module in self.wrapped_model.modules():
+             if isinstance(module, Cats):
+                 module.disable_collect_stats()
+
+     # def __getattr__(self, name: str):
+     #     """Forward missing attributes to the wrapped model."""
+     #     try:
+     #         return super().__getattr__(name)  # defer to nn.Module's logic
+     #     except AttributeError:
+     #         return getattr(self.wrapped_model, name)
+
+
+ def simple_exp():
+     model_dir = MISTRAL_7B
+     config = AutoConfig.from_pretrained(model_dir)
+     cats_config = CatsConfig(config, wrapped_model_class_name="MistralForCausalLM")
+     model = CatsModel(cats_config, wrapped_model_pretrained_dir=None)
+     print(model)
+     print(model.wrapped_model)
+     print(model.config)
+
+     # Register the custom classes so the pushed repo gets an auto_map and
+     # can be reloaded through the Auto* APIs with trust_remote_code=True.
+     CatsConfig.register_for_auto_class()
+     CatsModel.register_for_auto_class("AutoModelForCausalLM")
+
+     repo_id = "thrunlab/cats_exp"
+     model.push_to_hub(repo_id)
+     model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
+
+
+ if __name__ == "__main__":
+     simple_exp()
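
The config stores a target_sparsity, but this commit does not include the calibration step that turns the collected histograms into per-module thresholds. A minimal sketch of one plausible calibration, assuming the Cats module above (the helper name calibrate_threshold is hypothetical and not part of this commit):

    def calibrate_threshold(cats: Cats, target_sparsity: float) -> float:
        # Cumulative mass of |activation| over the collected histogram bins
        # (assumes forward() has been run with collect_stats enabled).
        counts = cats.abs_hist_counts
        cdf = torch.cumsum(counts, dim=0) / counts.sum()
        # Smallest bin whose cumulative mass reaches the target; zeroing
        # everything below its right edge sparsifies roughly that fraction
        # of the activations seen during collection.
        idx = int(torch.searchsorted(cdf, torch.tensor(target_sparsity)))
        return cats.histogram_bins[idx + 1].item()

    # e.g. cats.set_threshold(calibrate_threshold(cats, model.config.target_sparsity))
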
config.json CHANGED
@@ -3,6 +3,10 @@
    "CatsModel"
  ],
  "attention_dropout": 0.0,
+ "auto_map": {
+   "AutoConfig": "cats.CatsConfig",
+   "AutoModelForCausalLM": "cats.CatsModel"
+ },
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
model-00001-of-00006.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:518889943dc9e710f8b6591e913c7828e0b7809753d4ec1425f6361e62819d7c
+ oid sha256:02be45417249d551272c3ee9d68883dedc06937e310525f19a4aa5f079e290a6
  size 4987198176
model-00002-of-00006.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:ad42e8e9d7e18c3b36d41e6fdbdfa396fde96ab4805eae83789be3c5edc90bcc
+ oid sha256:421b3cda4afa202632eae121e820d640e18b13cb27b5af248ddc9d68601ed765
  size 4899117664
model-00003-of-00006.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:69d306913ca9bba54934aea5733d22b2f4f5277e6863e82849bd45277b6fae1f
+ oid sha256:42db0dfcfa4abc870330f959fb048964cf663d8b2c3c8aee01da6329f2ab0bd7
  size 4999814528
model-00004-of-00006.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:6589abda0af5979e33babc4ebf35ee403e0fac9f2cca16462601fbf2753050f2
+ oid sha256:bc0fe251d811c7d07606bab7cb44dd624c068ae10f240bf1d02a3faec2c5e778
  size 4999814528
model-00005-of-00006.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:11ea95a1c1474e697a4349921d82948c4888ee3a55fe296e0bbaaab4b2a4ac78
+ oid sha256:62446988cc65a531f9291d839567b046ce05c63e290a327018c76c43493f3b8c
  size 4832008712
model-00006-of-00006.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:cb6264074829100dcc0a4f76db86aa0e47e9aaa91fe5e74f5c3d696ef93aee02
- size 3724727904
+ oid sha256:a7d274e4390c9336160ddb433d1fb9165ba660269129305ff94b6746998a9088
+ size 4249016024
model.safetensors.index.json CHANGED
@@ -1,8 +1,9 @@
  {
    "metadata": {
-     "total_size": 28442640640
+     "total_size": 28966928640
    },
    "weight_map": {
+     "wrapped_model.lm_head.weight": "model-00006-of-00006.safetensors",
      "wrapped_model.model.embed_tokens.weight": "model-00001-of-00006.safetensors",
      "wrapped_model.model.layers.0.input_layernorm.weight": "model-00001-of-00006.safetensors",
      "wrapped_model.model.layers.0.mlp.act_fn.threshold": "model-00001-of-00006.safetensors",