victan committed on
Commit
5c71b2a
·
1 Parent(s): 8266600

Upload seamless_communication/models/aligner/loader.py with huggingface_hub

Browse files
seamless_communication/models/aligner/loader.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # MIT_LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Any, List, Mapping
8
+
9
+ import torch
10
+ from fairseq2.assets import asset_store, download_manager
11
+ from fairseq2.models.utils import ConfigLoader, ModelLoader
12
+
13
+ from seamless_communication.models.aligner.builder import (
14
+ UnitY2AlignmentConfig,
15
+ aligner_archs,
16
+ create_unity2_alignment_model,
17
+ )
18
+ from seamless_communication.models.aligner.model import UnitY2AlignmentModel
19
+ from seamless_communication.models.unity.char_tokenizer import load_unity_char_tokenizer
20
+
21
+
22
def convert_unity2_aligner_checkpoint(
    checkpoint: Mapping[str, Any], config: UnitY2AlignmentConfig
) -> Mapping[str, Any]:
    """Map a legacy UnitY2 aligner checkpoint onto the fairseq2 key layout.

    Returns the checkpoint unchanged when it already carries a converted
    ``"model"`` entry; otherwise assembles one from the legacy sub-state
    dicts and stores it under ``checkpoint["model"]`` (the input mapping is
    mutated in place).
    """
    already_converted = (
        "model" in checkpoint
        and "alignment_encoder.t_conv.1.weight" in checkpoint["model"]
    )
    if already_converted:
        return checkpoint

    # Embedding tables move under the alignment frontend; only the "weight"
    # entry is expected in each legacy embedding state dict.
    text_emb_keymap = {"weight": "alignment_frontend.embed_text.weight"}
    unit_emb_keymap = {"weight": "alignment_frontend.embed_unit.weight"}

    model_state = {}

    for name, tensor in checkpoint["text_emb_state"].items():
        model_state[text_emb_keymap[name]] = tensor

    for name, tensor in checkpoint["unit_emb_state"].items():
        model_state[unit_emb_keymap[name]] = tensor

    # Encoder parameters keep their names, prefixed for the wrapping model.
    for name, tensor in checkpoint["aligner_state"].items():
        model_state[f"alignment_encoder.{name}"] = tensor

    # Permute the character embedding rows from the tokenizer's SPM order
    # into dictionary order. The advanced-indexed read on the right-hand
    # side copies before the in-place write, so the permutation is safe.
    char_embeds = model_state["alignment_frontend.embed_text.weight"]
    index_mapping = _get_char_index_mapping(config)
    char_embeds[torch.arange(len(index_mapping))] = char_embeds[index_mapping]

    checkpoint["model"] = model_state

    return checkpoint
58
+
59
+
60
def _get_char_index_mapping(config: UnitY2AlignmentConfig) -> List[int]:
    """Build a permutation from the tokenizer's SPM token order to sorted
    dictionary order, keeping the first four (control) indices fixed."""
    tokenizer = load_unity_char_tokenizer(config.model_name_or_card)
    vocab_size = tokenizer.model.vocabulary_size

    # Tokens 0-3 are control symbols; the rest follow the SPM model's order.
    spm_order = [tokenizer.model.index_to_token(i) for i in range(4, vocab_size)]

    # Dictionary order assigns indices 4.. to the characters sorted
    # lexicographically.
    dict_index = {ch: pos for pos, ch in enumerate(sorted(spm_order), start=4)}

    return [0, 1, 2, 3] + [dict_index[ch] for ch in spm_order]
75
+
76
+
77
# Loads a UnitY2AlignmentConfig for a named architecture from the asset store.
load_unity2_alignment_config = ConfigLoader[UnitY2AlignmentConfig](
    asset_store, aligner_archs
)

# Loads a UnitY2AlignmentModel: resolves the asset card, downloads the
# checkpoint, builds the model with create_unity2_alignment_model, and passes
# the raw checkpoint through convert_unity2_aligner_checkpoint to normalize
# legacy state dicts into the fairseq2 key layout.
# NOTE(review): restrict_checkpoints=False presumably relaxes strict matching
# between checkpoint keys and the model's state dict — confirm against the
# fairseq2 ModelLoader documentation.
load_unity2_alignment_model = ModelLoader[UnitY2AlignmentModel, UnitY2AlignmentConfig](
    asset_store,
    download_manager,
    load_unity2_alignment_config,
    create_unity2_alignment_model,
    convert_unity2_aligner_checkpoint,
    restrict_checkpoints=False,
)