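"""Checkpoint loading utilities for the monotonic decoder model.

This module converts legacy fairseq monotonic decoder checkpoints to the
fairseq2 format and exposes config/model loaders built on fairseq2 assets.
"""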
from typing import Any, Mapping

import torch

from fairseq2.assets import asset_store, download_manager
from fairseq2.models.utils import ConfigLoader, ModelLoader
from fairseq2.models.utils.checkpoint import convert_fairseq_checkpoint

from seamless_communication.models.monotonic_decoder.builder import (
    MonotonicDecoderConfig,
    create_monotonic_decoder_model,
    monotonic_decoder_archs,
)
from seamless_communication.models.monotonic_decoder.model import MonotonicDecoderModel


def convert_monotonic_checkpoint(
    checkpoint: Mapping[str, Any], config: MonotonicDecoderConfig
) -> Mapping[str, Any]:
    """Convert a fairseq monotonic decoder checkpoint to the fairseq2 format.

    Note that ``config`` is unused here; it is part of the checkpoint-converter
    signature that ``ModelLoader`` expects.
    """
    state_dict = checkpoint["model"]

    # Check if we already have a fairseq2 checkpoint; if so, no conversion is
    # needed.
    if "text_decoder.layers.0.self_attn.k_proj.weight" in state_dict:
        return checkpoint

    # Map fairseq parameter names to their fairseq2 counterparts. The fairseq
    # monotonic energy layers map onto the fairseq2 p_choose_layer parameters.
    key_map = {
        r"^decoder\.embed_tokens\.": r"text_decoder_frontend.embed.",
        r"^decoder\.layers\.([0-9]+)\.self_attn\.out_proj\.": r"text_decoder.layers.\1.self_attn.output_proj.",
        r"^decoder\.layers\.([0-9]+)\.self_attn\.": r"text_decoder.layers.\1.self_attn.",
        r"^decoder\.layers\.([0-9]+)\.self_attn_layer_norm\.": r"text_decoder.layers.\1.self_attn_layer_norm.",
        r"^decoder\.layers\.([0-9]+)\.encoder_attn\.out_proj\.": r"text_decoder.layers.\1.encoder_decoder_attn.output_proj.",
        r"^decoder\.layers\.([0-9]+)\.encoder_attn\.energy_bias": r"text_decoder.layers.\1.p_choose_layer.energy_bias",
        r"^decoder\.layers\.([0-9]+)\.encoder_attn\.source_energy_layer\.": r"text_decoder.layers.\1.p_choose_layer.k_energy_proj.",
        r"^decoder\.layers\.([0-9]+)\.encoder_attn\.target_energy_layer\.": r"text_decoder.layers.\1.p_choose_layer.q_energy_proj.",
        r"^decoder\.layers\.([0-9]+)\.encoder_attn\.": r"text_decoder.layers.\1.encoder_decoder_attn.",
        r"^decoder\.layers\.([0-9]+)\.encoder_attn_layer_norm\.": r"text_decoder.layers.\1.encoder_decoder_attn_layer_norm.",
        r"^decoder\.layers\.([0-9]+)\.fc1\.": r"text_decoder.layers.\1.ffn.inner_proj.",
        r"^decoder\.layers\.([0-9]+)\.fc2\.": r"text_decoder.layers.\1.ffn.output_proj.",
        r"^decoder\.layers\.([0-9]+)\.final_layer_norm\.": r"text_decoder.layers.\1.ffn_layer_norm.",
        r"^decoder\.layer_norm\.": r"text_decoder.layer_norm.",
        r"^decoder\.output_projection\.": r"final_proj.",
    }

    checkpoint = convert_fairseq_checkpoint(checkpoint, key_map)

    state_dict = checkpoint["model"]

    embeds = state_dict["final_proj.weight"]

    # fairseq had a bug that accidentally introduced a dummy token in the
    # embedding table of NLLB-100. We just discard it.
    if embeds.size(0) == 256103:  # means NLLB-100
        embeds = embeds[:-1]

        state_dict["final_proj.weight"] = embeds

    # fairseq checkpoints have duplicate embedding weights; use a single
    # (tied) embedding table in fairseq2.
    state_dict["text_decoder_frontend.embed.weight"] = embeds

    # The embedding positions of the control symbols in the fairseq dict do
    # not match the SentencePiece model of the tokenizer; reorder them.
    with torch.inference_mode():
        embeds[[0, 1, 2, 3]] = embeds[[1, 3, 0, 2]]

    return checkpoint


load_monotonic_decoder_config = ConfigLoader[MonotonicDecoderConfig](
    asset_store, monotonic_decoder_archs
)


load_monotonic_decoder_model = ModelLoader[
    MonotonicDecoderModel, MonotonicDecoderConfig
](
    asset_store,
    download_manager,
    load_monotonic_decoder_config,
    create_monotonic_decoder_model,
    convert_monotonic_checkpoint,
    restrict_checkpoints=False,
)
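

# Usage sketch (not part of the original module): loading the model through
# the loader defined above. The asset card name below is an assumption for
# illustration; substitute whichever monotonic decoder card is registered in
# your asset store.
if __name__ == "__main__":
    model = load_monotonic_decoder_model(
        "seamless_streaming_monotonic_decoder",  # assumed card name
        device=torch.device("cpu"),
        dtype=torch.float32,
    )
    model.eval()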