xieli commited on
Commit
6852edb
·
1 Parent(s): 1e0e3cd

audio edit

Browse files

remove useless file

feat: add spaces, change edit_app name

feat: change readme, dockerfile

feat: change readme

feat: add default config

feat: remote useless file

feat: change readme

feat: change readme

feat: change requirements version

feat: change requirements

feat: remove dockerfile

feat: change pkg version

feat: support hf model source

feat: fix model loader

feat: test

feat: fix model loader

feat: fix model cache path

feat: add log

feat: fix download

feat: fix tokenizer

feat: fix tokenizer

feat: fix download

feat: add hf login

feat: add log

feat: remove useless log

feat: fix model loader

feat: fix model loader

feat: add log

feat: fix model loader

feat: rollback code

feat: fix

feat: fix model loader

feat: fix model path

feat: zerogpu

feat: fix

feat: fix app

feat: optimize download

feat: optimize download

feat: change app desc

feat: add log

feat: add log

This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +4 -4
  2. .gitignore +2 -0
  3. README.md +13 -1
  4. __init__.py +0 -0
  5. app.py +499 -0
  6. config/__init__.py +12 -0
  7. config/edit_config.py +33 -0
  8. config/prompts.py +62 -0
  9. funasr_detach/__init__.py +38 -0
  10. funasr_detach/auto/__init__.py +0 -0
  11. funasr_detach/auto/auto_frontend.py +90 -0
  12. funasr_detach/auto/auto_model.py +575 -0
  13. funasr_detach/auto/auto_tokenizer.py +7 -0
  14. funasr_detach/bin/__init__.py +0 -0
  15. funasr_detach/bin/compute_audio_cmvn.py +152 -0
  16. funasr_detach/bin/inference.py +33 -0
  17. funasr_detach/bin/tokenize_text.py +281 -0
  18. funasr_detach/bin/train.py +227 -0
  19. funasr_detach/datasets/__init__.py +0 -0
  20. funasr_detach/datasets/audio_datasets/__init__.py +0 -0
  21. funasr_detach/datasets/audio_datasets/datasets.py +112 -0
  22. funasr_detach/datasets/audio_datasets/index_ds.py +150 -0
  23. funasr_detach/datasets/audio_datasets/preprocessor.py +55 -0
  24. funasr_detach/datasets/audio_datasets/samplers.py +306 -0
  25. funasr_detach/datasets/audio_datasets/scp2jsonl.py +116 -0
  26. funasr_detach/download/__init__.py +0 -0
  27. funasr_detach/download/download_dataset_from_hub.py +19 -0
  28. funasr_detach/download/download_from_hub.py +231 -0
  29. funasr_detach/download/file.py +335 -0
  30. funasr_detach/download/name_maps_from_hub.py +13 -0
  31. funasr_detach/download/runtime_sdk_download_tool.py +60 -0
  32. funasr_detach/frontends/__init__.py +0 -0
  33. funasr_detach/frontends/default.py +347 -0
  34. funasr_detach/frontends/eend_ola_feature.py +49 -0
  35. funasr_detach/frontends/fused.py +144 -0
  36. funasr_detach/frontends/s3prl.py +139 -0
  37. funasr_detach/frontends/utils/__init__.py +1 -0
  38. funasr_detach/frontends/utils/beamformer.py +84 -0
  39. funasr_detach/frontends/utils/complex_utils.py +194 -0
  40. funasr_detach/frontends/utils/dnn_beamformer.py +173 -0
  41. funasr_detach/frontends/utils/dnn_wpe.py +93 -0
  42. funasr_detach/frontends/utils/feature_transform.py +263 -0
  43. funasr_detach/frontends/utils/frontend.py +151 -0
  44. funasr_detach/frontends/utils/log_mel.py +83 -0
  45. funasr_detach/frontends/utils/mask_estimator.py +77 -0
  46. funasr_detach/frontends/utils/stft.py +239 -0
  47. funasr_detach/frontends/wav_frontend.py +556 -0
  48. funasr_detach/frontends/windowing.py +74 -0
  49. funasr_detach/losses/__init__.py +0 -0
  50. funasr_detach/losses/label_smoothing_loss.py +125 -0
