import importlib
import json
import os
from typing import List, Optional

import numpy as np
import torch
import torch.nn as nn
from transformers import (
    PretrainedConfig,
    PreTrainedModel,
    AutoConfig,
    AutoModelForCausalLM,
)

from utils.constants import MISTRAL_7B
from utils.utils import _get_submodules


class Cats(nn.Module):
    """Wrap a module and zero out its small-magnitude outputs.

    Outputs with ``|x| < threshold`` are replaced by 0. While ``collect_stats``
    is enabled, histograms of the wrapped module's outputs and of their
    absolute values are accumulated in ``hist_counts`` / ``abs_hist_counts``.
    """

    def __init__(
        self,
        wrapped_module: nn.Module,
        threshold: float = 0,
        hist_num_bins: int = 1000,
        hist_min: int = -1,
        hist_max: int = 1,
    ):
        """
        Args:
            wrapped_module: Module whose output is thresholded.
            threshold: Outputs with absolute value below this are set to 0.
            hist_num_bins: Number of histogram bin edges. ``hist_num_bins - 2``
                evenly spaced edges span ``[hist_min, hist_max]``; ``-inf`` and
                ``+inf`` are added at the ends so out-of-range values fall into
                the outermost bins. This gives ``hist_num_bins - 1`` bins.
            hist_min: Lowest finite bin edge.
            hist_max: Highest finite bin edge.
        """
        super().__init__()
        self.wrapped_module = wrapped_module
        # A non-trainable Parameter: saved in the state dict and moved with the
        # module by ``.to(device)``. Stored as float so comparisons are uniform.
        self.threshold = nn.Parameter(
            torch.tensor(float(threshold)), requires_grad=False
        )
        finite_edges = torch.linspace(hist_min, hist_max, hist_num_bins - 2)
        self.histogram_bins = torch.cat(
            [torch.tensor([-torch.inf]), finite_edges, torch.tensor([torch.inf])]
        )
        # Plain tensors (not buffers): kept on CPU and not part of the state dict.
        self.hist_counts = torch.zeros(hist_num_bins - 1)
        self.abs_hist_counts = torch.zeros(hist_num_bins - 1)
        self.collect_stats = True

    def disable_collect_stats(self):
        """Stop accumulating activation histograms."""
        self.collect_stats = False

    def enable_collect_stats(self):
        """Resume accumulating activation histograms."""
        self.collect_stats = True

    def set_threshold(self, threshold: float):
        """Replace the threshold, keeping it on the current threshold's device."""
        self.threshold = nn.Parameter(
            torch.tensor(float(threshold), device=self.threshold.device),
            requires_grad=False,
        )

    def forward(self, x):
        """Run the wrapped module, optionally record stats, then threshold."""
        x = self.wrapped_module(x)
        if self.collect_stats:
            # torch.histogram only supports CPU tensors whose dtype matches the
            # bin edges, so histogram a detached CPU copy of the activations.
            stats_input = x.detach().to(
                device="cpu", dtype=self.histogram_bins.dtype
            )
            self.hist_counts += torch.histogram(
                stats_input, bins=self.histogram_bins
            )[0]
            self.abs_hist_counts += torch.histogram(
                stats_input.abs(), bins=self.histogram_bins
            )[0]
        # Out-of-place so any tensor the wrapped module saved for backward
        # is left untouched.
        return x.masked_fill(x.abs() < self.threshold, 0)


def load_data(file_path):
    """Load a JSON document from ``file_path``.

    Returns an empty dict if the file does not exist.
    """
    try:
        with open(file_path, "r", encoding="utf-8") as json_file:
            return json.load(json_file)
    except FileNotFoundError:
        return {}


def save_to_json(data, file_path):
    """Write ``data`` to ``file_path`` as indented JSON.

    Parent directories are created when ``file_path`` contains any.
    """
    directory = os.path.dirname(file_path)
    # os.makedirs("") raises, so only create directories when there is one.
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(file_path, "w", encoding="utf-8") as json_file:
        json.dump(data, json_file, indent=4)


