shethjenil committed on
Commit
e0ac99f
·
verified ·
1 Parent(s): 82be5fe

Upload 2 files

Browse files
Files changed (2) hide show
  1. config.yaml +105 -0
  2. main.py +195 -0
config.yaml ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ log_dir: "Models/Output"
2
+ save_freq: 5
3
+ log_interval: 10
4
+ device: "cuda"
5
+ epochs: 50
6
+ batch_size: 8
7
+ max_len: 400
8
+ pretrained_model: ""
9
+ second_stage_load_pretrained: true
10
+ load_only_params: true
11
+
12
+ external_models:
13
+ asr:
14
+ input_dim: 80
15
+ hidden_dim: 256
16
+ n_token: 178
17
+ plbert:
18
+ vocab_size: 178
19
+ hidden_size: 768
20
+ num_attention_heads: 12
21
+ intermediate_size: 2048
22
+ dropout: 0.1
23
+
24
+ data_params:
25
+ train_data: "shethjenil/audiodata"
26
+ root_path: ""
27
+ min_length: 50
28
+
29
+ preprocess_params:
30
+ sr: 24000
31
+ n_fft: 2048
32
+ win_length: 1200
33
+ hop_length: 300
34
+
35
+ model_params:
36
+ multispeaker: true
37
+ dim_in: 64
38
+ hidden_dim: 128
39
+ max_conv_dim: 512
40
+ n_layer: 2
41
+ n_mels: 80
42
+ n_token: 178
43
+ max_dur: 50
44
+ style_dim: 128
45
+ dropout: 0.2
46
+ decoder:
47
+ type: "istftnet"
48
+ hidden_dim: 256
49
+ decoder_out_dim: 256
50
+ asr_res_in: 128
51
+ resblock_kernel_sizes: [3, 3]
52
+ upsample_rates: [10, 6]
53
+ upsample_initial_channel: 256
54
+ resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5], [1, 3, 5]]
55
+ upsample_kernel_sizes: [20, 12]
56
+ gen_istft_n_fft: 20
57
+ gen_istft_hop_size: 5
58
+ disable_complex: true
59
+ slm:
60
+ model: "microsoft/wavlm-base-plus"
61
+ sr: 16000
62
+ hidden: 768
63
+ nlayers: 13
64
+ initial_channel: 64
65
+ diffusion:
66
+ embedding_mask_proba: 0.1
67
+ transformer:
68
+ num_layers: 3
69
+ num_heads: 8
70
+ head_features: 64
71
+ multiplier: 2
72
+ dist:
73
+ sigma_data: 0.2
74
+ estimate_sigma_data: true
75
+ mean: -3.0
76
+ std: 1.0
77
+
78
+ loss_params:
79
+ lambda_mel: 5.0
80
+ lambda_gen: 1.0
81
+ lambda_slm: 1.0
82
+ lambda_mono: 1.0
83
+ lambda_s2s: 1.0
84
+ lambda_F0: 1.0
85
+ lambda_norm: 1.0
86
+ lambda_dur: 1.0
87
+ lambda_ce: 20.0
88
+ lambda_sty: 1.0
89
+ lambda_diff: 1.0
90
+ diff_epoch: 10
91
+ joint_epoch: 30
92
+
93
+ optimizer_params:
94
+ lr: 0.0001
95
+ bert_lr: 0.00001
96
+ ft_lr: 0.0001
97
+
98
+ slmadv_params:
99
+ min_len: 400
100
+ max_len: 500
101
+ batch_percentage: 0.5
102
+ iter: 10
103
+ thresh: 5.0
104
+ scale: 0.01
105
+ sig: 1.5
main.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %%bash
2
+ # pip uninstall -q styletts2 -y
3
+ # pip install git+https://github.com/dummyjenil/styletts2 -q
4
+
5
+
6
+ import torch
7
+ from torch import nn
8
+ import onnx
9
+ import torch.nn.utils.parametrize as parametrize
10
+ from styletts2.models.styletts2 import StyleTTS2Model, StyleTTS2Config
11
+ from onnx_toolkit import ONNXParser
12
+ from safetensors.torch import save_file
13
+ from huggingface_hub import hf_hub_download
14
+
15
+
16
+ # -----------------------------
17
+ # Utils
18
+ # -----------------------------
19
def get_layer_from_key(model: nn.Module, key: str):
    """Return the submodule that owns the parameter named by a dotted state-dict key.

    The last path component (the parameter's own name) is discarded; numeric
    components index into sequential containers, all others are attribute lookups.
    """
    current = model
    path = key.split(".")[:-1]
    for part in path:
        if part.isdigit():
            current = current[int(part)]
        else:
            current = getattr(current, part)
    return current
