""" |
|
Preprocessing script before training DistilBERT. |
|
Specific to BERT -> DistilBERT. |
|
""" |

import argparse

import torch

from transformers import BertForMaskedLM


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Extracts some layers of the full BertForMaskedLM for Transfer Learned Distillation"
    )
    parser.add_argument("--model_type", default="bert", choices=["bert"])
    parser.add_argument("--model_name", default="bert-base-uncased", type=str)
    parser.add_argument("--dump_checkpoint", default="serialization_dir/tf_bert-base-uncased_0247911.pth", type=str)
    parser.add_argument("--vocab_transform", action="store_true")
    args = parser.parse_args()
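
    # Load the full teacher model; only BERT checkpoints are supported by this script.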
    if args.model_type == "bert":
        model = BertForMaskedLM.from_pretrained(args.model_name)
        prefix = "bert"
    else:
        raise ValueError('args.model_type should be "bert".')
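
    # The student state dict is assembled entry by entry from the teacher's weights.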
    state_dict = model.state_dict()
    compressed_sd = {}
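
    # Token and position embeddings, plus the embedding LayerNorm, are copied unchanged.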
    for w in ["word_embeddings", "position_embeddings"]:
        compressed_sd[f"distilbert.embeddings.{w}.weight"] = state_dict[f"{prefix}.embeddings.{w}.weight"]
    for w in ["weight", "bias"]:
        compressed_sd[f"distilbert.embeddings.LayerNorm.{w}"] = state_dict[f"{prefix}.embeddings.LayerNorm.{w}"]
    std_idx = 0
    for teacher_idx in [0, 2, 4, 7, 9, 11]:
        for w in ["weight", "bias"]:
            compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.q_lin.{w}"] = state_dict[
                f"{prefix}.encoder.layer.{teacher_idx}.attention.self.query.{w}"
            ]
            compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.k_lin.{w}"] = state_dict[
                f"{prefix}.encoder.layer.{teacher_idx}.attention.self.key.{w}"
            ]
            compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.v_lin.{w}"] = state_dict[
                f"{prefix}.encoder.layer.{teacher_idx}.attention.self.value.{w}"
            ]

            compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.out_lin.{w}"] = state_dict[
                f"{prefix}.encoder.layer.{teacher_idx}.attention.output.dense.{w}"
            ]
            compressed_sd[f"distilbert.transformer.layer.{std_idx}.sa_layer_norm.{w}"] = state_dict[
                f"{prefix}.encoder.layer.{teacher_idx}.attention.output.LayerNorm.{w}"
            ]

            compressed_sd[f"distilbert.transformer.layer.{std_idx}.ffn.lin1.{w}"] = state_dict[
                f"{prefix}.encoder.layer.{teacher_idx}.intermediate.dense.{w}"
            ]
            compressed_sd[f"distilbert.transformer.layer.{std_idx}.ffn.lin2.{w}"] = state_dict[
                f"{prefix}.encoder.layer.{teacher_idx}.output.dense.{w}"
            ]
            compressed_sd[f"distilbert.transformer.layer.{std_idx}.output_layer_norm.{w}"] = state_dict[
                f"{prefix}.encoder.layer.{teacher_idx}.output.LayerNorm.{w}"
            ]
        std_idx += 1
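
    # MLM head: the decoder projection (weight tied to the input embeddings) and its
    # bias, plus the optional pre-projection transform when --vocab_transform is set.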
    compressed_sd["vocab_projector.weight"] = state_dict["cls.predictions.decoder.weight"]
    compressed_sd["vocab_projector.bias"] = state_dict["cls.predictions.bias"]
    if args.vocab_transform:
        for w in ["weight", "bias"]:
            compressed_sd[f"vocab_transform.{w}"] = state_dict[f"cls.predictions.transform.dense.{w}"]
            compressed_sd[f"vocab_layer_norm.{w}"] = state_dict[f"cls.predictions.transform.LayerNorm.{w}"]

    print(f"N layers selected for distillation: {std_idx}")
    print(f"Number of weight tensors transferred for distillation: {len(compressed_sd)}")

    print(f"Save transferred checkpoint to {args.dump_checkpoint}.")
    torch.save(compressed_sd, args.dump_checkpoint)
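
    # The saved state dict can then initialize a student model; a minimal sketch,
    # assuming a DistilBERT config with matching dimensions (the `state_dict=` kwarg
    # of `from_pretrained` may be deprecated in recent transformers versions):
    #   from transformers import DistilBertForMaskedLM
    #   student = DistilBertForMaskedLM.from_pretrained(
    #       "distilbert-base-uncased", state_dict=torch.load(args.dump_checkpoint)
    #   )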