.gitattributes CHANGED
@@ -1,4 +1,4 @@
1
- assets/*.pdf filter=lfs diff=lfs merge=lfs -text
2
- assets/*.png filter=lfs diff=lfs merge=lfs -text
3
- examples/*.wav filter=lfs diff=lfs merge=lfs -text
4
- * !text !filter !merge !diff
 
1
+ examples filter=lfs diff=lfs merge=lfs -text
2
+ speakers/nezha_prompt.wav filter=lfs diff=lfs merge=lfs -text
3
+ speakers/nezhaRAP_prompt.wav filter=lfs diff=lfs merge=lfs -text
4
+ speakers/nezha哼唱_prompt.wav filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ __pycache__/
2
+ output/
README.md CHANGED
@@ -1 +1,13 @@
1
- # Step-Audio-EditX
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Step-Audio-EditX
3
+ emoji: 🚀
4
+ colorFrom: red
5
+ colorTo: red
6
+ sdk: gradio
7
+ sdk_version: 5.49.1
8
+ app_file: app.py
9
+ pinned: true
10
+ short_description: Try out Step-Audio-EditX
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
__init__.py ADDED
File without changes
app.py ADDED
@@ -0,0 +1,499 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import argparse
4
+ import torch
5
+ import logging
6
+ import threading
7
+ from datetime import datetime
8
+ import torchaudio
9
+ import librosa
10
+ import soundfile as sf
11
+
12
+ # ZeroGPU support
13
+ try:
14
+ import spaces
15
+ ZEROGPU_AVAILABLE = True
16
+ except ImportError:
17
+ ZEROGPU_AVAILABLE = False
18
+ # Create a dummy decorator for non-ZeroGPU environments
19
+ class spaces:
20
+ @staticmethod
21
+ def GPU(duration=10):
22
+ def decorator(func):
23
+ return func
24
+ return decorator
25
+
26
+ # Project imports
27
+ from tokenizer import StepAudioTokenizer
28
+ from tts import StepAudioTTS
29
+ from model_loader import ModelSource
30
+ from config.edit_config import get_supported_edit_types
31
+
32
+ # Configure logging
33
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
34
+ logger = logging.getLogger(__name__)
35
+
36
+ # Global variables for ZeroGPU-optimized loading
37
+ encoder = None
38
+ common_tts_engine = None
39
+ args_global = None
40
+ _model_lock = threading.Lock() # Thread lock for model initialization
41
+
42
+ def initialize_models():
43
+ """Initialize models on first GPU call (ZeroGPU optimization: load inside GPU context)"""
44
+ global encoder, common_tts_engine, args_global
45
+
46
+ # Fast path: check if already initialized (without lock)
47
+ if common_tts_engine is not None:
48
+ return # Already initialized
49
+
50
+ # Slow path: acquire lock and double-check
51
+ with _model_lock:
52
+ # Double-check pattern: another thread might have initialized while waiting for lock
53
+ if common_tts_engine is not None:
54
+ return # Already initialized by another thread
55
+
56
+ if args_global is None:
57
+ raise RuntimeError("Global args not set. Cannot initialize models.")
58
+
59
+ try:
60
+ logger.info("🚀 Initializing models inside GPU context (first call)...")
61
+
62
+ # Determine model source
63
+ source_mapping = {
64
+ "auto": ModelSource.AUTO,
65
+ "local": ModelSource.LOCAL,
66
+ "modelscope": ModelSource.MODELSCOPE,
67
+ "huggingface": ModelSource.HUGGINGFACE
68
+ }
69
+ model_source = source_mapping[args_global.model_source]
70
+
71
+ # Load StepAudioTokenizer (avoid CUDA initialization in main process)
72
+ encoder = StepAudioTokenizer(
73
+ os.path.join(args_global.model_path, "Step-Audio-Tokenizer"),
74
+ model_source=model_source,
75
+ funasr_model_id=args_global.tokenizer_model_id
76
+ )
77
+ logger.info("✓ StepAudioTokenizer loaded")
78
+
79
+ # Initialize common TTS engine (avoid CUDA initialization in main process)
80
+ common_tts_engine = StepAudioTTS(
81
+ os.path.join(args_global.model_path, "Step-Audio-EditX"),
82
+ encoder,
83
+ model_source=model_source,
84
+ tts_model_id=args_global.tts_model_id
85
+ )
86
+ logger.info("✓ StepCommonAudioTTS loaded")
87
+ print("Models initialized inside GPU context.")
88
+
89
+ if ZEROGPU_AVAILABLE:
90
+ logger.info("💡 Models loaded inside GPU context - ready for inference")
91
+ else:
92
+ logger.info("💡 Models loaded - ready for inference")
93
+
94
+ except Exception as e:
95
+ logger.error(f"❌ Error loading models: {e}")
96
+ raise
97
+
98
+ def get_model_config():
99
+ """Get model configuration without initializing GPU models"""
100
+ if args_global is None:
101
+ raise RuntimeError("Global args not set. Cannot get model config.")
102
+
103
+ return {
104
+ "encoder_path": os.path.join(args_global.model_path, "Step-Audio-Tokenizer"),
105
+ "tts_path": os.path.join(args_global.model_path, "Step-Audio-EditX"),
106
+ "model_source": args_global.model_source,
107
+ "tokenizer_model_id": args_global.tokenizer_model_id,
108
+ "tts_model_id": args_global.tts_model_id
109
+ }
110
+
111
+ def get_gpu_duration(audio_input, text_input, target_text, task_type, task_info):
112
+ """Dynamic GPU duration based on whether models need initialization"""
113
+ global common_tts_engine
114
+
115
+ if common_tts_engine is None:
116
+ # First call - need time for model loading (up to 5 minutes)
117
+ return 300 # Maximum allowed duration for model initialization
118
+ else:
119
+ # Subsequent calls - only inference time needed
120
+ return 120 # Standard inference duration
121
+
122
+ @spaces.GPU(duration=get_gpu_duration) # Dynamic duration based on model state
123
+ def process_audio_with_gpu(audio_input, text_input, target_text, task_type, task_info):
124
+ """Process audio using GPU (models are loaded inside GPU context to avoid main process errors)"""
125
+ global common_tts_engine
126
+
127
+ # Initialize models if not already loaded (inside GPU context to avoid main process errors)
128
+ if common_tts_engine is None:
129
+ print("Initializing common_tts_engine inside GPU context...")
130
+ logger.info("🎯 GPU allocated for 300s (first call with model loading)...")
131
+ initialize_models()
132
+ logger.info("✅ Models loaded successfully inside GPU context")
133
+ else:
134
+ print("common_tts_engine already initialized.")
135
+ logger.info("🎯 GPU allocated for 120s (inference with loaded models)...")
136
+
137
+ try:
138
+ # Use loaded models (first call may include loading time, subsequent calls are fast)
139
+ if task_type == "clone":
140
+ output_audio, sr = common_tts_engine.clone(audio_input, text_input, target_text)
141
+ else:
142
+ output_audio, sr = common_tts_engine.edit(audio_input, text_input, task_type, task_info, target_text)
143
+
144
+ logger.info("✅ Audio processing completed")
145
+ return output_audio, sr
146
+
147
+ except Exception as e:
148
+ logger.error(f"❌ Audio processing failed: {e}")
149
+ raise
150
+ # GPU automatically deallocated when function exits
151
+
152
+ # Save audio to temporary directory
153
+ def save_audio(audio_type, audio_data, sr, tmp_dir):
154
+ """Save audio data to a temporary file with timestamp"""
155
+ current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
156
+ save_path = os.path.join(tmp_dir, audio_type, f"{current_time}.wav")
157
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
158
+
159
+ try:
160
+ if isinstance(audio_data, torch.Tensor):
161
+ torchaudio.save(save_path, audio_data, sr)
162
+ else:
163
+ sf.write(save_path, audio_data, sr)
164
+ logger.debug(f"Audio saved to: {save_path}")
165
+ return save_path
166
+ except Exception as e:
167
+ logger.error(f"Failed to save audio: {e}")
168
+ raise
169
+
170
+
171
+ class EditxTab:
172
+ """Audio editing and voice cloning interface tab"""
173
+
174
+ def __init__(self, args):
175
+ self.args = args
176
+ self.edit_type_list = list(get_supported_edit_types().keys())
177
+ self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")
178
+
179
+ def history_messages_to_show(self, messages):
180
+ """Convert message history to gradio chatbot format"""
181
+ show_msgs = []
182
+ for message in messages:
183
+ edit_type = message['edit_type']
184
+ edit_info = message['edit_info']
185
+ source_text = message['source_text']
186
+ target_text = message['target_text']
187
+ raw_audio_part = message['raw_wave']
188
+ edit_audio_part = message['edit_wave']
189
+ type_str = f"{edit_type}-{edit_info}" if edit_info is not None else f"{edit_type}"
190
+ show_msgs.extend([
191
+ {"role": "user", "content": f"任务类型:{type_str}\n文本:{source_text}"},
192
+ {"role": "user", "content": gr.Audio(value=raw_audio_part, interactive=False)},
193
+ {"role": "assistant", "content": f"输出音频:\n文本:{target_text}"},
194
+ {"role": "assistant", "content": gr.Audio(value=edit_audio_part, interactive=False)}
195
+ ])
196
+ return show_msgs
197
+
198
+ def generate_clone(self, prompt_text_input, prompt_audio_input, generated_text, edit_type, edit_info, state):
199
+ """Generate cloned audio (models are loaded on first GPU call)"""
200
+ self.logger.info("Starting voice cloning process")
201
+ state['history_audio'] = []
202
+ state['history_messages'] = []
203
+
204
+ # Input validation
205
+ if not prompt_text_input or prompt_text_input.strip() == "":
206
+ error_msg = "[Error] Uploaded text cannot be empty."
207
+ self.logger.error(error_msg)
208
+ return [{"role": "user", "content": error_msg}], state
209
+ if not prompt_audio_input:
210
+ error_msg = "[Error] Uploaded audio cannot be empty."
211
+ self.logger.error(error_msg)
212
+ return [{"role": "user", "content": error_msg}], state
213
+ if not generated_text or generated_text.strip() == "":
214
+ error_msg = "[Error] Clone content cannot be empty."
215
+ self.logger.error(error_msg)
216
+ return [{"role": "user", "content": error_msg}], state
217
+ if edit_type != "clone":
218
+ error_msg = "[Error] CLONE button must use clone task."
219
+ self.logger.error(error_msg)
220
+ return [{"role": "user", "content": error_msg}], state
221
+
222
+ try:
223
+ # Use GPU inference with models loaded inside GPU context
224
+ output_audio, output_sr = process_audio_with_gpu(
225
+ prompt_audio_input, prompt_text_input, generated_text, "clone", edit_info
226
+ )
227
+
228
+ if output_audio is not None and output_sr is not None:
229
+ # Convert tensor to numpy if needed
230
+ if isinstance(output_audio, torch.Tensor):
231
+ audio_numpy = output_audio.cpu().numpy().squeeze()
232
+ else:
233
+ audio_numpy = output_audio
234
+
235
+ # Load original audio for comparison
236
+ input_audio_data_numpy, input_sample_rate = librosa.load(prompt_audio_input)
237
+
238
+ # Create message for history
239
+ cur_assistant_msg = {
240
+ "edit_type": edit_type,
241
+ "edit_info": edit_info,
242
+ "source_text": prompt_text_input,
243
+ "target_text": generated_text,
244
+ "raw_wave": (input_sample_rate, input_audio_data_numpy),
245
+ "edit_wave": (output_sr, audio_numpy),
246
+ }
247
+ state["history_audio"].append((output_sr, audio_numpy, generated_text))
248
+ state["history_messages"].append(cur_assistant_msg)
249
+
250
+ show_msgs = self.history_messages_to_show(state["history_messages"])
251
+ self.logger.info("Voice cloning completed successfully")
252
+ return show_msgs, state
253
+ else:
254
+ error_msg = "[Error] Clone failed"
255
+ self.logger.error(error_msg)
256
+ return [{"role": "user", "content": error_msg}], state
257
+
258
+ except Exception as e:
259
+ error_msg = f"[Error] Clone failed: {str(e)}"
260
+ self.logger.error(error_msg)
261
+ return [{"role": "user", "content": error_msg}], state
262
+
263
+ def generate_edit(self, prompt_text_input, prompt_audio_input, generated_text, edit_type, edit_info, state):
264
+ """Generate edited audio (models are loaded on first GPU call)"""
265
+ self.logger.info("Starting audio editing process")
266
+
267
+ # Input validation
268
+ if not prompt_text_input or prompt_text_input.strip() == "":
269
+ error_msg = "[Error] Uploaded text cannot be empty."
270
+ self.logger.error(error_msg)
271
+ return [{"role": "user", "content": error_msg}], state
272
+ if not prompt_audio_input:
273
+ error_msg = "[Error] Uploaded audio cannot be empty."
274
+ self.logger.error(error_msg)
275
+ return [{"role": "user", "content": error_msg}], state
276
+
277
+ try:
278
+ # Determine which audio to use
279
+ if len(state["history_audio"]) == 0:
280
+ # First edit - use uploaded audio
281
+ audio_to_edit = prompt_audio_input
282
+ text_to_use = prompt_text_input
283
+ self.logger.debug("Using prompt audio, no history found")
284
+ else:
285
+ # Use previous edited audio - save it to temp file first
286
+ sample_rate, audio_numpy, previous_text = state["history_audio"][-1]
287
+ temp_path = save_audio("temp", audio_numpy, sample_rate, self.args.tmp_dir)
288
+ audio_to_edit = temp_path
289
+ text_to_use = previous_text
290
+ self.logger.debug(f"Using previous audio from history, count: {len(state['history_audio'])}")
291
+
292
+ # For para-linguistic, use generated_text; otherwise use source text
293
+ if edit_type not in {"para-linguistic"}:
294
+ generated_text = text_to_use
295
+
296
+ # Use GPU inference with models loaded inside GPU context
297
+ output_audio, output_sr = process_audio_with_gpu(
298
+ audio_to_edit, text_to_use, generated_text, edit_type, edit_info
299
+ )
300
+
301
+ if output_audio is not None and output_sr is not None:
302
+ # Convert tensor to numpy if needed
303
+ if isinstance(output_audio, torch.Tensor):
304
+ audio_numpy = output_audio.cpu().numpy().squeeze()
305
+ else:
306
+ audio_numpy = output_audio
307
+
308
+ # Load original audio for comparison
309
+ if len(state["history_audio"]) == 0:
310
+ input_audio_data_numpy, input_sample_rate = librosa.load(prompt_audio_input)
311
+ else:
312
+ input_sample_rate, input_audio_data_numpy, _ = state["history_audio"][-1]
313
+
314
+ # Create message for history
315
+ cur_assistant_msg = {
316
+ "edit_type": edit_type,
317
+ "edit_info": edit_info,
318
+ "source_text": text_to_use,
319
+ "target_text": generated_text,
320
+ "raw_wave": (input_sample_rate, input_audio_data_numpy),
321
+ "edit_wave": (output_sr, audio_numpy),
322
+ }
323
+ state["history_audio"].append((output_sr, audio_numpy, generated_text))
324
+ state["history_messages"].append(cur_assistant_msg)
325
+
326
+ show_msgs = self.history_messages_to_show(state["history_messages"])
327
+ self.logger.info("Audio editing completed successfully")
328
+ return show_msgs, state
329
+ else:
330
+ error_msg = "[Error] Edit failed"
331
+ self.logger.error(error_msg)
332
+ return [{"role": "user", "content": error_msg}], state
333
+
334
+ except Exception as e:
335
+ error_msg = f"[Error] Edit failed: {str(e)}"
336
+ self.logger.error(error_msg)
337
+ return [{"role": "user", "content": error_msg}], state
338
+
339
+ def clear_history(self, state):
340
+ """Clear conversation history"""
341
+ state["history_messages"] = []
342
+ state["history_audio"] = []
343
+ return [], state
344
+
345
+ def init_state(self):
346
+ """Initialize conversation state"""
347
+ return {
348
+ "history_messages": [],
349
+ "history_audio": []
350
+ }
351
+
352
+ def register_components(self):
353
+ """Register gradio components - maintaining exact layout from original"""
354
+ with gr.Tab("Editx"):
355
+ with gr.Row():
356
+ with gr.Column():
357
+ self.model_input = gr.Textbox(label="Model Name", value="Step-Audio-EditX", scale=1)
358
+ self.prompt_text_input = gr.Textbox(label="Audio Text Content", value="", scale=1)
359
+ self.prompt_audio_input = gr.Audio(
360
+ sources=["upload", "microphone"],
361
+ format="wav",
362
+ type="filepath",
363
+ label="Input Audio",
364
+ )
365
+ self.generated_text = gr.Textbox(label="Clone Text", lines=1, max_lines=200)
366
+ with gr.Row():
367
+ self.button_tts = gr.Button("CLONE")
368
+ self.button_edit = gr.Button("EDIT")
369
+
370
+ with gr.Column():
371
+ with gr.Row():
372
+ self.edit_type = gr.Dropdown(label="Task", choices=self.edit_type_list, value="clone")
373
+ self.edit_info = gr.Dropdown(label="Sub-task", choices=[], value=None)
374
+ self.chat_box = gr.Chatbot(label="History", type="messages", height=480*1)
375
+ self.clean_history_submit = gr.Button("Clear History")
376
+
377
+ gr.Markdown("---")
378
+ gr.Markdown("""
379
+ **Button Description:**
380
+ - CLONE: Synthesizes audio based on uploaded audio and text, only used for clone mode, will clear history information when used.
381
+ - EDIT: Edits based on uploaded audio, or continues to stack edit effects based on the previous round of generated audio.
382
+ """)
383
+ gr.Markdown("""
384
+ **Operation Workflow:**
385
+ - Upload the audio to be edited on the left side and fill in the corresponding text content of the audio;
386
+ - If the task requires modifying text content (such as clone, para-linguistic), fill in the text to be synthesized in the "clone text" field. For all other tasks, keep the uploaded audio text content unchanged;
387
+ - Select tasks and subtasks on the right side (some tasks have no subtasks, such as vad, etc.);
388
+ - Click the "CLONE" or "EDIT" button on the left side, and audio will be generated in the dialog box on the right side.
389
+ """)
390
+ gr.Markdown("""
391
+ **Para-linguistic Description:**
392
+ - Supported tags include: [Breathing] [Laughter] [Cough] [Sigh] [Confirmation-en] [Question-en] [Question-ah] [Question-oh] [Surprise-ah] [Surprise-oh] [Dissatisfaction-hnn] [Uhm] [Shh] [Crying] [Surprise-wa] [Surprise-yo] [Question-ei] [Question-yi]
393
+ - Example:
394
+ - Fill in "clone text" field: "Great, the weather is so nice today." Click the "CLONE" button to get audio.
395
+ - Change "clone text" field to: "Great[Laughter], the weather is so nice today[Surprise-ah]." Click the "EDIT" button to get para-linguistic audio.
396
+ """)
397
+
398
+ def register_events(self):
399
+ """Register event handlers"""
400
+ # Create independent state for each session
401
+ state = gr.State(self.init_state())
402
+
403
+ self.button_tts.click(self.generate_clone,
404
+ inputs=[self.prompt_text_input, self.prompt_audio_input, self.generated_text, self.edit_type, self.edit_info, state],
405
+ outputs=[self.chat_box, state])
406
+ self.button_edit.click(self.generate_edit,
407
+ inputs=[self.prompt_text_input, self.prompt_audio_input, self.generated_text, self.edit_type, self.edit_info, state],
408
+ outputs=[self.chat_box, state])
409
+
410
+ self.clean_history_submit.click(self.clear_history, inputs=[state], outputs=[self.chat_box, state])
411
+ self.edit_type.change(
412
+ fn=self.update_edit_info,
413
+ inputs=self.edit_type,
414
+ outputs=self.edit_info,
415
+ )
416
+
417
+ def update_edit_info(self, category):
418
+ """Update sub-task dropdown based on main task selection"""
419
+ category_items = get_supported_edit_types()
420
+ choices = category_items.get(category, [])
421
+ value = None if len(choices) == 0 else choices[0]
422
+ return gr.Dropdown(label="Sub-task", choices=choices, value=value)
423
+
424
+
425
+ def launch_demo(args, editx_tab):
426
+ """Launch the gradio demo"""
427
+ with gr.Blocks(title="🎙️ Step-Audio-EditX") as demo:
428
+ gr.Markdown("## 🎙️ Step-Audio-EditX")
429
+ gr.Markdown("Audio editing and voice cloning using Step-Audio-Edit model.")
430
+
431
+ # Register components
432
+ editx_tab.register_components()
433
+
434
+ # Register events
435
+ editx_tab.register_events()
436
+
437
+ # Launch demo
438
+ demo.queue().launch(
439
+ server_name=args.server_name,
440
+ server_port=args.server_port,
441
+ share=args.share if hasattr(args, 'share') else False
442
+ )
443
+
444
+
445
+ if __name__ == "__main__":
446
+ # Parse command line arguments
447
+ parser = argparse.ArgumentParser(description="Step-Audio Edit Demo")
448
+ parser.add_argument("--model-path", type=str, default="stepfun-ai", help="Model path.")
449
+ parser.add_argument("--server-name", type=str, default="0.0.0.0", help="Demo server name.")
450
+ parser.add_argument("--server-port", type=int, default=7860, help="Demo server port.")
451
+ parser.add_argument("--tmp-dir", type=str, default="/tmp/gradio", help="Save path.")
452
+ parser.add_argument("--share", action="store_true", help="Share gradio app.")
453
+
454
+ # Multi-source loading support parameters
455
+ parser.add_argument(
456
+ "--model-source",
457
+ type=str,
458
+ default="huggingface",
459
+ choices=["auto", "local", "modelscope", "huggingface"],
460
+ help="Model source: auto (detect automatically), local, modelscope, or huggingface"
461
+ )
462
+ parser.add_argument(
463
+ "--tokenizer-model-id",
464
+ type=str,
465
+ default="dengcunqin/speech_paraformer-large_asr_nat-zh-cantonese-en-16k-vocab8501-online",
466
+ help="Tokenizer model ID for online loading"
467
+ )
468
+ parser.add_argument(
469
+ "--tts-model-id",
470
+ type=str,
471
+ default=None,
472
+ help="TTS model ID for online loading (if different from model-path)"
473
+ )
474
+
475
+ args = parser.parse_args()
476
+
477
+ # Store args globally for model configuration
478
+ args_global = args
479
+
480
+ logger.info(f"Configuration loaded:")
481
+ logger.info(f"Model source: {args.model_source}")
482
+ logger.info(f"Model path: {args.model_path}")
483
+ logger.info(f"Tokenizer model ID: {args.tokenizer_model_id}")
484
+ if args.tts_model_id:
485
+ logger.info(f"TTS model ID: {args.tts_model_id}")
486
+
487
+ # Models will be initialized on first GPU call to avoid ZeroGPU main process errors
488
+
489
+ if ZEROGPU_AVAILABLE:
490
+ logger.info("🎉 ZeroGPU detected - using dynamic GPU duration management!")
491
+ logger.info("💡 First call: 300s (model loading), subsequent calls: 120s (inference only)")
492
+ else:
493
+ logger.info("💻 Running in local mode - models will be loaded on first call")
494
+
495
+ # Create EditxTab instance
496
+ editx_tab = EditxTab(args)
497
+
498
+ # Launch demo
499
+ launch_demo(args, editx_tab)
config/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Configuration module for Step-Audio
3
+ """
4
+
5
+ from .prompts import TTS_SYSTEM_PROMPTS, AUDIO_EDIT_SYSTEM_PROMPT
6
+ from .edit_config import get_supported_edit_types
7
+
8
+ __all__ = [
9
+ 'TTS_SYSTEM_PROMPTS',
10
+ 'AUDIO_EDIT_SYSTEM_PROMPT',
11
+ 'get_supported_edit_types'
12
+ ]
config/edit_config.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 音频编辑配置模块
3
+ 包含支持的编辑类型和相关配置
4
+ """
5
+
6
+ def get_supported_edit_types():
7
+ """
8
+ 获取支持的编辑类型和选项
9
+
10
+ Returns:
11
+ Dict[str, list]: Dictionary of edit types and their options
12
+ """
13
+ return {
14
+ "clone": [],
15
+ "emotion": [
16
+ 'happy', 'angry', 'sad', 'humour', 'confusion', 'disgusted',
17
+ 'empathy', 'embarrass', 'fear', 'surprised', 'excited',
18
+ 'depressed', 'coldness', 'admiration'
19
+ ],
20
+ "style": [
21
+ 'serious', 'arrogant', 'child', 'older', 'girl', 'pure',
22
+ 'sister', 'sweet', 'ethereal', 'whisper', 'gentle', 'recite',
23
+ 'generous', 'act_coy', 'warm', 'shy', 'comfort', 'authority',
24
+ 'chat', 'radio', 'soulful', 'story', 'vivid', 'program',
25
+ 'news', 'advertising', 'roar', 'murmur', 'shout', 'deeply', 'loudly'
26
+ ],
27
+ "vad": [],
28
+ "music": [],
29
+ "denoise": [],
30
+ "para-linguistic": [],
31
+ "speed": ["faster", "slower", "more faster", "more slower"],
32
+ "animal": [],
33
+ }
config/prompts.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 系统提示配置模块
3
+ 包含所有TTS和编辑相关的系统提示
4
+ """
5
+
6
+ # TTS相关系统提示
7
+ TTS_SYSTEM_PROMPTS = {
8
+ "sys_prompt_for_rap": "请参考对话历史里的音色,用RAP方式将文本内容大声说唱出来。",
9
+ "sys_prompt_for_vocal": "请参考对话历史里的音色,用哼唱的方式将文本内容大声唱出来。",
10
+ "sys_prompt_wo_spk": '以自然的语速读出下面的文字。',
11
+ "sys_prompt_with_spk": '请用{}的声音尽可能自然地说出下面这些话。',
12
+ }
13
+
14
+ # 音频编辑系统提示
15
+ AUDIO_EDIT_SYSTEM_PROMPT = """As a highly skilled audio editing and tuning specialist, you excel at interpreting user instructions and applying precise adjustments to audio files according to their needs. Your expertise spans a wide range of audio enhancement capabilities, including but not limited to the following:
16
+
17
+ # Emotional Enhancement of Speech:
18
+ You are capable of infusing speech with various emotions such as:
19
+ - happy
20
+ - angry
21
+ - sad
22
+ - fear
23
+ - disgusted
24
+ - surprised
25
+ - excited
26
+
27
+ # Speech Style Transfer:
28
+ You can adapt vocal delivery to diverse styles including:
29
+ - Whisper
30
+ - Coquettish
31
+ - Gentle
32
+ - Sweet
33
+ - Arrogant
34
+ - Innocent
35
+ - Radio Host
36
+ - Childlike
37
+ - Bold and Unconstrained
38
+ - Serious
39
+ - Expressive and Vivid
40
+ - Ethereal
41
+ - Exaggerated
42
+ - Recitation
43
+ - Girlish
44
+ - News Broadcast
45
+ - Mature Female Voice
46
+ - Middle-Aged or Elderly
47
+ - Program Hosting
48
+
49
+ # Paralinguistic Adjustments:
50
+ You can fine-tune non-verbal speech elements such as:
51
+ - Laughter Enhancement
52
+ - Emphatic Stress
53
+ - Rhythm and Pace Modulation
54
+
55
+ # Audio Tuning & Editing:
56
+ Your technical proficiency includes:
57
+ - Noise Reduction
58
+ - Background Music Removal
59
+ - Silence Trimming
60
+ - Speaker Extraction
61
+
62
+ Note: Users will provide instructions in natural language. You are expected to accurately interpret their requirements and perform the most suitable audio edits and enhancements."""
funasr_detach/__init__.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Initialize funasr package."""
2
+
3
+ import os
4
+ import pkgutil
5
+ import importlib
6
+
7
+ dirname = os.path.dirname(__file__)
8
+ version_file = os.path.join(dirname, "version.txt")
9
+ with open(version_file, "r") as f:
10
+ __version__ = f.read().strip()
11
+
12
+
13
+ import importlib
14
+ import pkgutil
15
+
16
+
17
+ def import_submodules(package, recursive=True):
18
+ if isinstance(package, str):
19
+ package = importlib.import_module(package)
20
+ results = {}
21
+ for loader, name, is_pkg in pkgutil.walk_packages(
22
+ package.__path__, package.__name__ + "."
23
+ ):
24
+ try:
25
+ results[name] = importlib.import_module(name)
26
+ except Exception as e:
27
+ # 如果想要看到导入错误的具体信息,可以取消注释下面的行
28
+ # print(f"Failed to import {name}: {e}")
29
+ pass
30
+ if recursive and is_pkg:
31
+ results.update(import_submodules(name))
32
+ return results
33
+
34
+
35
+ import_submodules(__name__)
36
+
37
+ from funasr_detach.auto.auto_model import AutoModel
38
+ from funasr_detach.auto.auto_frontend import AutoFrontend
funasr_detach/auto/__init__.py ADDED
File without changes
funasr_detach/auto/auto_frontend.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import logging
3
+ from tqdm import tqdm
4
+
5
+ from funasr_detach.register import tables
6
+ from funasr_detach.download.download_from_hub import download_model
7
+ from funasr_detach.utils.load_utils import load_audio_text_image_video, extract_fbank
8
+ from funasr_detach.auto.auto_model import prepare_data_iterator
9
+ from funasr_detach.auto.auto_model import prepare_data_iterator
10
+
11
+
12
+ class AutoFrontend:
13
+ def __init__(self, **kwargs):
14
+ assert "model" in kwargs
15
+ if "model_conf" not in kwargs:
16
+ logging.info(
17
+ "download models from model hub: {}".format(
18
+ kwargs.get("model_hub", "ms")
19
+ )
20
+ )
21
+ kwargs = download_model(**kwargs)
22
+
23
+ # build frontend
24
+ frontend = kwargs.get("frontend", None)
25
+ if frontend is not None:
26
+ frontend_class = tables.frontend_classes.get(frontend)
27
+ frontend = frontend_class(**kwargs["frontend_conf"])
28
+
29
+ self.frontend = frontend
30
+ if "frontend" in kwargs:
31
+ del kwargs["frontend"]
32
+ self.kwargs = kwargs
33
+
34
+ def __call__(self, input, input_len=None, kwargs=None, **cfg):
35
+
36
+ kwargs = self.kwargs if kwargs is None else kwargs
37
+ kwargs.update(cfg)
38
+
39
+ key_list, data_list = prepare_data_iterator(input, input_len=input_len)
40
+ batch_size = kwargs.get("batch_size", 1)
41
+ device = kwargs.get("device", "cpu")
42
+ if device == "cpu":
43
+ batch_size = 1
44
+
45
+ meta_data = {}
46
+
47
+ result_list = []
48
+ num_samples = len(data_list)
49
+ pbar = tqdm(colour="blue", total=num_samples + 1, dynamic_ncols=True)
50
+
51
+ time0 = time.perf_counter()
52
+ for beg_idx in range(0, num_samples, batch_size):
53
+ end_idx = min(num_samples, beg_idx + batch_size)
54
+ data_batch = data_list[beg_idx:end_idx]
55
+ key_batch = key_list[beg_idx:end_idx]
56
+
57
+ # extract fbank feats
58
+ time1 = time.perf_counter()
59
+ audio_sample_list = load_audio_text_image_video(
60
+ data_batch, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000)
61
+ )
62
+ time2 = time.perf_counter()
63
+ meta_data["load_data"] = f"{time2 - time1:0.3f}"
64
+ speech, speech_lengths = extract_fbank(
65
+ audio_sample_list,
66
+ data_type=kwargs.get("data_type", "sound"),
67
+ frontend=self.frontend,
68
+ **kwargs,
69
+ )
70
+ time3 = time.perf_counter()
71
+ meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
72
+ meta_data["batch_data_time"] = (
73
+ speech_lengths.sum().item()
74
+ * self.frontend.frame_shift
75
+ * self.frontend.lfr_n
76
+ / 1000
77
+ )
78
+
79
+ speech.to(device=device), speech_lengths.to(device=device)
80
+ batch = {"input": speech, "input_len": speech_lengths, "key": key_batch}
81
+ result_list.append(batch)
82
+
83
+ pbar.update(1)
84
+ description = f"{meta_data}, "
85
+ pbar.set_description(description)
86
+
87
+ time_end = time.perf_counter()
88
+ pbar.set_description(f"time escaped total: {time_end - time0:0.3f}")
89
+
90
+ return result_list
funasr_detach/auto/auto_model.py ADDED
@@ -0,0 +1,575 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import time
3
+ import copy
4
+ import torch
5
+ import random
6
+ import string
7
+ import logging
8
+ import os.path
9
+ import numpy as np
10
+ from tqdm import tqdm
11
+
12
+ from funasr_detach.register import tables
13
+ from funasr_detach.utils.load_utils import load_bytes
14
+ from funasr_detach.download.file import download_from_url
15
+ from funasr_detach.download.download_from_hub import download_model
16
+ from funasr_detach.utils.vad_utils import slice_padding_audio_samples
17
+ from funasr_detach.train_utils.set_all_random_seed import set_all_random_seed
18
+ from funasr_detach.train_utils.load_pretrained_model import load_pretrained_model
19
+ from funasr_detach.utils.load_utils import load_audio_text_image_video
20
+ from funasr_detach.utils.timestamp_tools import timestamp_sentence
21
+ from funasr_detach.models.campplus.utils import sv_chunk, postprocess, distribute_spk
22
+
23
+ try:
24
+ from funasr_detach.models.campplus.cluster_backend import ClusterBackend
25
+ except:
26
+ print("If you want to use the speaker diarization, please `pip install hdbscan`")
27
+
28
+
29
+ def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
30
+ """
31
+
32
+ :param input:
33
+ :param input_len:
34
+ :param data_type:
35
+ :param frontend:
36
+ :return:
37
+ """
38
+ data_list = []
39
+ key_list = []
40
+ filelist = [".scp", ".txt", ".json", ".jsonl"]
41
+
42
+ chars = string.ascii_letters + string.digits
43
+ if isinstance(data_in, str) and data_in.startswith("http"): # url
44
+ data_in = download_from_url(data_in)
45
+ if isinstance(data_in, str) and os.path.exists(
46
+ data_in
47
+ ): # wav_path; filelist: wav.scp, file.jsonl;text.txt;
48
+ _, file_extension = os.path.splitext(data_in)
49
+ file_extension = file_extension.lower()
50
+ if file_extension in filelist: # filelist: wav.scp, file.jsonl;text.txt;
51
+ with open(data_in, encoding="utf-8") as fin:
52
+ for line in fin:
53
+ key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
54
+ if data_in.endswith(
55
+ ".jsonl"
56
+ ): # file.jsonl: json.dumps({"source": data})
57
+ lines = json.loads(line.strip())
58
+ data = lines["source"]
59
+ key = data["key"] if "key" in data else key
60
+ else: # filelist, wav.scp, text.txt: id \t data or data
61
+ lines = line.strip().split(maxsplit=1)
62
+ data = lines[1] if len(lines) > 1 else lines[0]
63
+ key = lines[0] if len(lines) > 1 else key
64
+
65
+ data_list.append(data)
66
+ key_list.append(key)
67
+ else:
68
+ key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
69
+ data_list = [data_in]
70
+ key_list = [key]
71
+ elif isinstance(data_in, (list, tuple)):
72
+ if data_type is not None and isinstance(
73
+ data_type, (list, tuple)
74
+ ): # mutiple inputs
75
+ data_list_tmp = []
76
+ for data_in_i, data_type_i in zip(data_in, data_type):
77
+ key_list, data_list_i = prepare_data_iterator(
78
+ data_in=data_in_i, data_type=data_type_i
79
+ )
80
+ data_list_tmp.append(data_list_i)
81
+ data_list = []
82
+ for item in zip(*data_list_tmp):
83
+ data_list.append(item)
84
+ else:
85
+ # [audio sample point, fbank, text]
86
+ data_list = data_in
87
+ key_list = [
88
+ "rand_key_" + "".join(random.choice(chars) for _ in range(13))
89
+ for _ in range(len(data_in))
90
+ ]
91
+ else: # raw text; audio sample point, fbank; bytes
92
+ if isinstance(data_in, bytes): # audio bytes
93
+ data_in = load_bytes(data_in)
94
+ if key is None:
95
+ key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
96
+ data_list = [data_in]
97
+ key_list = [key]
98
+
99
+ return key_list, data_list
100
+
101
+
102
+ class AutoModel:
103
+
104
+ def __init__(self, **kwargs):
105
+ if not kwargs.get("disable_log", False):
106
+ tables.print()
107
+
108
+ model, kwargs = self.build_model(**kwargs)
109
+
110
+ # if vad_model is not None, build vad model else None
111
+ vad_model = kwargs.get("vad_model", None)
112
+ vad_kwargs = kwargs.get("vad_model_revision", None)
113
+ if vad_model is not None:
114
+ logging.info("Building VAD model.")
115
+ vad_kwargs = {
116
+ "model": vad_model,
117
+ "model_revision": vad_kwargs,
118
+ "device": kwargs["device"],
119
+ }
120
+ vad_model, vad_kwargs = self.build_model(**vad_kwargs)
121
+
122
+ # if punc_model is not None, build punc model else None
123
+ punc_model = kwargs.get("punc_model", None)
124
+ punc_kwargs = kwargs.get("punc_model_revision", None)
125
+ if punc_model is not None:
126
+ logging.info("Building punc model.")
127
+ punc_kwargs = {
128
+ "model": punc_model,
129
+ "model_revision": punc_kwargs,
130
+ "device": kwargs["device"],
131
+ }
132
+ punc_model, punc_kwargs = self.build_model(**punc_kwargs)
133
+
134
+ # if spk_model is not None, build spk model else None
135
+ spk_model = kwargs.get("spk_model", None)
136
+ spk_kwargs = kwargs.get("spk_model_revision", None)
137
+ if spk_model is not None:
138
+ logging.info("Building SPK model.")
139
+ spk_kwargs = {
140
+ "model": spk_model,
141
+ "model_revision": spk_kwargs,
142
+ "device": kwargs["device"],
143
+ }
144
+ spk_model, spk_kwargs = self.build_model(**spk_kwargs)
145
+ self.cb_model = ClusterBackend().to(kwargs["device"])
146
+ spk_mode = kwargs.get("spk_mode", "punc_segment")
147
+ if spk_mode not in ["default", "vad_segment", "punc_segment"]:
148
+ logging.error(
149
+ "spk_mode should be one of default, vad_segment and punc_segment."
150
+ )
151
+ self.spk_mode = spk_mode
152
+
153
+ self.kwargs = kwargs
154
+ self.model = model
155
+ self.vad_model = vad_model
156
+ self.vad_kwargs = vad_kwargs
157
+ self.punc_model = punc_model
158
+ self.punc_kwargs = punc_kwargs
159
+ self.spk_model = spk_model
160
+ self.spk_kwargs = spk_kwargs
161
+ self.model_path = kwargs.get("model_path")
162
+ self.repo_path = kwargs.get("repo_path")
163
+
164
+
165
+ def build_model(self, **kwargs):
166
+ assert "model" in kwargs
167
+ if "model_conf" not in kwargs:
168
+ logging.info(
169
+ "download models from model hub: {}".format(
170
+ kwargs.get("model_hub", "ms")
171
+ )
172
+ )
173
+ kwargs = download_model(**kwargs)
174
+
175
+ set_all_random_seed(kwargs.get("seed", 0))
176
+
177
+ device = kwargs.get("device", "cuda")
178
+ if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
179
+ device = "cpu"
180
+ kwargs["batch_size"] = 1
181
+ kwargs["device"] = device
182
+
183
+ if kwargs.get("ncpu", None):
184
+ torch.set_num_threads(kwargs.get("ncpu"))
185
+
186
+ # build tokenizer
187
+ tokenizer = kwargs.get("tokenizer", None)
188
+ if tokenizer is not None:
189
+ tokenizer_class = tables.tokenizer_classes.get(tokenizer)
190
+ tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
191
+ kwargs["tokenizer"] = tokenizer
192
+ kwargs["token_list"] = tokenizer.token_list
193
+ vocab_size = len(tokenizer.token_list)
194
+ else:
195
+ vocab_size = -1
196
+
197
+ # build frontend
198
+ frontend = kwargs.get("frontend", None)
199
+ if frontend is not None:
200
+ frontend_class = tables.frontend_classes.get(frontend)
201
+ frontend = frontend_class(**kwargs["frontend_conf"])
202
+ kwargs["frontend"] = frontend
203
+ kwargs["input_size"] = frontend.output_size()
204
+
205
+ # build model
206
+ model_class = tables.model_classes.get(kwargs["model"])
207
+ model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
208
+
209
+ model.to(device)
210
+
211
+ # init_param
212
+ init_param = kwargs.get("init_param", None)
213
+ if init_param is not None:
214
+ logging.info(f"Loading pretrained params from {init_param}")
215
+ load_pretrained_model(
216
+ model=model,
217
+ path=init_param,
218
+ ignore_init_mismatch=kwargs.get("ignore_init_mismatch", False),
219
+ oss_bucket=kwargs.get("oss_bucket", None),
220
+ scope_map=kwargs.get("scope_map", None),
221
+ excludes=kwargs.get("excludes", None),
222
+ )
223
+
224
+ return model, kwargs
225
+
226
+ def __call__(self, *args, **cfg):
227
+ kwargs = self.kwargs
228
+ kwargs.update(cfg)
229
+ res = self.model(*args, kwargs)
230
+ return res
231
+
232
+ def generate(self, input, input_len=None, **cfg):
233
+ if self.vad_model is None:
234
+ return self.inference(input, input_len=input_len, **cfg)
235
+
236
+ else:
237
+ return self.inference_with_vad(input, input_len=input_len, **cfg)
238
+
239
+ def inference(
240
+ self, input, input_len=None, model=None, kwargs=None, key=None, **cfg
241
+ ):
242
+ kwargs = self.kwargs if kwargs is None else kwargs
243
+ kwargs.update(cfg)
244
+ model = self.model if model is None else model
245
+ model = model.cuda()
246
+ model.eval()
247
+
248
+ batch_size = kwargs.get("batch_size", 1)
249
+ # if kwargs.get("device", "cpu") == "cpu":
250
+ # batch_size = 1
251
+
252
+ key_list, data_list = prepare_data_iterator(
253
+ input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key
254
+ )
255
+
256
+ speed_stats = {}
257
+ asr_result_list = []
258
+ num_samples = len(data_list)
259
+ disable_pbar = kwargs.get("disable_pbar", False)
260
+ pbar = (
261
+ tqdm(colour="blue", total=num_samples, dynamic_ncols=True)
262
+ if not disable_pbar
263
+ else None
264
+ )
265
+ time_speech_total = 0.0
266
+ time_escape_total = 0.0
267
+ for beg_idx in range(0, num_samples, batch_size):
268
+ end_idx = min(num_samples, beg_idx + batch_size)
269
+ data_batch = data_list[beg_idx:end_idx]
270
+ key_batch = key_list[beg_idx:end_idx]
271
+ batch = {"data_in": data_batch, "key": key_batch}
272
+ if (end_idx - beg_idx) == 1 and kwargs.get(
273
+ "data_type", None
274
+ ) == "fbank": # fbank
275
+ batch["data_in"] = data_batch[0]
276
+ batch["data_lengths"] = input_len
277
+
278
+ time1 = time.perf_counter()
279
+ with torch.no_grad():
280
+ results, meta_data = model.inference(**batch, **kwargs)
281
+ time2 = time.perf_counter()
282
+
283
+ asr_result_list.extend(results)
284
+
285
+ # batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
286
+ batch_data_time = meta_data.get("batch_data_time", -1)
287
+ time_escape = time2 - time1
288
+ speed_stats["load_data"] = meta_data.get("load_data", 0.0)
289
+ speed_stats["extract_feat"] = meta_data.get("extract_feat", 0.0)
290
+ speed_stats["forward"] = f"{time_escape:0.3f}"
291
+ speed_stats["batch_size"] = f"{len(results)}"
292
+ speed_stats["time_cost"] = f"{(time_escape)}"
293
+ speed_stats["rtf"] = f"{(time_escape) / batch_data_time:0.3f}"
294
+ description = f"{speed_stats}, "
295
+ if pbar:
296
+ pbar.update(1)
297
+ pbar.set_description(description)
298
+ time_speech_total += batch_data_time
299
+ time_escape_total += time_escape
300
+
301
+ if pbar:
302
+ # pbar.update(1)
303
+ pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
304
+ torch.cuda.empty_cache()
305
+ return asr_result_list
306
+
307
+ def inference_with_vad(self, input, input_len=None, **cfg):
308
+
309
+ # step.1: compute the vad model
310
+ self.vad_kwargs.update(cfg)
311
+ beg_vad = time.time()
312
+ res = self.inference(
313
+ input,
314
+ input_len=input_len,
315
+ model=self.vad_model,
316
+ kwargs=self.vad_kwargs,
317
+ **cfg,
318
+ )
319
+ end_vad = time.time()
320
+ print(f"time cost vad: {end_vad - beg_vad:0.3f}")
321
+
322
+ # step.2 compute asr model
323
+ model = self.model
324
+ kwargs = self.kwargs
325
+ kwargs.update(cfg)
326
+ batch_size = int(kwargs.get("batch_size_s", 300)) * 1000
327
+ batch_size_threshold_ms = int(kwargs.get("batch_size_threshold_s", 60)) * 1000
328
+ kwargs["batch_size"] = batch_size
329
+
330
+ key_list, data_list = prepare_data_iterator(
331
+ input, input_len=input_len, data_type=kwargs.get("data_type", None)
332
+ )
333
+ results_ret_list = []
334
+ time_speech_total_all_samples = 1e-6
335
+
336
+ beg_total = time.time()
337
+ pbar_total = tqdm(colour="red", total=len(res), dynamic_ncols=True)
338
+ for i in range(len(res)):
339
+ key = res[i]["key"]
340
+ vadsegments = res[i]["value"]
341
+ input_i = data_list[i]
342
+ speech = load_audio_text_image_video(
343
+ input_i, fs=kwargs["frontend"].fs, audio_fs=kwargs.get("fs", 16000)
344
+ )
345
+ speech_lengths = len(speech)
346
+ n = len(vadsegments)
347
+ data_with_index = [(vadsegments[i], i) for i in range(n)]
348
+ sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
349
+ results_sorted = []
350
+
351
+ if not len(sorted_data):
352
+ logging.info("decoding, utt: {}, empty speech".format(key))
353
+ continue
354
+
355
+ if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
356
+ batch_size = max(
357
+ batch_size, sorted_data[0][0][1] - sorted_data[0][0][0]
358
+ )
359
+
360
+ batch_size_ms_cum = 0
361
+ beg_idx = 0
362
+ beg_asr_total = time.time()
363
+ time_speech_total_per_sample = speech_lengths / 16000
364
+ time_speech_total_all_samples += time_speech_total_per_sample
365
+
366
+ all_segments = []
367
+ for j, _ in enumerate(range(0, n)):
368
+ # pbar_sample.update(1)
369
+ batch_size_ms_cum += sorted_data[j][0][1] - sorted_data[j][0][0]
370
+ if (
371
+ j < n - 1
372
+ and (
373
+ batch_size_ms_cum
374
+ + sorted_data[j + 1][0][1]
375
+ - sorted_data[j + 1][0][0]
376
+ )
377
+ < batch_size
378
+ and (sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0])
379
+ < batch_size_threshold_ms
380
+ ):
381
+ continue
382
+ batch_size_ms_cum = 0
383
+ end_idx = j + 1
384
+ speech_j, speech_lengths_j = slice_padding_audio_samples(
385
+ speech, speech_lengths, sorted_data[beg_idx:end_idx]
386
+ )
387
+ results = self.inference(
388
+ speech_j,
389
+ input_len=None,
390
+ model=model,
391
+ kwargs=kwargs,
392
+ disable_pbar=True,
393
+ **cfg,
394
+ )
395
+ if self.spk_model is not None:
396
+ # compose vad segments: [[start_time_sec, end_time_sec, speech], [...]]
397
+ for _b in range(len(speech_j)):
398
+ vad_segments = [
399
+ [
400
+ sorted_data[beg_idx:end_idx][_b][0][0] / 1000.0,
401
+ sorted_data[beg_idx:end_idx][_b][0][1] / 1000.0,
402
+ np.array(speech_j[_b]),
403
+ ]
404
+ ]
405
+ segments = sv_chunk(vad_segments)
406
+ all_segments.extend(segments)
407
+ speech_b = [i[2] for i in segments]
408
+ spk_res = self.inference(
409
+ speech_b,
410
+ input_len=None,
411
+ model=self.spk_model,
412
+ kwargs=kwargs,
413
+ disable_pbar=True,
414
+ **cfg,
415
+ )
416
+ results[_b]["spk_embedding"] = spk_res[0]["spk_embedding"]
417
+ beg_idx = end_idx
418
+ if len(results) < 1:
419
+ continue
420
+ results_sorted.extend(results)
421
+
422
+ restored_data = [0] * n
423
+ for j in range(n):
424
+ index = sorted_data[j][1]
425
+ restored_data[index] = results_sorted[j]
426
+ result = {}
427
+
428
+ # results combine for texts, timestamps, speaker embeddings and others
429
+ # TODO: rewrite for clean code
430
+ for j in range(n):
431
+ for k, v in restored_data[j].items():
432
+ if k.startswith("timestamp"):
433
+ if k not in result:
434
+ result[k] = []
435
+ for t in restored_data[j][k]:
436
+ t[0] += vadsegments[j][0]
437
+ t[1] += vadsegments[j][0]
438
+ result[k].extend(restored_data[j][k])
439
+ elif k == "spk_embedding":
440
+ if k not in result:
441
+ result[k] = restored_data[j][k]
442
+ else:
443
+ result[k] = torch.cat(
444
+ [result[k], restored_data[j][k]], dim=0
445
+ )
446
+ elif "text" in k:
447
+ if k not in result:
448
+ result[k] = restored_data[j][k]
449
+ else:
450
+ result[k] += " " + restored_data[j][k]
451
+ else:
452
+ if k not in result:
453
+ result[k] = restored_data[j][k]
454
+ else:
455
+ result[k] += restored_data[j][k]
456
+
457
+ return_raw_text = kwargs.get("return_raw_text", False)
458
+ # step.3 compute punc model
459
+ if self.punc_model is not None:
460
+ self.punc_kwargs.update(cfg)
461
+ punc_res = self.inference(
462
+ result["text"],
463
+ model=self.punc_model,
464
+ kwargs=self.punc_kwargs,
465
+ disable_pbar=True,
466
+ **cfg,
467
+ )
468
+ raw_text = copy.copy(result["text"])
469
+ if return_raw_text:
470
+ result["raw_text"] = raw_text
471
+ result["text"] = punc_res[0]["text"]
472
+ else:
473
+ raw_text = None
474
+
475
+ # speaker embedding cluster after resorted
476
+ if self.spk_model is not None and kwargs.get("return_spk_res", True):
477
+ if raw_text is None:
478
+ logging.error("Missing punc_model, which is required by spk_model.")
479
+ all_segments = sorted(all_segments, key=lambda x: x[0])
480
+ spk_embedding = result["spk_embedding"]
481
+ labels = self.cb_model(
482
+ spk_embedding.cpu(), oracle_num=kwargs.get("preset_spk_num", None)
483
+ )
484
+ # del result['spk_embedding']
485
+ sv_output = postprocess(all_segments, None, labels, spk_embedding.cpu())
486
+ if self.spk_mode == "vad_segment": # recover sentence_list
487
+ sentence_list = []
488
+ for res, vadsegment in zip(restored_data, vadsegments):
489
+ if "timestamp" not in res:
490
+ logging.error(
491
+ "Only 'iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch' \
492
+ and 'iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'\
493
+ can predict timestamp, and speaker diarization relies on timestamps."
494
+ )
495
+ sentence_list.append(
496
+ {
497
+ "start": vadsegment[0],
498
+ "end": vadsegment[1],
499
+ "sentence": res["text"],
500
+ "timestamp": res["timestamp"],
501
+ }
502
+ )
503
+ elif self.spk_mode == "punc_segment":
504
+ if "timestamp" not in result:
505
+ logging.error(
506
+ "Only 'iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch' \
507
+ and 'iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'\
508
+ can predict timestamp, and speaker diarization relies on timestamps."
509
+ )
510
+ sentence_list = timestamp_sentence(
511
+ punc_res[0]["punc_array"],
512
+ result["timestamp"],
513
+ raw_text,
514
+ return_raw_text=return_raw_text,
515
+ )
516
+ distribute_spk(sentence_list, sv_output)
517
+ result["sentence_info"] = sentence_list
518
+ elif kwargs.get("sentence_timestamp", False):
519
+ sentence_list = timestamp_sentence(
520
+ punc_res[0]["punc_array"],
521
+ result["timestamp"],
522
+ raw_text,
523
+ return_raw_text=return_raw_text,
524
+ )
525
+ result["sentence_info"] = sentence_list
526
+ if "spk_embedding" in result:
527
+ del result["spk_embedding"]
528
+
529
+ result["key"] = key
530
+ results_ret_list.append(result)
531
+ end_asr_total = time.time()
532
+ time_escape_total_per_sample = end_asr_total - beg_asr_total
533
+ pbar_total.update(1)
534
+ pbar_total.set_description(
535
+ f"rtf_avg: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
536
+ f"time_speech: {time_speech_total_per_sample: 0.3f}, "
537
+ f"time_escape: {time_escape_total_per_sample:0.3f}"
538
+ )
539
+
540
+ return results_ret_list
541
+
542
+ def infer_encoder(
543
+ self, input, input_len=None, model=None, kwargs=None, key=None, **cfg
544
+ ):
545
+ kwargs = self.kwargs if kwargs is None else kwargs
546
+ kwargs.update(cfg)
547
+ model = self.model if model is None else model
548
+ model = model.cuda()
549
+ model.eval()
550
+
551
+ batch_size = kwargs.get("batch_size", 1)
552
+
553
+ key_list, data_list = prepare_data_iterator(
554
+ input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key
555
+ )
556
+
557
+ asr_result_list = []
558
+ num_samples = len(data_list)
559
+ for beg_idx in range(0, num_samples, batch_size):
560
+ end_idx = min(num_samples, beg_idx + batch_size)
561
+ data_batch = data_list[beg_idx:end_idx]
562
+ key_batch = key_list[beg_idx:end_idx]
563
+ batch = {"data_in": data_batch, "key": key_batch}
564
+ if (end_idx - beg_idx) == 1 and kwargs.get(
565
+ "data_type", None
566
+ ) == "fbank": # fbank
567
+ batch["data_in"] = data_batch[0]
568
+ batch["data_lengths"] = input_len
569
+
570
+ with torch.no_grad():
571
+ results, meta_data, cache = model.infer_encoder(**batch, **kwargs)
572
+ asr_result_list.extend(results)
573
+
574
+ torch.cuda.empty_cache()
575
+ return asr_result_list, cache
funasr_detach/auto/auto_tokenizer.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ class AutoTokenizer:
2
+ """
3
+ Undo
4
+ """
5
+
6
+ def __init__(self):
7
+ pass
funasr_detach/bin/__init__.py ADDED
File without changes
funasr_detach/bin/compute_audio_cmvn.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import numpy as np
4
+ import torch
5
+ import hydra
6
+ import logging
7
+ from omegaconf import DictConfig, OmegaConf
8
+
9
+ from funasr_detach.register import tables
10
+ from funasr_detach.download.download_from_hub import download_model
11
+ from funasr_detach.train_utils.set_all_random_seed import set_all_random_seed
12
+
13
+
14
+ @hydra.main(config_name=None, version_base=None)
15
+ def main_hydra(kwargs: DictConfig):
16
+ if kwargs.get("debug", False):
17
+ import pdb
18
+
19
+ pdb.set_trace()
20
+
21
+ assert "model" in kwargs
22
+ if "model_conf" not in kwargs:
23
+ logging.info(
24
+ "download models from model hub: {}".format(kwargs.get("model_hub", "ms"))
25
+ )
26
+ kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)
27
+
28
+ main(**kwargs)
29
+
30
+
31
+ def main(**kwargs):
32
+ print(kwargs)
33
+ # set random seed
34
+ tables.print()
35
+ set_all_random_seed(kwargs.get("seed", 0))
36
+ torch.backends.cudnn.enabled = kwargs.get(
37
+ "cudnn_enabled", torch.backends.cudnn.enabled
38
+ )
39
+ torch.backends.cudnn.benchmark = kwargs.get(
40
+ "cudnn_benchmark", torch.backends.cudnn.benchmark
41
+ )
42
+ torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
43
+
44
+ tokenizer = kwargs.get("tokenizer", None)
45
+
46
+ # build frontend if frontend is none None
47
+ frontend = kwargs.get("frontend", None)
48
+ if frontend is not None:
49
+ frontend_class = tables.frontend_classes.get(frontend)
50
+ frontend = frontend_class(**kwargs["frontend_conf"])
51
+ kwargs["frontend"] = frontend
52
+ kwargs["input_size"] = frontend.output_size()
53
+
54
+ # dataset
55
+ dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
56
+ dataset_train = dataset_class(
57
+ kwargs.get("train_data_set_list"),
58
+ frontend=frontend,
59
+ tokenizer=None,
60
+ is_training=False,
61
+ **kwargs.get("dataset_conf")
62
+ )
63
+
64
+ # dataloader
65
+ batch_sampler = kwargs["dataset_conf"].get(
66
+ "batch_sampler", "DynamicBatchLocalShuffleSampler"
67
+ )
68
+ batch_sampler_train = None
69
+ if batch_sampler is not None:
70
+ batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
71
+ dataset_conf = kwargs.get("dataset_conf")
72
+ dataset_conf["batch_type"] = "example"
73
+ dataset_conf["batch_size"] = 1
74
+ batch_sampler_train = batch_sampler_class(
75
+ dataset_train, is_training=False, **dataset_conf
76
+ )
77
+
78
+ dataloader_train = torch.utils.data.DataLoader(
79
+ dataset_train,
80
+ collate_fn=dataset_train.collator,
81
+ batch_sampler=batch_sampler_train,
82
+ num_workers=int(kwargs.get("dataset_conf").get("num_workers", 4)),
83
+ pin_memory=True,
84
+ )
85
+
86
+ iter_stop = int(kwargs.get("scale", 1.0) * len(dataloader_train))
87
+
88
+ total_frames = 0
89
+ for batch_idx, batch in enumerate(dataloader_train):
90
+ if batch_idx >= iter_stop:
91
+ break
92
+
93
+ fbank = batch["speech"].numpy()[0, :, :]
94
+ if total_frames == 0:
95
+ mean_stats = np.sum(fbank, axis=0)
96
+ var_stats = np.sum(np.square(fbank), axis=0)
97
+ else:
98
+ mean_stats += np.sum(fbank, axis=0)
99
+ var_stats += np.sum(np.square(fbank), axis=0)
100
+ total_frames += fbank.shape[0]
101
+
102
+ cmvn_info = {
103
+ "mean_stats": list(mean_stats.tolist()),
104
+ "var_stats": list(var_stats.tolist()),
105
+ "total_frames": total_frames,
106
+ }
107
+ cmvn_file = kwargs.get("cmvn_file", "cmvn.json")
108
+ # import pdb;pdb.set_trace()
109
+ with open(cmvn_file, "w") as fout:
110
+ fout.write(json.dumps(cmvn_info))
111
+
112
+ mean = -1.0 * mean_stats / total_frames
113
+ var = 1.0 / np.sqrt(var_stats / total_frames - mean * mean)
114
+ dims = mean.shape[0]
115
+ am_mvn = os.path.dirname(cmvn_file) + "/am.mvn"
116
+ with open(am_mvn, "w") as fout:
117
+ fout.write(
118
+ "<Nnet>"
119
+ + "\n"
120
+ + "<Splice> "
121
+ + str(dims)
122
+ + " "
123
+ + str(dims)
124
+ + "\n"
125
+ + "[ 0 ]"
126
+ + "\n"
127
+ + "<AddShift> "
128
+ + str(dims)
129
+ + " "
130
+ + str(dims)
131
+ + "\n"
132
+ )
133
+ mean_str = (
134
+ str(list(mean)).replace(",", "").replace("[", "[ ").replace("]", " ]")
135
+ )
136
+ fout.write("<LearnRateCoef> 0 " + mean_str + "\n")
137
+ fout.write("<Rescale> " + str(dims) + " " + str(dims) + "\n")
138
+ var_str = str(list(var)).replace(",", "").replace("[", "[ ").replace("]", " ]")
139
+ fout.write("<LearnRateCoef> 0 " + var_str + "\n")
140
+ fout.write("</Nnet>" + "\n")
141
+
142
+
143
+ """
144
+ python funasr/bin/compute_audio_cmvn.py \
145
+ --config-path "/Users/zhifu/funasr1.0/examples/aishell/paraformer/conf" \
146
+ --config-name "train_asr_paraformer_conformer_12e_6d_2048_256.yaml" \
147
+ ++train_data_set_list="/Users/zhifu/funasr1.0/data/list/audio_datasets.jsonl" \
148
+ ++cmvn_file="/Users/zhifu/funasr1.0/data/list/cmvn.json" \
149
+ ++dataset_conf.num_workers=0
150
+ """
151
+ if __name__ == "__main__":
152
+ main_hydra()
funasr_detach/bin/inference.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hydra
2
+ import logging
3
+ from omegaconf import DictConfig, OmegaConf, ListConfig
4
+
5
+ from funasr_detach.auto.auto_model import AutoModel
6
+
7
+
8
+ @hydra.main(config_name=None, version_base=None)
9
+ def main_hydra(cfg: DictConfig):
10
+ def to_plain_list(cfg_item):
11
+ if isinstance(cfg_item, ListConfig):
12
+ return OmegaConf.to_container(cfg_item, resolve=True)
13
+ elif isinstance(cfg_item, DictConfig):
14
+ return {k: to_plain_list(v) for k, v in cfg_item.items()}
15
+ else:
16
+ return cfg_item
17
+
18
+ kwargs = to_plain_list(cfg)
19
+ log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
20
+
21
+ logging.basicConfig(level=log_level)
22
+
23
+ if kwargs.get("debug", False):
24
+ import pdb
25
+
26
+ pdb.set_trace()
27
+ model = AutoModel(**kwargs)
28
+ res = model.generate(input=kwargs["input"])
29
+ print(res)
30
+
31
+
32
+ if __name__ == "__main__":
33
+ main_hydra()
funasr_detach/bin/tokenize_text.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import argparse
3
+ from collections import Counter
4
+ import logging
5
+ from pathlib import Path
6
+ import sys
7
+ from typing import List
8
+ from typing import Optional
9
+
10
+
11
+ from funasr_detach.utils.cli_utils import get_commandline_args
12
+ from funasr_detach.tokenizer.build_tokenizer import build_tokenizer
13
+ from funasr_detach.tokenizer.cleaner import TextCleaner
14
+ from funasr_detach.tokenizer.phoneme_tokenizer import g2p_classes
15
+ from funasr_detach.utils.types import str2bool
16
+ from funasr_detach.utils.types import str_or_none
17
+
18
+
19
+ def field2slice(field: Optional[str]) -> slice:
20
+ """Convert field string to slice
21
+
22
+ Note that field string accepts 1-based integer.
23
+
24
+ Examples:
25
+ >>> field2slice("1-")
26
+ slice(0, None, None)
27
+ >>> field2slice("1-3")
28
+ slice(0, 3, None)
29
+ >>> field2slice("-3")
30
+ slice(None, 3, None)
31
+ """
32
+ field = field.strip()
33
+ try:
34
+ if "-" in field:
35
+ # e.g. "2-" or "2-5" or "-7"
36
+ s1, s2 = field.split("-", maxsplit=1)
37
+ if s1.strip() == "":
38
+ s1 = None
39
+ else:
40
+ s1 = int(s1)
41
+ if s1 == 0:
42
+ raise ValueError("1-based string")
43
+ if s2.strip() == "":
44
+ s2 = None
45
+ else:
46
+ s2 = int(s2)
47
+ else:
48
+ # e.g. "2"
49
+ s1 = int(field)
50
+ s2 = s1 + 1
51
+ if s1 == 0:
52
+ raise ValueError("must be 1 or more value")
53
+ except ValueError:
54
+ raise RuntimeError(f"Format error: e.g. '2-', '2-5', or '-5': {field}")
55
+
56
+ if s1 is None:
57
+ slic = slice(None, s2)
58
+ else:
59
+ # -1 because of 1-based integer following "cut" command
60
+ # e.g "1-3" -> slice(0, 3)
61
+ slic = slice(s1 - 1, s2)
62
+ return slic
63
+
64
+
65
+ def tokenize(
66
+ input: str,
67
+ output: str,
68
+ field: Optional[str],
69
+ delimiter: Optional[str],
70
+ token_type: str,
71
+ space_symbol: str,
72
+ non_linguistic_symbols: Optional[str],
73
+ bpemodel: Optional[str],
74
+ log_level: str,
75
+ write_vocabulary: bool,
76
+ vocabulary_size: int,
77
+ remove_non_linguistic_symbols: bool,
78
+ cutoff: int,
79
+ add_symbol: List[str],
80
+ cleaner: Optional[str],
81
+ g2p: Optional[str],
82
+ ):
83
+
84
+ logging.basicConfig(
85
+ level=log_level,
86
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
87
+ )
88
+ if input == "-":
89
+ fin = sys.stdin
90
+ else:
91
+ fin = Path(input).open("r", encoding="utf-8")
92
+ if output == "-":
93
+ fout = sys.stdout
94
+ else:
95
+ p = Path(output)
96
+ p.parent.mkdir(parents=True, exist_ok=True)
97
+ fout = p.open("w", encoding="utf-8")
98
+
99
+ cleaner = TextCleaner(cleaner)
100
+ tokenizer = build_tokenizer(
101
+ token_type=token_type,
102
+ bpemodel=bpemodel,
103
+ delimiter=delimiter,
104
+ space_symbol=space_symbol,
105
+ non_linguistic_symbols=non_linguistic_symbols,
106
+ remove_non_linguistic_symbols=remove_non_linguistic_symbols,
107
+ g2p_type=g2p,
108
+ )
109
+
110
+ counter = Counter()
111
+ if field is not None:
112
+ field = field2slice(field)
113
+
114
+ for line in fin:
115
+ line = line.rstrip()
116
+ if field is not None:
117
+ # e.g. field="2-"
118
+ # uttidA hello world!! -> hello world!!
119
+ tokens = line.split(delimiter)
120
+ tokens = tokens[field]
121
+ if delimiter is None:
122
+ line = " ".join(tokens)
123
+ else:
124
+ line = delimiter.join(tokens)
125
+
126
+ line = cleaner(line)
127
+ tokens = tokenizer.text2tokens(line)
128
+ if not write_vocabulary:
129
+ fout.write(" ".join(tokens) + "\n")
130
+ else:
131
+ for t in tokens:
132
+ counter[t] += 1
133
+
134
+ if not write_vocabulary:
135
+ return
136
+
137
+ ## FIXME
138
+ ## del duplicate add_symbols in counter
139
+ for symbol_and_id in add_symbol:
140
+ # e.g symbol="<blank>:0"
141
+ try:
142
+ symbol, idx = symbol_and_id.split(":")
143
+ except ValueError:
144
+ raise RuntimeError(f"Format error: e.g. '<blank>:0': {symbol_and_id}")
145
+ symbol = symbol.strip()
146
+ if symbol in counter:
147
+ del counter[symbol]
148
+
149
+ # ======= write_vocabulary mode from here =======
150
+ # Sort by the number of occurrences in descending order
151
+ # and filter lower frequency words than cutoff value
152
+ words_and_counts = list(
153
+ filter(lambda x: x[1] > cutoff, sorted(counter.items(), key=lambda x: -x[1]))
154
+ )
155
+ # Restrict the vocabulary size
156
+ if vocabulary_size > 0:
157
+ if vocabulary_size < len(add_symbol):
158
+ raise RuntimeError(f"vocabulary_size is too small: {vocabulary_size}")
159
+ words_and_counts = words_and_counts[: vocabulary_size - len(add_symbol)]
160
+
161
+ # Parse the values of --add_symbol
162
+ for symbol_and_id in add_symbol:
163
+ # e.g symbol="<blank>:0"
164
+ try:
165
+ symbol, idx = symbol_and_id.split(":")
166
+ idx = int(idx)
167
+ except ValueError:
168
+ raise RuntimeError(f"Format error: e.g. '<blank>:0': {symbol_and_id}")
169
+ symbol = symbol.strip()
170
+
171
+ # e.g. idx=0 -> append as the first symbol
172
+ # e.g. idx=-1 -> append as the last symbol
173
+ if idx < 0:
174
+ idx = len(words_and_counts) + 1 + idx
175
+ words_and_counts.insert(idx, (symbol, None))
176
+
177
+ # Write words
178
+ for w, c in words_and_counts:
179
+ fout.write(w + "\n")
180
+
181
+ # Logging
182
+ total_count = sum(counter.values())
183
+ invocab_count = sum(c for w, c in words_and_counts if c is not None)
184
+ logging.info(f"OOV rate = {(total_count - invocab_count) / total_count * 100} %")
185
+
186
+
187
+ def get_parser() -> argparse.ArgumentParser:
188
+ parser = argparse.ArgumentParser(
189
+ description="Tokenize texts",
190
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
191
+ )
192
+ parser.add_argument(
193
+ "--log_level",
194
+ type=lambda x: x.upper(),
195
+ default="INFO",
196
+ choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
197
+ help="The verbose level of logging",
198
+ )
199
+
200
+ parser.add_argument(
201
+ "--input", "-i", required=True, help="Input text. - indicates sys.stdin"
202
+ )
203
+ parser.add_argument(
204
+ "--output", "-o", required=True, help="Output text. - indicates sys.stdout"
205
+ )
206
+ parser.add_argument(
207
+ "--field",
208
+ "-f",
209
+ help="The target columns of the input text as 1-based integer. e.g 2-",
210
+ )
211
+ parser.add_argument(
212
+ "--token_type",
213
+ "-t",
214
+ default="char",
215
+ choices=["char", "bpe", "word", "phn"],
216
+ help="Token type",
217
+ )
218
+ parser.add_argument("--delimiter", "-d", default=None, help="The delimiter")
219
+ parser.add_argument("--space_symbol", default="<space>", help="The space symbol")
220
+ parser.add_argument("--bpemodel", default=None, help="The bpemodel file path")
221
+ parser.add_argument(
222
+ "--non_linguistic_symbols",
223
+ type=str_or_none,
224
+ help="non_linguistic_symbols file path",
225
+ )
226
+ parser.add_argument(
227
+ "--remove_non_linguistic_symbols",
228
+ type=str2bool,
229
+ default=False,
230
+ help="Remove non-language-symbols from tokens",
231
+ )
232
+ parser.add_argument(
233
+ "--cleaner",
234
+ type=str_or_none,
235
+ choices=[None, "tacotron", "jaconv", "vietnamese", "korean_cleaner"],
236
+ default=None,
237
+ help="Apply text cleaning",
238
+ )
239
+ parser.add_argument(
240
+ "--g2p",
241
+ type=str_or_none,
242
+ choices=g2p_classes,
243
+ default=None,
244
+ help="Specify g2p method if --token_type=phn",
245
+ )
246
+
247
+ group = parser.add_argument_group("write_vocabulary mode related")
248
+ group.add_argument(
249
+ "--write_vocabulary",
250
+ type=str2bool,
251
+ default=False,
252
+ help="Write tokens list instead of tokenized text per line",
253
+ )
254
+ group.add_argument("--vocabulary_size", type=int, default=0, help="Vocabulary size")
255
+ group.add_argument(
256
+ "--cutoff",
257
+ default=0,
258
+ type=int,
259
+ help="cut-off frequency used for write-vocabulary mode",
260
+ )
261
+ group.add_argument(
262
+ "--add_symbol",
263
+ type=str,
264
+ default=[],
265
+ action="append",
266
+ help="Append symbol e.g. --add_symbol '<blank>:0' --add_symbol '<unk>:1'",
267
+ )
268
+
269
+ return parser
270
+
271
+
272
+ def main(cmd=None):
273
+ print(get_commandline_args(), file=sys.stderr)
274
+ parser = get_parser()
275
+ args = parser.parse_args(cmd)
276
+ kwargs = vars(args)
277
+ tokenize(**kwargs)
278
+
279
+
280
+ if __name__ == "__main__":
281
+ main()
funasr_detach/bin/train.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ import os
5
+ import sys
6
+ import torch
7
+ import hydra
8
+ import logging
9
+ import argparse
10
+ from io import BytesIO
11
+ import torch.distributed as dist
12
+ from collections.abc import Sequence
13
+ from omegaconf import DictConfig, OmegaConf
14
+ from torch.nn.parallel import DistributedDataParallel as DDP
15
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
16
+
17
+ from funasr_detach.register import tables
18
+ from funasr_detach.optimizers import optim_classes
19
+ from funasr_detach.train_utils.trainer import Trainer
20
+ from funasr_detach.schedulers import scheduler_classes
21
+ from funasr_detach.train_utils.initialize import initialize
22
+ from funasr_detach.download.download_from_hub import download_model
23
+ from funasr_detach.models.lora.utils import mark_only_lora_as_trainable
24
+ from funasr_detach.train_utils.set_all_random_seed import set_all_random_seed
25
+ from funasr_detach.train_utils.load_pretrained_model import load_pretrained_model
26
+
27
+ # from funasr_detach.tokenizer.build_tokenizer import build_tokenizer
28
+ # from funasr_detach.tokenizer.token_id_converter import TokenIDConverter
29
+ # from funasr_detach.tokenizer.funtoken import build_tokenizer
30
+
31
+
32
+ @hydra.main(config_name=None, version_base=None)
33
+ def main_hydra(kwargs: DictConfig):
34
+ if kwargs.get("debug", False):
35
+ import pdb
36
+
37
+ pdb.set_trace()
38
+
39
+ assert "model" in kwargs
40
+ if "model_conf" not in kwargs:
41
+ logging.info(
42
+ "download models from model hub: {}".format(kwargs.get("model_hub", "ms"))
43
+ )
44
+ kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)
45
+
46
+ main(**kwargs)
47
+
48
+
49
+ def main(**kwargs):
50
+ print(kwargs)
51
+
52
+ # set random seed
53
+ set_all_random_seed(kwargs.get("seed", 0))
54
+ torch.backends.cudnn.enabled = kwargs.get(
55
+ "cudnn_enabled", torch.backends.cudnn.enabled
56
+ )
57
+ torch.backends.cudnn.benchmark = kwargs.get(
58
+ "cudnn_benchmark", torch.backends.cudnn.benchmark
59
+ )
60
+ torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
61
+
62
+ local_rank = int(os.environ.get("LOCAL_RANK", 0))
63
+ if local_rank == 0:
64
+ tables.print()
65
+ # Check if we are using DDP or FSDP
66
+ use_ddp = "WORLD_SIZE" in os.environ and int(os.environ["WORLD_SIZE"]) > 1
67
+ use_fsdp = kwargs.get("use_fsdp", None)
68
+ if use_ddp or use_fsdp:
69
+ dist.init_process_group(
70
+ backend=kwargs.get("backend", "nccl"), init_method="env://"
71
+ )
72
+ torch.cuda.set_device(local_rank)
73
+
74
+ # save config.yaml
75
+ if (
76
+ (use_ddp or use_fsdp)
77
+ and dist.get_rank() == 0
78
+ or not (use_ddp or use_fsdp)
79
+ and local_rank == 0
80
+ ):
81
+ os.makedirs(kwargs.get("output_dir", "./"), exist_ok=True)
82
+ yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml")
83
+ OmegaConf.save(config=kwargs, f=yaml_file)
84
+ logging.info("config.yaml is saved to: %s", yaml_file)
85
+
86
+ tokenizer = kwargs.get("tokenizer", None)
87
+ if tokenizer is not None:
88
+ tokenizer_class = tables.tokenizer_classes.get(tokenizer)
89
+ tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
90
+ kwargs["tokenizer"] = tokenizer
91
+
92
+ # build frontend if frontend is none None
93
+ frontend = kwargs.get("frontend", None)
94
+ if frontend is not None:
95
+ frontend_class = tables.frontend_classes.get(frontend)
96
+ frontend = frontend_class(**kwargs["frontend_conf"])
97
+ kwargs["frontend"] = frontend
98
+ kwargs["input_size"] = frontend.output_size()
99
+
100
+ # build model
101
+ model_class = tables.model_classes.get(kwargs["model"])
102
+ model = model_class(
103
+ **kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list)
104
+ )
105
+
106
+ # init_param
107
+ init_param = kwargs.get("init_param", None)
108
+ if init_param is not None:
109
+ if not isinstance(init_param, (list, tuple)):
110
+ init_param = (init_param,)
111
+ logging.info("init_param is not None: %s", init_param)
112
+ for p in init_param:
113
+ logging.info(f"Loading pretrained params from {p}")
114
+ load_pretrained_model(
115
+ model=model,
116
+ path=p,
117
+ ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
118
+ oss_bucket=kwargs.get("oss_bucket", None),
119
+ scope_map=kwargs.get("scope_map", None),
120
+ excludes=kwargs.get("excludes", None),
121
+ )
122
+ else:
123
+ initialize(model, kwargs.get("init", "kaiming_normal"))
124
+
125
+ # freeze_param
126
+ freeze_param = kwargs.get("freeze_param", None)
127
+ if freeze_param is not None:
128
+ freeze_param = eval(freeze_param)
129
+ if isinstance(freeze_param, Sequence):
130
+ freeze_param = (freeze_param,)
131
+ logging.info("freeze_param is not None: %s", freeze_param)
132
+ for t in freeze_param:
133
+ for k, p in model.named_parameters():
134
+ if k.startswith(t + ".") or k == t:
135
+ logging.info(f"Setting {k}.requires_grad = False")
136
+ p.requires_grad = False
137
+
138
+ if use_ddp:
139
+ model = model.cuda(local_rank)
140
+ model = DDP(
141
+ model,
142
+ device_ids=[local_rank],
143
+ find_unused_parameters=kwargs.get("train_conf", {}).get(
144
+ "find_unused_parameters", False
145
+ ),
146
+ )
147
+ elif use_fsdp:
148
+ model = FSDP(model).cuda(local_rank)
149
+ else:
150
+ model = model.to(device=kwargs.get("device", "cuda"))
151
+
152
+ # optim
153
+ optim = kwargs.get("optim", "adam")
154
+ assert optim in optim_classes
155
+ optim_class = optim_classes.get(optim)
156
+ optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))
157
+
158
+ # scheduler
159
+ scheduler = kwargs.get("scheduler", "warmuplr")
160
+ assert scheduler in scheduler_classes
161
+ scheduler_class = scheduler_classes.get(scheduler)
162
+ scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
163
+
164
+ # dataset
165
+ dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
166
+ dataset_tr = dataset_class(
167
+ kwargs.get("train_data_set_list"),
168
+ frontend=frontend,
169
+ tokenizer=tokenizer,
170
+ is_training=True,
171
+ **kwargs.get("dataset_conf"),
172
+ )
173
+ dataset_val = dataset_class(
174
+ kwargs.get("valid_data_set_list"),
175
+ frontend=frontend,
176
+ tokenizer=tokenizer,
177
+ is_training=False,
178
+ **kwargs.get("dataset_conf"),
179
+ )
180
+
181
+ # dataloader
182
+ batch_sampler = kwargs["dataset_conf"].get(
183
+ "batch_sampler", "DynamicBatchLocalShuffleSampler"
184
+ )
185
+ batch_sampler_val = None
186
+ if batch_sampler is not None:
187
+ batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
188
+ batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
189
+ batch_sampler_val = batch_sampler_class(
190
+ dataset_val, is_training=False, **kwargs.get("dataset_conf")
191
+ )
192
+ dataloader_tr = torch.utils.data.DataLoader(
193
+ dataset_tr,
194
+ collate_fn=dataset_tr.collator,
195
+ batch_sampler=batch_sampler,
196
+ num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
197
+ pin_memory=True,
198
+ )
199
+
200
+ dataloader_val = torch.utils.data.DataLoader(
201
+ dataset_val,
202
+ collate_fn=dataset_val.collator,
203
+ batch_sampler=batch_sampler_val,
204
+ num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
205
+ pin_memory=True,
206
+ )
207
+ trainer = Trainer(
208
+ model=model,
209
+ optim=optim,
210
+ scheduler=scheduler,
211
+ dataloader_train=dataloader_tr,
212
+ dataloader_val=dataloader_val,
213
+ local_rank=local_rank,
214
+ use_ddp=use_ddp,
215
+ use_fsdp=use_fsdp,
216
+ output_dir=kwargs.get("output_dir", "./exp"),
217
+ resume=kwargs.get("resume", True),
218
+ **kwargs.get("train_conf"),
219
+ )
220
+ trainer.run()
221
+
222
+ if use_ddp or use_fsdp:
223
+ torch.distributed.destroy_process_group()
224
+
225
+
226
+ if __name__ == "__main__":
227
+ main_hydra()
funasr_detach/datasets/__init__.py ADDED
File without changes
funasr_detach/datasets/audio_datasets/__init__.py ADDED
File without changes
funasr_detach/datasets/audio_datasets/datasets.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from funasr_detach.register import tables
4
+ from funasr_detach.utils.load_utils import extract_fbank, load_audio_text_image_video
5
+
6
+
7
+ @tables.register("dataset_classes", "AudioDataset")
8
+ class AudioDataset(torch.utils.data.Dataset):
9
+ """
10
+ AudioDataset
11
+ """
12
+
13
+ def __init__(
14
+ self,
15
+ path,
16
+ index_ds: str = None,
17
+ frontend=None,
18
+ tokenizer=None,
19
+ int_pad_value: int = -1,
20
+ float_pad_value: float = 0.0,
21
+ **kwargs
22
+ ):
23
+ super().__init__()
24
+ index_ds_class = tables.index_ds_classes.get(index_ds)
25
+ self.index_ds = index_ds_class(path, **kwargs)
26
+ preprocessor_speech = kwargs.get("preprocessor_speech", None)
27
+ if preprocessor_speech:
28
+ preprocessor_speech_class = tables.preprocessor_classes.get(
29
+ preprocessor_speech
30
+ )
31
+ preprocessor_speech = preprocessor_speech_class(
32
+ **kwargs.get("preprocessor_speech_conf")
33
+ )
34
+ self.preprocessor_speech = preprocessor_speech
35
+ preprocessor_text = kwargs.get("preprocessor_text", None)
36
+ if preprocessor_text:
37
+ preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text)
38
+ preprocessor_text = preprocessor_text_class(
39
+ **kwargs.get("preprocessor_text_conf")
40
+ )
41
+ self.preprocessor_text = preprocessor_text
42
+
43
+ self.frontend = frontend
44
+ self.fs = 16000 if frontend is None else frontend.fs
45
+ self.data_type = "sound"
46
+ self.tokenizer = tokenizer
47
+
48
+ self.int_pad_value = int_pad_value
49
+ self.float_pad_value = float_pad_value
50
+
51
+ def get_source_len(self, index):
52
+ item = self.index_ds[index]
53
+ return self.index_ds.get_source_len(item)
54
+
55
+ def get_target_len(self, index):
56
+ item = self.index_ds[index]
57
+ return self.index_ds.get_target_len(item)
58
+
59
+ def __len__(self):
60
+ return len(self.index_ds)
61
+
62
+ def __getitem__(self, index):
63
+ item = self.index_ds[index]
64
+ # import pdb;
65
+ # pdb.set_trace()
66
+ source = item["source"]
67
+ data_src = load_audio_text_image_video(source, fs=self.fs)
68
+ if self.preprocessor_speech:
69
+ data_src = self.preprocessor_speech(data_src, fs=self.fs)
70
+ speech, speech_lengths = extract_fbank(
71
+ data_src, data_type=self.data_type, frontend=self.frontend, is_final=True
72
+ ) # speech: [b, T, d]
73
+
74
+ target = item["target"]
75
+ if self.preprocessor_text:
76
+ target = self.preprocessor_text(target)
77
+ if self.tokenizer:
78
+ ids = self.tokenizer.encode(target)
79
+ text = torch.tensor(ids, dtype=torch.int64)
80
+ else:
81
+ ids = target
82
+ text = ids
83
+ ids_lengths = len(ids)
84
+ text_lengths = torch.tensor([ids_lengths], dtype=torch.int32)
85
+
86
+ return {
87
+ "speech": speech[0, :, :],
88
+ "speech_lengths": speech_lengths,
89
+ "text": text,
90
+ "text_lengths": text_lengths,
91
+ }
92
+
93
+ def collator(self, samples: list = None):
94
+ outputs = {}
95
+ for sample in samples:
96
+ for key in sample.keys():
97
+ if key not in outputs:
98
+ outputs[key] = []
99
+ outputs[key].append(sample[key])
100
+
101
+ for key, data_list in outputs.items():
102
+ if isinstance(data_list[0], torch.Tensor):
103
+ if data_list[0].dtype == torch.int64:
104
+
105
+ pad_value = self.int_pad_value
106
+ else:
107
+ pad_value = self.float_pad_value
108
+
109
+ outputs[key] = torch.nn.utils.rnn.pad_sequence(
110
+ data_list, batch_first=True, padding_value=pad_value
111
+ )
112
+ return outputs
funasr_detach/datasets/audio_datasets/index_ds.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import torch
4
+ import logging
5
+ import concurrent.futures
6
+ import librosa
7
+ import torch.distributed as dist
8
+
9
+ from funasr_detach.register import tables
10
+
11
+
12
+ @tables.register("index_ds_classes", "IndexDSJsonlRankSplit")
13
+ class IndexDSJsonlRankSplit(torch.utils.data.Dataset):
14
+
15
+ def __init__(self, path):
16
+ super().__init__()
17
+
18
+ contents = []
19
+ with open(path, encoding="utf-8") as fin:
20
+ for line in fin:
21
+ data = json.loads(line.strip())
22
+ if "text" in data: # for sft
23
+ self.contents.append(data["text"])
24
+ if "source" in data: # for speech lab pretrain
25
+ prompt = data["prompt"]
26
+ source = data["source"]
27
+ target = data["target"]
28
+ source_len = data["source_len"]
29
+ target_len = data["target_len"]
30
+
31
+ contents.append(
32
+ {
33
+ "source": source,
34
+ "prompt": prompt,
35
+ "target": target,
36
+ "source_len": source_len,
37
+ "target_len": target_len,
38
+ }
39
+ )
40
+
41
+ self.contents = []
42
+ total_num = len(contents)
43
+ try:
44
+ rank = dist.get_rank()
45
+ world_size = dist.get_world_size()
46
+ except:
47
+ rank = 0
48
+ world_size = 1
49
+ logging.warning("distributed is not initialized, only single shard")
50
+ num_per_rank = total_num // world_size
51
+
52
+ # rank = 0
53
+ # import ipdb; ipdb.set_trace()
54
+ self.contents = contents[rank * num_per_rank : (rank + 1) * num_per_rank]
55
+
56
+ logging.info(
57
+ "in rank: {}, num of samplers: {}, total_num of samplers across ranks: {}".format(
58
+ rank, len(self.contents), len(contents)
59
+ )
60
+ )
61
+
62
+ def __len__(self):
63
+ return len(self.contents)
64
+
65
+ def __getitem__(self, index):
66
+ try:
67
+ data = self.contents[index]
68
+ except:
69
+ print(index)
70
+ return data
71
+
72
+ def get_source_len(self, data_dict):
73
+ return data_dict["source_len"]
74
+
75
+ def get_target_len(self, data_dict):
76
+
77
+ return data_dict["target_len"] if "target_len" in data_dict else 0
78
+
79
+
80
+ @tables.register("index_ds_classes", "IndexDSJsonl")
81
+ @tables.register("index_ds_classes", "IndexDSJsonlRankFull")
82
+ class IndexDSJsonlRankFull(torch.utils.data.Dataset):
83
+
84
+ def __init__(self, path: str, **kwargs):
85
+ super().__init__()
86
+
87
+ if isinstance(path, (list, tuple)): # wav.scp, text.txt/text.trans
88
+ from funasr_detach.datasets.audio_datasets.scp2jsonl import (
89
+ gen_jsonl_from_wav_text_list,
90
+ )
91
+
92
+ jsonl_outdir = os.path.dirname(path[0])
93
+ jsonl_name = (
94
+ "datalist_train.jsonl"
95
+ if kwargs.get("is_training", True)
96
+ else "datalist_val.jsonl"
97
+ )
98
+ jsonl_file_out = os.path.join(jsonl_outdir, jsonl_name)
99
+ if not os.path.exists(jsonl_file_out):
100
+ print(f"datalist is: {path}, generate jsonl from it")
101
+ gen_jsonl_from_wav_text_list(
102
+ path, jsonl_file_out=jsonl_file_out, **kwargs
103
+ )
104
+ path = jsonl_file_out
105
+
106
+ contents = []
107
+ with open(path, encoding="utf-8") as fin:
108
+ for line in fin:
109
+ data = json.loads(line.strip())
110
+ if "text" in data: # for sft
111
+ self.contents.append(data["text"])
112
+ if "source" in data: # for speech lab pretrain
113
+ prompt = data.get("prompt", "<ASR>")
114
+ source = data["source"]
115
+ target = data["target"]
116
+ source_len = data.get("source_len", 1)
117
+ target_len = data.get("target_len", 0)
118
+
119
+ contents.append(
120
+ {
121
+ "source": source,
122
+ "prompt": prompt,
123
+ "target": target,
124
+ "source_len": source_len,
125
+ "target_len": target_len,
126
+ }
127
+ )
128
+
129
+ self.contents = contents
130
+
131
+ logging.info(
132
+ "total_num of samplers across ranks: {}".format(len(self.contents))
133
+ )
134
+
135
+ def __len__(self):
136
+ return len(self.contents)
137
+
138
+ def __getitem__(self, index):
139
+ try:
140
+ data = self.contents[index]
141
+ except:
142
+ print(index)
143
+ return data
144
+
145
+ def get_source_len(self, data_dict):
146
+ return data_dict.get("source_len", 1)
147
+
148
+ def get_target_len(self, data_dict):
149
+
150
+ return data_dict.get("target_len", 0)
funasr_detach/datasets/audio_datasets/preprocessor.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import torch
4
+ import logging
5
+ import concurrent.futures
6
+ import librosa
7
+ import torch.distributed as dist
8
+ from typing import Collection
9
+ import torch
10
+ import torchaudio
11
+ from torch import nn
12
+ import random
13
+ import re
14
+ from funasr_detach.tokenizer.cleaner import TextCleaner
15
+ from funasr_detach.register import tables
16
+
17
+
18
+ @tables.register("preprocessor_classes", "SpeechPreprocessSpeedPerturb")
19
+ class SpeechPreprocessSpeedPerturb(nn.Module):
20
+ def __init__(self, speed_perturb: list = None, **kwargs):
21
+ super().__init__()
22
+ self.speed_perturb = speed_perturb
23
+
24
+ def forward(self, waveform, fs, **kwargs):
25
+ if self.speed_perturb is None:
26
+ return waveform
27
+ speed = random.choice(self.speed_perturb)
28
+ if speed != 1.0:
29
+ if not isinstance(waveform, torch.Tensor):
30
+ waveform = torch.tensor(waveform)
31
+ waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
32
+ waveform.view(1, -1), fs, [["speed", str(speed)], ["rate", str(fs)]]
33
+ )
34
+ waveform = waveform.view(-1)
35
+
36
+ return waveform
37
+
38
+
39
+ @tables.register("preprocessor_classes", "TextPreprocessSegDict")
40
+ class TextPreprocessSegDict(nn.Module):
41
+ def __init__(
42
+ self,
43
+ seg_dict: str = None,
44
+ text_cleaner: Collection[str] = None,
45
+ split_with_space: bool = False,
46
+ **kwargs
47
+ ):
48
+ super().__init__()
49
+
50
+ self.text_cleaner = TextCleaner(text_cleaner)
51
+
52
+ def forward(self, text, **kwargs):
53
+ text = self.text_cleaner(text)
54
+
55
+ return text
funasr_detach/datasets/audio_datasets/samplers.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import logging
4
+ import torch.distributed as dist
5
+
6
+ from funasr_detach.register import tables
7
+
8
+
9
+ @tables.register("batch_sampler_classes", "DynamicBatchLocalShuffleSampler")
10
+ class BatchSampler(torch.utils.data.BatchSampler):
11
+
12
+ def __init__(
13
+ self,
14
+ dataset,
15
+ batch_type: str = "example",
16
+ batch_size: int = 100,
17
+ buffer_size: int = 30,
18
+ drop_last: bool = False,
19
+ shuffle: bool = True,
20
+ is_training: bool = True,
21
+ **kwargs
22
+ ):
23
+
24
+ self.drop_last = drop_last
25
+ self.pre_idx = -1
26
+ self.dataset = dataset
27
+ self.total_samples = len(dataset)
28
+ self.batch_type = batch_type
29
+ self.batch_size = int(batch_size)
30
+ self.buffer_size = buffer_size
31
+ self.max_token_length = kwargs.get("max_token_length", 5000)
32
+ self.shuffle_idx = np.arange(self.total_samples)
33
+ self.shuffle = shuffle and is_training
34
+ self.length_scale_source = kwargs.get("length_scale_source", 1.0)
35
+
36
+ def __len__(self):
37
+ return (self.total_samples - 1) // self.batch_size + 1
38
+
39
+ def set_epoch(self, epoch):
40
+ np.random.seed(epoch)
41
+
42
+ def __iter__(self):
43
+
44
+ if self.shuffle:
45
+ np.random.shuffle(self.shuffle_idx)
46
+
47
+ batch = []
48
+ max_token = 0
49
+ num_sample = 0
50
+
51
+ iter_num = (self.total_samples - 1) // self.buffer_size + 1
52
+ # print("iter_num: ", iter_num)
53
+ for iter in range(self.pre_idx + 1, iter_num):
54
+ datalen_with_index = []
55
+ for i in range(self.buffer_size):
56
+ idx = iter * self.buffer_size + i
57
+ if idx >= self.total_samples:
58
+ continue
59
+
60
+ idx_map = self.shuffle_idx[idx]
61
+ # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
62
+ target_len = (
63
+ self.dataset.get_target_len(idx_map)
64
+ if self.batch_type == "length"
65
+ else 0.0
66
+ )
67
+ source_len = (
68
+ self.dataset.get_source_len(idx_map) / self.length_scale_source
69
+ )
70
+ sample_len_cur = source_len + target_len
71
+
72
+ datalen_with_index.append([idx, sample_len_cur])
73
+
74
+ datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
75
+ for item in datalen_with_index_sort:
76
+ idx, sample_len_cur_raw = item
77
+ if sample_len_cur_raw > self.max_token_length:
78
+ continue
79
+
80
+ max_token_cur = max(max_token, sample_len_cur_raw)
81
+ max_token_padding = 1 + num_sample
82
+ if self.batch_type != "example":
83
+ max_token_padding *= max_token_cur
84
+ if max_token_padding <= self.batch_size:
85
+ batch.append(idx)
86
+ max_token = max_token_cur
87
+ num_sample += 1
88
+ else:
89
+ yield batch
90
+ batch = [idx]
91
+ max_token = sample_len_cur_raw
92
+ num_sample = 1
93
+
94
+
95
+ @tables.register("batch_sampler_classes", "BatchSampler")
96
+ @tables.register("batch_sampler_classes", "RankFullLocalShuffleBatchSampler")
97
+ class RankFullLocalShuffleBatchSampler(torch.utils.data.BatchSampler):
98
+
99
+ def __init__(
100
+ self,
101
+ dataset,
102
+ batch_type: str = "example",
103
+ batch_size: int = 100,
104
+ buffer_size: int = 30,
105
+ drop_last: bool = True,
106
+ shuffle: bool = True,
107
+ is_training: bool = True,
108
+ **kwargs
109
+ ):
110
+
111
+ self.drop_last = drop_last
112
+ self.pre_idx = -1
113
+ self.dataset = dataset
114
+ self.total_samples = len(dataset)
115
+ self.batch_type = batch_type
116
+ self.batch_size = int(batch_size)
117
+ self.buffer_size = buffer_size
118
+ self.max_token_length = kwargs.get("max_token_length", 1500)
119
+ self.shuffle_idx = np.arange(self.total_samples)
120
+ self.shuffle = shuffle and is_training
121
+ self.length_scale_source = kwargs.get("length_scale_source", 1.0)
122
+
123
+ try:
124
+ rank = dist.get_rank()
125
+ world_size = dist.get_world_size()
126
+ except:
127
+ rank = 0
128
+ world_size = 1
129
+ self.rank = rank
130
+ self.world_size = world_size
131
+
132
+ def __len__(self):
133
+ return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1
134
+
135
+ def set_epoch(self, epoch):
136
+ np.random.seed(epoch)
137
+
138
+ def __iter__(self):
139
+
140
+ batch_size_total = self.batch_size * self.world_size
141
+
142
+ if self.shuffle:
143
+ np.random.shuffle(self.shuffle_idx)
144
+
145
+ batch = []
146
+ max_token = 0
147
+ num_sample = 0
148
+
149
+ iter_num = (self.total_samples - 1) // self.buffer_size + 1
150
+ # print("iter_num: ", iter_num)
151
+ for iter in range(self.pre_idx + 1, iter_num):
152
+ # if iter == iter_num -1 and self.drop_last:
153
+ # continue
154
+ datalen_with_index = []
155
+ for i in range(self.buffer_size):
156
+ idx = iter * self.buffer_size + i
157
+ if idx >= self.total_samples:
158
+ continue
159
+
160
+ idx_map = self.shuffle_idx[idx]
161
+ # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
162
+
163
+ source_len = (
164
+ self.dataset.get_source_len(idx_map) / self.length_scale_source
165
+ )
166
+ target_len = (
167
+ self.dataset.get_target_len(idx_map)
168
+ if self.batch_type == "length"
169
+ else 0.0
170
+ )
171
+ sample_len_cur = source_len + target_len
172
+
173
+ datalen_with_index.append([idx, sample_len_cur])
174
+
175
+ datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
176
+ for item in datalen_with_index_sort:
177
+ idx, sample_len_cur_raw = item
178
+ if sample_len_cur_raw > self.max_token_length:
179
+ continue
180
+
181
+ max_token_cur = max(max_token, sample_len_cur_raw)
182
+ max_token_padding = 1 + num_sample
183
+ # if self.batch_type != 'example':
184
+ # max_token_padding *= max_token_cur
185
+ if max_token_padding <= batch_size_total:
186
+ batch.append(idx)
187
+ max_token = max_token_cur
188
+ num_sample += 1
189
+ else:
190
+ batch_rank = batch[
191
+ self.rank * self.batch_size : (self.rank + 1) * self.batch_size
192
+ ]
193
+ yield batch_rank
194
+ batch = [idx]
195
+ max_token = sample_len_cur_raw
196
+ num_sample = 1
197
+
198
+
199
+ @tables.register("batch_sampler_classes", "RankFullLocalShuffleDynamicBatchSampler")
200
+ class RankFullLocalShuffleDynamicBatchSampler(torch.utils.data.BatchSampler):
201
+
202
+ def __init__(
203
+ self,
204
+ dataset,
205
+ batch_type: str = "example",
206
+ batch_size: int = 100,
207
+ buffer_size: int = 30,
208
+ drop_last: bool = True,
209
+ shuffle: bool = True,
210
+ is_training: bool = True,
211
+ **kwargs
212
+ ):
213
+
214
+ self.drop_last = drop_last
215
+ self.pre_idx = -1
216
+ self.dataset = dataset
217
+ self.total_samples = len(dataset)
218
+ self.batch_type = batch_type
219
+ self.batch_size = int(batch_size)
220
+ self.buffer_size = buffer_size
221
+ self.max_token_length = kwargs.get("max_token_length", 1500)
222
+ self.shuffle_idx = np.arange(self.total_samples)
223
+ self.shuffle = shuffle and is_training
224
+ self.length_scale_source = kwargs.get("length_scale_source", 1.0)
225
+
226
+ try:
227
+ rank = dist.get_rank()
228
+ world_size = dist.get_world_size()
229
+ except:
230
+ rank = 0
231
+ world_size = 1
232
+ self.rank = rank
233
+ self.world_size = world_size
234
+
235
+ def __len__(self):
236
+ return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1
237
+
238
+ def set_epoch(self, epoch):
239
+ np.random.seed(epoch)
240
+
241
+ def __iter__(self):
242
+
243
+ batch_size_total = self.batch_size * self.world_size
244
+ if self.shuffle:
245
+ np.random.shuffle(self.shuffle_idx)
246
+
247
+ batch_list_all_rank = []
248
+ batch_list_cur = []
249
+ max_token = 0
250
+ num_sample = 0
251
+
252
+ iter_num = (self.total_samples - 1) // self.buffer_size + 1
253
+ # print("iter_num: ", iter_num)
254
+ for iter in range(self.pre_idx + 1, iter_num):
255
+ # if iter == iter_num - 1 and self.drop_last:
256
+ # continue
257
+ datalen_with_index = []
258
+ for i in range(self.buffer_size):
259
+ idx = iter * self.buffer_size + i
260
+ if idx >= self.total_samples:
261
+ continue
262
+
263
+ idx_map = self.shuffle_idx[idx]
264
+ # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
265
+
266
+ source_len = (
267
+ self.dataset.get_source_len(idx_map) / self.length_scale_source
268
+ )
269
+ target_len = (
270
+ self.dataset.get_target_len(idx_map)
271
+ if self.batch_type == "length"
272
+ else 0.0
273
+ )
274
+ sample_len_cur = source_len + target_len
275
+
276
+ datalen_with_index.append([idx, sample_len_cur])
277
+
278
+ datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
279
+ for ii, item in enumerate(datalen_with_index_sort):
280
+ is_last_batch = iter == iter_num - 1 and ii == len(
281
+ datalen_with_index_sort
282
+ )
283
+ idx, sample_len_cur_raw = item
284
+ if sample_len_cur_raw > self.max_token_length:
285
+ continue
286
+
287
+ max_token_cur = max(max_token, sample_len_cur_raw)
288
+ max_token_padding = 1 + num_sample
289
+
290
+ if self.batch_type != "example":
291
+ max_token_padding *= max_token_cur
292
+ if len(batch_list_all_rank) < self.world_size:
293
+
294
+ if max_token_padding <= self.batch_size:
295
+ batch_list_cur.append(idx)
296
+ max_token = max_token_cur
297
+ num_sample += 1
298
+ else:
299
+ batch_list_all_rank.append(batch_list_cur)
300
+ batch_list_cur = []
301
+ else:
302
+ batch_rank = batch_list_all_rank[self.rank]
303
+ yield batch_rank
304
+ batch_list_all_rank = [idx]
305
+ max_token = sample_len_cur_raw
306
+ num_sample = 1
funasr_detach/datasets/audio_datasets/scp2jsonl.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import torch
4
+ import logging
5
+ import hydra
6
+ from omegaconf import DictConfig, OmegaConf
7
+ import concurrent.futures
8
+ import librosa
9
+ import torch.distributed as dist
10
+
11
+
12
+ def gen_jsonl_from_wav_text_list(
13
+ path, data_type_list=("source", "target"), jsonl_file_out: str = None, **kwargs
14
+ ):
15
+ try:
16
+ rank = dist.get_rank()
17
+ world_size = dist.get_world_size()
18
+ except:
19
+ rank = 0
20
+ world_size = 1
21
+
22
+ cpu_cores = os.cpu_count() or 1
23
+ print(f"convert wav.scp text to jsonl, ncpu: {cpu_cores}")
24
+ if rank == 0:
25
+ json_dict = {}
26
+ for data_type, data_file in zip(data_type_list, path):
27
+ json_dict[data_type] = {}
28
+ with open(data_file, "r") as f:
29
+
30
+ data_file_lists = f.readlines()
31
+ lines_for_each_th = (len(data_file_lists) - 1) // cpu_cores + 1
32
+ task_num = cpu_cores if len(data_file_lists) > cpu_cores else 1
33
+ with concurrent.futures.ThreadPoolExecutor(
34
+ max_workers=cpu_cores
35
+ ) as executor:
36
+
37
+ futures = [
38
+ executor.submit(
39
+ parse_context_length,
40
+ data_file_lists[
41
+ i * lines_for_each_th : (i + 1) * lines_for_each_th
42
+ ],
43
+ data_type,
44
+ )
45
+ for i in range(task_num)
46
+ ]
47
+
48
+ for future in concurrent.futures.as_completed(futures):
49
+
50
+ json_dict[data_type].update(future.result())
51
+ # print(json_dict)
52
+
53
+ with open(jsonl_file_out, "w") as f:
54
+ for key in json_dict[data_type_list[0]].keys():
55
+ jsonl_line = {"key": key}
56
+ for data_file in data_type_list:
57
+ jsonl_line.update(json_dict[data_file][key])
58
+ jsonl_line = json.dumps(jsonl_line, ensure_ascii=False)
59
+ f.write(jsonl_line + "\n")
60
+ f.flush()
61
+
62
+ else:
63
+ pass
64
+
65
+ if world_size > 1:
66
+ dist.barrier()
67
+
68
+
69
+ def parse_context_length(data_list: list, data_type: str):
70
+
71
+ res = {}
72
+ for i, line in enumerate(data_list):
73
+ key, line = line.strip().split(maxsplit=1)
74
+ line = line.strip()
75
+ if os.path.exists(line):
76
+ waveform, _ = librosa.load(line, sr=16000)
77
+ sample_num = len(waveform)
78
+ context_len = int(sample_num // 16000 * 1000 / 10)
79
+ else:
80
+ context_len = len(line.split()) if " " in line else len(line)
81
+ res[key] = {data_type: line, f"{data_type}_len": context_len}
82
+ return res
83
+
84
+
85
+ @hydra.main(config_name=None, version_base=None)
86
+ def main_hydra(cfg: DictConfig):
87
+
88
+ kwargs = OmegaConf.to_container(cfg, resolve=True)
89
+
90
+ scp_file_list = kwargs.get(
91
+ "scp_file_list",
92
+ (
93
+ "/Users/zhifu/funasr1.0/test_local/wav.scp",
94
+ "/Users/zhifu/funasr1.0/test_local/text.txt",
95
+ ),
96
+ )
97
+ if isinstance(scp_file_list, str):
98
+ scp_file_list = eval(scp_file_list)
99
+ data_type_list = kwargs.get("data_type_list", ("source", "target"))
100
+ jsonl_file_out = kwargs.get(
101
+ "jsonl_file_out", "/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl"
102
+ )
103
+ gen_jsonl_from_wav_text_list(
104
+ scp_file_list, data_type_list=data_type_list, jsonl_file_out=jsonl_file_out
105
+ )
106
+
107
+
108
+ """
109
+ python -m funasr_detach.datasets.audio_datasets.scp2jsonl \
110
+ ++scp_file_list='["/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"]' \
111
+ ++data_type_list='["source", "target"]' \
112
+ ++jsonl_file_out=/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl
113
+ """
114
+
115
+ if __name__ == "__main__":
116
+ main_hydra()
funasr_detach/download/__init__.py ADDED
File without changes
funasr_detach/download/download_dataset_from_hub.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def download_dataset():
2
+ pass
3
+
4
+
5
+ def download_dataset_from_ms(**kwargs):
6
+ from modelscope.msdatasets import MsDataset
7
+
8
+ dataset_name = kwargs.get(
9
+ "dataset_name", "speech_asr/speech_asr_aishell1_trainsets"
10
+ )
11
+ subset_name = kwargs.get("subset_name", "default")
12
+ split = kwargs.get("split", "train")
13
+ data_dump_dir = kwargs.get("data_dump_dir", None)
14
+ ds = MsDataset.load(
15
+ dataset_name=dataset_name,
16
+ subset_name=subset_name,
17
+ split=split,
18
+ cache_dir=data_dump_dir,
19
+ )
funasr_detach/download/download_from_hub.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import threading
4
+ from omegaconf import OmegaConf
5
+
6
+ from funasr_detach.download.name_maps_from_hub import name_maps_ms, name_maps_hf
7
+
8
+ # Global cache for downloaded models to avoid repeated downloads
9
+ # Key: (repo_id, model_revision, model_hub)
10
+ # Value: repo_cache_dir
11
+ _model_cache = {}
12
+ _cache_lock = threading.Lock()
13
+
14
+
15
+ def download_model(**kwargs):
16
+ model_hub = kwargs.get("model_hub", "ms")
17
+ model_or_path = kwargs.get("model")
18
+ repo_path = kwargs.get("repo_path", "")
19
+
20
+ # Handle name mapping based on model_hub
21
+ if model_hub == "ms" and model_or_path in name_maps_ms:
22
+ model_or_path = name_maps_ms[model_or_path]
23
+ elif model_hub == "hf" and model_or_path in name_maps_hf:
24
+ model_or_path = name_maps_hf[model_or_path]
25
+
26
+ model_revision = kwargs.get("model_revision")
27
+
28
+ # Download model if it doesn't exist locally
29
+ if not os.path.exists(model_or_path):
30
+ if model_hub == "local":
31
+ # For local models, the path should already exist
32
+ raise FileNotFoundError(f"Local model path does not exist: {model_or_path}")
33
+ elif model_hub in ["ms", "hf"]:
34
+ repo_path, model_or_path = get_or_download_model_dir(
35
+ model_or_path,
36
+ model_revision,
37
+ is_training=kwargs.get("is_training"),
38
+ check_latest=kwargs.get("kwargs", True),
39
+ model_hub=model_hub,
40
+ )
41
+ else:
42
+ raise ValueError(f"Unsupported model_hub: {model_hub}")
43
+
44
+ print(f"Using model path: {model_or_path}")
45
+ kwargs["model_path"] = model_or_path
46
+ kwargs["repo_path"] = repo_path
47
+
48
+ # Common logic for processing configuration files (same for all model hubs)
49
+ if os.path.exists(os.path.join(model_or_path, "configuration.json")):
50
+ with open(
51
+ os.path.join(model_or_path, "configuration.json"), "r", encoding="utf-8"
52
+ ) as f:
53
+ conf_json = json.load(f)
54
+ cfg = {}
55
+ add_file_root_path(model_or_path, conf_json["file_path_metas"], cfg)
56
+ cfg.update(kwargs)
57
+ config = OmegaConf.load(cfg["config"])
58
+ kwargs = OmegaConf.merge(config, cfg)
59
+ kwargs["model"] = config["model"]
60
+ elif os.path.exists(os.path.join(model_or_path, "config.yaml")) and os.path.exists(
61
+ os.path.join(model_or_path, "model.pt")
62
+ ):
63
+ config = OmegaConf.load(os.path.join(model_or_path, "config.yaml"))
64
+ kwargs = OmegaConf.merge(config, kwargs)
65
+ init_param = os.path.join(model_or_path, "model.pb")
66
+ kwargs["init_param"] = init_param
67
+ if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
68
+ kwargs["tokenizer_conf"]["token_list"] = os.path.join(
69
+ model_or_path, "tokens.txt"
70
+ )
71
+ if os.path.exists(os.path.join(model_or_path, "tokens.json")):
72
+ kwargs["tokenizer_conf"]["token_list"] = os.path.join(
73
+ model_or_path, "tokens.json"
74
+ )
75
+ if os.path.exists(os.path.join(model_or_path, "seg_dict")):
76
+ kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(
77
+ model_or_path, "seg_dict"
78
+ )
79
+ if os.path.exists(os.path.join(model_or_path, "bpe.model")):
80
+ kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(
81
+ model_or_path, "bpe.model"
82
+ )
83
+ kwargs["model"] = config["model"]
84
+ if os.path.exists(os.path.join(model_or_path, "am.mvn")):
85
+ kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
86
+ if os.path.exists(os.path.join(model_or_path, "jieba_usr_dict")):
87
+ kwargs["jieba_usr_dict"] = os.path.join(model_or_path, "jieba_usr_dict")
88
+
89
+ return OmegaConf.to_container(kwargs, resolve=True)
90
+
91
+
92
+ def add_file_root_path(model_or_path: str, file_path_metas: dict, cfg={}):
93
+
94
+ if isinstance(file_path_metas, dict):
95
+ for k, v in file_path_metas.items():
96
+ if isinstance(v, str):
97
+ p = os.path.join(model_or_path, v)
98
+ if os.path.exists(p):
99
+ cfg[k] = p
100
+ elif isinstance(v, dict):
101
+ if k not in cfg:
102
+ cfg[k] = {}
103
+ add_file_root_path(model_or_path, v, cfg[k])
104
+
105
+ return cfg
106
+
107
+
108
+ def get_or_download_model_dir(
109
+ model,
110
+ model_revision=None,
111
+ is_training=False,
112
+ check_latest=True,
113
+ model_hub="ms",
114
+ ):
115
+ """Get local model directory or download model if necessary.
116
+
117
+ Args:
118
+ model (str): model id or path to local model directory.
119
+ For HF subfolders, use format: "repo_id/subfolder_path"
120
+ model_revision (str, optional): model version number.
121
+ is_training (bool): Whether this is for training
122
+ check_latest (bool): Whether to check for latest version
123
+ model_hub (str): Model hub type ("ms" for ModelScope, "hf" for HuggingFace)
124
+ """
125
+ # Extract repo_id for caching (handle subfolder case)
126
+ if "/" in model and len(model.split("/")) > 2:
127
+ parts = model.split("/")
128
+ repo_id = "/".join(parts[:2]) # e.g., "organization/repo" or "stepfun-ai/Step-Audio-EditX"
129
+ subfolder = "/".join(parts[2:]) # e.g., "subfolder/model"
130
+ else:
131
+ repo_id = model
132
+ subfolder = None
133
+
134
+ # Create cache key
135
+ cache_key = (repo_id, model_revision, model_hub)
136
+
137
+ # Check cache first
138
+ with _cache_lock:
139
+ if cache_key in _model_cache:
140
+ cached_repo_dir = _model_cache[cache_key]
141
+ print(f"Using cached model for {repo_id}: {cached_repo_dir}")
142
+
143
+ # For subfolder case, construct the model_cache_dir from cached repo
144
+ if subfolder:
145
+ model_cache_dir = os.path.join(cached_repo_dir, subfolder)
146
+ if not os.path.exists(model_cache_dir):
147
+ raise FileNotFoundError(f"Subfolder {subfolder} not found in cached repo {repo_id}")
148
+ else:
149
+ model_cache_dir = cached_repo_dir
150
+
151
+ return cached_repo_dir, model_cache_dir
152
+
153
+ # Cache miss, need to download
154
+ if model_hub == "ms":
155
+ # ModelScope download
156
+ from modelscope.hub.snapshot_download import snapshot_download
157
+ from modelscope.utils.constant import Invoke, ThirdParty
158
+
159
+ key = Invoke.LOCAL_TRAINER if is_training else Invoke.PIPELINE
160
+
161
+ # Download the repo (use repo_id, not the full model path with subfolder)
162
+ repo_cache_dir = snapshot_download(
163
+ repo_id,
164
+ revision=model_revision,
165
+ user_agent={Invoke.KEY: key, ThirdParty.KEY: "funasr"},
166
+ )
167
+ repo_cache_dir = normalize_cache_path(repo_cache_dir)
168
+
169
+ # Construct model_cache_dir
170
+ if subfolder:
171
+ model_cache_dir = os.path.join(repo_cache_dir, subfolder)
172
+ if not os.path.exists(model_cache_dir):
173
+ raise FileNotFoundError(f"Subfolder {subfolder} not found in downloaded repo {repo_id}")
174
+ else:
175
+ model_cache_dir = normalize_cache_path(repo_cache_dir)
176
+
177
+ elif model_hub == "hf":
178
+ # HuggingFace download
179
+ try:
180
+ from huggingface_hub import snapshot_download
181
+ except ImportError:
182
+ raise ImportError(
183
+ "huggingface_hub is required for downloading from HuggingFace. "
184
+ "Please install it with: pip install huggingface_hub"
185
+ )
186
+
187
+ # Download the repo (use repo_id, not the full model path with subfolder)
188
+ repo_cache_dir = snapshot_download(
189
+ repo_id=repo_id,
190
+ revision=model_revision,
191
+ allow_patterns=None, # Download all files to ensure resource files are available
192
+ )
193
+ repo_cache_dir = normalize_cache_path(repo_cache_dir)
194
+
195
+ # Construct model_cache_dir
196
+ if subfolder:
197
+ model_cache_dir = os.path.join(repo_cache_dir, subfolder)
198
+ if not os.path.exists(model_cache_dir):
199
+ raise FileNotFoundError(f"Subfolder {subfolder} not found in downloaded repo {repo_id}")
200
+ else:
201
+ model_cache_dir = normalize_cache_path(repo_cache_dir)
202
+ else:
203
+ raise ValueError(f"Unsupported model_hub: {model_hub}")
204
+
205
+ # Cache the result before returning
206
+ with _cache_lock:
207
+ _model_cache[cache_key] = repo_cache_dir
208
+
209
+ print(f"Model downloaded to: {model_cache_dir}")
210
+ return repo_cache_dir, model_cache_dir
211
+
212
+ def normalize_cache_path(cache_path):
213
+ """Normalize cache path to ensure consistent format with snapshots/{commit_id}."""
214
+ # Check if the cache_path directory contains a snapshots folder
215
+ snapshots_dir = os.path.join(cache_path, "snapshots")
216
+ if os.path.exists(snapshots_dir) and os.path.isdir(snapshots_dir):
217
+ # Find the commit_id subdirectory in snapshots
218
+ try:
219
+ snapshot_items = os.listdir(snapshots_dir)
220
+ # Look for the first directory (should be the commit_id)
221
+ for item in snapshot_items:
222
+ item_path = os.path.join(snapshots_dir, item)
223
+ if os.path.isdir(item_path):
224
+ # Found commit_id directory, return the full path
225
+ return os.path.join(cache_path, "snapshots", item)
226
+ except OSError:
227
+ pass
228
+
229
+ # If no snapshots directory found or error occurred, return original path
230
+ return cache_path
231
+
funasr_detach/download/file.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+
3
+ import contextlib
4
+ import os
5
+ import tempfile
6
+ from abc import ABCMeta, abstractmethod
7
+ from pathlib import Path
8
+ from typing import Generator, Union
9
+
10
+ import requests
11
+ from urllib.parse import urlparse
12
+
13
+
14
+ def download_from_url(url):
15
+ result = urlparse(url)
16
+ file_path = None
17
+ if result.scheme is not None and len(result.scheme) > 0:
18
+ storage = HTTPStorage()
19
+ # bytes
20
+ data = storage.read(url)
21
+ work_dir = tempfile.TemporaryDirectory().name
22
+ if not os.path.exists(work_dir):
23
+ os.makedirs(work_dir)
24
+ file_path = os.path.join(work_dir, os.path.basename(url))
25
+ with open(file_path, "wb") as fb:
26
+ fb.write(data)
27
+ assert file_path is not None, f"failed to download: {url}"
28
+ return file_path
29
+
30
+
31
+ class Storage(metaclass=ABCMeta):
32
+ """Abstract class of storage.
33
+
34
+ All backends need to implement two apis: ``read()`` and ``read_text()``.
35
+ ``read()`` reads the file as a byte stream and ``read_text()`` reads
36
+ the file as texts.
37
+ """
38
+
39
+ @abstractmethod
40
+ def read(self, filepath: str):
41
+ pass
42
+
43
+ @abstractmethod
44
+ def read_text(self, filepath: str):
45
+ pass
46
+
47
+ @abstractmethod
48
+ def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
49
+ pass
50
+
51
+ @abstractmethod
52
+ def write_text(
53
+ self, obj: str, filepath: Union[str, Path], encoding: str = "utf-8"
54
+ ) -> None:
55
+ pass
56
+
57
+
58
+ class LocalStorage(Storage):
59
+ """Local hard disk storage"""
60
+
61
+ def read(self, filepath: Union[str, Path]) -> bytes:
62
+ """Read data from a given ``filepath`` with 'rb' mode.
63
+
64
+ Args:
65
+ filepath (str or Path): Path to read data.
66
+
67
+ Returns:
68
+ bytes: Expected bytes object.
69
+ """
70
+ with open(filepath, "rb") as f:
71
+ content = f.read()
72
+ return content
73
+
74
+ def read_text(self, filepath: Union[str, Path], encoding: str = "utf-8") -> str:
75
+ """Read data from a given ``filepath`` with 'r' mode.
76
+
77
+ Args:
78
+ filepath (str or Path): Path to read data.
79
+ encoding (str): The encoding format used to open the ``filepath``.
80
+ Default: 'utf-8'.
81
+
82
+ Returns:
83
+ str: Expected text reading from ``filepath``.
84
+ """
85
+ with open(filepath, "r", encoding=encoding) as f:
86
+ value_buf = f.read()
87
+ return value_buf
88
+
89
+ def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
90
+ """Write data to a given ``filepath`` with 'wb' mode.
91
+
92
+ Note:
93
+ ``write`` will create a directory if the directory of ``filepath``
94
+ does not exist.
95
+
96
+ Args:
97
+ obj (bytes): Data to be written.
98
+ filepath (str or Path): Path to write data.
99
+ """
100
+ dirname = os.path.dirname(filepath)
101
+ if dirname and not os.path.exists(dirname):
102
+ os.makedirs(dirname, exist_ok=True)
103
+
104
+ with open(filepath, "wb") as f:
105
+ f.write(obj)
106
+
107
+ def write_text(
108
+ self, obj: str, filepath: Union[str, Path], encoding: str = "utf-8"
109
+ ) -> None:
110
+ """Write data to a given ``filepath`` with 'w' mode.
111
+
112
+ Note:
113
+ ``write_text`` will create a directory if the directory of
114
+ ``filepath`` does not exist.
115
+
116
+ Args:
117
+ obj (str): Data to be written.
118
+ filepath (str or Path): Path to write data.
119
+ encoding (str): The encoding format used to open the ``filepath``.
120
+ Default: 'utf-8'.
121
+ """
122
+ dirname = os.path.dirname(filepath)
123
+ if dirname and not os.path.exists(dirname):
124
+ os.makedirs(dirname, exist_ok=True)
125
+
126
+ with open(filepath, "w", encoding=encoding) as f:
127
+ f.write(obj)
128
+
129
+ @contextlib.contextmanager
130
+ def as_local_path(
131
+ self, filepath: Union[str, Path]
132
+ ) -> Generator[Union[str, Path], None, None]:
133
+ """Only for unified API and do nothing."""
134
+ yield filepath
135
+
136
+
137
+ class HTTPStorage(Storage):
138
+ """HTTP and HTTPS storage."""
139
+
140
+ def read(self, url):
141
+ # TODO @wenmeng.zwm add progress bar if file is too large
142
+ r = requests.get(url)
143
+ r.raise_for_status()
144
+ return r.content
145
+
146
+ def read_text(self, url):
147
+ r = requests.get(url)
148
+ r.raise_for_status()
149
+ return r.text
150
+
151
+ @contextlib.contextmanager
152
+ def as_local_path(self, filepath: str) -> Generator[Union[str, Path], None, None]:
153
+ """Download a file from ``filepath``.
154
+
155
+ ``as_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
156
+ can be called with ``with`` statement, and when exists from the
157
+ ``with`` statement, the temporary path will be released.
158
+
159
+ Args:
160
+ filepath (str): Download a file from ``filepath``.
161
+
162
+ Examples:
163
+ >>> storage = HTTPStorage()
164
+ >>> # After existing from the ``with`` clause,
165
+ >>> # the path will be removed
166
+ >>> with storage.get_local_path('http://path/to/file') as path:
167
+ ... # do something here
168
+ """
169
+ try:
170
+ f = tempfile.NamedTemporaryFile(delete=False)
171
+ f.write(self.read(filepath))
172
+ f.close()
173
+ yield f.name
174
+ finally:
175
+ os.remove(f.name)
176
+
177
+ def write(self, obj: bytes, url: Union[str, Path]) -> None:
178
+ raise NotImplementedError("write is not supported by HTTP Storage")
179
+
180
+ def write_text(
181
+ self, obj: str, url: Union[str, Path], encoding: str = "utf-8"
182
+ ) -> None:
183
+ raise NotImplementedError("write_text is not supported by HTTP Storage")
184
+
185
+
186
+ class OSSStorage(Storage):
187
+ """OSS storage."""
188
+
189
+ def __init__(self, oss_config_file=None):
190
+ # read from config file or env var
191
+ raise NotImplementedError("OSSStorage.__init__ to be implemented in the future")
192
+
193
+ def read(self, filepath):
194
+ raise NotImplementedError("OSSStorage.read to be implemented in the future")
195
+
196
+ def read_text(self, filepath, encoding="utf-8"):
197
+ raise NotImplementedError(
198
+ "OSSStorage.read_text to be implemented in the future"
199
+ )
200
+
201
+ @contextlib.contextmanager
202
+ def as_local_path(self, filepath: str) -> Generator[Union[str, Path], None, None]:
203
+ """Download a file from ``filepath``.
204
+
205
+ ``as_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
206
+ can be called with ``with`` statement, and when exists from the
207
+ ``with`` statement, the temporary path will be released.
208
+
209
+ Args:
210
+ filepath (str): Download a file from ``filepath``.
211
+
212
+ Examples:
213
+ >>> storage = OSSStorage()
214
+ >>> # After existing from the ``with`` clause,
215
+ >>> # the path will be removed
216
+ >>> with storage.get_local_path('http://path/to/file') as path:
217
+ ... # do something here
218
+ """
219
+ try:
220
+ f = tempfile.NamedTemporaryFile(delete=False)
221
+ f.write(self.read(filepath))
222
+ f.close()
223
+ yield f.name
224
+ finally:
225
+ os.remove(f.name)
226
+
227
+ def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
228
+ raise NotImplementedError("OSSStorage.write to be implemented in the future")
229
+
230
+ def write_text(
231
+ self, obj: str, filepath: Union[str, Path], encoding: str = "utf-8"
232
+ ) -> None:
233
+ raise NotImplementedError(
234
+ "OSSStorage.write_text to be implemented in the future"
235
+ )
236
+
237
+
238
+ G_STORAGES = {}
239
+
240
+
241
+ class File(object):
242
+ _prefix_to_storage: dict = {
243
+ "oss": OSSStorage,
244
+ "http": HTTPStorage,
245
+ "https": HTTPStorage,
246
+ "local": LocalStorage,
247
+ }
248
+
249
+ @staticmethod
250
+ def _get_storage(uri):
251
+ assert isinstance(uri, str), f"uri should be str type, but got {type(uri)}"
252
+
253
+ if "://" not in uri:
254
+ # local path
255
+ storage_type = "local"
256
+ else:
257
+ prefix, _ = uri.split("://")
258
+ storage_type = prefix
259
+
260
+ assert storage_type in File._prefix_to_storage, (
261
+ f"Unsupported uri {uri}, valid prefixs: "
262
+ f"{list(File._prefix_to_storage.keys())}"
263
+ )
264
+
265
+ if storage_type not in G_STORAGES:
266
+ G_STORAGES[storage_type] = File._prefix_to_storage[storage_type]()
267
+
268
+ return G_STORAGES[storage_type]
269
+
270
+ @staticmethod
271
+ def read(uri: str) -> bytes:
272
+ """Read data from a given ``filepath`` with 'rb' mode.
273
+
274
+ Args:
275
+ filepath (str or Path): Path to read data.
276
+
277
+ Returns:
278
+ bytes: Expected bytes object.
279
+ """
280
+ storage = File._get_storage(uri)
281
+ return storage.read(uri)
282
+
283
+ @staticmethod
284
+ def read_text(uri: Union[str, Path], encoding: str = "utf-8") -> str:
285
+ """Read data from a given ``filepath`` with 'r' mode.
286
+
287
+ Args:
288
+ filepath (str or Path): Path to read data.
289
+ encoding (str): The encoding format used to open the ``filepath``.
290
+ Default: 'utf-8'.
291
+
292
+ Returns:
293
+ str: Expected text reading from ``filepath``.
294
+ """
295
+ storage = File._get_storage(uri)
296
+ return storage.read_text(uri)
297
+
298
+ @staticmethod
299
+ def write(obj: bytes, uri: Union[str, Path]) -> None:
300
+ """Write data to a given ``filepath`` with 'wb' mode.
301
+
302
+ Note:
303
+ ``write`` will create a directory if the directory of ``filepath``
304
+ does not exist.
305
+
306
+ Args:
307
+ obj (bytes): Data to be written.
308
+ filepath (str or Path): Path to write data.
309
+ """
310
+ storage = File._get_storage(uri)
311
+ return storage.write(obj, uri)
312
+
313
+ @staticmethod
314
+ def write_text(obj: str, uri: str, encoding: str = "utf-8") -> None:
315
+ """Write data to a given ``filepath`` with 'w' mode.
316
+
317
+ Note:
318
+ ``write_text`` will create a directory if the directory of
319
+ ``filepath`` does not exist.
320
+
321
+ Args:
322
+ obj (str): Data to be written.
323
+ filepath (str or Path): Path to write data.
324
+ encoding (str): The encoding format used to open the ``filepath``.
325
+ Default: 'utf-8'.
326
+ """
327
+ storage = File._get_storage(uri)
328
+ return storage.write_text(obj, uri)
329
+
330
+ @contextlib.contextmanager
331
+ def as_local_path(uri: str) -> Generator[Union[str, Path], None, None]:
332
+ """Only for unified API and do nothing."""
333
+ storage = File._get_storage(uri)
334
+ with storage.as_local_path(uri) as local_path:
335
+ yield local_path
funasr_detach/download/name_maps_from_hub.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name_maps_ms = {
2
+ "paraformer-zh": "damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
3
+ "paraformer-en": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
4
+ "paraformer-en-spk": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
5
+ "paraformer-zh-streaming": "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
6
+ "fsmn-vad": "damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
7
+ "ct-punc": "damo/punc_ct-transformer_cn-en-common-vocab471067-large",
8
+ "ct-punc-c": "damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
9
+ "fa-zh": "damo/speech_timestamp_prediction-v1-16k-offline",
10
+ "cam++": "damo/speech_campplus_sv_zh-cn_16k-common",
11
+ }
12
+
13
+ name_maps_hf = {}
funasr_detach/download/runtime_sdk_download_tool.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ from pathlib import Path
4
+
5
+ from funasr_detach.utils.types import str2bool
6
+
7
+
8
+ def main():
9
+ parser = argparse.ArgumentParser()
10
+ parser.add_argument("--model-name", type=str, required=True)
11
+ parser.add_argument("--export-dir", type=str, required=True)
12
+ parser.add_argument(
13
+ "--export", type=str2bool, default=True, help="whether to export model"
14
+ )
15
+ parser.add_argument("--type", type=str, default="onnx", help='["onnx", "torch"]')
16
+ parser.add_argument("--device", type=str, default="cpu", help='["cpu", "cuda"]')
17
+ parser.add_argument(
18
+ "--quantize", type=str2bool, default=False, help="export quantized model"
19
+ )
20
+ parser.add_argument(
21
+ "--fallback-num", type=int, default=0, help="amp fallback number"
22
+ )
23
+ parser.add_argument("--audio_in", type=str, default=None, help='["wav", "wav.scp"]')
24
+ parser.add_argument(
25
+ "--model_revision", type=str, default=None, help="model_revision"
26
+ )
27
+ parser.add_argument("--calib_num", type=int, default=200, help="calib max num")
28
+ args = parser.parse_args()
29
+
30
+ model_dir = args.model_name
31
+ if not Path(args.model_name).exists():
32
+ from modelscope.hub.snapshot_download import snapshot_download
33
+
34
+ try:
35
+ model_dir = snapshot_download(
36
+ args.model_name, cache_dir=args.export_dir, revision=args.model_revision
37
+ )
38
+ except:
39
+ raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
40
+ model_dir
41
+ )
42
+ if args.export:
43
+ model_file = os.path.join(model_dir, "model.onnx")
44
+ if args.quantize:
45
+ model_file = os.path.join(model_dir, "model_quant.onnx")
46
+ if not os.path.exists(model_file):
47
+ print(".onnx is not exist, begin to export onnx")
48
+ from funasr_detach.bin.export_model import ModelExport
49
+
50
+ export_model = ModelExport(
51
+ cache_dir=args.export_dir,
52
+ onnx=True,
53
+ device="cpu",
54
+ quant=args.quantize,
55
+ )
56
+ export_model.export(model_dir)
57
+
58
+
59
+ if __name__ == "__main__":
60
+ main()
funasr_detach/frontends/__init__.py ADDED
File without changes
funasr_detach/frontends/default.py ADDED
@@ -0,0 +1,347 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from typing import Optional
3
+ from typing import Tuple
4
+ from typing import Union
5
+ import logging
6
+ import humanfriendly
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ try:
12
+ from torch_complex.tensor import ComplexTensor
13
+ except:
14
+ print("Please install torch_complex firstly")
15
+
16
+ from funasr_detach.frontends.utils.log_mel import LogMel
17
+ from funasr_detach.frontends.utils.stft import Stft
18
+ from funasr_detach.frontends.utils.frontend import Frontend
19
+ from funasr_detach.models.transformer.utils.nets_utils import make_pad_mask
20
+
21
+
22
+ class DefaultFrontend(nn.Module):
23
+ """Conventional frontend structure for ASR.
24
+ Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
25
+ """
26
+
27
+ def __init__(
28
+ self,
29
+ fs: Union[int, str] = 16000,
30
+ n_fft: int = 512,
31
+ win_length: int = None,
32
+ hop_length: int = 128,
33
+ window: Optional[str] = "hann",
34
+ center: bool = True,
35
+ normalized: bool = False,
36
+ onesided: bool = True,
37
+ n_mels: int = 80,
38
+ fmin: int = None,
39
+ fmax: int = None,
40
+ htk: bool = False,
41
+ frontend_conf: Optional[dict] = None,
42
+ apply_stft: bool = True,
43
+ use_channel: int = None,
44
+ ):
45
+ super().__init__()
46
+ if isinstance(fs, str):
47
+ fs = humanfriendly.parse_size(fs)
48
+
49
+ # Deepcopy (In general, dict shouldn't be used as default arg)
50
+ frontend_conf = copy.deepcopy(frontend_conf)
51
+ self.hop_length = hop_length
52
+
53
+ if apply_stft:
54
+ self.stft = Stft(
55
+ n_fft=n_fft,
56
+ win_length=win_length,
57
+ hop_length=hop_length,
58
+ center=center,
59
+ window=window,
60
+ normalized=normalized,
61
+ onesided=onesided,
62
+ )
63
+ else:
64
+ self.stft = None
65
+ self.apply_stft = apply_stft
66
+
67
+ if frontend_conf is not None:
68
+ self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
69
+ else:
70
+ self.frontend = None
71
+
72
+ self.logmel = LogMel(
73
+ fs=fs,
74
+ n_fft=n_fft,
75
+ n_mels=n_mels,
76
+ fmin=fmin,
77
+ fmax=fmax,
78
+ htk=htk,
79
+ )
80
+ self.n_mels = n_mels
81
+ self.use_channel = use_channel
82
+ self.frontend_type = "default"
83
+
84
+ def output_size(self) -> int:
85
+ return self.n_mels
86
+
87
+ def forward(
88
+ self, input: torch.Tensor, input_lengths: torch.Tensor
89
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
90
+ # 1. Domain-conversion: e.g. Stft: time -> time-freq
91
+ if self.stft is not None:
92
+ input_stft, feats_lens = self._compute_stft(input, input_lengths)
93
+ else:
94
+ input_stft = ComplexTensor(input[..., 0], input[..., 1])
95
+ feats_lens = input_lengths
96
+ # 2. [Option] Speech enhancement
97
+ if self.frontend is not None:
98
+ assert isinstance(input_stft, ComplexTensor), type(input_stft)
99
+ # input_stft: (Batch, Length, [Channel], Freq)
100
+ input_stft, _, mask = self.frontend(input_stft, feats_lens)
101
+
102
+ # 3. [Multi channel case]: Select a channel
103
+ if input_stft.dim() == 4:
104
+ # h: (B, T, C, F) -> h: (B, T, F)
105
+ if self.training:
106
+ if self.use_channel is not None:
107
+ input_stft = input_stft[:, :, self.use_channel, :]
108
+ else:
109
+ # Select 1ch randomly
110
+ ch = np.random.randint(input_stft.size(2))
111
+ input_stft = input_stft[:, :, ch, :]
112
+ else:
113
+ # Use the first channel
114
+ input_stft = input_stft[:, :, 0, :]
115
+
116
+ # 4. STFT -> Power spectrum
117
+ # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
118
+ input_power = input_stft.real**2 + input_stft.imag**2
119
+
120
+ # 5. Feature transform e.g. Stft -> Log-Mel-Fbank
121
+ # input_power: (Batch, [Channel,] Length, Freq)
122
+ # -> input_feats: (Batch, Length, Dim)
123
+ input_feats, _ = self.logmel(input_power, feats_lens)
124
+
125
+ return input_feats, feats_lens
126
+
127
+ def _compute_stft(
128
+ self, input: torch.Tensor, input_lengths: torch.Tensor
129
+ ) -> torch.Tensor:
130
+ input_stft, feats_lens = self.stft(input, input_lengths)
131
+
132
+ assert input_stft.dim() >= 4, input_stft.shape
133
+ # "2" refers to the real/imag parts of Complex
134
+ assert input_stft.shape[-1] == 2, input_stft.shape
135
+
136
+ # Change torch.Tensor to ComplexTensor
137
+ # input_stft: (..., F, 2) -> (..., F)
138
+ input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
139
+ return input_stft, feats_lens
140
+
141
+
142
+ class MultiChannelFrontend(nn.Module):
143
+ """Conventional frontend structure for ASR.
144
+ Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
145
+ """
146
+
147
+ def __init__(
148
+ self,
149
+ fs: Union[int, str] = 16000,
150
+ n_fft: int = 512,
151
+ win_length: int = None,
152
+ hop_length: int = None,
153
+ frame_length: int = None,
154
+ frame_shift: int = None,
155
+ window: Optional[str] = "hann",
156
+ center: bool = True,
157
+ normalized: bool = False,
158
+ onesided: bool = True,
159
+ n_mels: int = 80,
160
+ fmin: int = None,
161
+ fmax: int = None,
162
+ htk: bool = False,
163
+ frontend_conf: Optional[dict] = None,
164
+ apply_stft: bool = True,
165
+ use_channel: int = None,
166
+ lfr_m: int = 1,
167
+ lfr_n: int = 1,
168
+ cmvn_file: str = None,
169
+ mc: bool = True,
170
+ ):
171
+ super().__init__()
172
+ if isinstance(fs, str):
173
+ fs = humanfriendly.parse_size(fs)
174
+
175
+ # Deepcopy (In general, dict shouldn't be used as default arg)
176
+ frontend_conf = copy.deepcopy(frontend_conf)
177
+ if win_length is None and hop_length is None:
178
+ self.win_length = frame_length * 16
179
+ self.hop_length = frame_shift * 16
180
+ elif frame_length is None and frame_shift is None:
181
+ self.win_length = self.win_length
182
+ self.hop_length = self.hop_length
183
+ else:
184
+ logging.error(
185
+ "Only one of (win_length, hop_length) and (frame_length, frame_shift)"
186
+ "can be set."
187
+ )
188
+ exit(1)
189
+
190
+ if apply_stft:
191
+ self.stft = Stft(
192
+ n_fft=n_fft,
193
+ win_length=self.win_length,
194
+ hop_length=self.hop_length,
195
+ center=center,
196
+ window=window,
197
+ normalized=normalized,
198
+ onesided=onesided,
199
+ )
200
+ else:
201
+ self.stft = None
202
+ self.apply_stft = apply_stft
203
+
204
+ if frontend_conf is not None:
205
+ self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
206
+ else:
207
+ self.frontend = None
208
+
209
+ self.logmel = LogMel(
210
+ fs=fs,
211
+ n_fft=n_fft,
212
+ n_mels=n_mels,
213
+ fmin=fmin,
214
+ fmax=fmax,
215
+ htk=htk,
216
+ )
217
+ self.n_mels = n_mels
218
+ self.use_channel = use_channel
219
+ self.mc = mc
220
+ if not self.mc:
221
+ if self.use_channel is not None:
222
+ logging.info("use the channel %d" % (self.use_channel))
223
+ else:
224
+ logging.info("random select channel")
225
+ self.cmvn_file = cmvn_file
226
+ if self.cmvn_file is not None:
227
+ mean, std = self._load_cmvn(self.cmvn_file)
228
+ self.register_buffer("mean", torch.from_numpy(mean))
229
+ self.register_buffer("std", torch.from_numpy(std))
230
+ self.frontend_type = "multichannelfrontend"
231
+
232
+ def output_size(self) -> int:
233
+ return self.n_mels
234
+
235
+ def forward(
236
+ self, input: torch.Tensor, input_lengths: torch.Tensor
237
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
238
+ # 1. Domain-conversion: e.g. Stft: time -> time-freq
239
+ # import pdb;pdb.set_trace()
240
+ if self.stft is not None:
241
+ input_stft, feats_lens = self._compute_stft(input, input_lengths)
242
+ else:
243
+ input_stft = ComplexTensor(input[..., 0], input[..., 1])
244
+ feats_lens = input_lengths
245
+ # 2. [Option] Speech enhancement
246
+ if self.frontend is not None:
247
+ assert isinstance(input_stft, ComplexTensor), type(input_stft)
248
+ # input_stft: (Batch, Length, [Channel], Freq)
249
+ input_stft, _, mask = self.frontend(input_stft, feats_lens)
250
+
251
+ # 3. [Multi channel case]: Select a channel(sa_asr)
252
+ if input_stft.dim() == 4 and not self.mc:
253
+ # h: (B, T, C, F) -> h: (B, T, F)
254
+ if self.training:
255
+ if self.use_channel is not None:
256
+ input_stft = input_stft[:, :, self.use_channel, :]
257
+
258
+ else:
259
+ # Select 1ch randomly
260
+ ch = np.random.randint(input_stft.size(2))
261
+ input_stft = input_stft[:, :, ch, :]
262
+ else:
263
+ # Use the first channel
264
+ input_stft = input_stft[:, :, 0, :]
265
+
266
+ # 4. STFT -> Power spectrum
267
+ # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
268
+ input_power = input_stft.real**2 + input_stft.imag**2
269
+
270
+ # 5. Feature transform e.g. Stft -> Log-Mel-Fbank
271
+ # input_power: (Batch, [Channel,] Length, Freq)
272
+ # -> input_feats: (Batch, Length, Dim)
273
+ input_feats, _ = self.logmel(input_power, feats_lens)
274
+ if self.mc:
275
+ # MFCCA
276
+ if input_feats.dim() == 4:
277
+ bt = input_feats.size(0)
278
+ channel_size = input_feats.size(2)
279
+ input_feats = (
280
+ input_feats.transpose(1, 2)
281
+ .reshape(bt * channel_size, -1, 80)
282
+ .contiguous()
283
+ )
284
+ feats_lens = feats_lens.repeat(1, channel_size).squeeze()
285
+ else:
286
+ channel_size = 1
287
+ return input_feats, feats_lens, channel_size
288
+ else:
289
+ # 6. Apply CMVN
290
+ if self.cmvn_file is not None:
291
+ if feats_lens is None:
292
+ feats_lens = input_feats.new_full(
293
+ [input_feats.size(0)], input_feats.size(1)
294
+ )
295
+ self.mean = self.mean.to(input_feats.device, input_feats.dtype)
296
+ self.std = self.std.to(input_feats.device, input_feats.dtype)
297
+ mask = make_pad_mask(feats_lens, input_feats, 1)
298
+
299
+ if input_feats.requires_grad:
300
+ input_feats = input_feats + self.mean
301
+ else:
302
+ input_feats += self.mean
303
+ if input_feats.requires_grad:
304
+ input_feats = input_feats.masked_fill(mask, 0.0)
305
+ else:
306
+ input_feats.masked_fill_(mask, 0.0)
307
+
308
+ input_feats *= self.std
309
+
310
+ return input_feats, feats_lens
311
+
312
+ def _compute_stft(
313
+ self, input: torch.Tensor, input_lengths: torch.Tensor
314
+ ) -> torch.Tensor:
315
+ input_stft, feats_lens = self.stft(input, input_lengths)
316
+
317
+ assert input_stft.dim() >= 4, input_stft.shape
318
+ # "2" refers to the real/imag parts of Complex
319
+ assert input_stft.shape[-1] == 2, input_stft.shape
320
+
321
+ # Change torch.Tensor to ComplexTensor
322
+ # input_stft: (..., F, 2) -> (..., F)
323
+ input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
324
+ return input_stft, feats_lens
325
+
326
+ def _load_cmvn(self, cmvn_file):
327
+ with open(cmvn_file, "r", encoding="utf-8") as f:
328
+ lines = f.readlines()
329
+ means_list = []
330
+ vars_list = []
331
+ for i in range(len(lines)):
332
+ line_item = lines[i].split()
333
+ if line_item[0] == "<AddShift>":
334
+ line_item = lines[i + 1].split()
335
+ if line_item[0] == "<LearnRateCoef>":
336
+ add_shift_line = line_item[3 : (len(line_item) - 1)]
337
+ means_list = list(add_shift_line)
338
+ continue
339
+ elif line_item[0] == "<Rescale>":
340
+ line_item = lines[i + 1].split()
341
+ if line_item[0] == "<LearnRateCoef>":
342
+ rescale_line = line_item[3 : (len(line_item) - 1)]
343
+ vars_list = list(rescale_line)
344
+ continue
345
+ means = np.array(means_list).astype(np.float)
346
+ vars = np.array(vars_list).astype(np.float)
347
+ return means, vars
funasr_detach/frontends/eend_ola_feature.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita)
2
+ # Licensed under the MIT license.
3
+ #
4
+ # This module is for computing audio features
5
+
6
+ import librosa
7
+ import numpy as np
8
+
9
+
10
+ def transform(Y, dtype=np.float32):
11
+ Y = np.abs(Y)
12
+ n_fft = 2 * (Y.shape[1] - 1)
13
+ sr = 8000
14
+ n_mels = 23
15
+ mel_basis = librosa.filters.mel(sr, n_fft, n_mels)
16
+ Y = np.dot(Y**2, mel_basis.T)
17
+ Y = np.log10(np.maximum(Y, 1e-10))
18
+ mean = np.mean(Y, axis=0)
19
+ Y = Y - mean
20
+ return Y.astype(dtype)
21
+
22
+
23
+ def subsample(Y, T, subsampling=1):
24
+ Y_ss = Y[::subsampling]
25
+ T_ss = T[::subsampling]
26
+ return Y_ss, T_ss
27
+
28
+
29
+ def splice(Y, context_size=0):
30
+ Y_pad = np.pad(Y, [(context_size, context_size), (0, 0)], "constant")
31
+ Y_spliced = np.lib.stride_tricks.as_strided(
32
+ np.ascontiguousarray(Y_pad),
33
+ (Y.shape[0], Y.shape[1] * (2 * context_size + 1)),
34
+ (Y.itemsize * Y.shape[1], Y.itemsize),
35
+ writeable=False,
36
+ )
37
+ return Y_spliced
38
+
39
+
40
+ def stft(data, frame_size=1024, frame_shift=256):
41
+ fft_size = 1 << (frame_size - 1).bit_length()
42
+ if len(data) % frame_shift == 0:
43
+ return librosa.stft(
44
+ data, n_fft=fft_size, win_length=frame_size, hop_length=frame_shift
45
+ ).T[:-1]
46
+ else:
47
+ return librosa.stft(
48
+ data, n_fft=fft_size, win_length=frame_size, hop_length=frame_shift
49
+ ).T
funasr_detach/frontends/fused.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from funasr_detach.frontends.default import DefaultFrontend
2
+ from funasr_detach.frontends.s3prl import S3prlFrontend
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from typing import Tuple
7
+
8
+
9
+ class FusedFrontends(nn.Module):
10
+ def __init__(
11
+ self, frontends=None, align_method="linear_projection", proj_dim=100, fs=16000
12
+ ):
13
+
14
+ super().__init__()
15
+ self.align_method = (
16
+ align_method # fusing method : linear_projection only for now
17
+ )
18
+ self.proj_dim = proj_dim # dim of the projection done on each frontend
19
+ self.frontends = [] # list of the frontends to combine
20
+
21
+ for i, frontend in enumerate(frontends):
22
+ frontend_type = frontend["frontend_type"]
23
+ if frontend_type == "default":
24
+ n_mels, fs, n_fft, win_length, hop_length = (
25
+ frontend.get("n_mels", 80),
26
+ fs,
27
+ frontend.get("n_fft", 512),
28
+ frontend.get("win_length"),
29
+ frontend.get("hop_length", 128),
30
+ )
31
+ window, center, normalized, onesided = (
32
+ frontend.get("window", "hann"),
33
+ frontend.get("center", True),
34
+ frontend.get("normalized", False),
35
+ frontend.get("onesided", True),
36
+ )
37
+ fmin, fmax, htk, apply_stft = (
38
+ frontend.get("fmin", None),
39
+ frontend.get("fmax", None),
40
+ frontend.get("htk", False),
41
+ frontend.get("apply_stft", True),
42
+ )
43
+
44
+ self.frontends.append(
45
+ DefaultFrontend(
46
+ n_mels=n_mels,
47
+ n_fft=n_fft,
48
+ fs=fs,
49
+ win_length=win_length,
50
+ hop_length=hop_length,
51
+ window=window,
52
+ center=center,
53
+ normalized=normalized,
54
+ onesided=onesided,
55
+ fmin=fmin,
56
+ fmax=fmax,
57
+ htk=htk,
58
+ apply_stft=apply_stft,
59
+ )
60
+ )
61
+ elif frontend_type == "s3prl":
62
+ frontend_conf, download_dir, multilayer_feature = (
63
+ frontend.get("frontend_conf"),
64
+ frontend.get("download_dir"),
65
+ frontend.get("multilayer_feature"),
66
+ )
67
+ self.frontends.append(
68
+ S3prlFrontend(
69
+ fs=fs,
70
+ frontend_conf=frontend_conf,
71
+ download_dir=download_dir,
72
+ multilayer_feature=multilayer_feature,
73
+ )
74
+ )
75
+
76
+ else:
77
+ raise NotImplementedError # frontends are only default or s3prl
78
+
79
+ self.frontends = torch.nn.ModuleList(self.frontends)
80
+
81
+ self.gcd = np.gcd.reduce([frontend.hop_length for frontend in self.frontends])
82
+ self.factors = [frontend.hop_length // self.gcd for frontend in self.frontends]
83
+ if torch.cuda.is_available():
84
+ dev = "cuda"
85
+ else:
86
+ dev = "cpu"
87
+ if self.align_method == "linear_projection":
88
+ self.projection_layers = [
89
+ torch.nn.Linear(
90
+ in_features=frontend.output_size(),
91
+ out_features=self.factors[i] * self.proj_dim,
92
+ )
93
+ for i, frontend in enumerate(self.frontends)
94
+ ]
95
+ self.projection_layers = torch.nn.ModuleList(self.projection_layers)
96
+ self.projection_layers = self.projection_layers.to(torch.device(dev))
97
+
98
+ def output_size(self) -> int:
99
+ return len(self.frontends) * self.proj_dim
100
+
101
+ def forward(
102
+ self, input: torch.Tensor, input_lengths: torch.Tensor
103
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
104
+
105
+ # step 0 : get all frontends features
106
+ self.feats = []
107
+ for frontend in self.frontends:
108
+ with torch.no_grad():
109
+ input_feats, feats_lens = frontend.forward(input, input_lengths)
110
+ self.feats.append([input_feats, feats_lens])
111
+
112
+ if (
113
+ self.align_method == "linear_projection"
114
+ ): # TODO(Dan): to add other align methods
115
+
116
+ # first step : projections
117
+ self.feats_proj = []
118
+ for i, frontend in enumerate(self.frontends):
119
+ input_feats = self.feats[i][0]
120
+ self.feats_proj.append(self.projection_layers[i](input_feats))
121
+
122
+ # 2nd step : reshape
123
+ self.feats_reshaped = []
124
+ for i, frontend in enumerate(self.frontends):
125
+ input_feats_proj = self.feats_proj[i]
126
+ bs, nf, dim = input_feats_proj.shape
127
+ input_feats_reshaped = torch.reshape(
128
+ input_feats_proj, (bs, nf * self.factors[i], dim // self.factors[i])
129
+ )
130
+ self.feats_reshaped.append(input_feats_reshaped)
131
+
132
+ # 3rd step : drop the few last frames
133
+ m = min([x.shape[1] for x in self.feats_reshaped])
134
+ self.feats_final = [x[:, :m, :] for x in self.feats_reshaped]
135
+
136
+ input_feats = torch.cat(
137
+ self.feats_final, dim=-1
138
+ ) # change the input size of the preencoder : proj_dim * n_frontends
139
+ feats_lens = torch.ones_like(self.feats[0][1]) * (m)
140
+
141
+ else:
142
+ raise NotImplementedError
143
+
144
+ return input_feats, feats_lens
funasr_detach/frontends/s3prl.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import logging
3
+ import os
4
+ from argparse import Namespace
5
+ from typing import Optional
6
+ from typing import Tuple
7
+ from typing import Union
8
+
9
+ import humanfriendly
10
+ import torch
11
+ import torch.nn as nn
12
+
13
+ from funasr_detach.frontends.utils.frontend import Frontend
14
+ from funasr_detach.models.transformer.utils.nets_utils import pad_list
15
+
16
+
17
+ def base_s3prl_setup(args):
18
+ args.upstream_feature_selection = getattr(args, "upstream_feature_selection", None)
19
+ args.upstream_model_config = getattr(args, "upstream_model_config", None)
20
+ args.upstream_refresh = getattr(args, "upstream_refresh", False)
21
+ args.upstream_ckpt = getattr(args, "upstream_ckpt", None)
22
+ args.init_ckpt = getattr(args, "init_ckpt", None)
23
+ args.verbose = getattr(args, "verbose", False)
24
+ args.tile_factor = getattr(args, "tile_factor", 1)
25
+ return args
26
+
27
+
28
+ class S3prlFrontend(nn.Module):
29
+ """Speech Pretrained Representation frontend structure for ASR."""
30
+
31
+ def __init__(
32
+ self,
33
+ fs: Union[int, str] = 16000,
34
+ frontend_conf: Optional[dict] = None,
35
+ download_dir: str = None,
36
+ multilayer_feature: bool = False,
37
+ ):
38
+ super().__init__()
39
+ if isinstance(fs, str):
40
+ fs = humanfriendly.parse_size(fs)
41
+
42
+ if download_dir is not None:
43
+ torch.hub.set_dir(download_dir)
44
+
45
+ self.multilayer_feature = multilayer_feature
46
+ self.upstream, self.featurizer = self._get_upstream(frontend_conf)
47
+ self.pretrained_params = copy.deepcopy(self.upstream.state_dict())
48
+ self.output_dim = self.featurizer.output_dim
49
+ self.frontend_type = "s3prl"
50
+ self.hop_length = self.upstream.get_downsample_rates("key")
51
+
52
+ def _get_upstream(self, frontend_conf):
53
+ """Get S3PRL upstream model."""
54
+ s3prl_args = base_s3prl_setup(
55
+ Namespace(**frontend_conf, device="cpu"),
56
+ )
57
+ self.args = s3prl_args
58
+
59
+ s3prl_path = None
60
+ python_path_list = os.environ.get("PYTHONPATH", "(None)").split(":")
61
+ for p in python_path_list:
62
+ if p.endswith("s3prl"):
63
+ s3prl_path = p
64
+ break
65
+ assert s3prl_path is not None
66
+
67
+ s3prl_upstream = torch.hub.load(
68
+ s3prl_path,
69
+ s3prl_args.upstream,
70
+ ckpt=s3prl_args.upstream_ckpt,
71
+ model_config=s3prl_args.upstream_model_config,
72
+ refresh=s3prl_args.upstream_refresh,
73
+ source="local",
74
+ ).to("cpu")
75
+
76
+ if getattr(
77
+ s3prl_upstream, "model", None
78
+ ) is not None and s3prl_upstream.model.__class__.__name__ in [
79
+ "Wav2Vec2Model",
80
+ "HubertModel",
81
+ ]:
82
+ s3prl_upstream.model.encoder.layerdrop = 0.0
83
+
84
+ from s3prl.upstream.interfaces import Featurizer
85
+
86
+ if self.multilayer_feature is None:
87
+ feature_selection = "last_hidden_state"
88
+ else:
89
+ feature_selection = "hidden_states"
90
+ s3prl_featurizer = Featurizer(
91
+ upstream=s3prl_upstream,
92
+ feature_selection=feature_selection,
93
+ upstream_device="cpu",
94
+ )
95
+
96
+ return s3prl_upstream, s3prl_featurizer
97
+
98
+ def _tile_representations(self, feature):
99
+ """Tile up the representations by `tile_factor`.
100
+ Input - sequence of representations
101
+ shape: (batch_size, seq_len, feature_dim)
102
+ Output - sequence of tiled representations
103
+ shape: (batch_size, seq_len * factor, feature_dim)
104
+ """
105
+ assert (
106
+ len(feature.shape) == 3
107
+ ), "Input argument `feature` has invalid shape: {}".format(feature.shape)
108
+ tiled_feature = feature.repeat(1, 1, self.args.tile_factor)
109
+ tiled_feature = tiled_feature.reshape(
110
+ feature.size(0), feature.size(1) * self.args.tile_factor, feature.size(2)
111
+ )
112
+ return tiled_feature
113
+
114
+ def output_size(self) -> int:
115
+ return self.output_dim
116
+
117
+ def forward(
118
+ self, input: torch.Tensor, input_lengths: torch.Tensor
119
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
120
+ wavs = [wav[: input_lengths[i]] for i, wav in enumerate(input)]
121
+ self.upstream.eval()
122
+ with torch.no_grad():
123
+ feats = self.upstream(wavs)
124
+ feats = self.featurizer(wavs, feats)
125
+
126
+ if self.args.tile_factor != 1:
127
+ feats = self._tile_representations(feats)
128
+
129
+ input_feats = pad_list(feats, 0.0)
130
+ feats_lens = torch.tensor([f.shape[0] for f in feats], dtype=torch.long)
131
+
132
+ # Saving CUDA Memory
133
+ del feats
134
+
135
+ return input_feats, feats_lens
136
+
137
+ def reload_pretrained_parameters(self):
138
+ self.upstream.load_state_dict(self.pretrained_params)
139
+ logging.info("Pretrained S3PRL frontend model parameters reloaded!")
funasr_detach/frontends/utils/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Initialize sub package."""
funasr_detach/frontends/utils/beamformer.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch_complex import functional as FC
3
+ from torch_complex.tensor import ComplexTensor
4
+
5
+
6
+ def get_power_spectral_density_matrix(
7
+ xs: ComplexTensor, mask: torch.Tensor, normalization=True, eps: float = 1e-15
8
+ ) -> ComplexTensor:
9
+ """Return cross-channel power spectral density (PSD) matrix
10
+
11
+ Args:
12
+ xs (ComplexTensor): (..., F, C, T)
13
+ mask (torch.Tensor): (..., F, C, T)
14
+ normalization (bool):
15
+ eps (float):
16
+ Returns
17
+ psd (ComplexTensor): (..., F, C, C)
18
+
19
+ """
20
+ # outer product: (..., C_1, T) x (..., C_2, T) -> (..., T, C, C_2)
21
+ psd_Y = FC.einsum("...ct,...et->...tce", [xs, xs.conj()])
22
+
23
+ # Averaging mask along C: (..., C, T) -> (..., T)
24
+ mask = mask.mean(dim=-2)
25
+
26
+ # Normalized mask along T: (..., T)
27
+ if normalization:
28
+ # If assuming the tensor is padded with zero, the summation along
29
+ # the time axis is same regardless of the padding length.
30
+ mask = mask / (mask.sum(dim=-1, keepdim=True) + eps)
31
+
32
+ # psd: (..., T, C, C)
33
+ psd = psd_Y * mask[..., None, None]
34
+ # (..., T, C, C) -> (..., C, C)
35
+ psd = psd.sum(dim=-3)
36
+
37
+ return psd
38
+
39
+
40
+ def get_mvdr_vector(
41
+ psd_s: ComplexTensor,
42
+ psd_n: ComplexTensor,
43
+ reference_vector: torch.Tensor,
44
+ eps: float = 1e-15,
45
+ ) -> ComplexTensor:
46
+ """Return the MVDR(Minimum Variance Distortionless Response) vector:
47
+
48
+ h = (Npsd^-1 @ Spsd) / (Tr(Npsd^-1 @ Spsd)) @ u
49
+
50
+ Reference:
51
+ On optimal frequency-domain multichannel linear filtering
52
+ for noise reduction; M. Souden et al., 2010;
53
+ https://ieeexplore.ieee.org/document/5089420
54
+
55
+ Args:
56
+ psd_s (ComplexTensor): (..., F, C, C)
57
+ psd_n (ComplexTensor): (..., F, C, C)
58
+ reference_vector (torch.Tensor): (..., C)
59
+ eps (float):
60
+ Returns:
61
+ beamform_vector (ComplexTensor)r: (..., F, C)
62
+ """
63
+ # Add eps
64
+ C = psd_n.size(-1)
65
+ eye = torch.eye(C, dtype=psd_n.dtype, device=psd_n.device)
66
+ shape = [1 for _ in range(psd_n.dim() - 2)] + [C, C]
67
+ eye = eye.view(*shape)
68
+ psd_n += eps * eye
69
+
70
+ # numerator: (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3)
71
+ numerator = FC.einsum("...ec,...cd->...ed", [psd_n.inverse(), psd_s])
72
+ # ws: (..., C, C) / (...,) -> (..., C, C)
73
+ ws = numerator / (FC.trace(numerator)[..., None, None] + eps)
74
+ # h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
75
+ beamform_vector = FC.einsum("...fec,...c->...fe", [ws, reference_vector])
76
+ return beamform_vector
77
+
78
+
79
+ def apply_beamforming_vector(
80
+ beamform_vector: ComplexTensor, mix: ComplexTensor
81
+ ) -> ComplexTensor:
82
+ # (..., C) x (..., C, T) -> (..., T)
83
+ es = FC.einsum("...c,...ct->...t", [beamform_vector.conj(), mix])
84
+ return es
funasr_detach/frontends/utils/complex_utils.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Beamformer module."""
2
+
3
+ from distutils.version import LooseVersion
4
+ from typing import Sequence
5
+ from typing import Tuple
6
+ from typing import Union
7
+
8
+ import torch
9
+
10
+ try:
11
+ from torch_complex import functional as FC
12
+ from torch_complex.tensor import ComplexTensor
13
+ except:
14
+ print("Please install torch_complex firstly")
15
+
16
+
17
+ EPS = torch.finfo(torch.double).eps
18
+ is_torch_1_8_plus = LooseVersion(torch.__version__) >= LooseVersion("1.8.0")
19
+ is_torch_1_9_plus = LooseVersion(torch.__version__) >= LooseVersion("1.9.0")
20
+
21
+
22
+ def new_complex_like(
23
+ ref: Union[torch.Tensor, ComplexTensor],
24
+ real_imag: Tuple[torch.Tensor, torch.Tensor],
25
+ ):
26
+ if isinstance(ref, ComplexTensor):
27
+ return ComplexTensor(*real_imag)
28
+ elif is_torch_complex_tensor(ref):
29
+ return torch.complex(*real_imag)
30
+ else:
31
+ raise ValueError(
32
+ "Please update your PyTorch version to 1.9+ for complex support."
33
+ )
34
+
35
+
36
+ def is_torch_complex_tensor(c):
37
+ return (
38
+ not isinstance(c, ComplexTensor) and is_torch_1_9_plus and torch.is_complex(c)
39
+ )
40
+
41
+
42
+ def is_complex(c):
43
+ return isinstance(c, ComplexTensor) or is_torch_complex_tensor(c)
44
+
45
+
46
+ def to_double(c):
47
+ if not isinstance(c, ComplexTensor) and is_torch_1_9_plus and torch.is_complex(c):
48
+ return c.to(dtype=torch.complex128)
49
+ else:
50
+ return c.double()
51
+
52
+
53
+ def to_float(c):
54
+ if not isinstance(c, ComplexTensor) and is_torch_1_9_plus and torch.is_complex(c):
55
+ return c.to(dtype=torch.complex64)
56
+ else:
57
+ return c.float()
58
+
59
+
60
+ def cat(seq: Sequence[Union[ComplexTensor, torch.Tensor]], *args, **kwargs):
61
+ if not isinstance(seq, (list, tuple)):
62
+ raise TypeError(
63
+ "cat(): argument 'tensors' (position 1) must be tuple of Tensors, "
64
+ "not Tensor"
65
+ )
66
+ if isinstance(seq[0], ComplexTensor):
67
+ return FC.cat(seq, *args, **kwargs)
68
+ else:
69
+ return torch.cat(seq, *args, **kwargs)
70
+
71
+
72
+ def complex_norm(
73
+ c: Union[torch.Tensor, ComplexTensor], dim=-1, keepdim=False
74
+ ) -> torch.Tensor:
75
+ if not is_complex(c):
76
+ raise TypeError("Input is not a complex tensor.")
77
+ if is_torch_complex_tensor(c):
78
+ return torch.norm(c, dim=dim, keepdim=keepdim)
79
+ else:
80
+ return torch.sqrt((c.real**2 + c.imag**2).sum(dim=dim, keepdim=keepdim) + EPS)
81
+
82
+
83
+ def einsum(equation, *operands):
84
+ # NOTE: Do not mix ComplexTensor and torch.complex in the input!
85
+ # NOTE (wangyou): Until PyTorch 1.9.0, torch.einsum does not support
86
+ # mixed input with complex and real tensors.
87
+ if len(operands) == 1:
88
+ if isinstance(operands[0], (tuple, list)):
89
+ operands = operands[0]
90
+ complex_module = FC if isinstance(operands[0], ComplexTensor) else torch
91
+ return complex_module.einsum(equation, *operands)
92
+ elif len(operands) != 2:
93
+ op0 = operands[0]
94
+ same_type = all(op.dtype == op0.dtype for op in operands[1:])
95
+ if same_type:
96
+ _einsum = FC.einsum if isinstance(op0, ComplexTensor) else torch.einsum
97
+ return _einsum(equation, *operands)
98
+ else:
99
+ raise ValueError("0 or More than 2 operands are not supported.")
100
+ a, b = operands
101
+ if isinstance(a, ComplexTensor) or isinstance(b, ComplexTensor):
102
+ return FC.einsum(equation, a, b)
103
+ elif is_torch_1_9_plus and (torch.is_complex(a) or torch.is_complex(b)):
104
+ if not torch.is_complex(a):
105
+ o_real = torch.einsum(equation, a, b.real)
106
+ o_imag = torch.einsum(equation, a, b.imag)
107
+ return torch.complex(o_real, o_imag)
108
+ elif not torch.is_complex(b):
109
+ o_real = torch.einsum(equation, a.real, b)
110
+ o_imag = torch.einsum(equation, a.imag, b)
111
+ return torch.complex(o_real, o_imag)
112
+ else:
113
+ return torch.einsum(equation, a, b)
114
+ else:
115
+ return torch.einsum(equation, a, b)
116
+
117
+
118
+ def inverse(
119
+ c: Union[torch.Tensor, ComplexTensor],
120
+ ) -> Union[torch.Tensor, ComplexTensor]:
121
+ if isinstance(c, ComplexTensor):
122
+ return c.inverse2()
123
+ else:
124
+ return c.inverse()
125
+
126
+
127
+ def matmul(
128
+ a: Union[torch.Tensor, ComplexTensor], b: Union[torch.Tensor, ComplexTensor]
129
+ ) -> Union[torch.Tensor, ComplexTensor]:
130
+ # NOTE: Do not mix ComplexTensor and torch.complex in the input!
131
+ # NOTE (wangyou): Until PyTorch 1.9.0, torch.matmul does not support
132
+ # multiplication between complex and real tensors.
133
+ if isinstance(a, ComplexTensor) or isinstance(b, ComplexTensor):
134
+ return FC.matmul(a, b)
135
+ elif is_torch_1_9_plus and (torch.is_complex(a) or torch.is_complex(b)):
136
+ if not torch.is_complex(a):
137
+ o_real = torch.matmul(a, b.real)
138
+ o_imag = torch.matmul(a, b.imag)
139
+ return torch.complex(o_real, o_imag)
140
+ elif not torch.is_complex(b):
141
+ o_real = torch.matmul(a.real, b)
142
+ o_imag = torch.matmul(a.imag, b)
143
+ return torch.complex(o_real, o_imag)
144
+ else:
145
+ return torch.matmul(a, b)
146
+ else:
147
+ return torch.matmul(a, b)
148
+
149
+
150
+ def trace(a: Union[torch.Tensor, ComplexTensor]):
151
+ # NOTE (wangyou): until PyTorch 1.9.0, torch.trace does not
152
+ # support bacth processing. Use FC.trace() as fallback.
153
+ return FC.trace(a)
154
+
155
+
156
+ def reverse(a: Union[torch.Tensor, ComplexTensor], dim=0):
157
+ if isinstance(a, ComplexTensor):
158
+ return FC.reverse(a, dim=dim)
159
+ else:
160
+ return torch.flip(a, dims=(dim,))
161
+
162
+
163
+ def solve(b: Union[torch.Tensor, ComplexTensor], a: Union[torch.Tensor, ComplexTensor]):
164
+ """Solve the linear equation ax = b."""
165
+ # NOTE: Do not mix ComplexTensor and torch.complex in the input!
166
+ # NOTE (wangyou): Until PyTorch 1.9.0, torch.solve does not support
167
+ # mixed input with complex and real tensors.
168
+ if isinstance(a, ComplexTensor) or isinstance(b, ComplexTensor):
169
+ if isinstance(a, ComplexTensor) and isinstance(b, ComplexTensor):
170
+ return FC.solve(b, a, return_LU=False)
171
+ else:
172
+ return matmul(inverse(a), b)
173
+ elif is_torch_1_9_plus and (torch.is_complex(a) or torch.is_complex(b)):
174
+ if torch.is_complex(a) and torch.is_complex(b):
175
+ return torch.linalg.solve(a, b)
176
+ else:
177
+ return matmul(inverse(a), b)
178
+ else:
179
+ if is_torch_1_8_plus:
180
+ return torch.linalg.solve(a, b)
181
+ else:
182
+ return torch.solve(b, a)[0]
183
+
184
+
185
+ def stack(seq: Sequence[Union[ComplexTensor, torch.Tensor]], *args, **kwargs):
186
+ if not isinstance(seq, (list, tuple)):
187
+ raise TypeError(
188
+ "stack(): argument 'tensors' (position 1) must be tuple of Tensors, "
189
+ "not Tensor"
190
+ )
191
+ if isinstance(seq[0], ComplexTensor):
192
+ return FC.stack(seq, *args, **kwargs)
193
+ else:
194
+ return torch.stack(seq, *args, **kwargs)
funasr_detach/frontends/utils/dnn_beamformer.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """DNN beamformer module."""
2
+
3
+ from typing import Tuple
4
+
5
+ import torch
6
+ from torch.nn import functional as F
7
+
8
+ from funasr_detach.frontends.utils.beamformer import apply_beamforming_vector
9
+ from funasr_detach.frontends.utils.beamformer import get_mvdr_vector
10
+ from funasr_detach.frontends.utils.beamformer import (
11
+ get_power_spectral_density_matrix, # noqa: H301
12
+ )
13
+ from funasr_detach.frontends.utils.mask_estimator import MaskEstimator
14
+ from torch_complex.tensor import ComplexTensor
15
+
16
+
17
+ class DNN_Beamformer(torch.nn.Module):
18
+ """DNN mask based Beamformer
19
+
20
+ Citation:
21
+ Multichannel End-to-end Speech Recognition; T. Ochiai et al., 2017;
22
+ https://arxiv.org/abs/1703.04783
23
+
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ bidim,
29
+ btype="blstmp",
30
+ blayers=3,
31
+ bunits=300,
32
+ bprojs=320,
33
+ bnmask=2,
34
+ dropout_rate=0.0,
35
+ badim=320,
36
+ ref_channel: int = -1,
37
+ beamformer_type="mvdr",
38
+ ):
39
+ super().__init__()
40
+ self.mask = MaskEstimator(
41
+ btype, bidim, blayers, bunits, bprojs, dropout_rate, nmask=bnmask
42
+ )
43
+ self.ref = AttentionReference(bidim, badim)
44
+ self.ref_channel = ref_channel
45
+
46
+ self.nmask = bnmask
47
+
48
+ if beamformer_type != "mvdr":
49
+ raise ValueError(
50
+ "Not supporting beamformer_type={}".format(beamformer_type)
51
+ )
52
+ self.beamformer_type = beamformer_type
53
+
54
+ def forward(
55
+ self, data: ComplexTensor, ilens: torch.LongTensor
56
+ ) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
57
+ """The forward function
58
+
59
+ Notation:
60
+ B: Batch
61
+ C: Channel
62
+ T: Time or Sequence length
63
+ F: Freq
64
+
65
+ Args:
66
+ data (ComplexTensor): (B, T, C, F)
67
+ ilens (torch.Tensor): (B,)
68
+ Returns:
69
+ enhanced (ComplexTensor): (B, T, F)
70
+ ilens (torch.Tensor): (B,)
71
+
72
+ """
73
+
74
+ def apply_beamforming(data, ilens, psd_speech, psd_noise):
75
+ # u: (B, C)
76
+ if self.ref_channel < 0:
77
+ u, _ = self.ref(psd_speech, ilens)
78
+ else:
79
+ # (optional) Create onehot vector for fixed reference microphone
80
+ u = torch.zeros(
81
+ *(data.size()[:-3] + (data.size(-2),)), device=data.device
82
+ )
83
+ u[..., self.ref_channel].fill_(1)
84
+
85
+ ws = get_mvdr_vector(psd_speech, psd_noise, u)
86
+ enhanced = apply_beamforming_vector(ws, data)
87
+
88
+ return enhanced, ws
89
+
90
+ # data (B, T, C, F) -> (B, F, C, T)
91
+ data = data.permute(0, 3, 2, 1)
92
+
93
+ # mask: (B, F, C, T)
94
+ masks, _ = self.mask(data, ilens)
95
+ assert self.nmask == len(masks)
96
+
97
+ if self.nmask == 2: # (mask_speech, mask_noise)
98
+ mask_speech, mask_noise = masks
99
+
100
+ psd_speech = get_power_spectral_density_matrix(data, mask_speech)
101
+ psd_noise = get_power_spectral_density_matrix(data, mask_noise)
102
+
103
+ enhanced, ws = apply_beamforming(data, ilens, psd_speech, psd_noise)
104
+
105
+ # (..., F, T) -> (..., T, F)
106
+ enhanced = enhanced.transpose(-1, -2)
107
+ mask_speech = mask_speech.transpose(-1, -3)
108
+ else: # multi-speaker case: (mask_speech1, ..., mask_noise)
109
+ mask_speech = list(masks[:-1])
110
+ mask_noise = masks[-1]
111
+
112
+ psd_speeches = [
113
+ get_power_spectral_density_matrix(data, mask) for mask in mask_speech
114
+ ]
115
+ psd_noise = get_power_spectral_density_matrix(data, mask_noise)
116
+
117
+ enhanced = []
118
+ ws = []
119
+ for i in range(self.nmask - 1):
120
+ psd_speech = psd_speeches.pop(i)
121
+ # treat all other speakers' psd_speech as noises
122
+ enh, w = apply_beamforming(
123
+ data, ilens, psd_speech, sum(psd_speeches) + psd_noise
124
+ )
125
+ psd_speeches.insert(i, psd_speech)
126
+
127
+ # (..., F, T) -> (..., T, F)
128
+ enh = enh.transpose(-1, -2)
129
+ mask_speech[i] = mask_speech[i].transpose(-1, -3)
130
+
131
+ enhanced.append(enh)
132
+ ws.append(w)
133
+
134
+ return enhanced, ilens, mask_speech
135
+
136
+
137
+ class AttentionReference(torch.nn.Module):
138
+ def __init__(self, bidim, att_dim):
139
+ super().__init__()
140
+ self.mlp_psd = torch.nn.Linear(bidim, att_dim)
141
+ self.gvec = torch.nn.Linear(att_dim, 1)
142
+
143
+ def forward(
144
+ self, psd_in: ComplexTensor, ilens: torch.LongTensor, scaling: float = 2.0
145
+ ) -> Tuple[torch.Tensor, torch.LongTensor]:
146
+ """The forward function
147
+
148
+ Args:
149
+ psd_in (ComplexTensor): (B, F, C, C)
150
+ ilens (torch.Tensor): (B,)
151
+ scaling (float):
152
+ Returns:
153
+ u (torch.Tensor): (B, C)
154
+ ilens (torch.Tensor): (B,)
155
+ """
156
+ B, _, C = psd_in.size()[:3]
157
+ assert psd_in.size(2) == psd_in.size(3), psd_in.size()
158
+ # psd_in: (B, F, C, C)
159
+ psd = psd_in.masked_fill(
160
+ torch.eye(C, dtype=torch.bool, device=psd_in.device), 0
161
+ )
162
+ # psd: (B, F, C, C) -> (B, C, F)
163
+ psd = (psd.sum(dim=-1) / (C - 1)).transpose(-1, -2)
164
+
165
+ # Calculate amplitude
166
+ psd_feat = (psd.real**2 + psd.imag**2) ** 0.5
167
+
168
+ # (B, C, F) -> (B, C, F2)
169
+ mlp_psd = self.mlp_psd(psd_feat)
170
+ # (B, C, F2) -> (B, C, 1) -> (B, C)
171
+ e = self.gvec(torch.tanh(mlp_psd)).squeeze(-1)
172
+ u = F.softmax(scaling * e, dim=-1)
173
+ return u, ilens
funasr_detach/frontends/utils/dnn_wpe.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+
3
+ from pytorch_wpe import wpe_one_iteration
4
+ import torch
5
+ from torch_complex.tensor import ComplexTensor
6
+
7
+ from funasr_detach.frontends.utils.mask_estimator import MaskEstimator
8
+ from funasr_detach.models.transformer.utils.nets_utils import make_pad_mask
9
+
10
+
11
+ class DNN_WPE(torch.nn.Module):
12
+ def __init__(
13
+ self,
14
+ wtype: str = "blstmp",
15
+ widim: int = 257,
16
+ wlayers: int = 3,
17
+ wunits: int = 300,
18
+ wprojs: int = 320,
19
+ dropout_rate: float = 0.0,
20
+ taps: int = 5,
21
+ delay: int = 3,
22
+ use_dnn_mask: bool = True,
23
+ iterations: int = 1,
24
+ normalization: bool = False,
25
+ ):
26
+ super().__init__()
27
+ self.iterations = iterations
28
+ self.taps = taps
29
+ self.delay = delay
30
+
31
+ self.normalization = normalization
32
+ self.use_dnn_mask = use_dnn_mask
33
+
34
+ self.inverse_power = True
35
+
36
+ if self.use_dnn_mask:
37
+ self.mask_est = MaskEstimator(
38
+ wtype, widim, wlayers, wunits, wprojs, dropout_rate, nmask=1
39
+ )
40
+
41
+ def forward(
42
+ self, data: ComplexTensor, ilens: torch.LongTensor
43
+ ) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
44
+ """The forward function
45
+
46
+ Notation:
47
+ B: Batch
48
+ C: Channel
49
+ T: Time or Sequence length
50
+ F: Freq or Some dimension of the feature vector
51
+
52
+ Args:
53
+ data: (B, C, T, F)
54
+ ilens: (B,)
55
+ Returns:
56
+ data: (B, C, T, F)
57
+ ilens: (B,)
58
+ """
59
+ # (B, T, C, F) -> (B, F, C, T)
60
+ enhanced = data = data.permute(0, 3, 2, 1)
61
+ mask = None
62
+
63
+ for i in range(self.iterations):
64
+ # Calculate power: (..., C, T)
65
+ power = enhanced.real**2 + enhanced.imag**2
66
+ if i == 0 and self.use_dnn_mask:
67
+ # mask: (B, F, C, T)
68
+ (mask,), _ = self.mask_est(enhanced, ilens)
69
+ if self.normalization:
70
+ # Normalize along T
71
+ mask = mask / mask.sum(dim=-1)[..., None]
72
+ # (..., C, T) * (..., C, T) -> (..., C, T)
73
+ power = power * mask
74
+
75
+ # Averaging along the channel axis: (..., C, T) -> (..., T)
76
+ power = power.mean(dim=-2)
77
+
78
+ # enhanced: (..., C, T) -> (..., C, T)
79
+ enhanced = wpe_one_iteration(
80
+ data.contiguous(),
81
+ power,
82
+ taps=self.taps,
83
+ delay=self.delay,
84
+ inverse_power=self.inverse_power,
85
+ )
86
+
87
+ enhanced.masked_fill_(make_pad_mask(ilens, enhanced.real), 0)
88
+
89
+ # (B, F, C, T) -> (B, T, C, F)
90
+ enhanced = enhanced.permute(0, 3, 2, 1)
91
+ if mask is not None:
92
+ mask = mask.transpose(-1, -3)
93
+ return enhanced, ilens, mask
funasr_detach/frontends/utils/feature_transform.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ from typing import Tuple
3
+ from typing import Union
4
+
5
+ import librosa
6
+ import numpy as np
7
+ import torch
8
+ from torch_complex.tensor import ComplexTensor
9
+
10
+ from funasr_detach.models.transformer.utils.nets_utils import make_pad_mask
11
+
12
+
13
+ class FeatureTransform(torch.nn.Module):
14
+ def __init__(
15
+ self,
16
+ # Mel options,
17
+ fs: int = 16000,
18
+ n_fft: int = 512,
19
+ n_mels: int = 80,
20
+ fmin: float = 0.0,
21
+ fmax: float = None,
22
+ # Normalization
23
+ stats_file: str = None,
24
+ apply_uttmvn: bool = True,
25
+ uttmvn_norm_means: bool = True,
26
+ uttmvn_norm_vars: bool = False,
27
+ ):
28
+ super().__init__()
29
+ self.apply_uttmvn = apply_uttmvn
30
+
31
+ self.logmel = LogMel(fs=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
32
+ self.stats_file = stats_file
33
+ if stats_file is not None:
34
+ self.global_mvn = GlobalMVN(stats_file)
35
+ else:
36
+ self.global_mvn = None
37
+
38
+ if self.apply_uttmvn is not None:
39
+ self.uttmvn = UtteranceMVN(
40
+ norm_means=uttmvn_norm_means, norm_vars=uttmvn_norm_vars
41
+ )
42
+ else:
43
+ self.uttmvn = None
44
+
45
+ def forward(
46
+ self, x: ComplexTensor, ilens: Union[torch.LongTensor, np.ndarray, List[int]]
47
+ ) -> Tuple[torch.Tensor, torch.LongTensor]:
48
+ # (B, T, F) or (B, T, C, F)
49
+ if x.dim() not in (3, 4):
50
+ raise ValueError(f"Input dim must be 3 or 4: {x.dim()}")
51
+ if not torch.is_tensor(ilens):
52
+ ilens = torch.from_numpy(np.asarray(ilens)).to(x.device)
53
+
54
+ if x.dim() == 4:
55
+ # h: (B, T, C, F) -> h: (B, T, F)
56
+ if self.training:
57
+ # Select 1ch randomly
58
+ ch = np.random.randint(x.size(2))
59
+ h = x[:, :, ch, :]
60
+ else:
61
+ # Use the first channel
62
+ h = x[:, :, 0, :]
63
+ else:
64
+ h = x
65
+
66
+ # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
67
+ h = h.real**2 + h.imag**2
68
+
69
+ h, _ = self.logmel(h, ilens)
70
+ if self.stats_file is not None:
71
+ h, _ = self.global_mvn(h, ilens)
72
+ if self.apply_uttmvn:
73
+ h, _ = self.uttmvn(h, ilens)
74
+
75
+ return h, ilens
76
+
77
+
78
+ class LogMel(torch.nn.Module):
79
+ """Convert STFT to fbank feats
80
+
81
+ The arguments is same as librosa.filters.mel
82
+
83
+ Args:
84
+ fs: number > 0 [scalar] sampling rate of the incoming signal
85
+ n_fft: int > 0 [scalar] number of FFT components
86
+ n_mels: int > 0 [scalar] number of Mel bands to generate
87
+ fmin: float >= 0 [scalar] lowest frequency (in Hz)
88
+ fmax: float >= 0 [scalar] highest frequency (in Hz).
89
+ If `None`, use `fmax = fs / 2.0`
90
+ htk: use HTK formula instead of Slaney
91
+ norm: {None, 1, np.inf} [scalar]
92
+ if 1, divide the triangular mel weights by the width of the mel band
93
+ (area normalization). Otherwise, leave all the triangles aiming for
94
+ a peak value of 1.0
95
+
96
+ """
97
+
98
+ def __init__(
99
+ self,
100
+ fs: int = 16000,
101
+ n_fft: int = 512,
102
+ n_mels: int = 80,
103
+ fmin: float = 0.0,
104
+ fmax: float = None,
105
+ htk: bool = False,
106
+ norm=1,
107
+ ):
108
+ super().__init__()
109
+
110
+ _mel_options = dict(
111
+ sr=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax, htk=htk, norm=norm
112
+ )
113
+ self.mel_options = _mel_options
114
+
115
+ # Note(kamo): The mel matrix of librosa is different from kaldi.
116
+ melmat = librosa.filters.mel(**_mel_options)
117
+ # melmat: (D2, D1) -> (D1, D2)
118
+ self.register_buffer("melmat", torch.from_numpy(melmat.T).float())
119
+
120
+ def extra_repr(self):
121
+ return ", ".join(f"{k}={v}" for k, v in self.mel_options.items())
122
+
123
+ def forward(
124
+ self, feat: torch.Tensor, ilens: torch.LongTensor
125
+ ) -> Tuple[torch.Tensor, torch.LongTensor]:
126
+ # feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2)
127
+ mel_feat = torch.matmul(feat, self.melmat)
128
+
129
+ logmel_feat = (mel_feat + 1e-20).log()
130
+ # Zero padding
131
+ logmel_feat = logmel_feat.masked_fill(make_pad_mask(ilens, logmel_feat, 1), 0.0)
132
+ return logmel_feat, ilens
133
+
134
+
135
+ class GlobalMVN(torch.nn.Module):
136
+ """Apply global mean and variance normalization
137
+
138
+ Args:
139
+ stats_file(str): npy file of 1-dim array or text file.
140
+ From the _first element to
141
+ the {(len(array) - 1) / 2}th element are treated as
142
+ the sum of features,
143
+ and the rest excluding the last elements are
144
+ treated as the sum of the square value of features,
145
+ and the last elements eqauls to the number of samples.
146
+ std_floor(float):
147
+ """
148
+
149
+ def __init__(
150
+ self,
151
+ stats_file: str,
152
+ norm_means: bool = True,
153
+ norm_vars: bool = True,
154
+ eps: float = 1.0e-20,
155
+ ):
156
+ super().__init__()
157
+ self.norm_means = norm_means
158
+ self.norm_vars = norm_vars
159
+
160
+ self.stats_file = stats_file
161
+ stats = np.load(stats_file)
162
+
163
+ stats = stats.astype(float)
164
+ assert (len(stats) - 1) % 2 == 0, stats.shape
165
+
166
+ count = stats.flatten()[-1]
167
+ mean = stats[: (len(stats) - 1) // 2] / count
168
+ var = stats[(len(stats) - 1) // 2 : -1] / count - mean * mean
169
+ std = np.maximum(np.sqrt(var), eps)
170
+
171
+ self.register_buffer("bias", torch.from_numpy(-mean.astype(np.float32)))
172
+ self.register_buffer("scale", torch.from_numpy(1 / std.astype(np.float32)))
173
+
174
+ def extra_repr(self):
175
+ return (
176
+ f"stats_file={self.stats_file}, "
177
+ f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
178
+ )
179
+
180
+ def forward(
181
+ self, x: torch.Tensor, ilens: torch.LongTensor
182
+ ) -> Tuple[torch.Tensor, torch.LongTensor]:
183
+ # feat: (B, T, D)
184
+ if self.norm_means:
185
+ x += self.bias.type_as(x)
186
+ x.masked_fill(make_pad_mask(ilens, x, 1), 0.0)
187
+
188
+ if self.norm_vars:
189
+ x *= self.scale.type_as(x)
190
+ return x, ilens
191
+
192
+
193
+ class UtteranceMVN(torch.nn.Module):
194
+ def __init__(
195
+ self, norm_means: bool = True, norm_vars: bool = False, eps: float = 1.0e-20
196
+ ):
197
+ super().__init__()
198
+ self.norm_means = norm_means
199
+ self.norm_vars = norm_vars
200
+ self.eps = eps
201
+
202
+ def extra_repr(self):
203
+ return f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
204
+
205
+ def forward(
206
+ self, x: torch.Tensor, ilens: torch.LongTensor
207
+ ) -> Tuple[torch.Tensor, torch.LongTensor]:
208
+ return utterance_mvn(
209
+ x, ilens, norm_means=self.norm_means, norm_vars=self.norm_vars, eps=self.eps
210
+ )
211
+
212
+
213
+ def utterance_mvn(
214
+ x: torch.Tensor,
215
+ ilens: torch.LongTensor,
216
+ norm_means: bool = True,
217
+ norm_vars: bool = False,
218
+ eps: float = 1.0e-20,
219
+ ) -> Tuple[torch.Tensor, torch.LongTensor]:
220
+ """Apply utterance mean and variance normalization
221
+
222
+ Args:
223
+ x: (B, T, D), assumed zero padded
224
+ ilens: (B, T, D)
225
+ norm_means:
226
+ norm_vars:
227
+ eps:
228
+
229
+ """
230
+ ilens_ = ilens.type_as(x)
231
+ # mean: (B, D)
232
+ mean = x.sum(dim=1) / ilens_[:, None]
233
+
234
+ if norm_means:
235
+ x -= mean[:, None, :]
236
+ x_ = x
237
+ else:
238
+ x_ = x - mean[:, None, :]
239
+
240
+ # Zero padding
241
+ x_.masked_fill(make_pad_mask(ilens, x_, 1), 0.0)
242
+ if norm_vars:
243
+ var = x_.pow(2).sum(dim=1) / ilens_[:, None]
244
+ var = torch.clamp(var, min=eps)
245
+ x /= var.sqrt()[:, None, :]
246
+ x_ = x
247
+ return x_, ilens
248
+
249
+
250
+ def feature_transform_for(args, n_fft):
251
+ return FeatureTransform(
252
+ # Mel options,
253
+ fs=args.fbank_fs,
254
+ n_fft=n_fft,
255
+ n_mels=args.n_mels,
256
+ fmin=args.fbank_fmin,
257
+ fmax=args.fbank_fmax,
258
+ # Normalization
259
+ stats_file=args.stats_file,
260
+ apply_uttmvn=args.apply_uttmvn,
261
+ uttmvn_norm_means=args.uttmvn_norm_means,
262
+ uttmvn_norm_vars=args.uttmvn_norm_vars,
263
+ )
funasr_detach/frontends/utils/frontend.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ from typing import Optional
3
+ from typing import Tuple
4
+ from typing import Union
5
+
6
+ import numpy
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch_complex.tensor import ComplexTensor
10
+
11
+ from funasr_detach.frontends.utils.dnn_beamformer import DNN_Beamformer
12
+ from funasr_detach.frontends.utils.dnn_wpe import DNN_WPE
13
+
14
+
15
+ class Frontend(nn.Module):
16
+ def __init__(
17
+ self,
18
+ idim: int,
19
+ # WPE options
20
+ use_wpe: bool = False,
21
+ wtype: str = "blstmp",
22
+ wlayers: int = 3,
23
+ wunits: int = 300,
24
+ wprojs: int = 320,
25
+ wdropout_rate: float = 0.0,
26
+ taps: int = 5,
27
+ delay: int = 3,
28
+ use_dnn_mask_for_wpe: bool = True,
29
+ # Beamformer options
30
+ use_beamformer: bool = False,
31
+ btype: str = "blstmp",
32
+ blayers: int = 3,
33
+ bunits: int = 300,
34
+ bprojs: int = 320,
35
+ bnmask: int = 2,
36
+ badim: int = 320,
37
+ ref_channel: int = -1,
38
+ bdropout_rate=0.0,
39
+ ):
40
+ super().__init__()
41
+
42
+ self.use_beamformer = use_beamformer
43
+ self.use_wpe = use_wpe
44
+ self.use_dnn_mask_for_wpe = use_dnn_mask_for_wpe
45
+ # use frontend for all the data,
46
+ # e.g. in the case of multi-speaker speech separation
47
+ self.use_frontend_for_all = bnmask > 2
48
+
49
+ if self.use_wpe:
50
+ if self.use_dnn_mask_for_wpe:
51
+ # Use DNN for power estimation
52
+ # (Not observed significant gains)
53
+ iterations = 1
54
+ else:
55
+ # Performing as conventional WPE, without DNN Estimator
56
+ iterations = 2
57
+
58
+ self.wpe = DNN_WPE(
59
+ wtype=wtype,
60
+ widim=idim,
61
+ wunits=wunits,
62
+ wprojs=wprojs,
63
+ wlayers=wlayers,
64
+ taps=taps,
65
+ delay=delay,
66
+ dropout_rate=wdropout_rate,
67
+ iterations=iterations,
68
+ use_dnn_mask=use_dnn_mask_for_wpe,
69
+ )
70
+ else:
71
+ self.wpe = None
72
+
73
+ if self.use_beamformer:
74
+ self.beamformer = DNN_Beamformer(
75
+ btype=btype,
76
+ bidim=idim,
77
+ bunits=bunits,
78
+ bprojs=bprojs,
79
+ blayers=blayers,
80
+ bnmask=bnmask,
81
+ dropout_rate=bdropout_rate,
82
+ badim=badim,
83
+ ref_channel=ref_channel,
84
+ )
85
+ else:
86
+ self.beamformer = None
87
+
88
+ def forward(
89
+ self, x: ComplexTensor, ilens: Union[torch.LongTensor, numpy.ndarray, List[int]]
90
+ ) -> Tuple[ComplexTensor, torch.LongTensor, Optional[ComplexTensor]]:
91
+ assert len(x) == len(ilens), (len(x), len(ilens))
92
+ # (B, T, F) or (B, T, C, F)
93
+ if x.dim() not in (3, 4):
94
+ raise ValueError(f"Input dim must be 3 or 4: {x.dim()}")
95
+ if not torch.is_tensor(ilens):
96
+ ilens = torch.from_numpy(numpy.asarray(ilens)).to(x.device)
97
+
98
+ mask = None
99
+ h = x
100
+ if h.dim() == 4:
101
+ if self.training:
102
+ choices = [(False, False)] if not self.use_frontend_for_all else []
103
+ if self.use_wpe:
104
+ choices.append((True, False))
105
+
106
+ if self.use_beamformer:
107
+ choices.append((False, True))
108
+
109
+ use_wpe, use_beamformer = choices[numpy.random.randint(len(choices))]
110
+
111
+ else:
112
+ use_wpe = self.use_wpe
113
+ use_beamformer = self.use_beamformer
114
+
115
+ # 1. WPE
116
+ if use_wpe:
117
+ # h: (B, T, C, F) -> h: (B, T, C, F)
118
+ h, ilens, mask = self.wpe(h, ilens)
119
+
120
+ # 2. Beamformer
121
+ if use_beamformer:
122
+ # h: (B, T, C, F) -> h: (B, T, F)
123
+ h, ilens, mask = self.beamformer(h, ilens)
124
+
125
+ return h, ilens, mask
126
+
127
+
128
+ def frontend_for(args, idim):
129
+ return Frontend(
130
+ idim=idim,
131
+ # WPE options
132
+ use_wpe=args.use_wpe,
133
+ wtype=args.wtype,
134
+ wlayers=args.wlayers,
135
+ wunits=args.wunits,
136
+ wprojs=args.wprojs,
137
+ wdropout_rate=args.wdropout_rate,
138
+ taps=args.wpe_taps,
139
+ delay=args.wpe_delay,
140
+ use_dnn_mask_for_wpe=args.use_dnn_mask_for_wpe,
141
+ # Beamformer options
142
+ use_beamformer=args.use_beamformer,
143
+ btype=args.btype,
144
+ blayers=args.blayers,
145
+ bunits=args.bunits,
146
+ bprojs=args.bprojs,
147
+ bnmask=args.bnmask,
148
+ badim=args.badim,
149
+ ref_channel=args.ref_channel,
150
+ bdropout_rate=args.bdropout_rate,
151
+ )
funasr_detach/frontends/utils/log_mel.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import librosa
2
+ import torch
3
+ from typing import Tuple
4
+
5
+ from funasr_detach.models.transformer.utils.nets_utils import make_pad_mask
6
+
7
+
8
+ class LogMel(torch.nn.Module):
9
+ """Convert STFT to fbank feats
10
+
11
+ The arguments is same as librosa.filters.mel
12
+
13
+ Args:
14
+ fs: number > 0 [scalar] sampling rate of the incoming signal
15
+ n_fft: int > 0 [scalar] number of FFT components
16
+ n_mels: int > 0 [scalar] number of Mel bands to generate
17
+ fmin: float >= 0 [scalar] lowest frequency (in Hz)
18
+ fmax: float >= 0 [scalar] highest frequency (in Hz).
19
+ If `None`, use `fmax = fs / 2.0`
20
+ htk: use HTK formula instead of Slaney
21
+ """
22
+
23
+ def __init__(
24
+ self,
25
+ fs: int = 16000,
26
+ n_fft: int = 512,
27
+ n_mels: int = 80,
28
+ fmin: float = None,
29
+ fmax: float = None,
30
+ htk: bool = False,
31
+ log_base: float = None,
32
+ ):
33
+ super().__init__()
34
+
35
+ fmin = 0 if fmin is None else fmin
36
+ fmax = fs / 2 if fmax is None else fmax
37
+ _mel_options = dict(
38
+ sr=fs,
39
+ n_fft=n_fft,
40
+ n_mels=n_mels,
41
+ fmin=fmin,
42
+ fmax=fmax,
43
+ htk=htk,
44
+ )
45
+ self.mel_options = _mel_options
46
+ self.log_base = log_base
47
+
48
+ # Note(kamo): The mel matrix of librosa is different from kaldi.
49
+ melmat = librosa.filters.mel(**_mel_options)
50
+ # melmat: (D2, D1) -> (D1, D2)
51
+ self.register_buffer("melmat", torch.from_numpy(melmat.T).float())
52
+
53
+ def extra_repr(self):
54
+ return ", ".join(f"{k}={v}" for k, v in self.mel_options.items())
55
+
56
+ def forward(
57
+ self,
58
+ feat: torch.Tensor,
59
+ ilens: torch.Tensor = None,
60
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
61
+ # feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2)
62
+ mel_feat = torch.matmul(feat, self.melmat)
63
+ mel_feat = torch.clamp(mel_feat, min=1e-10)
64
+
65
+ if self.log_base is None:
66
+ logmel_feat = mel_feat.log()
67
+ elif self.log_base == 2.0:
68
+ logmel_feat = mel_feat.log2()
69
+ elif self.log_base == 10.0:
70
+ logmel_feat = mel_feat.log10()
71
+ else:
72
+ logmel_feat = mel_feat.log() / torch.log(self.log_base)
73
+
74
+ # Zero padding
75
+ if ilens is not None:
76
+ logmel_feat = logmel_feat.masked_fill(
77
+ make_pad_mask(ilens, logmel_feat, 1), 0.0
78
+ )
79
+ else:
80
+ ilens = feat.new_full(
81
+ [feat.size(0)], fill_value=feat.size(1), dtype=torch.long
82
+ )
83
+ return logmel_feat, ilens
funasr_detach/frontends/utils/mask_estimator.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+
3
+ import numpy as np
4
+ import torch
5
+ from torch.nn import functional as F
6
+ from torch_complex.tensor import ComplexTensor
7
+
8
+ from funasr_detach.models.transformer.utils.nets_utils import make_pad_mask
9
+ from funasr_detach.models.language_model.rnn.encoders import RNN
10
+ from funasr_detach.models.language_model.rnn.encoders import RNNP
11
+
12
+
13
+ class MaskEstimator(torch.nn.Module):
14
+ def __init__(self, type, idim, layers, units, projs, dropout, nmask=1):
15
+ super().__init__()
16
+ subsample = np.ones(layers + 1, dtype=np.int32)
17
+
18
+ typ = type.lstrip("vgg").rstrip("p")
19
+ if type[-1] == "p":
20
+ self.brnn = RNNP(idim, layers, units, projs, subsample, dropout, typ=typ)
21
+ else:
22
+ self.brnn = RNN(idim, layers, units, projs, dropout, typ=typ)
23
+
24
+ self.type = type
25
+ self.nmask = nmask
26
+ self.linears = torch.nn.ModuleList(
27
+ [torch.nn.Linear(projs, idim) for _ in range(nmask)]
28
+ )
29
+
30
+ def forward(
31
+ self, xs: ComplexTensor, ilens: torch.LongTensor
32
+ ) -> Tuple[Tuple[torch.Tensor, ...], torch.LongTensor]:
33
+ """The forward function
34
+
35
+ Args:
36
+ xs: (B, F, C, T)
37
+ ilens: (B,)
38
+ Returns:
39
+ hs (torch.Tensor): The hidden vector (B, F, C, T)
40
+ masks: A tuple of the masks. (B, F, C, T)
41
+ ilens: (B,)
42
+ """
43
+ assert xs.size(0) == ilens.size(0), (xs.size(0), ilens.size(0))
44
+ _, _, C, input_length = xs.size()
45
+ # (B, F, C, T) -> (B, C, T, F)
46
+ xs = xs.permute(0, 2, 3, 1)
47
+
48
+ # Calculate amplitude: (B, C, T, F) -> (B, C, T, F)
49
+ xs = (xs.real**2 + xs.imag**2) ** 0.5
50
+ # xs: (B, C, T, F) -> xs: (B * C, T, F)
51
+ xs = xs.contiguous().view(-1, xs.size(-2), xs.size(-1))
52
+ # ilens: (B,) -> ilens_: (B * C)
53
+ ilens_ = ilens[:, None].expand(-1, C).contiguous().view(-1)
54
+
55
+ # xs: (B * C, T, F) -> xs: (B * C, T, D)
56
+ xs, _, _ = self.brnn(xs, ilens_)
57
+ # xs: (B * C, T, D) -> xs: (B, C, T, D)
58
+ xs = xs.view(-1, C, xs.size(-2), xs.size(-1))
59
+
60
+ masks = []
61
+ for linear in self.linears:
62
+ # xs: (B, C, T, D) -> mask:(B, C, T, F)
63
+ mask = linear(xs)
64
+
65
+ mask = torch.sigmoid(mask)
66
+ # Zero padding
67
+ mask.masked_fill(make_pad_mask(ilens, mask, length_dim=2), 0)
68
+
69
+ # (B, C, T, F) -> (B, F, C, T)
70
+ mask = mask.permute(0, 3, 1, 2)
71
+
72
+ # Take cares of multi gpu cases: If input_length > max(ilens)
73
+ if mask.size(-1) < input_length:
74
+ mask = F.pad(mask, [0, input_length - mask.size(-1)], value=0)
75
+ masks.append(mask)
76
+
77
+ return tuple(masks), ilens
funasr_detach/frontends/utils/stft.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from distutils.version import LooseVersion
2
+ from typing import Optional
3
+ from typing import Tuple
4
+ from typing import Union
5
+
6
+ import torch
7
+
8
+ try:
9
+ from torch_complex.tensor import ComplexTensor
10
+ except:
11
+ print("Please install torch_complex firstly")
12
+ from funasr_detach.models.transformer.utils.nets_utils import make_pad_mask
13
+ from funasr_detach.frontends.utils.complex_utils import is_complex
14
+
15
+ import librosa
16
+ import numpy as np
17
+
18
+ is_torch_1_9_plus = LooseVersion(torch.__version__) >= LooseVersion("1.9.0")
19
+
20
+
21
+ is_torch_1_7_plus = LooseVersion(torch.__version__) >= LooseVersion("1.7")
22
+
23
+
24
+ class Stft(torch.nn.Module):
25
+ def __init__(
26
+ self,
27
+ n_fft: int = 512,
28
+ win_length: int = None,
29
+ hop_length: int = 128,
30
+ window: Optional[str] = "hann",
31
+ center: bool = True,
32
+ normalized: bool = False,
33
+ onesided: bool = True,
34
+ ):
35
+ super().__init__()
36
+ self.n_fft = n_fft
37
+ if win_length is None:
38
+ self.win_length = n_fft
39
+ else:
40
+ self.win_length = win_length
41
+ self.hop_length = hop_length
42
+ self.center = center
43
+ self.normalized = normalized
44
+ self.onesided = onesided
45
+ if window is not None and not hasattr(torch, f"{window}_window"):
46
+ if window.lower() != "povey":
47
+ raise ValueError(f"{window} window is not implemented")
48
+ self.window = window
49
+
50
+ def extra_repr(self):
51
+ return (
52
+ f"n_fft={self.n_fft}, "
53
+ f"win_length={self.win_length}, "
54
+ f"hop_length={self.hop_length}, "
55
+ f"center={self.center}, "
56
+ f"normalized={self.normalized}, "
57
+ f"onesided={self.onesided}"
58
+ )
59
+
60
+ def forward(
61
+ self, input: torch.Tensor, ilens: torch.Tensor = None
62
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
63
+ """STFT forward function.
64
+
65
+ Args:
66
+ input: (Batch, Nsamples) or (Batch, Nsample, Channels)
67
+ ilens: (Batch)
68
+ Returns:
69
+ output: (Batch, Frames, Freq, 2) or (Batch, Frames, Channels, Freq, 2)
70
+
71
+ """
72
+ bs = input.size(0)
73
+ if input.dim() == 3:
74
+ multi_channel = True
75
+ # input: (Batch, Nsample, Channels) -> (Batch * Channels, Nsample)
76
+ input = input.transpose(1, 2).reshape(-1, input.size(1))
77
+ else:
78
+ multi_channel = False
79
+
80
+ # NOTE(kamo):
81
+ # The default behaviour of torch.stft is compatible with librosa.stft
82
+ # about padding and scaling.
83
+ # Note that it's different from scipy.signal.stft
84
+
85
+ # output: (Batch, Freq, Frames, 2=real_imag)
86
+ # or (Batch, Channel, Freq, Frames, 2=real_imag)
87
+ if self.window is not None:
88
+ if self.window.lower() == "povey":
89
+ window = torch.hann_window(
90
+ self.win_length,
91
+ periodic=False,
92
+ device=input.device,
93
+ dtype=input.dtype,
94
+ ).pow(0.85)
95
+ else:
96
+ window_func = getattr(torch, f"{self.window}_window")
97
+ window = window_func(
98
+ self.win_length, dtype=input.dtype, device=input.device
99
+ )
100
+ else:
101
+ window = None
102
+
103
+ # For the compatibility of ARM devices, which do not support
104
+ # torch.stft() due to the lake of MKL.
105
+ if input.is_cuda or torch.backends.mkl.is_available():
106
+ stft_kwargs = dict(
107
+ n_fft=self.n_fft,
108
+ win_length=self.win_length,
109
+ hop_length=self.hop_length,
110
+ center=self.center,
111
+ window=window,
112
+ normalized=self.normalized,
113
+ onesided=self.onesided,
114
+ )
115
+ if is_torch_1_7_plus:
116
+ stft_kwargs["return_complex"] = False
117
+ output = torch.stft(input, **stft_kwargs)
118
+ else:
119
+ if self.training:
120
+ raise NotImplementedError(
121
+ "stft is implemented with librosa on this device, which does not "
122
+ "support the training mode."
123
+ )
124
+
125
+ # use stft_kwargs to flexibly control different PyTorch versions' kwargs
126
+ stft_kwargs = dict(
127
+ n_fft=self.n_fft,
128
+ win_length=self.win_length,
129
+ hop_length=self.hop_length,
130
+ center=self.center,
131
+ window=window,
132
+ )
133
+
134
+ if window is not None:
135
+ # pad the given window to n_fft
136
+ n_pad_left = (self.n_fft - window.shape[0]) // 2
137
+ n_pad_right = self.n_fft - window.shape[0] - n_pad_left
138
+ stft_kwargs["window"] = torch.cat(
139
+ [torch.zeros(n_pad_left), window, torch.zeros(n_pad_right)], 0
140
+ ).numpy()
141
+ else:
142
+ win_length = (
143
+ self.win_length if self.win_length is not None else self.n_fft
144
+ )
145
+ stft_kwargs["window"] = torch.ones(win_length)
146
+
147
+ output = []
148
+ # iterate over istances in a batch
149
+ for i, instance in enumerate(input):
150
+ stft = librosa.stft(input[i].numpy(), **stft_kwargs)
151
+ output.append(torch.tensor(np.stack([stft.real, stft.imag], -1)))
152
+ output = torch.stack(output, 0)
153
+ if not self.onesided:
154
+ len_conj = self.n_fft - output.shape[1]
155
+ conj = output[:, 1 : 1 + len_conj].flip(1)
156
+ conj[:, :, :, -1].data *= -1
157
+ output = torch.cat([output, conj], 1)
158
+ if self.normalized:
159
+ output = output * (stft_kwargs["window"].shape[0] ** (-0.5))
160
+
161
+ # output: (Batch, Freq, Frames, 2=real_imag)
162
+ # -> (Batch, Frames, Freq, 2=real_imag)
163
+ output = output.transpose(1, 2)
164
+ if multi_channel:
165
+ # output: (Batch * Channel, Frames, Freq, 2=real_imag)
166
+ # -> (Batch, Frame, Channel, Freq, 2=real_imag)
167
+ output = output.view(bs, -1, output.size(1), output.size(2), 2).transpose(
168
+ 1, 2
169
+ )
170
+
171
+ if ilens is not None:
172
+ if self.center:
173
+ pad = self.n_fft // 2
174
+ ilens = ilens + 2 * pad
175
+
176
+ olens = (ilens - self.n_fft) // self.hop_length + 1
177
+ output.masked_fill_(make_pad_mask(olens, output, 1), 0.0)
178
+ else:
179
+ olens = None
180
+
181
+ return output, olens
182
+
183
+ def inverse(
184
+ self, input: Union[torch.Tensor, ComplexTensor], ilens: torch.Tensor = None
185
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
186
+ """Inverse STFT.
187
+
188
+ Args:
189
+ input: Tensor(batch, T, F, 2) or ComplexTensor(batch, T, F)
190
+ ilens: (batch,)
191
+ Returns:
192
+ wavs: (batch, samples)
193
+ ilens: (batch,)
194
+ """
195
+ if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
196
+ istft = torch.functional.istft
197
+ else:
198
+ try:
199
+ import torchaudio
200
+ except ImportError:
201
+ raise ImportError(
202
+ "Please install torchaudio>=0.3.0 or use torch>=1.6.0"
203
+ )
204
+
205
+ if not hasattr(torchaudio.functional, "istft"):
206
+ raise ImportError(
207
+ "Please install torchaudio>=0.3.0 or use torch>=1.6.0"
208
+ )
209
+ istft = torchaudio.functional.istft
210
+
211
+ if self.window is not None:
212
+ window_func = getattr(torch, f"{self.window}_window")
213
+ if is_complex(input):
214
+ datatype = input.real.dtype
215
+ else:
216
+ datatype = input.dtype
217
+ window = window_func(self.win_length, dtype=datatype, device=input.device)
218
+ else:
219
+ window = None
220
+
221
+ if is_complex(input):
222
+ input = torch.stack([input.real, input.imag], dim=-1)
223
+ elif input.shape[-1] != 2:
224
+ raise TypeError("Invalid input type")
225
+ input = input.transpose(1, 2)
226
+
227
+ wavs = istft(
228
+ input,
229
+ n_fft=self.n_fft,
230
+ hop_length=self.hop_length,
231
+ win_length=self.win_length,
232
+ window=window,
233
+ center=self.center,
234
+ normalized=self.normalized,
235
+ onesided=self.onesided,
236
+ length=ilens.max() if ilens is not None else ilens,
237
+ )
238
+
239
+ return wavs, ilens
funasr_detach/frontends/wav_frontend.py ADDED
@@ -0,0 +1,556 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ # Part of the implementation is borrowed from espnet/espnet.
3
+ from typing import Tuple
4
+ import copy
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torchaudio.compliance.kaldi as kaldi
9
+ from torch.nn.utils.rnn import pad_sequence
10
+
11
+ import funasr_detach.frontends.eend_ola_feature as eend_ola_feature
12
+ from funasr_detach.register import tables
13
+
14
+
15
+ def load_cmvn(cmvn_file):
16
+ with open(cmvn_file, "r", encoding="utf-8") as f:
17
+ lines = f.readlines()
18
+ means_list = []
19
+ vars_list = []
20
+ for i in range(len(lines)):
21
+ line_item = lines[i].split()
22
+ if line_item[0] == "<AddShift>":
23
+ line_item = lines[i + 1].split()
24
+ if line_item[0] == "<LearnRateCoef>":
25
+ add_shift_line = line_item[3 : (len(line_item) - 1)]
26
+ means_list = list(add_shift_line)
27
+ continue
28
+ elif line_item[0] == "<Rescale>":
29
+ line_item = lines[i + 1].split()
30
+ if line_item[0] == "<LearnRateCoef>":
31
+ rescale_line = line_item[3 : (len(line_item) - 1)]
32
+ vars_list = list(rescale_line)
33
+ continue
34
+ means = np.array(means_list).astype(np.float32)
35
+ vars = np.array(vars_list).astype(np.float32)
36
+ cmvn = np.array([means, vars])
37
+ cmvn = torch.as_tensor(cmvn, dtype=torch.float32)
38
+ return cmvn
39
+
40
+
41
+ def apply_cmvn(inputs, cmvn): # noqa
42
+ """
43
+ Apply CMVN with mvn data
44
+ """
45
+
46
+ device = inputs.device
47
+ dtype = inputs.dtype
48
+ frame, dim = inputs.shape
49
+
50
+ means = cmvn[0:1, :dim]
51
+ vars = cmvn[1:2, :dim]
52
+ inputs += means.to(device)
53
+ inputs *= vars.to(device)
54
+
55
+ return inputs.type(torch.float32)
56
+
57
+
58
+ def apply_lfr(inputs, lfr_m, lfr_n):
59
+ LFR_inputs = []
60
+ T = inputs.shape[0]
61
+ T_lfr = int(np.ceil(T / lfr_n))
62
+ left_padding = inputs[0].repeat((lfr_m - 1) // 2, 1)
63
+ inputs = torch.vstack((left_padding, inputs))
64
+ T = T + (lfr_m - 1) // 2
65
+ for i in range(T_lfr):
66
+ if lfr_m <= T - i * lfr_n:
67
+ LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).view(1, -1))
68
+ else: # process last LFR frame
69
+ num_padding = lfr_m - (T - i * lfr_n)
70
+ frame = (inputs[i * lfr_n :]).view(-1)
71
+ for _ in range(num_padding):
72
+ frame = torch.hstack((frame, inputs[-1]))
73
+ LFR_inputs.append(frame)
74
+ LFR_outputs = torch.vstack(LFR_inputs)
75
+ return LFR_outputs.type(torch.float32)
76
+
77
+
78
+ @tables.register("frontend_classes", "WavFrontend")
79
+ class WavFrontend(nn.Module):
80
+ """Conventional frontend structure for ASR."""
81
+
82
+ def __init__(
83
+ self,
84
+ cmvn_file: str = None,
85
+ fs: int = 16000,
86
+ window: str = "hamming",
87
+ n_mels: int = 80,
88
+ frame_length: int = 25,
89
+ frame_shift: int = 10,
90
+ filter_length_min: int = -1,
91
+ filter_length_max: int = -1,
92
+ lfr_m: int = 1,
93
+ lfr_n: int = 1,
94
+ dither: float = 1.0,
95
+ snip_edges: bool = True,
96
+ upsacle_samples: bool = True,
97
+ **kwargs,
98
+ ):
99
+ super().__init__()
100
+ self.fs = fs
101
+ self.window = window
102
+ self.n_mels = n_mels
103
+ self.frame_length = frame_length
104
+ self.frame_shift = frame_shift
105
+ self.filter_length_min = filter_length_min
106
+ self.filter_length_max = filter_length_max
107
+ self.lfr_m = lfr_m
108
+ self.lfr_n = lfr_n
109
+ self.cmvn_file = cmvn_file
110
+ self.dither = dither
111
+ self.snip_edges = snip_edges
112
+ self.upsacle_samples = upsacle_samples
113
+ self.cmvn = None if self.cmvn_file is None else load_cmvn(self.cmvn_file)
114
+
115
+ def output_size(self) -> int:
116
+ return self.n_mels * self.lfr_m
117
+
118
+ def forward(
119
+ self,
120
+ input: torch.Tensor,
121
+ input_lengths,
122
+ **kwargs,
123
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
124
+ batch_size = input.size(0)
125
+ feats = []
126
+ feats_lens = []
127
+ for i in range(batch_size):
128
+ waveform_length = input_lengths[i]
129
+ waveform = input[i][:waveform_length]
130
+ if self.upsacle_samples:
131
+ waveform = waveform * (1 << 15)
132
+ waveform = waveform.unsqueeze(0)
133
+ mat = kaldi.fbank(
134
+ waveform,
135
+ num_mel_bins=self.n_mels,
136
+ frame_length=self.frame_length,
137
+ frame_shift=self.frame_shift,
138
+ dither=self.dither,
139
+ energy_floor=0.0,
140
+ window_type=self.window,
141
+ sample_frequency=self.fs,
142
+ snip_edges=self.snip_edges,
143
+ )
144
+
145
+ if self.lfr_m != 1 or self.lfr_n != 1:
146
+ mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
147
+ if self.cmvn is not None:
148
+ mat = apply_cmvn(mat, self.cmvn)
149
+ feat_length = mat.size(0)
150
+ feats.append(mat)
151
+ feats_lens.append(feat_length)
152
+
153
+ feats_lens = torch.as_tensor(feats_lens)
154
+ if batch_size == 1:
155
+ feats_pad = feats[0][None, :, :]
156
+ else:
157
+ feats_pad = pad_sequence(feats, batch_first=True, padding_value=0.0)
158
+ return feats_pad, feats_lens
159
+
160
+ def forward_fbank(
161
+ self, input: torch.Tensor, input_lengths: torch.Tensor
162
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
163
+ batch_size = input.size(0)
164
+ feats = []
165
+ feats_lens = []
166
+ for i in range(batch_size):
167
+ waveform_length = input_lengths[i]
168
+ waveform = input[i][:waveform_length]
169
+ waveform = waveform * (1 << 15)
170
+ waveform = waveform.unsqueeze(0)
171
+ mat = kaldi.fbank(
172
+ waveform,
173
+ num_mel_bins=self.n_mels,
174
+ frame_length=self.frame_length,
175
+ frame_shift=self.frame_shift,
176
+ dither=self.dither,
177
+ energy_floor=0.0,
178
+ window_type=self.window,
179
+ sample_frequency=self.fs,
180
+ )
181
+
182
+ feat_length = mat.size(0)
183
+ feats.append(mat)
184
+ feats_lens.append(feat_length)
185
+
186
+ feats_lens = torch.as_tensor(feats_lens)
187
+ feats_pad = pad_sequence(feats, batch_first=True, padding_value=0.0)
188
+ return feats_pad, feats_lens
189
+
190
+ def forward_lfr_cmvn(
191
+ self, input: torch.Tensor, input_lengths: torch.Tensor
192
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
193
+ batch_size = input.size(0)
194
+ feats = []
195
+ feats_lens = []
196
+ for i in range(batch_size):
197
+ mat = input[i, : input_lengths[i], :]
198
+ if self.lfr_m != 1 or self.lfr_n != 1:
199
+ mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
200
+ if self.cmvn is not None:
201
+ mat = apply_cmvn(mat, self.cmvn)
202
+ feat_length = mat.size(0)
203
+ feats.append(mat)
204
+ feats_lens.append(feat_length)
205
+
206
+ feats_lens = torch.as_tensor(feats_lens)
207
+ feats_pad = pad_sequence(feats, batch_first=True, padding_value=0.0)
208
+ return feats_pad, feats_lens
209
+
210
+
211
+ @tables.register("frontend_classes", "WavFrontendOnline")
212
+ class WavFrontendOnline(nn.Module):
213
+ """Conventional frontend structure for streaming ASR/VAD."""
214
+
215
+ def __init__(
216
+ self,
217
+ cmvn_file: str = None,
218
+ fs: int = 16000,
219
+ window: str = "hamming",
220
+ n_mels: int = 80,
221
+ frame_length: int = 25,
222
+ frame_shift: int = 10,
223
+ filter_length_min: int = -1,
224
+ filter_length_max: int = -1,
225
+ lfr_m: int = 1,
226
+ lfr_n: int = 1,
227
+ dither: float = 1.0,
228
+ snip_edges: bool = True,
229
+ upsacle_samples: bool = True,
230
+ **kwargs,
231
+ ):
232
+ super().__init__()
233
+ self.fs = fs
234
+ self.window = window
235
+ self.n_mels = n_mels
236
+ self.frame_length = frame_length
237
+ self.frame_shift = frame_shift
238
+ self.frame_sample_length = int(self.frame_length * self.fs / 1000)
239
+ self.frame_shift_sample_length = int(self.frame_shift * self.fs / 1000)
240
+ self.filter_length_min = filter_length_min
241
+ self.filter_length_max = filter_length_max
242
+ self.lfr_m = lfr_m
243
+ self.lfr_n = lfr_n
244
+ self.cmvn_file = cmvn_file
245
+ self.dither = dither
246
+ self.snip_edges = snip_edges
247
+ self.upsacle_samples = upsacle_samples
248
+ # self.waveforms = None
249
+ # self.reserve_waveforms = None
250
+ # self.fbanks = None
251
+ # self.fbanks_lens = None
252
+ self.cmvn = None if self.cmvn_file is None else load_cmvn(self.cmvn_file)
253
+ # self.input_cache = None
254
+ # self.lfr_splice_cache = []
255
+
256
+ def output_size(self) -> int:
257
+ return self.n_mels * self.lfr_m
258
+
259
+ @staticmethod
260
+ def apply_cmvn(inputs: torch.Tensor, cmvn: torch.Tensor) -> torch.Tensor:
261
+ """
262
+ Apply CMVN with mvn data
263
+ """
264
+
265
+ device = inputs.device
266
+ dtype = inputs.dtype
267
+ frame, dim = inputs.shape
268
+
269
+ means = np.tile(cmvn[0:1, :dim], (frame, 1))
270
+ vars = np.tile(cmvn[1:2, :dim], (frame, 1))
271
+ inputs += torch.from_numpy(means).type(dtype).to(device)
272
+ inputs *= torch.from_numpy(vars).type(dtype).to(device)
273
+
274
+ return inputs.type(torch.float32)
275
+
276
+ @staticmethod
277
+ def apply_lfr(
278
+ inputs: torch.Tensor, lfr_m: int, lfr_n: int, is_final: bool = False
279
+ ) -> Tuple[torch.Tensor, torch.Tensor, int]:
280
+ """
281
+ Apply lfr with data
282
+ """
283
+
284
+ LFR_inputs = []
285
+ # inputs = torch.vstack((inputs_lfr_cache, inputs))
286
+ T = inputs.shape[0] # include the right context
287
+ T_lfr = int(
288
+ np.ceil((T - (lfr_m - 1) // 2) / lfr_n)
289
+ ) # minus the right context: (lfr_m - 1) // 2
290
+ splice_idx = T_lfr
291
+ for i in range(T_lfr):
292
+ if lfr_m <= T - i * lfr_n:
293
+ LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).view(1, -1))
294
+ else: # process last LFR frame
295
+ if is_final:
296
+ num_padding = lfr_m - (T - i * lfr_n)
297
+ frame = (inputs[i * lfr_n :]).view(-1)
298
+ for _ in range(num_padding):
299
+ frame = torch.hstack((frame, inputs[-1]))
300
+ LFR_inputs.append(frame)
301
+ else:
302
+ # update splice_idx and break the circle
303
+ splice_idx = i
304
+ break
305
+ splice_idx = min(T - 1, splice_idx * lfr_n)
306
+ lfr_splice_cache = inputs[splice_idx:, :]
307
+ LFR_outputs = torch.vstack(LFR_inputs)
308
+ return LFR_outputs.type(torch.float32), lfr_splice_cache, splice_idx
309
+
310
+ @staticmethod
311
+ def compute_frame_num(
312
+ sample_length: int, frame_sample_length: int, frame_shift_sample_length: int
313
+ ) -> int:
314
+ frame_num = int(
315
+ (sample_length - frame_sample_length) / frame_shift_sample_length + 1
316
+ )
317
+ return (
318
+ frame_num if frame_num >= 1 and sample_length >= frame_sample_length else 0
319
+ )
320
+
321
+ def forward_fbank(
322
+ self,
323
+ input: torch.Tensor,
324
+ input_lengths: torch.Tensor,
325
+ cache: dict = {},
326
+ **kwargs,
327
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
328
+ batch_size = input.size(0)
329
+ assert batch_size == 1
330
+ input = torch.cat((cache["input_cache"], input), dim=1)
331
+ frame_num = self.compute_frame_num(
332
+ input.shape[-1], self.frame_sample_length, self.frame_shift_sample_length
333
+ )
334
+ # update self.in_cache
335
+ cache["input_cache"] = input[
336
+ :, -(input.shape[-1] - frame_num * self.frame_shift_sample_length) :
337
+ ]
338
+ waveforms = torch.empty(0)
339
+ feats_pad = torch.empty(0)
340
+ feats_lens = torch.empty(0)
341
+ if frame_num:
342
+ waveforms = []
343
+ feats = []
344
+ feats_lens = []
345
+ for i in range(batch_size):
346
+ waveform = input[i].cuda()
347
+ # we need accurate wave samples that used for fbank extracting
348
+ waveforms.append(
349
+ waveform[
350
+ : (
351
+ (frame_num - 1) * self.frame_shift_sample_length
352
+ + self.frame_sample_length
353
+ )
354
+ ]
355
+ )
356
+ waveform = waveform * (1 << 15)
357
+ waveform = waveform.unsqueeze(0)
358
+ mat = kaldi.fbank(
359
+ waveform,
360
+ num_mel_bins=self.n_mels,
361
+ frame_length=self.frame_length,
362
+ frame_shift=self.frame_shift,
363
+ dither=self.dither,
364
+ energy_floor=0.0,
365
+ window_type=self.window,
366
+ sample_frequency=self.fs,
367
+ )
368
+
369
+ feat_length = mat.size(0)
370
+ feats.append(mat)
371
+ feats_lens.append(feat_length)
372
+
373
+ waveforms = torch.stack(waveforms)
374
+ feats_lens = torch.as_tensor(feats_lens)
375
+ feats_pad = pad_sequence(feats, batch_first=True, padding_value=0.0)
376
+ cache["fbanks"] = feats_pad
377
+ cache["fbanks_lens"] = copy.deepcopy(feats_lens)
378
+ return waveforms, feats_pad, feats_lens
379
+
380
+ def forward_lfr_cmvn(
381
+ self,
382
+ input: torch.Tensor,
383
+ input_lengths: torch.Tensor,
384
+ is_final: bool = False,
385
+ cache: dict = {},
386
+ **kwargs,
387
+ ):
388
+ batch_size = input.size(0)
389
+ feats = []
390
+ feats_lens = []
391
+ lfr_splice_frame_idxs = []
392
+ for i in range(batch_size):
393
+ mat = input[i, : input_lengths[i], :]
394
+ if self.lfr_m != 1 or self.lfr_n != 1:
395
+ # update self.lfr_splice_cache in self.apply_lfr
396
+ # mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(mat, self.lfr_m, self.lfr_n, self.lfr_splice_cache[i],
397
+ mat, cache["lfr_splice_cache"][i], lfr_splice_frame_idx = (
398
+ self.apply_lfr(mat, self.lfr_m, self.lfr_n, is_final)
399
+ )
400
+ if self.cmvn_file is not None:
401
+ mat = self.apply_cmvn(mat, self.cmvn)
402
+ feat_length = mat.size(0)
403
+ feats.append(mat)
404
+ feats_lens.append(feat_length)
405
+ lfr_splice_frame_idxs.append(lfr_splice_frame_idx)
406
+ feats_lens = torch.as_tensor(feats_lens)
407
+ feats_pad = pad_sequence(feats, batch_first=True, padding_value=0.0)
408
+ lfr_splice_frame_idxs = torch.as_tensor(lfr_splice_frame_idxs)
409
+ return feats_pad, feats_lens, lfr_splice_frame_idxs
410
+
411
+ def forward(self, input: torch.Tensor, input_lengths: torch.Tensor, **kwargs):
412
+ is_final = kwargs.get("is_final", False)
413
+ cache = kwargs.get("cache", {})
414
+ if len(cache) == 0:
415
+ self.init_cache(cache)
416
+
417
+ batch_size = input.shape[0]
418
+ assert (
419
+ batch_size == 1
420
+ ), "we support to extract feature online only when the batch size is equal to 1 now"
421
+
422
+ waveforms, feats, feats_lengths = self.forward_fbank(
423
+ input, input_lengths, cache=cache
424
+ ) # input shape: B T D
425
+
426
+ if feats.shape[0]:
427
+
428
+ cache["waveforms"] = torch.cat(
429
+ (cache["reserve_waveforms"], waveforms.cpu()), dim=1
430
+ )
431
+
432
+ if not cache["lfr_splice_cache"]: # 初始化splice_cache
433
+ for i in range(batch_size):
434
+ cache["lfr_splice_cache"].append(
435
+ feats[i][0, :].unsqueeze(dim=0).repeat((self.lfr_m - 1) // 2, 1)
436
+ )
437
+ # need the number of the input frames + self.lfr_splice_cache[0].shape[0] is greater than self.lfr_m
438
+ if feats_lengths[0] + cache["lfr_splice_cache"][0].shape[0] >= self.lfr_m:
439
+ lfr_splice_cache_tensor = torch.stack(
440
+ cache["lfr_splice_cache"]
441
+ ) # B T D
442
+ feats = torch.cat((lfr_splice_cache_tensor, feats), dim=1)
443
+
444
+ feats_lengths += lfr_splice_cache_tensor[0].shape[0]
445
+ frame_from_waveforms = int(
446
+ (cache["waveforms"].shape[1] - self.frame_sample_length)
447
+ / self.frame_shift_sample_length
448
+ + 1
449
+ )
450
+ minus_frame = (
451
+ (self.lfr_m - 1) // 2
452
+ if cache["reserve_waveforms"].numel() == 0
453
+ else 0
454
+ )
455
+ feats, feats_lengths, lfr_splice_frame_idxs = self.forward_lfr_cmvn(
456
+ feats, feats_lengths, is_final, cache=cache
457
+ )
458
+ if self.lfr_m == 1:
459
+ cache["reserve_waveforms"] = torch.empty(0)
460
+ else:
461
+ reserve_frame_idx = lfr_splice_frame_idxs[0] - minus_frame
462
+ # print('reserve_frame_idx: ' + str(reserve_frame_idx))
463
+ # print('frame_frame: ' + str(frame_from_waveforms))
464
+ cache["reserve_waveforms"] = cache["waveforms"][
465
+ :,
466
+ reserve_frame_idx
467
+ * self.frame_shift_sample_length : frame_from_waveforms
468
+ * self.frame_shift_sample_length,
469
+ ]
470
+ sample_length = (
471
+ frame_from_waveforms - 1
472
+ ) * self.frame_shift_sample_length + self.frame_sample_length
473
+ cache["waveforms"] = cache["waveforms"][:, :sample_length]
474
+ else:
475
+ # update self.reserve_waveforms and self.lfr_splice_cache
476
+ cache["reserve_waveforms"] = cache["waveforms"][
477
+ :, : -(self.frame_sample_length - self.frame_shift_sample_length)
478
+ ]
479
+ for i in range(batch_size):
480
+ cache["lfr_splice_cache"][i] = torch.cat(
481
+ (cache["lfr_splice_cache"][i], feats[i]), dim=0
482
+ )
483
+ return torch.empty(0), feats_lengths
484
+ else:
485
+ if is_final:
486
+ cache["waveforms"] = (
487
+ waveforms
488
+ if cache["reserve_waveforms"].numel() == 0
489
+ else cache["reserve_waveforms"]
490
+ )
491
+ feats = torch.stack(cache["lfr_splice_cache"])
492
+ feats_lengths = (
493
+ torch.zeros(batch_size, dtype=torch.int) + feats.shape[1]
494
+ )
495
+ feats, feats_lengths, _ = self.forward_lfr_cmvn(
496
+ feats, feats_lengths, is_final, cache=cache
497
+ )
498
+ # if is_final:
499
+ # self.init_cache(cache)
500
+ return feats, feats_lengths
501
+
502
+ def init_cache(self, cache: dict = {}):
503
+ cache["reserve_waveforms"] = torch.empty(0)
504
+ cache["input_cache"] = torch.empty(0)
505
+ cache["lfr_splice_cache"] = []
506
+ cache["waveforms"] = None
507
+ cache["fbanks"] = None
508
+ cache["fbanks_lens"] = None
509
+ return cache
510
+
511
+
512
+ class WavFrontendMel23(nn.Module):
513
+ """Conventional frontend structure for ASR."""
514
+
515
+ def __init__(
516
+ self,
517
+ fs: int = 16000,
518
+ frame_length: int = 25,
519
+ frame_shift: int = 10,
520
+ lfr_m: int = 1,
521
+ lfr_n: int = 1,
522
+ **kwargs,
523
+ ):
524
+ super().__init__()
525
+ self.fs = fs
526
+ self.frame_length = frame_length
527
+ self.frame_shift = frame_shift
528
+ self.lfr_m = lfr_m
529
+ self.lfr_n = lfr_n
530
+ self.n_mels = 23
531
+
532
+ def output_size(self) -> int:
533
+ return self.n_mels * (2 * self.lfr_m + 1)
534
+
535
+ def forward(
536
+ self, input: torch.Tensor, input_lengths: torch.Tensor
537
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
538
+ batch_size = input.size(0)
539
+ feats = []
540
+ feats_lens = []
541
+ for i in range(batch_size):
542
+ waveform_length = input_lengths[i]
543
+ waveform = input[i][:waveform_length]
544
+ waveform = waveform.numpy()
545
+ mat = eend_ola_feature.stft(waveform, self.frame_length, self.frame_shift)
546
+ mat = eend_ola_feature.transform(mat)
547
+ mat = eend_ola_feature.splice(mat, context_size=self.lfr_m)
548
+ mat = mat[:: self.lfr_n]
549
+ mat = torch.from_numpy(mat)
550
+ feat_length = mat.size(0)
551
+ feats.append(mat)
552
+ feats_lens.append(feat_length)
553
+
554
+ feats_lens = torch.as_tensor(feats_lens)
555
+ feats_pad = pad_sequence(feats, batch_first=True, padding_value=0.0)
556
+ return feats_pad, feats_lens
funasr_detach/frontends/windowing.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # 2020, Technische Universität München; Ludwig Kürzinger
3
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
4
+
5
+ """Sliding Window for raw audio input data."""
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from typing import Tuple
10
+
11
+
12
+ class SlidingWindow(nn.Module):
13
+ """Sliding Window.
14
+ Provides a sliding window over a batched continuous raw audio tensor.
15
+ Optionally, provides padding (Currently not implemented).
16
+ Combine this module with a pre-encoder compatible with raw audio data,
17
+ for example Sinc convolutions.
18
+ Known issues:
19
+ Output length is calculated incorrectly if audio shorter than win_length.
20
+ WARNING: trailing values are discarded - padding not implemented yet.
21
+ There is currently no additional window function applied to input values.
22
+ """
23
+
24
+ def __init__(
25
+ self,
26
+ win_length: int = 400,
27
+ hop_length: int = 160,
28
+ channels: int = 1,
29
+ padding: int = None,
30
+ fs=None,
31
+ ):
32
+ """Initialize.
33
+ Args:
34
+ win_length: Length of frame.
35
+ hop_length: Relative starting point of next frame.
36
+ channels: Number of input channels.
37
+ padding: Padding (placeholder, currently not implemented).
38
+ fs: Sampling rate (placeholder for compatibility, not used).
39
+ """
40
+ super().__init__()
41
+ self.fs = fs
42
+ self.win_length = win_length
43
+ self.hop_length = hop_length
44
+ self.channels = channels
45
+ self.padding = padding
46
+
47
+ def forward(
48
+ self, input: torch.Tensor, input_lengths: torch.Tensor
49
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
50
+ """Apply a sliding window on the input.
51
+ Args:
52
+ input: Input (B, T, C*D) or (B, T*C*D), with D=C=1.
53
+ input_lengths: Input lengths within batch.
54
+ Returns:
55
+ Tensor: Output with dimensions (B, T, C, D), with D=win_length.
56
+ Tensor: Output lengths within batch.
57
+ """
58
+ input_size = input.size()
59
+ B = input_size[0]
60
+ T = input_size[1]
61
+ C = self.channels
62
+ D = self.win_length
63
+ # (B, T, C) --> (T, B, C)
64
+ continuous = input.view(B, T, C).permute(1, 0, 2)
65
+ windowed = continuous.unfold(0, D, self.hop_length)
66
+ # (T, B, C, D) --> (B, T, C, D)
67
+ output = windowed.permute(1, 0, 2, 3).contiguous()
68
+ # After unfold(), windowed lengths change:
69
+ output_lengths = (input_lengths - self.win_length) // self.hop_length + 1
70
+ return output, output_lengths
71
+
72
+ def output_size(self) -> int:
73
+ """Return output length of feature dimension D, i.e. the window length."""
74
+ return self.win_length
funasr_detach/losses/__init__.py ADDED
File without changes
funasr_detach/losses/label_smoothing_loss.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Copyright 2019 Shigeki Karita
5
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6
+
7
+ """Label smoothing module."""
8
+
9
+ import torch
10
+ from torch import nn
11
+ from funasr_detach.models.transformer.utils.nets_utils import make_pad_mask
12
+
13
+
14
+ class LabelSmoothingLoss(nn.Module):
15
+ """Label-smoothing loss.
16
+
17
+ :param int size: the number of class
18
+ :param int padding_idx: ignored class id
19
+ :param float smoothing: smoothing rate (0.0 means the conventional CE)
20
+ :param bool normalize_length: normalize loss by sequence length if True
21
+ :param torch.nn.Module criterion: loss function to be smoothed
22
+ """
23
+
24
+ def __init__(
25
+ self,
26
+ size,
27
+ padding_idx,
28
+ smoothing,
29
+ normalize_length=False,
30
+ criterion=nn.KLDivLoss(reduction="none"),
31
+ ):
32
+ """Construct an LabelSmoothingLoss object."""
33
+ super(LabelSmoothingLoss, self).__init__()
34
+ self.criterion = criterion
35
+ self.padding_idx = padding_idx
36
+ self.confidence = 1.0 - smoothing
37
+ self.smoothing = smoothing
38
+ self.size = size
39
+ self.true_dist = None
40
+ self.normalize_length = normalize_length
41
+
42
+ def forward(self, x, target):
43
+ """Compute loss between x and target.
44
+
45
+ :param torch.Tensor x: prediction (batch, seqlen, class)
46
+ :param torch.Tensor target:
47
+ target signal masked with self.padding_id (batch, seqlen)
48
+ :return: scalar float value
49
+ :rtype torch.Tensor
50
+ """
51
+ assert x.size(2) == self.size
52
+ batch_size = x.size(0)
53
+ x = x.view(-1, self.size)
54
+ target = target.view(-1)
55
+ with torch.no_grad():
56
+ true_dist = x.clone()
57
+ true_dist.fill_(self.smoothing / (self.size - 1))
58
+ ignore = target == self.padding_idx # (B,)
59
+ total = len(target) - ignore.sum().item()
60
+ target = target.masked_fill(ignore, 0) # avoid -1 index
61
+ true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
62
+ kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
63
+ denom = total if self.normalize_length else batch_size
64
+ return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
65
+
66
+
67
+ class SequenceBinaryCrossEntropy(nn.Module):
68
+ def __init__(
69
+ self, normalize_length=False, criterion=nn.BCEWithLogitsLoss(reduction="none")
70
+ ):
71
+ super().__init__()
72
+ self.normalize_length = normalize_length
73
+ self.criterion = criterion
74
+
75
+ def forward(self, pred, label, lengths):
76
+ pad_mask = make_pad_mask(lengths, maxlen=pred.shape[1]).to(pred.device)
77
+ loss = self.criterion(pred, label)
78
+ denom = (~pad_mask).sum() if self.normalize_length else pred.shape[0]
79
+ return loss.masked_fill(pad_mask.unsqueeze(-1), 0).sum() / denom
80
+
81
+
82
+ class NllLoss(nn.Module):
83
+ """Nll loss.
84
+
85
+ :param int size: the number of class
86
+ :param int padding_idx: ignored class id
87
+ :param bool normalize_length: normalize loss by sequence length if True
88
+ :param torch.nn.Module criterion: loss function
89
+ """
90
+
91
+ def __init__(
92
+ self,
93
+ size,
94
+ padding_idx,
95
+ normalize_length=False,
96
+ criterion=nn.NLLLoss(reduction="none"),
97
+ ):
98
+ """Construct an NllLoss object."""
99
+ super(NllLoss, self).__init__()
100
+ self.criterion = criterion
101
+ self.padding_idx = padding_idx
102
+ self.size = size
103
+ self.true_dist = None
104
+ self.normalize_length = normalize_length
105
+
106
+ def forward(self, x, target):
107
+ """Compute loss between x and target.
108
+
109
+ :param torch.Tensor x: prediction (batch, seqlen, class)
110
+ :param torch.Tensor target:
111
+ target signal masked with self.padding_id (batch, seqlen)
112
+ :return: scalar float value
113
+ :rtype torch.Tensor
114
+ """
115
+ assert x.size(2) == self.size
116
+ batch_size = x.size(0)
117
+ x = x.view(-1, self.size)
118
+ target = target.view(-1)
119
+ with torch.no_grad():
120
+ ignore = target == self.padding_idx # (B,)
121
+ total = len(target) - ignore.sum().item()
122
+ target = target.masked_fill(ignore, 0) # avoid -1 index
123
+ kl = self.criterion(x, target)
124
+ denom = total if self.normalize_length else batch_size
125
+ return kl.masked_fill(ignore, 0).sum() / denom