24
+
25
+
26
def is_bidirectional(node):
    """Return True if an ONNX LSTM node's ``direction`` attribute is "bidirectional".

    Nodes without a ``direction`` attribute default to False (forward-only).
    """
    for attr in node.attribute:
        if attr.name != "direction":
            continue
        value = onnx.helper.get_attribute_value(attr)
        # ONNX string attributes come back as raw bytes; normalize to str.
        if isinstance(value, bytes):
            value = value.decode("utf-8")
        return value == "bidirectional"
    return False
34
+
35
+
36
+ # -----------------------------
37
+ # LSTM Conversion
38
+ # -----------------------------
39
def convert_onnx_lstm(W, R, B, layer_name="lstm", bidirectional=False):
    """Map ONNX LSTM weight tensors onto PyTorch LSTM state-dict entries.

    ONNX packs gates in (input, output, forget, cell) order while PyTorch
    expects (input, forget, cell, output), so every weight/bias tensor is
    re-chunked and re-concatenated. Direction 0 fills the forward keys and
    direction 1 (when present) fills the ``_reverse`` keys.

    NOTE: ``bidirectional`` is accepted for caller convenience but never
    read — the number of directions is taken from ``W.shape[0]``.
    """
    def _iofc_to_ifco(t):
        gate_i, gate_o, gate_f, gate_c = torch.chunk(t, 4, dim=0)
        return torch.cat((gate_i, gate_f, gate_c, gate_o), dim=0)

    converted = {}
    for direction in range(W.shape[0]):
        tag = "_reverse" if direction else ""

        # ONNX stores ih and hh biases fused in one vector; split them first.
        bias_ih, bias_hh = torch.chunk(torch.tensor(B[direction]), 2, dim=0)

        converted[f"{layer_name}.weight_ih_l0{tag}"] = _iofc_to_ifco(torch.tensor(W[direction]))
        converted[f"{layer_name}.weight_hh_l0{tag}"] = _iofc_to_ifco(torch.tensor(R[direction]))
        converted[f"{layer_name}.bias_ih_l0{tag}"] = _iofc_to_ifco(bias_ih)
        converted[f"{layer_name}.bias_hh_l0{tag}"] = _iofc_to_ifco(bias_hh)

    return converted
