# Copyright 2020-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Once a model has been fine-pruned, the weights that are masked during the forward pass can be pruned once and for all.
For instance, once a :class:`~emmental.MaskedBertForSequenceClassification` model is trained, it can be saved (and
then loaded) as a standard :class:`~transformers.BertForSequenceClassification`.
"""

import argparse
import os
import shutil

import torch

from emmental.modules import MagnitudeBinarizer, ThresholdBinarizer, TopKBinarizer


def main(args):
    pruning_method = args.pruning_method
    threshold = args.threshold

    model_name_or_path = args.model_name_or_path.rstrip("/")
    target_model_path = args.target_model_path

    print(f"Load fine-pruned model from {model_name_or_path}")
    model = torch.load(os.path.join(model_name_or_path, "pytorch_model.bin"))
    pruned_model = {}

    for name, tensor in model.items():
        if "embeddings" in name or "LayerNorm" in name or "pooler" in name:
            # Embeddings, LayerNorms and the pooler are never masked: copy them as-is.
            pruned_model[name] = tensor
            print(f"Copied layer {name}")
        elif "classifier" in name or "qa_output" in name:
            # Task-specific heads (classification / QA) are not masked either.
            pruned_model[name] = tensor
            print(f"Copied layer {name}")
        elif "bias" in name:
            # Only weight matrices are masked; biases are copied as-is.
            pruned_model[name] = tensor
            print(f"Copied layer {name}")
        else:
            if pruning_method == "magnitude":
                mask = MagnitudeBinarizer.apply(inputs=tensor, threshold=threshold)
                pruned_model[name] = tensor * mask
                print(f"Pruned layer {name}")
            elif pruning_method == "topK":
                if "mask_scores" in name:
                    # The learned scores themselves are dropped from the standard checkpoint.
                    continue
                prefix_ = name[:-6]  # strip the trailing "weight" to find the matching mask scores
                scores = model[f"{prefix_}mask_scores"]
                mask = TopKBinarizer.apply(scores, threshold)
                pruned_model[name] = tensor * mask
                print(f"Pruned layer {name}")
            elif pruning_method == "sigmoied_threshold":
                if "mask_scores" in name:
                    continue
                prefix_ = name[:-6]
                scores = model[f"{prefix_}mask_scores"]
                mask = ThresholdBinarizer.apply(scores, threshold, True)
                pruned_model[name] = tensor * mask
                print(f"Pruned layer {name}")
            elif pruning_method == "l0":
                if "mask_scores" in name:
                    continue
                prefix_ = name[:-6]
                scores = model[f"{prefix_}mask_scores"]
                l, r = -0.1, 1.1  # stretch interval of the hard-concrete distribution
                s = torch.sigmoid(scores)
                s_bar = s * (r - l) + l
                mask = s_bar.clamp(min=0.0, max=1.0)
                pruned_model[name] = tensor * mask
                print(f"Pruned layer {name}")
            else:
                raise ValueError("Unknown pruning method")

    if target_model_path is None:
        target_model_path = os.path.join(
            os.path.dirname(model_name_or_path), f"bertarized_{os.path.basename(model_name_or_path)}"
        )

    if not os.path.isdir(target_model_path):
        shutil.copytree(model_name_or_path, target_model_path)
        print(f"\nCreated folder {target_model_path}")

    torch.save(pruned_model, os.path.join(target_model_path, "pytorch_model.bin"))
    print("\nPruned model saved! See you later!")
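

# Illustrative addition (not part of the original script): a minimal sketch of how the docstring's claim can be
# exercised, i.e. loading the "bertarized" checkpoint back as a standard BertForSequenceClassification. It assumes
# the target folder was produced by `main` above (so it still contains the config.json copied by `shutil.copytree`);
# the helper name `load_bertarized_model` is hypothetical.
def load_bertarized_model(target_model_path):
    from transformers import BertForSequenceClassification

    # The saved pytorch_model.bin is a plain dense state dict (mask scores were dropped and masked weights are
    # exact zeros), so the regular `from_pretrained` entry point is sufficient.
    model = BertForSequenceClassification.from_pretrained(target_model_path)
    model.eval()
    return model
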
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pruning_method",
        choices=["l0", "magnitude", "topK", "sigmoied_threshold"],
        type=str,
        required=True,
        help=(
            "Pruning Method (l0 = L0 regularization, magnitude = Magnitude pruning, topK = Movement pruning,"
            " sigmoied_threshold = Soft movement pruning)"
        ),
    )
    parser.add_argument(
        "--threshold",
        type=float,
        required=False,
        help=(
            "For `magnitude` and `topK`, it is the level of remaining weights (a float between 0 and 1) in the"
            " fine-pruned model. For `sigmoied_threshold`, it is the threshold tau against which the (sigmoied)"
            " scores are compared. Not needed for `l0`."
        ),
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        required=True,
        help="Folder containing the model that was previously fine-pruned",
    )
    parser.add_argument(
        "--target_model_path",
        default=None,
        type=str,
        required=False,
        help=(
            "Folder in which the pruned (standard) model is saved. Defaults to a `bertarized_*` folder created"
            " next to `--model_name_or_path`."
        ),
    )

    args = parser.parse_args()

    main(args)
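
# Example invocation (a sketch; the script name, paths and the 0.10 ratio below are illustrative, not taken from the
# original documentation):
#
#   python bertarize.py \
#       --pruning_method topK \
#       --threshold 0.10 \
#       --model_name_or_path serialization_dir/fine_pruned_model
#
# With no --target_model_path, the dense checkpoint is written to serialization_dir/bertarized_fine_pruned_model.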