#!/usr/bin/env python
import os
from pathlib import Path
from typing import Dict, List

import fire
import torch

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.utils.logging import get_logger


logger = get_logger(__name__)


def remove_prefix(text: str, prefix: str):
    if text.startswith(prefix):
        return text[len(prefix) :]
    return text  # the prefix is absent, so return the key unchanged


def sanitize(sd):
    # Strip the "model." prefix that pytorch_lightning prepends to every key in the saved state dict.
    return {remove_prefix(k, "model."): v for k, v in sd.items()}


def average_state_dicts(state_dicts: List[Dict[str, torch.Tensor]]):
    """Average the value of each key across a list of state dicts that share the same keys."""
    new_sd = {}
    for k in state_dicts[0].keys():
        tensors = [sd[k] for sd in state_dicts]
        new_t = sum(tensors) / len(tensors)
        assert isinstance(new_t, torch.Tensor)
        new_sd[k] = new_t
    return new_sd


def convert_pl_to_hf(pl_ckpt_path: str, hf_src_model_dir: str, save_path: str) -> None:
    """Clean up a pytorch-lightning .ckpt file or experiment dir and save a huggingface model with that state dict.

    Extra pytorch-lightning keys (like ``teacher.``) are silently allowed. Loads all ckpt state dicts into CPU RAM at once!

    Args:
        pl_ckpt_path (:obj:`str`): Path to a .ckpt file saved by pytorch_lightning, or a dir containing ckpt files.
            If a directory is passed, all .ckpt files inside it will be averaged!
        hf_src_model_dir (:obj:`str`): Path to a directory containing a correctly shaped checkpoint.
        save_path (:obj:`str`): Directory to save the new model.
    """
    hf_model = AutoModelForSeq2SeqLM.from_pretrained(hf_src_model_dir)
    if os.path.isfile(pl_ckpt_path):
        ckpt_files = [pl_ckpt_path]
    else:
        assert os.path.isdir(pl_ckpt_path)
        ckpt_files = list(Path(pl_ckpt_path).glob("*.ckpt"))
        assert ckpt_files, f"could not find any ckpt files inside the {pl_ckpt_path} directory"

    if len(ckpt_files) > 1:
        logger.info(f"averaging the weights of {ckpt_files}")

    state_dicts = [sanitize(torch.load(x, map_location="cpu")["state_dict"]) for x in ckpt_files]
    state_dict = average_state_dicts(state_dicts)

    # strict=False tolerates extra pytorch-lightning keys (e.g. teacher weights), but every hf key must be present.
    missing, unexpected = hf_model.load_state_dict(state_dict, strict=False)
    assert not missing, f"missing keys: {missing}"
    hf_model.save_pretrained(save_path)
    try:
        tok = AutoTokenizer.from_pretrained(hf_src_model_dir)
        tok.save_pretrained(save_path)
    except Exception:
        # Don't copy the tokenizer if it can't be loaded from the source model dir.
        pass


if __name__ == "__main__":
    fire.Fire(convert_pl_to_hf)