62
+
63
+
64
+ # -----------------------------
65
+ # MAIN FUNCTION (UPDATED)
66
+ # -----------------------------
67
def convert_model_from_local(
    onnx_path: str,
    config,
    save_path: str
):
    """Port weights from a local ONNX model into a StyleTTS2 safetensors file.

    Args:
        onnx_path: local path to the ONNX model (a path returned by
            ``hf_hub_download`` works too).
        config: ``StyleTTS2Config`` instance or a plain config dict.
        save_path: output path for the ``.safetensors`` file.
    """

    # ---- Config handling ----
    if isinstance(config, dict):
        config = StyleTTS2Config(**config)

    model = StyleTTS2Model(config)

    # Drop training-only submodules so their parameters never enter the
    # exported state_dict.
    for attr in [
        "wd", "msd", "mpd", "pitch_extractor",
        "text_aligner", "diffusion",
        "predictor_encoder", "style_encoder"
    ]:
        if hasattr(model, attr):
            delattr(model, attr)

    # Bake weight-norm style parametrizations into plain weights so the
    # exported tensor names match the ONNX initializers.
    for module in model.modules():
        if hasattr(module, "parametrizations") and "weight" in module.parametrizations:
            parametrize.remove_parametrizations(module, "weight", leave_parametrized=True)

    state_dict = model.state_dict()

    # ---- Load ONNX (local path) ----
    m = ONNXParser(onnx_path)

    # -------- LSTM handling --------
    pytorch_lstm = {
        name: module for name, module in model.named_modules()
        if isinstance(module, nn.LSTM)
    }

    # BUGFIX: the previous code used ``name.rstrip("/LSTM")``, which strips
    # any trailing run of the characters '/', 'L', 'S', 'T', 'M' — a node
    # named e.g. "/text_encoder/lstm/LSTM" would also lose characters of
    # "lstm" itself and never match the PyTorch module name. Strip the exact
    # "/LSTM" suffix instead.
    def _node_key(node_name: str) -> str:
        if node_name.endswith("/LSTM"):
            node_name = node_name[:-len("/LSTM")]
        return node_name.lstrip("/").replace("/", ".")

    onnx_lstm_layers = {
        _node_key(i.name): i
        for i in m.find().find_by_op_type("LSTM")
    }

    predefined_dict = {}

    for pt_name in pytorch_lstm:
        if pt_name not in onnx_lstm_layers:
            continue

        node = onnx_lstm_layers[pt_name]
        block = m.find().find_by_name(node.name, exact=True)

        # An ONNX LSTM node carries exactly three weight tensors: W, R, B.
        tensors = list(block.tensor().values())
        if len(tensors) != 3:
            continue

        w, r, b = tensors

        converted = convert_onnx_lstm(
            w, r, b,
            pt_name,
            is_bidirectional(block.single_node)
        )

        predefined_dict.update(converted)

    # -------- Build state_dict --------
    finder = m.find()
    new_state_dict = {}

    for name, tensor in state_dict.items():
        # ONNX initializers are exported under a "kmodel." prefix.
        full_key = "kmodel." + name
        results = finder.find_by_tensor(full_key)

        if results:
            new_state_dict[name] = torch.tensor(results[0].tensor()[full_key])
            continue

        module = get_layer_from_key(model, name)

        # LSTM weights are not plain initializers; use the gate-reordered
        # tensors prepared above.
        if isinstance(module, nn.LSTM) and name in predefined_dict:
            new_state_dict[name] = predefined_dict[name]
            continue

        new_state_dict[name] = tensor  # fallback: keep the fresh init

    # The ISTFT window is a buffer, not an ONNX weight — copy it from the
    # freshly built model so load_state_dict sees a complete dict.
    new_state_dict['decoder.generator.stft.window'] = model.decoder.generator.stft.window.clone()
    model.load_state_dict(new_state_dict)

    # safetensors refuses non-contiguous tensors; normalize before saving.
    final_sd = {
        k: v.contiguous() if not v.is_contiguous() else v
        for k, v in model.state_dict().items()
    }

    save_file(final_sd, save_path)
168
+
169
+
170
+
171
+ # config = StyleTTS2Config.from_yaml("config.yaml")
172
+ # convert_model_from_local(
173
+ # onnx_path=hf_hub_download("KittenML/kitten-tts-nano-0.8-fp32","kitten_tts_nano_v0_8.onnx"),
174
+ # config=config,
175
+ # save_path="mini_model.safetensors"
176
+ # )
177
+
178
+
179
+
180
+ # ----
181
+ # TODO: quantization is still pending
182
+
183
+
184
+ # config.model_params.style_dim = 512
185
+ # config.model_params.hidden_dim = 512
186
+ # config.model_params.decoder.hidden_dim = 1024
187
+ # config.model_params.decoder.decoder_out_dim = 512
188
+ # config.model_params.decoder.asr_res_in = 256
189
+ # config.model_params.decoder.upsample_initial_channel = 512
190
+ # convert_model_from_local(
191
+ # onnx_path=hf_hub_download("KittenML/kitten-tts-mini-0.8","kitten_tts_mini_v0_8.onnx"),
192
+ # config=config,
193
+ # save_path="mini_model.safetensors"
194
+ # )
195
+