Upload seamless_communication/models/aligner/loader.py with huggingface_hub
seamless_communication/models/aligner/loader.py
ADDED
@@ -0,0 +1,88 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.

from typing import Any, List, Mapping

import torch
from fairseq2.assets import asset_store, download_manager
from fairseq2.models.utils import ConfigLoader, ModelLoader

from seamless_communication.models.aligner.builder import (
    UnitY2AlignmentConfig,
    aligner_archs,
    create_unity2_alignment_model,
)
from seamless_communication.models.aligner.model import UnitY2AlignmentModel
from seamless_communication.models.unity.char_tokenizer import load_unity_char_tokenizer


def convert_unity2_aligner_checkpoint(
    checkpoint: Mapping[str, Any], config: UnitY2AlignmentConfig
) -> Mapping[str, Any]:
    # Checkpoint is already in the fairseq2 layout; nothing to convert.
    if (
        "model" in checkpoint
        and "alignment_encoder.t_conv.1.weight" in checkpoint["model"]
    ):
        return checkpoint

    # Remap the legacy text/unit embedding state dicts onto the
    # alignment_frontend module.
    alignment_frontend_statedict = {}
    text_emb_state_keymap = {"weight": "alignment_frontend.embed_text.weight"}
    for k, v in checkpoint["text_emb_state"].items():
        alignment_frontend_statedict[text_emb_state_keymap[k]] = v

    unit_emb_state_keymap = {"weight": "alignment_frontend.embed_unit.weight"}
    for k, v in checkpoint["unit_emb_state"].items():
        alignment_frontend_statedict[unit_emb_state_keymap[k]] = v

    # Prefix the aligner encoder weights with their fairseq2 module name.
    alignment_encoder_state_dict = {}
    for k, v in checkpoint["aligner_state"].items():
        alignment_encoder_state_dict[f"alignment_encoder.{k}"] = v

    model_state = {
        **alignment_encoder_state_dict,
        **alignment_frontend_statedict,
    }

    char_embeds = model_state["alignment_frontend.embed_text.weight"]

    # Permute the character embedding rows: row i receives the row at
    # index_mapping[i], moving embeddings stored in sorted-vocabulary order
    # into the tokenizer's native SPM order.
    index_mapping = _get_char_index_mapping(config)
    vocab_size = len(index_mapping)
    char_embeds[torch.arange(vocab_size)] = char_embeds[index_mapping]

    checkpoint["model"] = model_state

    return checkpoint
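
# Example of the key remapping above (hypothetical legacy layout): a key like
# checkpoint["aligner_state"]["t_conv.1.weight"] ends up in checkpoint["model"]
# as "alignment_encoder.t_conv.1.weight", and
# checkpoint["text_emb_state"]["weight"] becomes
# "alignment_frontend.embed_text.weight".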


def _get_char_index_mapping(config: UnitY2AlignmentConfig) -> List[int]:
    char_tokenizer = load_unity_char_tokenizer(config.model_name_or_card)
    # Tokens in the tokenizer's native SPM order, skipping the 4 special tokens.
    spm_order = [
        char_tokenizer.model.index_to_token(i)
        for i in range(char_tokenizer.model.vocabulary_size)
    ][4:]
    # Index of each character once the vocabulary is sorted, offset by the
    # 4 special tokens.
    spm_to_dict_mapping = {
        ch: idx
        for (idx, ch) in zip(
            range(4, char_tokenizer.model.vocabulary_size),
            sorted(spm_order),
        )
    }
    # Identity for the specials, then the sorted-order index of the token at
    # each SPM position.
    model_to_dict_mapping = [0, 1, 2, 3] + [spm_to_dict_mapping[ch] for ch in spm_order]
    return model_to_dict_mapping
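
# Illustration (hypothetical toy vocabulary): if the SPM order after the 4
# special tokens is ["c", "a", "b"], sorting gives {"a": 4, "b": 5, "c": 6},
# so the mapping returned is [0, 1, 2, 3, 6, 4, 5]: SPM position 4 ("c")
# reads its embedding from sorted position 6, and so on.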


# Loader entry points: resolve an asset card to a UnitY2AlignmentConfig, then
# download the checkpoint, build the model, and run the converter above.
load_unity2_alignment_config = ConfigLoader[UnitY2AlignmentConfig](
    asset_store, aligner_archs
)

load_unity2_alignment_model = ModelLoader[UnitY2AlignmentModel, UnitY2AlignmentConfig](
    asset_store,
    download_manager,
    load_unity2_alignment_config,
    create_unity2_alignment_model,
    convert_unity2_aligner_checkpoint,
    restrict_checkpoints=False,
)
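
A minimal usage sketch. The card name "nar_t2u_aligner" (the aligner card used
in the seamless_communication docs), the device, and the dtype below are
illustrative assumptions, not part of this file:

import torch

from seamless_communication.models.aligner.loader import load_unity2_alignment_model

# "nar_t2u_aligner" is an assumed card name; substitute the aligner card
# registered in your asset store.
aligner = load_unity2_alignment_model(
    "nar_t2u_aligner", device=torch.device("cpu"), dtype=torch.float32
)
aligner.eval()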