victan committed
Commit
8192381
1 Parent(s): c24ca9f

Upload seamless_communication/cli/m4t/finetune/finetune.py with huggingface_hub

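For reference, an upload like the one described in this commit message is typically done with the huggingface_hub `upload_file` API. A minimal sketch, assuming an already logged-in token; the repo id below is a placeholder, since the target repo is not named in this view:

from huggingface_hub import HfApi

api = HfApi()  # picks up the token from `huggingface-cli login` or the HF_TOKEN env var
api.upload_file(
    path_or_fileobj="seamless_communication/cli/m4t/finetune/finetune.py",  # local file to push
    path_in_repo="seamless_communication/cli/m4t/finetune/finetune.py",     # destination path in the repo
    repo_id="<user>/<repo>",  # placeholder, not taken from this commit
    commit_message="Upload seamless_communication/cli/m4t/finetune/finetune.py with huggingface_hub",
)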
seamless_communication/cli/m4t/finetune/finetune.py ADDED
@@ -0,0 +1,183 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # MIT_LICENSE file in the root directory of this source tree.
+
+ import argparse
+ import logging
+ import os
+ from pathlib import Path
+
+ import torch
+ from fairseq2.models.nllb.tokenizer import NllbTokenizer
+
+ from seamless_communication.cli.m4t.finetune import dataloader, dist_utils, trainer
+ from seamless_communication.models.unity import (
+     UnitTokenizer,
+     UnitYModel,
+     load_unity_model,
+     load_unity_text_tokenizer,
+     load_unity_unit_tokenizer,
+ )
+
+ logging.basicConfig(
+     level=logging.INFO,
+     format=f"%(asctime)s %(levelname)s -- %(name)s.{os.getpid()}: %(message)s",
+ )
+
+ logger = logging.getLogger("finetune")
+
+
+ def init_parser() -> argparse.ArgumentParser:
+     parser = argparse.ArgumentParser(
+         description="Example finetuning script for M4T models"
+     )
+     parser.add_argument(
+         "--train_dataset",
+         type=Path,
+         required=True,
+         help="Path to manifest with train samples",
+     )
+     parser.add_argument(
+         "--eval_dataset",
+         type=Path,
+         required=True,
+         help="Path to manifest with eval samples",
+     )
+     parser.add_argument(
+         "--model_name",
+         type=str,
+         default="seamlessM4T_medium",
+         help="Base model name (`seamlessM4T_medium`, `seamlessM4T_large`)",
+     )
+     parser.add_argument(
+         "--save_model_to",
+         type=Path,
+         required=True,
+         help="Path to save best finetuned model",
+     )
+     parser.add_argument(
+         "--seed",
+         type=int,
+         default=2343,
+         help="Randomizer seed value",
+     )
+     parser.add_argument(
+         "--batch_size",
+         type=int,
+         default=5,
+         help="Batch size for training and evaluation",
+     )
+     parser.add_argument(
+         "--patience",
+         type=int,
+         default=3,
+         help=(
+             "Set early termination after `patience` number of evaluations "
+             "without eval loss improvements"
+         ),
+     )
+     parser.add_argument(
+         "--max_epochs",
+         type=int,
+         default=10,
+         help=("Max number of training epochs"),
+     )
+     parser.add_argument(
+         "--learning_rate",
+         type=float,
+         default=1e-7,
+         help=("Finetuning learning rate"),
+     )
+     parser.add_argument(
+         "--warmup_steps",
+         type=int,
+         default=100,
+         help=("Number of steps with linearly increasing learning rate"),
+     )
+     parser.add_argument(
+         "--eval_steps",
+         type=int,
+         default=50,
+         help=("Get eval loss after each `eval_steps` training steps"),
+     )
+     parser.add_argument(
+         "--log_steps",
+         type=int,
+         default=10,
+         help=("Log inner loss after each `log_steps` training steps"),
+     )
+     parser.add_argument(
+         "--mode",
+         type=trainer.FinetuneMode,
+         choices=list(trainer.FinetuneMode),
+         default=trainer.FinetuneMode.SPEECH_TO_TEXT,
+         help=(
+             "* `SPEECH_TO_SPEECH` -- finetune S2T and T2U parts of the model; "
+             "* `TEXT_TO_SPEECH` -- finetune only T2U; "
+             "* `SPEECH_TO_TEXT` -- finetune only S2T"
+         ),
+     )
+     return parser
+
+
+ def main() -> None:
+     args = init_parser().parse_args()
+     dist_utils.init_distributed([logger, trainer.logger])
+     device = torch.device("cuda")
+     text_tokenizer: NllbTokenizer = load_unity_text_tokenizer(args.model_name)
+     unit_tokenizer: UnitTokenizer = load_unity_unit_tokenizer(args.model_name)
+     finetune_params = trainer.FinetuneParams(
+         finetune_mode=args.mode,
+         save_model_path=args.save_model_to,
+         device=device,
+         train_batch_size=args.batch_size,
+         eval_batch_size=args.batch_size,
+         patience=args.patience,
+         max_epochs=args.max_epochs,
+         learning_rate=args.learning_rate,
+         warmup_steps=args.warmup_steps,
+         eval_steps=args.eval_steps,
+         log_steps=args.log_steps,
+     )
+     logger.info(f"Finetune params: {finetune_params}")
+     model: UnitYModel = load_unity_model(
+         args.model_name, device=finetune_params.device, dtype=torch.float16
+     )
+     logger.info(f"Model {model}")
+     assert model.target_vocab_info == text_tokenizer.vocab_info
+     assert model.t2u_model is not None
+     assert model.t2u_model.target_vocab_info == unit_tokenizer.vocab_info
+
+     train_dataloader = dataloader.UnitYDataLoader(
+         text_tokenizer=text_tokenizer,
+         unit_tokenizer=unit_tokenizer,
+         batching_config=dataloader.BatchingConfig(
+             batch_size=finetune_params.train_batch_size,
+             rank=dist_utils.get_rank(),
+             world_size=dist_utils.get_world_size(),
+         ),
+         dataset_manifest_path=args.train_dataset,
+     )
+     eval_dataloader = dataloader.UnitYDataLoader(
+         text_tokenizer=text_tokenizer,
+         unit_tokenizer=unit_tokenizer,
+         batching_config=dataloader.BatchingConfig(
+             batch_size=finetune_params.eval_batch_size,
+             rank=dist_utils.get_rank(),
+             world_size=dist_utils.get_world_size(),
+         ),
+         dataset_manifest_path=args.eval_dataset,
+     )
+     finetune = trainer.UnitYFinetune(
+         model=model,
+         params=finetune_params,
+         train_data_loader=train_dataloader,
+         eval_data_loader=eval_dataloader,
+     )
+     finetune.run()
+
+
+ if __name__ == "__main__":
+     main()
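Since the script calls dist_utils.init_distributed and hard-codes torch.device("cuda"), it is presumably meant to be launched under a torch.distributed launcher on a CUDA machine. A plausible single-node invocation, assuming the seamless_communication package is installed; the manifest and output paths are placeholders:

torchrun --nnodes=1 --nproc_per_node=2 \
    -m seamless_communication.cli.m4t.finetune.finetune \
    --mode SPEECH_TO_TEXT \
    --train_dataset <path/to/train_manifest.json> \
    --eval_dataset <path/to/eval_manifest.json> \
    --model_name seamlessM4T_medium \
    --save_model_to <path/to/checkpoint.pt> \
    --batch_size 5 \
    --learning_rate 1e-7

All of the flags above are defined by init_parser in this file; only the launcher settings and file paths are illustrative.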