victan's picture
Upload seamless_communication/models/aligner/loader.py with huggingface_hub
5c71b2a
raw
history blame contribute delete
No virus
2.92 kB
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.
from typing import Any, List, Mapping
import torch
from fairseq2.assets import asset_store, download_manager
from fairseq2.models.utils import ConfigLoader, ModelLoader
from seamless_communication.models.aligner.builder import (
UnitY2AlignmentConfig,
aligner_archs,
create_unity2_alignment_model,
)
from seamless_communication.models.aligner.model import UnitY2AlignmentModel
from seamless_communication.models.unity.char_tokenizer import load_unity_char_tokenizer
def convert_unity2_aligner_checkpoint(
    checkpoint: Mapping[str, Any], config: UnitY2AlignmentConfig
) -> Mapping[str, Any]:
    """Rewrite a UnitY2 aligner checkpoint into a single ``"model"`` state dict.

    A checkpoint that already has a ``"model"`` entry containing the key
    ``"alignment_encoder.t_conv.1.weight"`` is returned untouched.

    Otherwise the function reads ``checkpoint["text_emb_state"]``,
    ``checkpoint["unit_emb_state"]`` and ``checkpoint["aligner_state"]``:

    * the ``"weight"`` entry of each embedding state is renamed to
      ``alignment_frontend.embed_text.weight`` /
      ``alignment_frontend.embed_unit.weight``;
    * every aligner key gets the ``alignment_encoder.`` prefix;
    * the leading rows of the text embedding are reordered in place using the
      index list from ``_get_char_index_mapping(config)``.

    The merged dict is stored under ``checkpoint["model"]`` and the same
    ``checkpoint`` object is returned.

    Raises:
        KeyError: if one of the three source entries is missing, or if an
            embedding state contains a key other than ``"weight"``.
    """
    converted_marker = "alignment_encoder.t_conv.1.weight"
    if "model" in checkpoint and converted_marker in checkpoint["model"]:
        return checkpoint

    # Embedding tables: (source entry in the checkpoint, destination key).
    embedding_sources = (
        ("text_emb_state", "alignment_frontend.embed_text.weight"),
        ("unit_emb_state", "alignment_frontend.embed_unit.weight"),
    )
    frontend_state = {}
    for source_name, target_key in embedding_sources:
        # Only "weight" is a known parameter name; anything else is a KeyError.
        rename = {"weight": target_key}
        for param_name, tensor in checkpoint[source_name].items():
            frontend_state[rename[param_name]] = tensor

    encoder_state = {
        f"alignment_encoder.{param_name}": tensor
        for param_name, tensor in checkpoint["aligner_state"].items()
    }

    # Encoder entries first, then frontend entries.
    model_state = dict(encoder_state)
    model_state.update(frontend_state)

    # Row i (for i < len(permutation)) receives the old row permutation[i].
    # The right-hand side is evaluated first, so rows are read before any
    # of them is overwritten.
    text_embeddings = model_state["alignment_frontend.embed_text.weight"]
    permutation = _get_char_index_mapping(config)
    num_rows = len(permutation)
    text_embeddings[torch.arange(num_rows)] = text_embeddings[permutation]

    checkpoint["model"] = model_state
    return checkpoint
def _get_char_index_mapping(config: UnitY2AlignmentConfig) -> List[int]:
    """Build the index list used to reorder character-embedding rows.

    The char tokenizer for ``config.model_name_or_card`` is loaded and its
    tokens are read in tokenizer index order. Indices 0-3 map to themselves.
    Every later token maps to ``4 + rank``, where ``rank`` is the token's
    position among the later tokens after ``sorted()``.

    NOTE(review): the first four entries are presumably the tokenizer's
    special symbols and the sorted order presumably matches the dictionary
    the checkpoint was trained with -- confirm against the tokenizer.

    Returns:
        A list with one entry per tokenizer index (when the vocabulary has at
        least four entries).
    """
    tokenizer_model = load_unity_char_tokenizer(config.model_name_or_card).model
    vocab_size = tokenizer_model.vocabulary_size

    all_tokens = [tokenizer_model.index_to_token(index) for index in range(vocab_size)]
    regular_tokens = all_tokens[4:]

    # Position of each regular token once the regular tokens are sorted,
    # offset by the four leading indices that keep their place.
    sorted_position = {
        token: position
        for position, token in enumerate(sorted(regular_tokens), start=4)
    }

    mapping = list(range(4))
    mapping.extend(sorted_position[token] for token in regular_tokens)
    return mapping
# Module-level config loader for UnitY2 alignment models: a fairseq2
# `ConfigLoader` parameterized with `UnitY2AlignmentConfig`, constructed from
# the fairseq2 `asset_store` and the `aligner_archs` object imported from the
# aligner builder module.
load_unity2_alignment_config = ConfigLoader[UnitY2AlignmentConfig](
asset_store, aligner_archs
)
# Module-level model loader for UnitY2 alignment models: a fairseq2
# `ModelLoader` constructed from the asset store, the download manager,
# `load_unity2_alignment_config`, `create_unity2_alignment_model`, and
# `convert_unity2_aligner_checkpoint` (presumably the config loader, model
# factory and checkpoint converter, in that order -- confirm against the
# fairseq2 `ModelLoader` signature).
# NOTE(review): `restrict_checkpoints=False` presumably turns off fairseq2's
# restricted (safe) checkpoint unpickling -- confirm, and make sure only
# trusted checkpoints are loaded through this loader.
load_unity2_alignment_model = ModelLoader[UnitY2AlignmentModel, UnitY2AlignmentConfig](
asset_store,
download_manager,
load_unity2_alignment_config,
create_unity2_alignment_model,
convert_unity2_aligner_checkpoint,
restrict_checkpoints=False,
)