import re

import jax.numpy as jnp
import torch
from flax.traverse_util import flatten_dict, unflatten_dict

from configuration_vqgan import VQGANConfig
from modeling_flax_vqgan import VQModel

# Matches a run of word characters followed by ".<digits>", e.g. "down.0" in
# "encoder.down.0.block.1.conv1.weight".
regex = r"\w+[.]\d+"


def rename_key(key):
    """Rewrite every ``name.<index>`` segment of a PyTorch key as ``name_<index>``.

    Example: ``"encoder.down.0.block.1.conv1.weight"`` becomes
    ``"encoder.down_0.block_1.conv1.weight"``. Presumably this matches the
    submodule names used by the Flax model -- confirm against modeling_flax_vqgan.

    Args:
        key: A dot-separated PyTorch state-dict key.

    Returns:
        The key with each ``regex`` match's dot replaced by an underscore.
    """
    pats = re.findall(regex, key)
    for pat in pats:
        key = key.replace(pat, "_".join(pat.split(".")))
    return key


def _rename_key_and_reshape_tensor(pt_key, pt_tuple_key, pt_tensor, random_flax_state_dict):
    """Map one PyTorch (key, array) pair onto the Flax parameter name and layout.

    Args:
        pt_key: The original dot-separated PyTorch key (used for the "norm" test).
        pt_tuple_key: The key split into a tuple, already prefix-adjusted.
        pt_tensor: The parameter as a NumPy array.
        random_flax_state_dict: Flattened (tuple-keyed) params of the Flax model,
            used to decide which Flax leaf name a PyTorch parameter maps to.

    Returns:
        ``(flax_tuple_key, flax_array)``.
    """
    # --- Normalization parameters ---------------------------------------
    if (
        "norm" in pt_key
        and pt_tuple_key[-1] == "bias"
        and pt_tuple_key[:-1] + ("bias",) in random_flax_state_dict
    ):
        # Norm bias is kept under "bias" but lifted to rank 4 (1, 1, 1, C).
        pt_tensor = pt_tensor[None, None, None, :]
    elif (
        "norm" in pt_key
        and pt_tuple_key[-1] == "bias"
        and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict
    ):
        # NOTE(review): this writes a norm *bias* into the Flax "scale" leaf.
        # If the same layer's weight is also mapped to "scale" (branch below),
        # whichever key is processed last wins -- confirm this is intended.
        pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
        pt_tensor = pt_tensor[None, None, None, :]
    elif pt_tuple_key[-1] in ("weight", "gamma") and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
        # Norm weight (or legacy "gamma") -> Flax "scale", lifted to rank 4.
        pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
        pt_tensor = pt_tensor[None, None, None, :]

    # --- Embedding / conv / linear weights and legacy norm names ---------
    if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
        pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
    elif pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4 and pt_tuple_key not in random_flax_state_dict:
        # Conv layer: a torch Conv2d weight is (out, in, kh, kw); the
        # transpose yields (kh, kw, in, out), the Flax kernel layout.
        pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
        pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
    elif pt_tuple_key[-1] == "weight" and pt_tuple_key not in random_flax_state_dict:
        # Linear layer: torch stores (out, in); Flax Dense expects (in, out).
        pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
        pt_tensor = pt_tensor.T
    elif pt_tuple_key[-1] == "gamma":
        pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
    elif pt_tuple_key[-1] == "beta":
        pt_tuple_key = pt_tuple_key[:-1] + ("bias",)

    return pt_tuple_key, pt_tensor


# Adapted from https://github.com/huggingface/transformers/blob/ff5cdc086be1e0c3e2bbad8e3469b34cffb55a85/src/transformers/modeling_flax_pytorch_utils.py#L61
def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
    """Convert a (renamed) PyTorch state dict into a nested Flax params dict.

    Args:
        pt_state_dict: Mapping of dot-separated keys to torch tensors.
        flax_model: The target Flax model; its ``params`` and
            ``base_model_prefix`` drive key renaming and shape checks.

    Returns:
        A nested (unflattened) dict of ``jnp`` arrays, ready to assign to
        ``flax_model.params``.

    Raises:
        ValueError: If a converted tensor's shape differs from the shape of the
            Flax parameter with the same key.
    """
    # convert pytorch tensors to numpy
    pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}

    prefix = flax_model.base_model_prefix
    random_flax_state_dict = flatten_dict(flax_model.params)
    # Parent-module paths of every Flax parameter, independent of leaf name
    # ("kernel" vs "weight", etc.), used to decide whether to add the prefix.
    flax_module_paths = {flax_key[:-1] for flax_key in random_flax_state_dict}
    pt_top_level_names = {k.split(".")[0] for k in pt_state_dict}

    prefix_in_flax = prefix in flax_model.params
    prefix_in_pt = prefix in pt_top_level_names
    remove_base_model_prefix = (not prefix_in_flax) and prefix_in_pt
    add_base_model_prefix = prefix_in_flax and (not prefix_in_pt)

    flax_state_dict = {}
    # Need to change some parameter names to match Flax names so that we don't have to fork any layer
    for pt_key, pt_tensor in pt_state_dict.items():
        pt_tuple_key = tuple(pt_key.split("."))

        if remove_base_model_prefix and pt_tuple_key[0] == prefix:
            pt_tuple_key = pt_tuple_key[1:]
        elif add_base_model_prefix and (prefix,) + pt_tuple_key[:-1] in flax_module_paths:
            # Compare module paths, not full keys: the PyTorch leaf name
            # (e.g. "weight") may only become the Flax one ("kernel") later.
            pt_tuple_key = (prefix,) + pt_tuple_key

        flax_key, flax_tensor = _rename_key_and_reshape_tensor(
            pt_key, pt_tuple_key, pt_tensor, random_flax_state_dict
        )

        if flax_key in random_flax_state_dict:
            if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
                raise ValueError(
                    f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
                    f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
                )

        # Keys the Flax model does not declare are kept too (presumably so a
        # later load step can warn about unexpected weights).
        flax_state_dict[flax_key] = jnp.asarray(flax_tensor)

    return unflatten_dict(flax_state_dict)


def convert_model(config_path, pt_state_dict_path, save_path):
    """Load a PyTorch VQGAN checkpoint, convert it, and save a Flax VQModel.

    Args:
        config_path: Path or identifier accepted by ``VQGANConfig.from_pretrained``.
        pt_state_dict_path: Checkpoint file whose weights live under a
            ``"state_dict"`` key.
        save_path: Directory passed to ``model.save_pretrained``.
    """
    config = VQGANConfig.from_pretrained(config_path)
    model = VQModel(config)

    # SECURITY NOTE: torch.load unpickles the file and can execute arbitrary
    # code; only convert checkpoints from trusted sources.
    state_dict = torch.load(pt_state_dict_path, map_location="cpu")["state_dict"]
    for key in list(state_dict):  # iterate a copy: the dict is mutated below
        if key.startswith("loss"):
            # Weights under "loss" are dropped (presumably training-only
            # loss-module parameters with no counterpart in the Flax model).
            state_dict.pop(key)
            continue
        state_dict[rename_key(key)] = state_dict.pop(key)

    # convert_pytorch_state_dict_to_flax already returns a nested dict;
    # unflattening it again would split string keys into single characters.
    model.params = convert_pytorch_state_dict_to_flax(state_dict, model)
    model.save_pretrained(save_path)