class CatsConfig(PretrainedConfig):
    """Config for ``CatsModel``: the wrapped model's config plus CATS settings."""

    model_type = "cats_model"

    def __init__(
        self,
        wrapped_model_config: Optional[PretrainedConfig] = None,
        wrapped_model_class_name: str = "MistralForCausalLM",
        target_modules: Optional[List[str]] = None,
        target_sparsity: float = 0.5,
        **kwargs,
    ):
        """
        Args:
            wrapped_model_config: Config of the model to wrap; its attributes
                are copied onto this config. Defaults to the config loaded from
                ``MISTRAL_7B`` (loaded on each call, not at import time).
            wrapped_model_class_name: Name of the ``transformers`` class used to
                build the wrapped model.
            target_modules: Submodule attribute names to wrap with ``Cats``.
                Defaults to ``["act_fn"]``.
            target_sparsity: Desired sparsity level (stored only).
            **kwargs: Forwarded to ``PretrainedConfig.__init__``.
        """
        if wrapped_model_config is None:
            wrapped_model_config = AutoConfig.from_pretrained(MISTRAL_7B)
        if target_modules is None:
            target_modules = ["act_fn"]
        self.target_modules = target_modules
        self.target_sparsity = target_sparsity
        self.wrapped_model_class_name = wrapped_model_class_name
        wrapped_attrs = dict(wrapped_model_config.__dict__)
        self.__dict__.update(wrapped_attrs)
        super().__init__(**kwargs)
        # PretrainedConfig.__init__ resets common attributes (e.g.
        # tie_word_embeddings, bos/eos_token_id) to library defaults when they
        # are absent from kwargs; restore the wrapped model's public values for
        # those, while explicitly passed kwargs still take precedence.
        for key, value in wrapped_attrs.items():
            if not key.startswith("_") and key not in kwargs:
                self.__dict__[key] = value


class CatsModel(PreTrainedModel):
    """A ``transformers`` model whose target submodules are wrapped in ``Cats``."""

    config_class = CatsConfig

    def __init__(self, config, wrapped_model_pretrained_dir: str = None, **kwargs):
        """
        Args:
            config: A ``CatsConfig``; ``config.wrapped_model_class_name`` selects
                the ``transformers`` class to instantiate.
            wrapped_model_pretrained_dir: If given, the wrapped model is loaded
                with ``from_pretrained`` from this path; otherwise it is built
                (randomly initialised) from ``config``.
        """
        super().__init__(config)
        transformers_module = importlib.import_module("transformers")
        self.wrapped_model_class = getattr(
            transformers_module, config.wrapped_model_class_name
        )
        # Build only one wrapped model: constructing from config and then
        # discarding it for the pretrained one wastes memory and time.
        if wrapped_model_pretrained_dir is None:
            self.wrapped_model = self.wrapped_model_class(config)
        else:
            self.wrapped_model = self.wrapped_model_class.from_pretrained(
                wrapped_model_pretrained_dir
            )
        self.inject_cats()

    def forward(self, *args, **kwargs):
        """Delegate the forward pass to the wrapped model."""
        return self.wrapped_model(*args, **kwargs)

    def inject_cats(self):
        """Wrap every submodule named in ``config.target_modules`` with ``Cats``."""
        # Snapshot names first: the module tree is modified inside the loop.
        module_names = [name for name, _ in self.wrapped_model.named_modules()]
        for name in module_names:
            parent, target, target_name = _get_submodules(self.wrapped_model, name)
            if target_name not in self.config.target_modules:
                continue
            if isinstance(target, Cats):
                # Already wrapped (e.g. inject_cats called twice).
                continue
            print(f"{name} is replaced.")
            setattr(parent, target_name, Cats(wrapped_module=target))

    def enable_collect_stats(self):
        """Turn on histogram collection in every injected ``Cats`` module."""
        # modules() yields modules; named_modules() yields (name, module)
        # tuples, which would never pass the isinstance check.
        for module in self.wrapped_model.modules():
            if isinstance(module, Cats):
                module.enable_collect_stats()

    def disable_collect_stats(self):
        """Turn off histogram collection in every injected ``Cats`` module."""
        for module in self.wrapped_model.modules():
            if isinstance(module, Cats):
                module.disable_collect_stats()

    def disable_adapters(self) -> None:
        """Backward-compatible alias of ``disable_collect_stats``."""
        self.disable_collect_stats()


def simple_exp():
    """Build a CatsModel around Mistral-7B, push it to the Hub, and reload it."""
    model_dir = MISTRAL_7B
    config = AutoConfig.from_pretrained(model_dir)
    cats_config = CatsConfig(config, wrapped_model_class_name="MistralForCausalLM")
    model = CatsModel(cats_config, wrapped_model_pretrained_dir=None)
    print(model)
    print(model.wrapped_model)
    print(model.config)

    # Register so AutoConfig / AutoModelForCausalLM can load the custom code.
    CatsConfig.register_for_auto_class()
    CatsModel.register_for_auto_class("AutoModelForCausalLM")

    repo_id = "thrunlab/cats_exp"
    model.push_to_hub(repo_id)
    model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)


if __name__ == "__main__":
    simple_exp()