Spaces: audio edit (Running on Zero)
Commit: remove useless file

Commit history:
feat: add spaces, change edit_app name
feat: change readme, dockerfile
feat: change readme
feat: add default config
feat: remove useless file
feat: change readme
feat: change readme
feat: change requirements version
feat: change requirements
feat: remove dockerfile
feat: change pkg version
feat: support hf model source
feat: fix model loader
feat: test
feat: fix model loader
feat: fix model cache path
feat: add log
feat: fix download
feat: fix tokenizer
feat: fix tokenizer
feat: fix download
feat: add hf login
feat: add log
feat: remove useless log
feat: fix model loader
feat: fix model loader
feat: add log
feat: fix model loader
feat: rollback code
feat: fix
feat: fix model loader
feat: fix model path
feat: zerogpu
feat: fix
feat: fix app
feat: optimize download
feat: optimize download
feat: change app desc
feat: add log
feat: add log

Files changed:
- .gitattributes +4 -4
- .gitignore +2 -0
- README.md +13 -1
- __init__.py +0 -0
- app.py +499 -0
- config/__init__.py +12 -0
- config/edit_config.py +33 -0
- config/prompts.py +62 -0
- funasr_detach/__init__.py +38 -0
- funasr_detach/auto/__init__.py +0 -0
- funasr_detach/auto/auto_frontend.py +90 -0
- funasr_detach/auto/auto_model.py +575 -0
- funasr_detach/auto/auto_tokenizer.py +7 -0
- funasr_detach/bin/__init__.py +0 -0
- funasr_detach/bin/compute_audio_cmvn.py +152 -0
- funasr_detach/bin/inference.py +33 -0
- funasr_detach/bin/tokenize_text.py +281 -0
- funasr_detach/bin/train.py +227 -0
- funasr_detach/datasets/__init__.py +0 -0
- funasr_detach/datasets/audio_datasets/__init__.py +0 -0
- funasr_detach/datasets/audio_datasets/datasets.py +112 -0
- funasr_detach/datasets/audio_datasets/index_ds.py +150 -0
- funasr_detach/datasets/audio_datasets/preprocessor.py +55 -0
- funasr_detach/datasets/audio_datasets/samplers.py +306 -0
- funasr_detach/datasets/audio_datasets/scp2jsonl.py +116 -0
- funasr_detach/download/__init__.py +0 -0
- funasr_detach/download/download_dataset_from_hub.py +19 -0
- funasr_detach/download/download_from_hub.py +231 -0
- funasr_detach/download/file.py +335 -0
- funasr_detach/download/name_maps_from_hub.py +13 -0
- funasr_detach/download/runtime_sdk_download_tool.py +60 -0
- funasr_detach/frontends/__init__.py +0 -0
- funasr_detach/frontends/default.py +347 -0
- funasr_detach/frontends/eend_ola_feature.py +49 -0
- funasr_detach/frontends/fused.py +144 -0
- funasr_detach/frontends/s3prl.py +139 -0
- funasr_detach/frontends/utils/__init__.py +1 -0
- funasr_detach/frontends/utils/beamformer.py +84 -0
- funasr_detach/frontends/utils/complex_utils.py +194 -0
- funasr_detach/frontends/utils/dnn_beamformer.py +173 -0
- funasr_detach/frontends/utils/dnn_wpe.py +93 -0
- funasr_detach/frontends/utils/feature_transform.py +263 -0
- funasr_detach/frontends/utils/frontend.py +151 -0
- funasr_detach/frontends/utils/log_mel.py +83 -0
- funasr_detach/frontends/utils/mask_estimator.py +77 -0
- funasr_detach/frontends/utils/stft.py +239 -0
- funasr_detach/frontends/wav_frontend.py +556 -0
- funasr_detach/frontends/windowing.py +74 -0
- funasr_detach/losses/__init__.py +0 -0
- funasr_detach/losses/label_smoothing_loss.py +125 -0

Diffs:

.gitattributes
@@ -1,4 +1,4 @@
-
-
-
-
+examples filter=lfs diff=lfs merge=lfs -text
+speakers/nezha_prompt.wav filter=lfs diff=lfs merge=lfs -text
+speakers/nezhaRAP_prompt.wav filter=lfs diff=lfs merge=lfs -text
+speakers/nezha哼唱_prompt.wav filter=lfs diff=lfs merge=lfs -text

.gitignore
@@ -0,0 +1,2 @@
+__pycache__/
+output/

README.md
@@ -1 +1,13 @@
-
+---
+title: Step-Audio-EditX
+emoji: 🚀
+colorFrom: red
+colorTo: red
+sdk: gradio
+sdk_version: 5.49.1
+app_file: app.py
+pinned: true
+short_description: Try out Step-Audio-EditX
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
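For local runs outside the Hub, the same `app_file` entry point can be started directly. A minimal sketch, with the flag names and defaults taken from the argparse block in app.py later in this diff (assumption: the Space's Python dependencies are already installed):

import subprocess

# Launch the Gradio demo the way the Hub's `app_file: app.py` setting would,
# spelling out the defaults from app.py's argument parser.
subprocess.run([
    "python", "app.py",
    "--model-source", "huggingface",
    "--server-name", "0.0.0.0",
    "--server-port", "7860",
])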

__init__.py
File without changes

app.py
@@ -0,0 +1,499 @@
+import gradio as gr
+import os
+import argparse
+import torch
+import logging
+import threading
+from datetime import datetime
+import torchaudio
+import librosa
+import soundfile as sf
+
+# ZeroGPU support
+try:
+    import spaces
+    ZEROGPU_AVAILABLE = True
+except ImportError:
+    ZEROGPU_AVAILABLE = False
+    # Create a dummy decorator for non-ZeroGPU environments
+    class spaces:
+        @staticmethod
+        def GPU(duration=10):
+            def decorator(func):
+                return func
+            return decorator
+
+# Project imports
+from tokenizer import StepAudioTokenizer
+from tts import StepAudioTTS
+from model_loader import ModelSource
+from config.edit_config import get_supported_edit_types
+
+# Configure logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+
+# Global variables for ZeroGPU-optimized loading
+encoder = None
+common_tts_engine = None
+args_global = None
+_model_lock = threading.Lock()  # Thread lock for model initialization
+
+def initialize_models():
+    """Initialize models on first GPU call (ZeroGPU optimization: load inside GPU context)"""
+    global encoder, common_tts_engine, args_global
+
+    # Fast path: check if already initialized (without lock)
+    if common_tts_engine is not None:
+        return  # Already initialized
+
+    # Slow path: acquire lock and double-check
+    with _model_lock:
+        # Double-check pattern: another thread might have initialized while waiting for lock
+        if common_tts_engine is not None:
+            return  # Already initialized by another thread
+
+        if args_global is None:
+            raise RuntimeError("Global args not set. Cannot initialize models.")
+
+        try:
+            logger.info("🚀 Initializing models inside GPU context (first call)...")
+
+            # Determine model source
+            source_mapping = {
+                "auto": ModelSource.AUTO,
+                "local": ModelSource.LOCAL,
+                "modelscope": ModelSource.MODELSCOPE,
+                "huggingface": ModelSource.HUGGINGFACE
+            }
+            model_source = source_mapping[args_global.model_source]
+
+            # Load StepAudioTokenizer (avoid CUDA initialization in main process)
+            encoder = StepAudioTokenizer(
+                os.path.join(args_global.model_path, "Step-Audio-Tokenizer"),
+                model_source=model_source,
+                funasr_model_id=args_global.tokenizer_model_id
+            )
+            logger.info("✓ StepAudioTokenizer loaded")
+
+            # Initialize common TTS engine (avoid CUDA initialization in main process)
+            common_tts_engine = StepAudioTTS(
+                os.path.join(args_global.model_path, "Step-Audio-EditX"),
+                encoder,
+                model_source=model_source,
+                tts_model_id=args_global.tts_model_id
+            )
+            logger.info("✓ StepCommonAudioTTS loaded")
+            print("Models initialized inside GPU context.")
+
+            if ZEROGPU_AVAILABLE:
+                logger.info("💡 Models loaded inside GPU context - ready for inference")
+            else:
+                logger.info("💡 Models loaded - ready for inference")
+
+        except Exception as e:
+            logger.error(f"❌ Error loading models: {e}")
+            raise
+
+def get_model_config():
+    """Get model configuration without initializing GPU models"""
+    if args_global is None:
+        raise RuntimeError("Global args not set. Cannot get model config.")
+
+    return {
+        "encoder_path": os.path.join(args_global.model_path, "Step-Audio-Tokenizer"),
+        "tts_path": os.path.join(args_global.model_path, "Step-Audio-EditX"),
+        "model_source": args_global.model_source,
+        "tokenizer_model_id": args_global.tokenizer_model_id,
+        "tts_model_id": args_global.tts_model_id
+    }
+
+def get_gpu_duration(audio_input, text_input, target_text, task_type, task_info):
+    """Dynamic GPU duration based on whether models need initialization"""
+    global common_tts_engine
+
+    if common_tts_engine is None:
+        # First call - need time for model loading (up to 5 minutes)
+        return 300  # Maximum allowed duration for model initialization
+    else:
+        # Subsequent calls - only inference time needed
+        return 120  # Standard inference duration
+
+@spaces.GPU(duration=get_gpu_duration)  # Dynamic duration based on model state
+def process_audio_with_gpu(audio_input, text_input, target_text, task_type, task_info):
+    """Process audio using GPU (models are loaded inside GPU context to avoid main process errors)"""
+    global common_tts_engine
+
+    # Initialize models if not already loaded (inside GPU context to avoid main process errors)
+    if common_tts_engine is None:
+        print("Initializing common_tts_engine inside GPU context...")
+        logger.info("🎯 GPU allocated for 300s (first call with model loading)...")
+        initialize_models()
+        logger.info("✅ Models loaded successfully inside GPU context")
+    else:
+        print("common_tts_engine already initialized.")
+        logger.info("🎯 GPU allocated for 120s (inference with loaded models)...")
+
+    try:
+        # Use loaded models (first call may include loading time, subsequent calls are fast)
+        if task_type == "clone":
+            output_audio, sr = common_tts_engine.clone(audio_input, text_input, target_text)
+        else:
+            output_audio, sr = common_tts_engine.edit(audio_input, text_input, task_type, task_info, target_text)
+
+        logger.info("✅ Audio processing completed")
+        return output_audio, sr
+
+    except Exception as e:
+        logger.error(f"❌ Audio processing failed: {e}")
+        raise
+    # GPU automatically deallocated when function exits
+
+# Save audio to temporary directory
+def save_audio(audio_type, audio_data, sr, tmp_dir):
+    """Save audio data to a temporary file with timestamp"""
+    current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+    save_path = os.path.join(tmp_dir, audio_type, f"{current_time}.wav")
+    os.makedirs(os.path.dirname(save_path), exist_ok=True)
+
+    try:
+        if isinstance(audio_data, torch.Tensor):
+            torchaudio.save(save_path, audio_data, sr)
+        else:
+            sf.write(save_path, audio_data, sr)
+        logger.debug(f"Audio saved to: {save_path}")
+        return save_path
+    except Exception as e:
+        logger.error(f"Failed to save audio: {e}")
+        raise
+
+
+class EditxTab:
+    """Audio editing and voice cloning interface tab"""
+
+    def __init__(self, args):
+        self.args = args
+        self.edit_type_list = list(get_supported_edit_types().keys())
+        self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")
+
+    def history_messages_to_show(self, messages):
+        """Convert message history to gradio chatbot format"""
+        show_msgs = []
+        for message in messages:
+            edit_type = message['edit_type']
+            edit_info = message['edit_info']
+            source_text = message['source_text']
+            target_text = message['target_text']
+            raw_audio_part = message['raw_wave']
+            edit_audio_part = message['edit_wave']
+            type_str = f"{edit_type}-{edit_info}" if edit_info is not None else f"{edit_type}"
+            show_msgs.extend([
+                {"role": "user", "content": f"Task type: {type_str}\nText: {source_text}"},
+                {"role": "user", "content": gr.Audio(value=raw_audio_part, interactive=False)},
+                {"role": "assistant", "content": f"Output audio:\nText: {target_text}"},
+                {"role": "assistant", "content": gr.Audio(value=edit_audio_part, interactive=False)}
+            ])
+        return show_msgs
+
+    def generate_clone(self, prompt_text_input, prompt_audio_input, generated_text, edit_type, edit_info, state):
+        """Generate cloned audio (models are loaded on first GPU call)"""
+        self.logger.info("Starting voice cloning process")
+        state['history_audio'] = []
+        state['history_messages'] = []
+
+        # Input validation
+        if not prompt_text_input or prompt_text_input.strip() == "":
+            error_msg = "[Error] Uploaded text cannot be empty."
+            self.logger.error(error_msg)
+            return [{"role": "user", "content": error_msg}], state
+        if not prompt_audio_input:
+            error_msg = "[Error] Uploaded audio cannot be empty."
+            self.logger.error(error_msg)
+            return [{"role": "user", "content": error_msg}], state
+        if not generated_text or generated_text.strip() == "":
+            error_msg = "[Error] Clone content cannot be empty."
+            self.logger.error(error_msg)
+            return [{"role": "user", "content": error_msg}], state
+        if edit_type != "clone":
+            error_msg = "[Error] CLONE button must use clone task."
+            self.logger.error(error_msg)
+            return [{"role": "user", "content": error_msg}], state
+
+        try:
+            # Use GPU inference with models loaded inside GPU context
+            output_audio, output_sr = process_audio_with_gpu(
+                prompt_audio_input, prompt_text_input, generated_text, "clone", edit_info
+            )
+
+            if output_audio is not None and output_sr is not None:
+                # Convert tensor to numpy if needed
+                if isinstance(output_audio, torch.Tensor):
+                    audio_numpy = output_audio.cpu().numpy().squeeze()
+                else:
+                    audio_numpy = output_audio
+
+                # Load original audio for comparison
+                input_audio_data_numpy, input_sample_rate = librosa.load(prompt_audio_input)
+
+                # Create message for history
+                cur_assistant_msg = {
+                    "edit_type": edit_type,
+                    "edit_info": edit_info,
+                    "source_text": prompt_text_input,
+                    "target_text": generated_text,
+                    "raw_wave": (input_sample_rate, input_audio_data_numpy),
+                    "edit_wave": (output_sr, audio_numpy),
+                }
+                state["history_audio"].append((output_sr, audio_numpy, generated_text))
+                state["history_messages"].append(cur_assistant_msg)
+
+                show_msgs = self.history_messages_to_show(state["history_messages"])
+                self.logger.info("Voice cloning completed successfully")
+                return show_msgs, state
+            else:
+                error_msg = "[Error] Clone failed"
+                self.logger.error(error_msg)
+                return [{"role": "user", "content": error_msg}], state
+
+        except Exception as e:
+            error_msg = f"[Error] Clone failed: {str(e)}"
+            self.logger.error(error_msg)
+            return [{"role": "user", "content": error_msg}], state
+
+    def generate_edit(self, prompt_text_input, prompt_audio_input, generated_text, edit_type, edit_info, state):
+        """Generate edited audio (models are loaded on first GPU call)"""
+        self.logger.info("Starting audio editing process")
+
+        # Input validation
+        if not prompt_text_input or prompt_text_input.strip() == "":
+            error_msg = "[Error] Uploaded text cannot be empty."
+            self.logger.error(error_msg)
+            return [{"role": "user", "content": error_msg}], state
+        if not prompt_audio_input:
+            error_msg = "[Error] Uploaded audio cannot be empty."
+            self.logger.error(error_msg)
+            return [{"role": "user", "content": error_msg}], state
+
+        try:
+            # Determine which audio to use
+            if len(state["history_audio"]) == 0:
+                # First edit - use uploaded audio
+                audio_to_edit = prompt_audio_input
+                text_to_use = prompt_text_input
+                self.logger.debug("Using prompt audio, no history found")
+            else:
+                # Use previous edited audio - save it to temp file first
+                sample_rate, audio_numpy, previous_text = state["history_audio"][-1]
+                temp_path = save_audio("temp", audio_numpy, sample_rate, self.args.tmp_dir)
+                audio_to_edit = temp_path
+                text_to_use = previous_text
+                self.logger.debug(f"Using previous audio from history, count: {len(state['history_audio'])}")
+
+            # For para-linguistic, use generated_text; otherwise use source text
+            if edit_type not in {"para-linguistic"}:
+                generated_text = text_to_use
+
+            # Use GPU inference with models loaded inside GPU context
+            output_audio, output_sr = process_audio_with_gpu(
+                audio_to_edit, text_to_use, generated_text, edit_type, edit_info
+            )
+
+            if output_audio is not None and output_sr is not None:
+                # Convert tensor to numpy if needed
+                if isinstance(output_audio, torch.Tensor):
+                    audio_numpy = output_audio.cpu().numpy().squeeze()
+                else:
+                    audio_numpy = output_audio
+
+                # Load original audio for comparison
+                if len(state["history_audio"]) == 0:
+                    input_audio_data_numpy, input_sample_rate = librosa.load(prompt_audio_input)
+                else:
+                    input_sample_rate, input_audio_data_numpy, _ = state["history_audio"][-1]
+
+                # Create message for history
+                cur_assistant_msg = {
+                    "edit_type": edit_type,
+                    "edit_info": edit_info,
+                    "source_text": text_to_use,
+                    "target_text": generated_text,
+                    "raw_wave": (input_sample_rate, input_audio_data_numpy),
+                    "edit_wave": (output_sr, audio_numpy),
+                }
+                state["history_audio"].append((output_sr, audio_numpy, generated_text))
+                state["history_messages"].append(cur_assistant_msg)
+
+                show_msgs = self.history_messages_to_show(state["history_messages"])
+                self.logger.info("Audio editing completed successfully")
+                return show_msgs, state
+            else:
+                error_msg = "[Error] Edit failed"
+                self.logger.error(error_msg)
+                return [{"role": "user", "content": error_msg}], state
+
+        except Exception as e:
+            error_msg = f"[Error] Edit failed: {str(e)}"
+            self.logger.error(error_msg)
+            return [{"role": "user", "content": error_msg}], state
+
+    def clear_history(self, state):
+        """Clear conversation history"""
+        state["history_messages"] = []
+        state["history_audio"] = []
+        return [], state
+
+    def init_state(self):
+        """Initialize conversation state"""
+        return {
+            "history_messages": [],
+            "history_audio": []
+        }
+
+    def register_components(self):
+        """Register gradio components - maintaining exact layout from original"""
+        with gr.Tab("Editx"):
+            with gr.Row():
+                with gr.Column():
+                    self.model_input = gr.Textbox(label="Model Name", value="Step-Audio-EditX", scale=1)
+                    self.prompt_text_input = gr.Textbox(label="Audio Text Content", value="", scale=1)
+                    self.prompt_audio_input = gr.Audio(
+                        sources=["upload", "microphone"],
+                        format="wav",
+                        type="filepath",
+                        label="Input Audio",
+                    )
+                    self.generated_text = gr.Textbox(label="Clone Text", lines=1, max_lines=200)
+                    with gr.Row():
+                        self.button_tts = gr.Button("CLONE")
+                        self.button_edit = gr.Button("EDIT")
+
+                with gr.Column():
+                    with gr.Row():
+                        self.edit_type = gr.Dropdown(label="Task", choices=self.edit_type_list, value="clone")
+                        self.edit_info = gr.Dropdown(label="Sub-task", choices=[], value=None)
+                    self.chat_box = gr.Chatbot(label="History", type="messages", height=480*1)
+                    self.clean_history_submit = gr.Button("Clear History")
+
+            gr.Markdown("---")
+            gr.Markdown("""
+            **Button Description:**
+            - CLONE: Synthesizes audio from the uploaded audio and text; only used for clone mode, and clears the history when used.
+            - EDIT: Edits the uploaded audio, or stacks further edit effects on the previous round of generated audio.
+            """)
+            gr.Markdown("""
+            **Operation Workflow:**
+            - Upload the audio to be edited on the left side and fill in the corresponding text content of the audio;
+            - If the task requires modifying text content (such as clone, para-linguistic), fill in the text to be synthesized in the "Clone Text" field. For all other tasks, keep the uploaded audio text content unchanged;
+            - Select tasks and sub-tasks on the right side (some tasks, such as vad, have no sub-tasks);
+            - Click the "CLONE" or "EDIT" button on the left side, and audio will be generated in the dialog box on the right side.
+            """)
+            gr.Markdown("""
+            **Para-linguistic Description:**
+            - Supported tags include: [Breathing] [Laughter] [Cough] [Sigh] [Confirmation-en] [Question-en] [Question-ah] [Question-oh] [Surprise-ah] [Surprise-oh] [Dissatisfaction-hnn] [Uhm] [Shh] [Crying] [Surprise-wa] [Surprise-yo] [Question-ei] [Question-yi]
+            - Example:
+              - Fill in the "Clone Text" field: "Great, the weather is so nice today." Click the "CLONE" button to get audio.
+              - Change the "Clone Text" field to: "Great[Laughter], the weather is so nice today[Surprise-ah]." Click the "EDIT" button to get para-linguistic audio.
+            """)
+
+    def register_events(self):
+        """Register event handlers"""
+        # Create independent state for each session
+        state = gr.State(self.init_state())
+
+        self.button_tts.click(self.generate_clone,
+                              inputs=[self.prompt_text_input, self.prompt_audio_input, self.generated_text, self.edit_type, self.edit_info, state],
+                              outputs=[self.chat_box, state])
+        self.button_edit.click(self.generate_edit,
+                               inputs=[self.prompt_text_input, self.prompt_audio_input, self.generated_text, self.edit_type, self.edit_info, state],
+                               outputs=[self.chat_box, state])
+
+        self.clean_history_submit.click(self.clear_history, inputs=[state], outputs=[self.chat_box, state])
+        self.edit_type.change(
+            fn=self.update_edit_info,
+            inputs=self.edit_type,
+            outputs=self.edit_info,
+        )
+
+    def update_edit_info(self, category):
+        """Update sub-task dropdown based on main task selection"""
+        category_items = get_supported_edit_types()
+        choices = category_items.get(category, [])
+        value = None if len(choices) == 0 else choices[0]
+        return gr.Dropdown(label="Sub-task", choices=choices, value=value)
+
+
+def launch_demo(args, editx_tab):
+    """Launch the gradio demo"""
+    with gr.Blocks(title="🎙️ Step-Audio-EditX") as demo:
+        gr.Markdown("## 🎙️ Step-Audio-EditX")
+        gr.Markdown("Audio editing and voice cloning using the Step-Audio-Edit model.")
+
+        # Register components
+        editx_tab.register_components()
+
+        # Register events
+        editx_tab.register_events()
+
+    # Launch demo
+    demo.queue().launch(
+        server_name=args.server_name,
+        server_port=args.server_port,
+        share=args.share if hasattr(args, 'share') else False
+    )
+
+
+if __name__ == "__main__":
+    # Parse command line arguments
+    parser = argparse.ArgumentParser(description="Step-Audio Edit Demo")
+    parser.add_argument("--model-path", type=str, default="stepfun-ai", help="Model path.")
+    parser.add_argument("--server-name", type=str, default="0.0.0.0", help="Demo server name.")
+    parser.add_argument("--server-port", type=int, default=7860, help="Demo server port.")
+    parser.add_argument("--tmp-dir", type=str, default="/tmp/gradio", help="Save path.")
+    parser.add_argument("--share", action="store_true", help="Share gradio app.")
+
+    # Multi-source loading support parameters
+    parser.add_argument(
+        "--model-source",
+        type=str,
+        default="huggingface",
+        choices=["auto", "local", "modelscope", "huggingface"],
+        help="Model source: auto (detect automatically), local, modelscope, or huggingface"
+    )
+    parser.add_argument(
+        "--tokenizer-model-id",
+        type=str,
+        default="dengcunqin/speech_paraformer-large_asr_nat-zh-cantonese-en-16k-vocab8501-online",
+        help="Tokenizer model ID for online loading"
+    )
+    parser.add_argument(
+        "--tts-model-id",
+        type=str,
+        default=None,
+        help="TTS model ID for online loading (if different from model-path)"
+    )
+
+    args = parser.parse_args()
+
+    # Store args globally for model configuration
+    args_global = args
+
+    logger.info("Configuration loaded:")
+    logger.info(f"Model source: {args.model_source}")
+    logger.info(f"Model path: {args.model_path}")
+    logger.info(f"Tokenizer model ID: {args.tokenizer_model_id}")
+    if args.tts_model_id:
+        logger.info(f"TTS model ID: {args.tts_model_id}")
+
+    # Models will be initialized on first GPU call to avoid ZeroGPU main process errors
+
+    if ZEROGPU_AVAILABLE:
+        logger.info("🎉 ZeroGPU detected - using dynamic GPU duration management!")
+        logger.info("💡 First call: 300s (model loading), subsequent calls: 120s (inference only)")
+    else:
+        logger.info("💻 Running in local mode - models will be loaded on first call")

+    # Create EditxTab instance
+    editx_tab = EditxTab(args)
+
+    # Launch demo
+    launch_demo(args, editx_tab)
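The serving pattern above is the load-bearing piece of this commit: a ZeroGPU import fallback, lazily initialized lock-guarded global models, and a callable GPU-time budget. A minimal standalone sketch of the same pattern, with a hypothetical load_model() standing in for the StepAudioTokenizer/StepAudioTTS loaders:

import threading

# ZeroGPU fallback: off the Hub the `spaces` package is absent, so a no-op
# decorator with the same call shape stands in for spaces.GPU.
try:
    import spaces
except ImportError:
    class spaces:
        @staticmethod
        def GPU(duration=10):
            def decorator(func):
                return func
            return decorator

_model = None
_lock = threading.Lock()

def load_model():
    # Hypothetical stand-in for the heavy StepAudioTokenizer/StepAudioTTS setup.
    return object()

def get_model():
    global _model
    if _model is not None:        # fast path: no lock once initialized
        return _model
    with _lock:                   # slow path: double-checked locking
        if _model is None:
            _model = load_model()
    return _model

def gpu_seconds(*args):
    # Larger budget for the first call, which pays the model-loading cost.
    return 300 if _model is None else 120

@spaces.GPU(duration=gpu_seconds)
def infer(x):
    return get_model(), x         # models materialize inside the GPU context

The double-checked lock keeps concurrent Gradio requests from loading the models twice, and the duration callable mirrors get_gpu_duration above: a long first allocation that covers loading, short allocations afterwards.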

config/__init__.py
@@ -0,0 +1,12 @@
+"""
+Configuration module for Step-Audio
+"""
+
+from .prompts import TTS_SYSTEM_PROMPTS, AUDIO_EDIT_SYSTEM_PROMPT
+from .edit_config import get_supported_edit_types
+
+__all__ = [
+    'TTS_SYSTEM_PROMPTS',
+    'AUDIO_EDIT_SYSTEM_PROMPT',
+    'get_supported_edit_types'
+]
config/edit_config.py
@@ -0,0 +1,33 @@
+"""
+Audio editing configuration module.
+Contains the supported edit types and related configuration.
+"""
+
+def get_supported_edit_types():
+    """
+    Get the supported edit types and their options.
+
+    Returns:
+        Dict[str, list]: Dictionary of edit types and their options
+    """
+    return {
+        "clone": [],
+        "emotion": [
+            'happy', 'angry', 'sad', 'humour', 'confusion', 'disgusted',
+            'empathy', 'embarrass', 'fear', 'surprised', 'excited',
+            'depressed', 'coldness', 'admiration'
+        ],
+        "style": [
+            'serious', 'arrogant', 'child', 'older', 'girl', 'pure',
+            'sister', 'sweet', 'ethereal', 'whisper', 'gentle', 'recite',
+            'generous', 'act_coy', 'warm', 'shy', 'comfort', 'authority',
+            'chat', 'radio', 'soulful', 'story', 'vivid', 'program',
+            'news', 'advertising', 'roar', 'murmur', 'shout', 'deeply', 'loudly'
+        ],
+        "vad": [],
+        "music": [],
+        "denoise": [],
+        "para-linguistic": [],
+        "speed": ["faster", "slower", "more faster", "more slower"],
+        "animal": [],
+    }
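A short usage sketch of how the UI consumes this mapping (mirroring update_edit_info in app.py; the printed values come straight from the dictionary above):

from config.edit_config import get_supported_edit_types

edit_types = get_supported_edit_types()

# Main "Task" dropdown choices:
print(list(edit_types.keys()))
# ['clone', 'emotion', 'style', 'vad', 'music', 'denoise',
#  'para-linguistic', 'speed', 'animal']

# "Sub-task" dropdown for a selected task; tasks like "vad" have no sub-tasks:
print(edit_types.get("speed", []))  # ['faster', 'slower', 'more faster', 'more slower']
print(edit_types.get("vad", []))    # []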
config/prompts.py
@@ -0,0 +1,62 @@
+"""
+System prompt configuration module.
+Contains all TTS- and editing-related system prompts.
+"""
+
+# TTS-related system prompts
+TTS_SYSTEM_PROMPTS = {
+    "sys_prompt_for_rap": "请参考对话历史里的音色,用RAP方式将文本内容大声说唱出来。",
+    "sys_prompt_for_vocal": "请参考对话历史里的音色,用哼唱的方式将文本内容大声唱出来。",
+    "sys_prompt_wo_spk": '以自然的语速读出下面的文字。',
+    "sys_prompt_with_spk": '请用{}的声音尽可能自然地说出下面这些话。',
+}
+
+# Audio editing system prompt
+AUDIO_EDIT_SYSTEM_PROMPT = """As a highly skilled audio editing and tuning specialist, you excel at interpreting user instructions and applying precise adjustments to audio files according to their needs. Your expertise spans a wide range of audio enhancement capabilities, including but not limited to the following:
+
+# Emotional Enhancement of Speech:
+You are capable of infusing speech with various emotions such as:
+- happy
+- angry
+- sad
+- fear
+- disgusted
+- surprised
+- excited
+
+# Speech Style Transfer:
+You can adapt vocal delivery to diverse styles including:
+- Whisper
+- Coquettish
+- Gentle
+- Sweet
+- Arrogant
+- Innocent
+- Radio Host
+- Childlike
+- Bold and Unconstrained
+- Serious
+- Expressive and Vivid
+- Ethereal
+- Exaggerated
+- Recitation
+- Girlish
+- News Broadcast
+- Mature Female Voice
+- Middle-Aged or Elderly
+- Program Hosting
+
+# Paralinguistic Adjustments:
+You can fine-tune non-verbal speech elements such as:
+- Laughter Enhancement
+- Emphatic Stress
+- Rhythm and Pace Modulation
+
+# Audio Tuning & Editing:
+Your technical proficiency includes:
+- Noise Reduction
+- Background Music Removal
+- Silence Trimming
+- Speaker Extraction
+
+Note: Users will provide instructions in natural language. You are expected to accurately interpret their requirements and perform the most suitable audio edits and enhancements."""
funasr_detach/__init__.py
@@ -0,0 +1,38 @@
+"""Initialize funasr package."""
+
+import os
+import pkgutil
+import importlib
+
+dirname = os.path.dirname(__file__)
+version_file = os.path.join(dirname, "version.txt")
+with open(version_file, "r") as f:
+    __version__ = f.read().strip()
+
+
+import importlib
+import pkgutil
+
+
+def import_submodules(package, recursive=True):
+    if isinstance(package, str):
+        package = importlib.import_module(package)
+    results = {}
+    for loader, name, is_pkg in pkgutil.walk_packages(
+        package.__path__, package.__name__ + "."
+    ):
+        try:
+            results[name] = importlib.import_module(name)
+        except Exception as e:
+            # To see details of import failures, uncomment the line below.
+            # print(f"Failed to import {name}: {e}")
+            pass
+        if recursive and is_pkg:
+            results.update(import_submodules(name))
+    return results
+
+
+import_submodules(__name__)
+
+from funasr_detach.auto.auto_model import AutoModel
+from funasr_detach.auto.auto_frontend import AutoFrontend
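A brief consumption sketch: the eager import_submodules(__name__) walk above imports every submodule so that model and frontend classes register themselves in funasr_detach.register.tables before the two re-exports at the bottom. The constructor arguments below are hypothetical placeholders:

# Importing the package runs import_submodules(__name__) as a side effect,
# which populates the registry tables that AutoModel looks classes up in.
from funasr_detach import AutoModel

# Hypothetical construction: AutoModel.__init__ asserts a "model" kwarg and
# resolves it locally or from a model hub via download_model.
# model = AutoModel(model="<model-id-or-local-path>", device="cpu")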

funasr_detach/auto/__init__.py
File without changes

funasr_detach/auto/auto_frontend.py
@@ -0,0 +1,90 @@
+import time
+import logging
+from tqdm import tqdm
+
+from funasr_detach.register import tables
+from funasr_detach.download.download_from_hub import download_model
+from funasr_detach.utils.load_utils import load_audio_text_image_video, extract_fbank
+from funasr_detach.auto.auto_model import prepare_data_iterator
+from funasr_detach.auto.auto_model import prepare_data_iterator
+
+
+class AutoFrontend:
+    def __init__(self, **kwargs):
+        assert "model" in kwargs
+        if "model_conf" not in kwargs:
+            logging.info(
+                "download models from model hub: {}".format(
+                    kwargs.get("model_hub", "ms")
+                )
+            )
+            kwargs = download_model(**kwargs)
+
+        # build frontend
+        frontend = kwargs.get("frontend", None)
+        if frontend is not None:
+            frontend_class = tables.frontend_classes.get(frontend)
+            frontend = frontend_class(**kwargs["frontend_conf"])
+
+        self.frontend = frontend
+        if "frontend" in kwargs:
+            del kwargs["frontend"]
+        self.kwargs = kwargs
+
+    def __call__(self, input, input_len=None, kwargs=None, **cfg):
+
+        kwargs = self.kwargs if kwargs is None else kwargs
+        kwargs.update(cfg)
+
+        key_list, data_list = prepare_data_iterator(input, input_len=input_len)
+        batch_size = kwargs.get("batch_size", 1)
+        device = kwargs.get("device", "cpu")
+        if device == "cpu":
+            batch_size = 1
+
+        meta_data = {}
+
+        result_list = []
+        num_samples = len(data_list)
+        pbar = tqdm(colour="blue", total=num_samples + 1, dynamic_ncols=True)
+
+        time0 = time.perf_counter()
+        for beg_idx in range(0, num_samples, batch_size):
+            end_idx = min(num_samples, beg_idx + batch_size)
+            data_batch = data_list[beg_idx:end_idx]
+            key_batch = key_list[beg_idx:end_idx]
+
+            # extract fbank feats
+            time1 = time.perf_counter()
+            audio_sample_list = load_audio_text_image_video(
+                data_batch, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000)
+            )
+            time2 = time.perf_counter()
+            meta_data["load_data"] = f"{time2 - time1:0.3f}"
+            speech, speech_lengths = extract_fbank(
+                audio_sample_list,
+                data_type=kwargs.get("data_type", "sound"),
+                frontend=self.frontend,
+                **kwargs,
+            )
+            time3 = time.perf_counter()
+            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+            meta_data["batch_data_time"] = (
+                speech_lengths.sum().item()
+                * self.frontend.frame_shift
+                * self.frontend.lfr_n
+                / 1000
+            )
+
+            speech.to(device=device), speech_lengths.to(device=device)
+            batch = {"input": speech, "input_len": speech_lengths, "key": key_batch}
+            result_list.append(batch)
+
+            pbar.update(1)
+            description = f"{meta_data}, "
+            pbar.set_description(description)
+
+        time_end = time.perf_counter()
+        pbar.set_description(f"time elapsed total: {time_end - time0:0.3f}")
+
+        return result_list
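One hedged note on the batch-assembly step above: Tensor.to() returns a new tensor rather than moving its operand in place, so the bare tuple expression `speech.to(device=device), speech_lengths.to(device=device)` discards both results, and the batch keeps whatever device the tensors were already on. A minimal corrected sketch of that step, with stand-in tensors:

import torch

# Stand-ins for one batch of fbank features and their lengths.
speech = torch.zeros(1, 100, 80)
speech_lengths = torch.tensor([100])
device = "cpu"

# Rebind to keep the moved tensors; without the assignment the move is lost.
speech = speech.to(device=device)
speech_lengths = speech_lengths.to(device=device)

batch = {"input": speech, "input_len": speech_lengths, "key": ["rand_key_example"]}
print(batch["input"].device)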

funasr_detach/auto/auto_model.py
@@ -0,0 +1,575 @@
+import json
+import time
+import copy
+import torch
+import random
+import string
+import logging
+import os.path
+import numpy as np
+from tqdm import tqdm
+
+from funasr_detach.register import tables
+from funasr_detach.utils.load_utils import load_bytes
+from funasr_detach.download.file import download_from_url
+from funasr_detach.download.download_from_hub import download_model
+from funasr_detach.utils.vad_utils import slice_padding_audio_samples
+from funasr_detach.train_utils.set_all_random_seed import set_all_random_seed
+from funasr_detach.train_utils.load_pretrained_model import load_pretrained_model
+from funasr_detach.utils.load_utils import load_audio_text_image_video
+from funasr_detach.utils.timestamp_tools import timestamp_sentence
+from funasr_detach.models.campplus.utils import sv_chunk, postprocess, distribute_spk
+
+try:
+    from funasr_detach.models.campplus.cluster_backend import ClusterBackend
+except:
+    print("If you want to use the speaker diarization, please `pip install hdbscan`")
+
+
+def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
+    """
+
+    :param input:
+    :param input_len:
+    :param data_type:
+    :param frontend:
+    :return:
+    """
+    data_list = []
+    key_list = []
+    filelist = [".scp", ".txt", ".json", ".jsonl"]
+
+    chars = string.ascii_letters + string.digits
+    if isinstance(data_in, str) and data_in.startswith("http"):  # url
+        data_in = download_from_url(data_in)
+    if isinstance(data_in, str) and os.path.exists(
+        data_in
+    ):  # wav_path; filelist: wav.scp, file.jsonl;text.txt;
+        _, file_extension = os.path.splitext(data_in)
+        file_extension = file_extension.lower()
+        if file_extension in filelist:  # filelist: wav.scp, file.jsonl;text.txt;
+            with open(data_in, encoding="utf-8") as fin:
+                for line in fin:
+                    key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
+                    if data_in.endswith(
+                        ".jsonl"
+                    ):  # file.jsonl: json.dumps({"source": data})
+                        lines = json.loads(line.strip())
+                        data = lines["source"]
+                        key = data["key"] if "key" in data else key
+                    else:  # filelist, wav.scp, text.txt: id \t data or data
+                        lines = line.strip().split(maxsplit=1)
+                        data = lines[1] if len(lines) > 1 else lines[0]
+                        key = lines[0] if len(lines) > 1 else key
+
+                    data_list.append(data)
+                    key_list.append(key)
+        else:
+            key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
+            data_list = [data_in]
+            key_list = [key]
+    elif isinstance(data_in, (list, tuple)):
+        if data_type is not None and isinstance(
+            data_type, (list, tuple)
+        ):  # multiple inputs
+            data_list_tmp = []
+            for data_in_i, data_type_i in zip(data_in, data_type):
+                key_list, data_list_i = prepare_data_iterator(
+                    data_in=data_in_i, data_type=data_type_i
+                )
+                data_list_tmp.append(data_list_i)
+            data_list = []
+            for item in zip(*data_list_tmp):
+                data_list.append(item)
+        else:
+            # [audio sample point, fbank, text]
+            data_list = data_in
+            key_list = [
+                "rand_key_" + "".join(random.choice(chars) for _ in range(13))
+                for _ in range(len(data_in))
+            ]
+    else:  # raw text; audio sample point, fbank; bytes
+        if isinstance(data_in, bytes):  # audio bytes
+            data_in = load_bytes(data_in)
+        if key is None:
+            key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
+        data_list = [data_in]
+        key_list = [key]
+
+    return key_list, data_list
+
+
+class AutoModel:
+
+    def __init__(self, **kwargs):
+        if not kwargs.get("disable_log", False):
+            tables.print()
+
+        model, kwargs = self.build_model(**kwargs)
+
+        # if vad_model is not None, build vad model else None
+        vad_model = kwargs.get("vad_model", None)
+        vad_kwargs = kwargs.get("vad_model_revision", None)
+        if vad_model is not None:
+            logging.info("Building VAD model.")
+            vad_kwargs = {
+                "model": vad_model,
+                "model_revision": vad_kwargs,
+                "device": kwargs["device"],
+            }
+            vad_model, vad_kwargs = self.build_model(**vad_kwargs)
+
+        # if punc_model is not None, build punc model else None
+        punc_model = kwargs.get("punc_model", None)
+        punc_kwargs = kwargs.get("punc_model_revision", None)
+        if punc_model is not None:
+            logging.info("Building punc model.")
+            punc_kwargs = {
+                "model": punc_model,
+                "model_revision": punc_kwargs,
+                "device": kwargs["device"],
+            }
+            punc_model, punc_kwargs = self.build_model(**punc_kwargs)
+
+        # if spk_model is not None, build spk model else None
+        spk_model = kwargs.get("spk_model", None)
+        spk_kwargs = kwargs.get("spk_model_revision", None)
+        if spk_model is not None:
+            logging.info("Building SPK model.")
+            spk_kwargs = {
+                "model": spk_model,
+                "model_revision": spk_kwargs,
+                "device": kwargs["device"],
+            }
+            spk_model, spk_kwargs = self.build_model(**spk_kwargs)
+            self.cb_model = ClusterBackend().to(kwargs["device"])
+            spk_mode = kwargs.get("spk_mode", "punc_segment")
+            if spk_mode not in ["default", "vad_segment", "punc_segment"]:
+                logging.error(
+                    "spk_mode should be one of default, vad_segment and punc_segment."
+                )
+            self.spk_mode = spk_mode
+
+        self.kwargs = kwargs
+        self.model = model
+        self.vad_model = vad_model
+        self.vad_kwargs = vad_kwargs
+        self.punc_model = punc_model
+        self.punc_kwargs = punc_kwargs
+        self.spk_model = spk_model
+        self.spk_kwargs = spk_kwargs
+        self.model_path = kwargs.get("model_path")
+        self.repo_path = kwargs.get("repo_path")
+
+
+    def build_model(self, **kwargs):
+        assert "model" in kwargs
+        if "model_conf" not in kwargs:
+            logging.info(
+                "download models from model hub: {}".format(
+                    kwargs.get("model_hub", "ms")
+                )
+            )
+            kwargs = download_model(**kwargs)
+
+        set_all_random_seed(kwargs.get("seed", 0))
+
+        device = kwargs.get("device", "cuda")
+        if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
+            device = "cpu"
+            kwargs["batch_size"] = 1
+        kwargs["device"] = device
+
+        if kwargs.get("ncpu", None):
+            torch.set_num_threads(kwargs.get("ncpu"))
+
+        # build tokenizer
+        tokenizer = kwargs.get("tokenizer", None)
+        if tokenizer is not None:
+            tokenizer_class = tables.tokenizer_classes.get(tokenizer)
+            tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
+            kwargs["tokenizer"] = tokenizer
+            kwargs["token_list"] = tokenizer.token_list
+            vocab_size = len(tokenizer.token_list)
+        else:
+            vocab_size = -1
+
+        # build frontend
+        frontend = kwargs.get("frontend", None)
+        if frontend is not None:
+            frontend_class = tables.frontend_classes.get(frontend)
+            frontend = frontend_class(**kwargs["frontend_conf"])
+            kwargs["frontend"] = frontend
+            kwargs["input_size"] = frontend.output_size()
+
+        # build model
+        model_class = tables.model_classes.get(kwargs["model"])
+        model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
+
+        model.to(device)
+
+        # init_param
+        init_param = kwargs.get("init_param", None)
+        if init_param is not None:
+            logging.info(f"Loading pretrained params from {init_param}")
+            load_pretrained_model(
+                model=model,
+                path=init_param,
+                ignore_init_mismatch=kwargs.get("ignore_init_mismatch", False),
+                oss_bucket=kwargs.get("oss_bucket", None),
+                scope_map=kwargs.get("scope_map", None),
+                excludes=kwargs.get("excludes", None),
+            )
+
+        return model, kwargs
+
+    def __call__(self, *args, **cfg):
+        kwargs = self.kwargs
+        kwargs.update(cfg)
+        res = self.model(*args, **kwargs)
+        return res
+
+    def generate(self, input, input_len=None, **cfg):
+        if self.vad_model is None:
+            return self.inference(input, input_len=input_len, **cfg)
+
+        else:
+            return self.inference_with_vad(input, input_len=input_len, **cfg)
+
+    def inference(
+        self, input, input_len=None, model=None, kwargs=None, key=None, **cfg
+    ):
+        kwargs = self.kwargs if kwargs is None else kwargs
+        kwargs.update(cfg)
+        model = self.model if model is None else model
+        model = model.cuda()
+        model.eval()
+
+        batch_size = kwargs.get("batch_size", 1)
+        # if kwargs.get("device", "cpu") == "cpu":
+        #     batch_size = 1
+
+        key_list, data_list = prepare_data_iterator(
+            input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key
+        )
+
+        speed_stats = {}
+        asr_result_list = []
+        num_samples = len(data_list)
+        disable_pbar = kwargs.get("disable_pbar", False)
+        pbar = (
+            tqdm(colour="blue", total=num_samples, dynamic_ncols=True)
+            if not disable_pbar
+            else None
+        )
+        time_speech_total = 0.0
+        time_escape_total = 0.0
+        for beg_idx in range(0, num_samples, batch_size):
+            end_idx = min(num_samples, beg_idx + batch_size)
+            data_batch = data_list[beg_idx:end_idx]
+            key_batch = key_list[beg_idx:end_idx]
+            batch = {"data_in": data_batch, "key": key_batch}
+            if (end_idx - beg_idx) == 1 and kwargs.get(
+                "data_type", None
+            ) == "fbank":  # fbank
+                batch["data_in"] = data_batch[0]
+                batch["data_lengths"] = input_len
+
+            time1 = time.perf_counter()
+            with torch.no_grad():
+                results, meta_data = model.inference(**batch, **kwargs)
+            time2 = time.perf_counter()
+
+            asr_result_list.extend(results)
+
+            # batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
+            batch_data_time = meta_data.get("batch_data_time", -1)
+            time_escape = time2 - time1
+            speed_stats["load_data"] = meta_data.get("load_data", 0.0)
+            speed_stats["extract_feat"] = meta_data.get("extract_feat", 0.0)
+            speed_stats["forward"] = f"{time_escape:0.3f}"
+            speed_stats["batch_size"] = f"{len(results)}"
+            speed_stats["time_cost"] = f"{(time_escape)}"
+            speed_stats["rtf"] = f"{(time_escape) / batch_data_time:0.3f}"
+            description = f"{speed_stats}, "
+            if pbar:
+                pbar.update(1)
+                pbar.set_description(description)
+            time_speech_total += batch_data_time
+            time_escape_total += time_escape
+
+        if pbar:
+            # pbar.update(1)
+            pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
+        torch.cuda.empty_cache()
+        return asr_result_list
+
+    def inference_with_vad(self, input, input_len=None, **cfg):
+
+        # step.1: compute the vad model
+        self.vad_kwargs.update(cfg)
+        beg_vad = time.time()
+        res = self.inference(
+            input,
+            input_len=input_len,
+            model=self.vad_model,
+            kwargs=self.vad_kwargs,
+            **cfg,
+        )
+        end_vad = time.time()
+        print(f"time cost vad: {end_vad - beg_vad:0.3f}")
+
+        # step.2 compute asr model
+        model = self.model
+        kwargs = self.kwargs
+        kwargs.update(cfg)
+        batch_size = int(kwargs.get("batch_size_s", 300)) * 1000
+        batch_size_threshold_ms = int(kwargs.get("batch_size_threshold_s", 60)) * 1000
+        kwargs["batch_size"] = batch_size
+
+        key_list, data_list = prepare_data_iterator(
+            input, input_len=input_len, data_type=kwargs.get("data_type", None)
+        )
+        results_ret_list = []
+        time_speech_total_all_samples = 1e-6
+
+        beg_total = time.time()
+        pbar_total = tqdm(colour="red", total=len(res), dynamic_ncols=True)
+        for i in range(len(res)):
+            key = res[i]["key"]
+            vadsegments = res[i]["value"]
+            input_i = data_list[i]
+            speech = load_audio_text_image_video(
+                input_i, fs=kwargs["frontend"].fs, audio_fs=kwargs.get("fs", 16000)
+            )
+            speech_lengths = len(speech)
+            n = len(vadsegments)
+            data_with_index = [(vadsegments[i], i) for i in range(n)]
+            sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
+            results_sorted = []
+
+            if not len(sorted_data):
+                logging.info("decoding, utt: {}, empty speech".format(key))
+                continue
+
+            if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
+                batch_size = max(
+                    batch_size, sorted_data[0][0][1] - sorted_data[0][0][0]
+                )
+
+            batch_size_ms_cum = 0
+            beg_idx = 0
+            beg_asr_total = time.time()
+            time_speech_total_per_sample = speech_lengths / 16000
+            time_speech_total_all_samples += time_speech_total_per_sample
+
+            all_segments = []
+            for j, _ in enumerate(range(0, n)):
+                # pbar_sample.update(1)
+                batch_size_ms_cum += sorted_data[j][0][1] - sorted_data[j][0][0]
+                if (
+                    j < n - 1
+                    and (
+                        batch_size_ms_cum
+                        + sorted_data[j + 1][0][1]
+                        - sorted_data[j + 1][0][0]
+                    )
+                    < batch_size
+                    and (sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0])
+                    < batch_size_threshold_ms
+                ):
+                    continue
+                batch_size_ms_cum = 0
+                end_idx = j + 1
+                speech_j, speech_lengths_j = slice_padding_audio_samples(
+                    speech, speech_lengths, sorted_data[beg_idx:end_idx]
+                )
+                results = self.inference(
+                    speech_j,
+                    input_len=None,
+                    model=model,
+                    kwargs=kwargs,
+                    disable_pbar=True,
+                    **cfg,
+                )
+                if self.spk_model is not None:
+                    # compose vad segments: [[start_time_sec, end_time_sec, speech], [...]]
+                    for _b in range(len(speech_j)):
|
| 398 |
+
vad_segments = [
|
| 399 |
+
[
|
| 400 |
+
sorted_data[beg_idx:end_idx][_b][0][0] / 1000.0,
|
| 401 |
+
sorted_data[beg_idx:end_idx][_b][0][1] / 1000.0,
|
| 402 |
+
np.array(speech_j[_b]),
|
| 403 |
+
]
|
| 404 |
+
]
|
| 405 |
+
segments = sv_chunk(vad_segments)
|
| 406 |
+
all_segments.extend(segments)
|
| 407 |
+
speech_b = [i[2] for i in segments]
|
| 408 |
+
spk_res = self.inference(
|
| 409 |
+
speech_b,
|
| 410 |
+
input_len=None,
|
| 411 |
+
model=self.spk_model,
|
| 412 |
+
kwargs=kwargs,
|
| 413 |
+
disable_pbar=True,
|
| 414 |
+
**cfg,
|
| 415 |
+
)
|
| 416 |
+
results[_b]["spk_embedding"] = spk_res[0]["spk_embedding"]
|
| 417 |
+
beg_idx = end_idx
|
| 418 |
+
if len(results) < 1:
|
| 419 |
+
continue
|
| 420 |
+
results_sorted.extend(results)
|
| 421 |
+
|
| 422 |
+
restored_data = [0] * n
|
| 423 |
+
for j in range(n):
|
| 424 |
+
index = sorted_data[j][1]
|
| 425 |
+
restored_data[index] = results_sorted[j]
|
| 426 |
+
result = {}
|
| 427 |
+
|
| 428 |
+
# results combine for texts, timestamps, speaker embeddings and others
|
| 429 |
+
# TODO: rewrite for clean code
|
| 430 |
+
for j in range(n):
|
| 431 |
+
for k, v in restored_data[j].items():
|
| 432 |
+
if k.startswith("timestamp"):
|
| 433 |
+
if k not in result:
|
| 434 |
+
result[k] = []
|
| 435 |
+
for t in restored_data[j][k]:
|
| 436 |
+
t[0] += vadsegments[j][0]
|
| 437 |
+
t[1] += vadsegments[j][0]
|
| 438 |
+
result[k].extend(restored_data[j][k])
|
| 439 |
+
elif k == "spk_embedding":
|
| 440 |
+
if k not in result:
|
| 441 |
+
result[k] = restored_data[j][k]
|
| 442 |
+
else:
|
| 443 |
+
result[k] = torch.cat(
|
| 444 |
+
[result[k], restored_data[j][k]], dim=0
|
| 445 |
+
)
|
| 446 |
+
elif "text" in k:
|
| 447 |
+
if k not in result:
|
| 448 |
+
result[k] = restored_data[j][k]
|
| 449 |
+
else:
|
| 450 |
+
result[k] += " " + restored_data[j][k]
|
| 451 |
+
else:
|
| 452 |
+
if k not in result:
|
| 453 |
+
result[k] = restored_data[j][k]
|
| 454 |
+
else:
|
| 455 |
+
result[k] += restored_data[j][k]
|
| 456 |
+
|
| 457 |
+
return_raw_text = kwargs.get("return_raw_text", False)
|
| 458 |
+
# step.3 compute punc model
|
| 459 |
+
if self.punc_model is not None:
|
| 460 |
+
self.punc_kwargs.update(cfg)
|
| 461 |
+
punc_res = self.inference(
|
| 462 |
+
result["text"],
|
| 463 |
+
model=self.punc_model,
|
| 464 |
+
kwargs=self.punc_kwargs,
|
| 465 |
+
disable_pbar=True,
|
| 466 |
+
**cfg,
|
| 467 |
+
)
|
| 468 |
+
raw_text = copy.copy(result["text"])
|
| 469 |
+
if return_raw_text:
|
| 470 |
+
result["raw_text"] = raw_text
|
| 471 |
+
result["text"] = punc_res[0]["text"]
|
| 472 |
+
else:
|
| 473 |
+
raw_text = None
|
| 474 |
+
|
| 475 |
+
# speaker embedding cluster after resorted
|
| 476 |
+
if self.spk_model is not None and kwargs.get("return_spk_res", True):
|
| 477 |
+
if raw_text is None:
|
| 478 |
+
logging.error("Missing punc_model, which is required by spk_model.")
|
| 479 |
+
all_segments = sorted(all_segments, key=lambda x: x[0])
|
| 480 |
+
spk_embedding = result["spk_embedding"]
|
| 481 |
+
labels = self.cb_model(
|
| 482 |
+
spk_embedding.cpu(), oracle_num=kwargs.get("preset_spk_num", None)
|
| 483 |
+
)
|
| 484 |
+
# del result['spk_embedding']
|
| 485 |
+
sv_output = postprocess(all_segments, None, labels, spk_embedding.cpu())
|
| 486 |
+
if self.spk_mode == "vad_segment": # recover sentence_list
|
| 487 |
+
sentence_list = []
|
| 488 |
+
for res, vadsegment in zip(restored_data, vadsegments):
|
| 489 |
+
if "timestamp" not in res:
|
| 490 |
+
logging.error(
|
| 491 |
+
"Only 'iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch' \
|
| 492 |
+
and 'iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'\
|
| 493 |
+
can predict timestamp, and speaker diarization relies on timestamps."
|
| 494 |
+
)
|
| 495 |
+
sentence_list.append(
|
| 496 |
+
{
|
| 497 |
+
"start": vadsegment[0],
|
| 498 |
+
"end": vadsegment[1],
|
| 499 |
+
"sentence": res["text"],
|
| 500 |
+
"timestamp": res["timestamp"],
|
| 501 |
+
}
|
| 502 |
+
)
|
| 503 |
+
elif self.spk_mode == "punc_segment":
|
| 504 |
+
if "timestamp" not in result:
|
| 505 |
+
logging.error(
|
| 506 |
+
"Only 'iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch' \
|
| 507 |
+
and 'iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'\
|
| 508 |
+
can predict timestamp, and speaker diarization relies on timestamps."
|
| 509 |
+
)
|
| 510 |
+
sentence_list = timestamp_sentence(
|
| 511 |
+
punc_res[0]["punc_array"],
|
| 512 |
+
result["timestamp"],
|
| 513 |
+
raw_text,
|
| 514 |
+
return_raw_text=return_raw_text,
|
| 515 |
+
)
|
| 516 |
+
distribute_spk(sentence_list, sv_output)
|
| 517 |
+
result["sentence_info"] = sentence_list
|
| 518 |
+
elif kwargs.get("sentence_timestamp", False):
|
| 519 |
+
sentence_list = timestamp_sentence(
|
| 520 |
+
punc_res[0]["punc_array"],
|
| 521 |
+
result["timestamp"],
|
| 522 |
+
raw_text,
|
| 523 |
+
return_raw_text=return_raw_text,
|
| 524 |
+
)
|
| 525 |
+
result["sentence_info"] = sentence_list
|
| 526 |
+
if "spk_embedding" in result:
|
| 527 |
+
del result["spk_embedding"]
|
| 528 |
+
|
| 529 |
+
result["key"] = key
|
| 530 |
+
results_ret_list.append(result)
|
| 531 |
+
end_asr_total = time.time()
|
| 532 |
+
time_escape_total_per_sample = end_asr_total - beg_asr_total
|
| 533 |
+
pbar_total.update(1)
|
| 534 |
+
pbar_total.set_description(
|
| 535 |
+
f"rtf_avg: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
|
| 536 |
+
f"time_speech: {time_speech_total_per_sample: 0.3f}, "
|
| 537 |
+
f"time_escape: {time_escape_total_per_sample:0.3f}"
|
| 538 |
+
)
|
| 539 |
+
|
| 540 |
+
return results_ret_list
|
| 541 |
+
|
| 542 |
+
def infer_encoder(
|
| 543 |
+
self, input, input_len=None, model=None, kwargs=None, key=None, **cfg
|
| 544 |
+
):
|
| 545 |
+
kwargs = self.kwargs if kwargs is None else kwargs
|
| 546 |
+
kwargs.update(cfg)
|
| 547 |
+
model = self.model if model is None else model
|
| 548 |
+
model = model.cuda()
|
| 549 |
+
model.eval()
|
| 550 |
+
|
| 551 |
+
batch_size = kwargs.get("batch_size", 1)
|
| 552 |
+
|
| 553 |
+
key_list, data_list = prepare_data_iterator(
|
| 554 |
+
input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key
|
| 555 |
+
)
|
| 556 |
+
|
| 557 |
+
asr_result_list = []
|
| 558 |
+
num_samples = len(data_list)
|
| 559 |
+
for beg_idx in range(0, num_samples, batch_size):
|
| 560 |
+
end_idx = min(num_samples, beg_idx + batch_size)
|
| 561 |
+
data_batch = data_list[beg_idx:end_idx]
|
| 562 |
+
key_batch = key_list[beg_idx:end_idx]
|
| 563 |
+
batch = {"data_in": data_batch, "key": key_batch}
|
| 564 |
+
if (end_idx - beg_idx) == 1 and kwargs.get(
|
| 565 |
+
"data_type", None
|
| 566 |
+
) == "fbank": # fbank
|
| 567 |
+
batch["data_in"] = data_batch[0]
|
| 568 |
+
batch["data_lengths"] = input_len
|
| 569 |
+
|
| 570 |
+
with torch.no_grad():
|
| 571 |
+
results, meta_data, cache = model.infer_encoder(**batch, **kwargs)
|
| 572 |
+
asr_result_list.extend(results)
|
| 573 |
+
|
| 574 |
+
torch.cuda.empty_cache()
|
| 575 |
+
return asr_result_list, cache
|
|
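For orientation, a minimal usage sketch of the AutoModel wrapper above, assuming its constructor accepts model identifiers the way upstream FunASR's AutoModel does; the model names and the wav path below are placeholders, not this Space's actual configuration:

# Minimal sketch, assuming an upstream-FunASR-style constructor.
from funasr_detach.auto.auto_model import AutoModel

model = AutoModel(
    model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
    vad_model="fsmn-vad",   # optional: routes generate() through inference_with_vad()
    punc_model="ct-punc",   # optional: step-3 punctuation restoration
)
res = model.generate(input="example.wav", batch_size_s=300)
print(res[0]["text"])

With a VAD model present, long audio is split into segments, the segments are decoded shortest-first in batches capped by batch_size_s (seconds) and batch_size_threshold_s, and the hypotheses are restored to the original segment order before punctuation and speaker post-processing.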
@@ -0,0 +1,7 @@
class AutoTokenizer:
    """
    Undo
    """

    def __init__(self):
        pass
File without changes
@@ -0,0 +1,152 @@
import os
import json
import numpy as np
import torch
import hydra
import logging
from omegaconf import DictConfig, OmegaConf

from funasr_detach.register import tables
from funasr_detach.download.download_from_hub import download_model
from funasr_detach.train_utils.set_all_random_seed import set_all_random_seed


@hydra.main(config_name=None, version_base=None)
def main_hydra(kwargs: DictConfig):
    if kwargs.get("debug", False):
        import pdb

        pdb.set_trace()

    assert "model" in kwargs
    if "model_conf" not in kwargs:
        logging.info(
            "download models from model hub: {}".format(kwargs.get("model_hub", "ms"))
        )
        kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)

    main(**kwargs)


def main(**kwargs):
    print(kwargs)
    # set random seed
    tables.print()
    set_all_random_seed(kwargs.get("seed", 0))
    torch.backends.cudnn.enabled = kwargs.get(
        "cudnn_enabled", torch.backends.cudnn.enabled
    )
    torch.backends.cudnn.benchmark = kwargs.get(
        "cudnn_benchmark", torch.backends.cudnn.benchmark
    )
    torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)

    tokenizer = kwargs.get("tokenizer", None)

    # build frontend if frontend is not None
    frontend = kwargs.get("frontend", None)
    if frontend is not None:
        frontend_class = tables.frontend_classes.get(frontend)
        frontend = frontend_class(**kwargs["frontend_conf"])
        kwargs["frontend"] = frontend
        kwargs["input_size"] = frontend.output_size()

    # dataset
    dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
    dataset_train = dataset_class(
        kwargs.get("train_data_set_list"),
        frontend=frontend,
        tokenizer=None,
        is_training=False,
        **kwargs.get("dataset_conf")
    )

    # dataloader
    batch_sampler = kwargs["dataset_conf"].get(
        "batch_sampler", "DynamicBatchLocalShuffleSampler"
    )
    batch_sampler_train = None
    if batch_sampler is not None:
        batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
        dataset_conf = kwargs.get("dataset_conf")
        dataset_conf["batch_type"] = "example"
        dataset_conf["batch_size"] = 1
        batch_sampler_train = batch_sampler_class(
            dataset_train, is_training=False, **dataset_conf
        )

    dataloader_train = torch.utils.data.DataLoader(
        dataset_train,
        collate_fn=dataset_train.collator,
        batch_sampler=batch_sampler_train,
        num_workers=int(kwargs.get("dataset_conf").get("num_workers", 4)),
        pin_memory=True,
    )

    iter_stop = int(kwargs.get("scale", 1.0) * len(dataloader_train))

    total_frames = 0
    for batch_idx, batch in enumerate(dataloader_train):
        if batch_idx >= iter_stop:
            break

        fbank = batch["speech"].numpy()[0, :, :]
        if total_frames == 0:
            mean_stats = np.sum(fbank, axis=0)
            var_stats = np.sum(np.square(fbank), axis=0)
        else:
            mean_stats += np.sum(fbank, axis=0)
            var_stats += np.sum(np.square(fbank), axis=0)
        total_frames += fbank.shape[0]

    cmvn_info = {
        "mean_stats": list(mean_stats.tolist()),
        "var_stats": list(var_stats.tolist()),
        "total_frames": total_frames,
    }
    cmvn_file = kwargs.get("cmvn_file", "cmvn.json")
    # import pdb;pdb.set_trace()
    with open(cmvn_file, "w") as fout:
        fout.write(json.dumps(cmvn_info))

    mean = -1.0 * mean_stats / total_frames
    var = 1.0 / np.sqrt(var_stats / total_frames - mean * mean)
    dims = mean.shape[0]
    am_mvn = os.path.dirname(cmvn_file) + "/am.mvn"
    with open(am_mvn, "w") as fout:
        fout.write(
            "<Nnet>"
            + "\n"
            + "<Splice> "
            + str(dims)
            + " "
            + str(dims)
            + "\n"
            + "[ 0 ]"
            + "\n"
            + "<AddShift> "
            + str(dims)
            + " "
            + str(dims)
            + "\n"
        )
        mean_str = (
            str(list(mean)).replace(",", "").replace("[", "[ ").replace("]", " ]")
        )
        fout.write("<LearnRateCoef> 0 " + mean_str + "\n")
        fout.write("<Rescale> " + str(dims) + " " + str(dims) + "\n")
        var_str = str(list(var)).replace(",", "").replace("[", "[ ").replace("]", " ]")
        fout.write("<LearnRateCoef> 0 " + var_str + "\n")
        fout.write("</Nnet>" + "\n")


"""
python funasr/bin/compute_audio_cmvn.py \
--config-path "/Users/zhifu/funasr1.0/examples/aishell/paraformer/conf" \
--config-name "train_asr_paraformer_conformer_12e_6d_2048_256.yaml" \
++train_data_set_list="/Users/zhifu/funasr1.0/data/list/audio_datasets.jsonl" \
++cmvn_file="/Users/zhifu/funasr1.0/data/list/cmvn.json" \
++dataset_conf.num_workers=0
"""
if __name__ == "__main__":
    main_hydra()
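To make the stats concrete, a sketch of how mean_stats/var_stats/total_frames from the cmvn.json written above become the shift and scale rows of am.mvn (file name follows the script's default):

# Minimal sketch: shift = -mean, scale = 1/std, mirroring the am.mvn math above.
import json

import numpy as np

with open("cmvn.json") as f:
    info = json.load(f)

mean_stats = np.asarray(info["mean_stats"])
var_stats = np.asarray(info["var_stats"])
n = info["total_frames"]

shift = -mean_stats / n                                        # "<AddShift>" row
scale = 1.0 / np.sqrt(var_stats / n - (mean_stats / n) ** 2)   # "<Rescale>" row

fbank = np.random.randn(100, mean_stats.shape[0])  # dummy [frames, dims] features
normalized = (fbank + shift) * scale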
@@ -0,0 +1,33 @@
import hydra
import logging
from omegaconf import DictConfig, OmegaConf, ListConfig

from funasr_detach.auto.auto_model import AutoModel


@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
    def to_plain_list(cfg_item):
        if isinstance(cfg_item, ListConfig):
            return OmegaConf.to_container(cfg_item, resolve=True)
        elif isinstance(cfg_item, DictConfig):
            return {k: to_plain_list(v) for k, v in cfg_item.items()}
        else:
            return cfg_item

    kwargs = to_plain_list(cfg)
    log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())

    logging.basicConfig(level=log_level)

    if kwargs.get("debug", False):
        import pdb

        pdb.set_trace()
    model = AutoModel(**kwargs)
    res = model.generate(input=kwargs["input"])
    print(res)


if __name__ == "__main__":
    main_hydra()
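An illustrative invocation, following the hydra ++key=value override style used elsewhere in this diff; the model name and wav path are placeholders:

python funasr_detach/bin/inference.py \
    ++model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
    ++input="example.wav" \
    ++log_level="INFO"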
@@ -0,0 +1,281 @@
#!/usr/bin/env python3
import argparse
from collections import Counter
import logging
from pathlib import Path
import sys
from typing import List
from typing import Optional


from funasr_detach.utils.cli_utils import get_commandline_args
from funasr_detach.tokenizer.build_tokenizer import build_tokenizer
from funasr_detach.tokenizer.cleaner import TextCleaner
from funasr_detach.tokenizer.phoneme_tokenizer import g2p_classes
from funasr_detach.utils.types import str2bool
from funasr_detach.utils.types import str_or_none


def field2slice(field: Optional[str]) -> slice:
    """Convert field string to slice

    Note that field string accepts 1-based integer.

    Examples:
        >>> field2slice("1-")
        slice(0, None, None)
        >>> field2slice("1-3")
        slice(0, 3, None)
        >>> field2slice("-3")
        slice(None, 3, None)
    """
    field = field.strip()
    try:
        if "-" in field:
            # e.g. "2-" or "2-5" or "-7"
            s1, s2 = field.split("-", maxsplit=1)
            if s1.strip() == "":
                s1 = None
            else:
                s1 = int(s1)
                if s1 == 0:
                    raise ValueError("1-based string")
            if s2.strip() == "":
                s2 = None
            else:
                s2 = int(s2)
        else:
            # e.g. "2"
            s1 = int(field)
            s2 = s1 + 1
            if s1 == 0:
                raise ValueError("must be 1 or more value")
    except ValueError:
        raise RuntimeError(f"Format error: e.g. '2-', '2-5', or '-5': {field}")

    if s1 is None:
        slic = slice(None, s2)
    else:
        # -1 because of 1-based integer following "cut" command
        # e.g "1-3" -> slice(0, 3)
        slic = slice(s1 - 1, s2)
    return slic


def tokenize(
    input: str,
    output: str,
    field: Optional[str],
    delimiter: Optional[str],
    token_type: str,
    space_symbol: str,
    non_linguistic_symbols: Optional[str],
    bpemodel: Optional[str],
    log_level: str,
    write_vocabulary: bool,
    vocabulary_size: int,
    remove_non_linguistic_symbols: bool,
    cutoff: int,
    add_symbol: List[str],
    cleaner: Optional[str],
    g2p: Optional[str],
):

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if input == "-":
        fin = sys.stdin
    else:
        fin = Path(input).open("r", encoding="utf-8")
    if output == "-":
        fout = sys.stdout
    else:
        p = Path(output)
        p.parent.mkdir(parents=True, exist_ok=True)
        fout = p.open("w", encoding="utf-8")

    cleaner = TextCleaner(cleaner)
    tokenizer = build_tokenizer(
        token_type=token_type,
        bpemodel=bpemodel,
        delimiter=delimiter,
        space_symbol=space_symbol,
        non_linguistic_symbols=non_linguistic_symbols,
        remove_non_linguistic_symbols=remove_non_linguistic_symbols,
        g2p_type=g2p,
    )

    counter = Counter()
    if field is not None:
        field = field2slice(field)

    for line in fin:
        line = line.rstrip()
        if field is not None:
            # e.g. field="2-"
            # uttidA hello world!! -> hello world!!
            tokens = line.split(delimiter)
            tokens = tokens[field]
            if delimiter is None:
                line = " ".join(tokens)
            else:
                line = delimiter.join(tokens)

        line = cleaner(line)
        tokens = tokenizer.text2tokens(line)
        if not write_vocabulary:
            fout.write(" ".join(tokens) + "\n")
        else:
            for t in tokens:
                counter[t] += 1

    if not write_vocabulary:
        return

    ## FIXME
    ## del duplicate add_symbols in counter
    for symbol_and_id in add_symbol:
        # e.g symbol="<blank>:0"
        try:
            symbol, idx = symbol_and_id.split(":")
        except ValueError:
            raise RuntimeError(f"Format error: e.g. '<blank>:0': {symbol_and_id}")
        symbol = symbol.strip()
        if symbol in counter:
            del counter[symbol]

    # ======= write_vocabulary mode from here =======
    # Sort by the number of occurrences in descending order
    # and filter lower frequency words than cutoff value
    words_and_counts = list(
        filter(lambda x: x[1] > cutoff, sorted(counter.items(), key=lambda x: -x[1]))
    )
    # Restrict the vocabulary size
    if vocabulary_size > 0:
        if vocabulary_size < len(add_symbol):
            raise RuntimeError(f"vocabulary_size is too small: {vocabulary_size}")
        words_and_counts = words_and_counts[: vocabulary_size - len(add_symbol)]

    # Parse the values of --add_symbol
    for symbol_and_id in add_symbol:
        # e.g symbol="<blank>:0"
        try:
            symbol, idx = symbol_and_id.split(":")
            idx = int(idx)
        except ValueError:
            raise RuntimeError(f"Format error: e.g. '<blank>:0': {symbol_and_id}")
        symbol = symbol.strip()

        # e.g. idx=0 -> append as the first symbol
        # e.g. idx=-1 -> append as the last symbol
        if idx < 0:
            idx = len(words_and_counts) + 1 + idx
        words_and_counts.insert(idx, (symbol, None))

    # Write words
    for w, c in words_and_counts:
        fout.write(w + "\n")

    # Logging
    total_count = sum(counter.values())
    invocab_count = sum(c for w, c in words_and_counts if c is not None)
    logging.info(f"OOV rate = {(total_count - invocab_count) / total_count * 100} %")


def get_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        description="Tokenize texts",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--log_level",
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )

    parser.add_argument(
        "--input", "-i", required=True, help="Input text. - indicates sys.stdin"
    )
    parser.add_argument(
        "--output", "-o", required=True, help="Output text. - indicates sys.stdout"
    )
    parser.add_argument(
        "--field",
        "-f",
        help="The target columns of the input text as 1-based integer. e.g 2-",
    )
    parser.add_argument(
        "--token_type",
        "-t",
        default="char",
        choices=["char", "bpe", "word", "phn"],
        help="Token type",
    )
    parser.add_argument("--delimiter", "-d", default=None, help="The delimiter")
    parser.add_argument("--space_symbol", default="<space>", help="The space symbol")
    parser.add_argument("--bpemodel", default=None, help="The bpemodel file path")
    parser.add_argument(
        "--non_linguistic_symbols",
        type=str_or_none,
        help="non_linguistic_symbols file path",
    )
    parser.add_argument(
        "--remove_non_linguistic_symbols",
        type=str2bool,
        default=False,
        help="Remove non-language-symbols from tokens",
    )
    parser.add_argument(
        "--cleaner",
        type=str_or_none,
        choices=[None, "tacotron", "jaconv", "vietnamese", "korean_cleaner"],
        default=None,
        help="Apply text cleaning",
    )
    parser.add_argument(
        "--g2p",
        type=str_or_none,
        choices=g2p_classes,
        default=None,
        help="Specify g2p method if --token_type=phn",
    )

    group = parser.add_argument_group("write_vocabulary mode related")
    group.add_argument(
        "--write_vocabulary",
        type=str2bool,
        default=False,
        help="Write tokens list instead of tokenized text per line",
    )
    group.add_argument("--vocabulary_size", type=int, default=0, help="Vocabulary size")
    group.add_argument(
        "--cutoff",
        default=0,
        type=int,
        help="cut-off frequency used for write-vocabulary mode",
    )
    group.add_argument(
        "--add_symbol",
        type=str,
        default=[],
        action="append",
        help="Append symbol e.g. --add_symbol '<blank>:0' --add_symbol '<unk>:1'",
    )

    return parser


def main(cmd=None):
    print(get_commandline_args(), file=sys.stderr)
    parser = get_parser()
    args = parser.parse_args(cmd)
    kwargs = vars(args)
    tokenize(**kwargs)


if __name__ == "__main__":
    main()
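Two illustrative invocations of the tokenizer CLI above (file names are placeholders): the first tokenizes the text field of a transcript, the second builds a vocabulary with reserved symbols:

python funasr_detach/bin/tokenize_text.py \
    --input text.trans --output tokens.txt --field 2- --token_type char

python funasr_detach/bin/tokenize_text.py \
    --input text.trans --output vocab.txt --field 2- --token_type char \
    --write_vocabulary true --add_symbol '<blank>:0' --add_symbol '<unk>:1'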
@@ -0,0 +1,227 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-

import os
import sys
import torch
import hydra
import logging
import argparse
from io import BytesIO
import torch.distributed as dist
from collections.abc import Sequence
from omegaconf import DictConfig, OmegaConf
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from funasr_detach.register import tables
from funasr_detach.optimizers import optim_classes
from funasr_detach.train_utils.trainer import Trainer
from funasr_detach.schedulers import scheduler_classes
from funasr_detach.train_utils.initialize import initialize
from funasr_detach.download.download_from_hub import download_model
from funasr_detach.models.lora.utils import mark_only_lora_as_trainable
from funasr_detach.train_utils.set_all_random_seed import set_all_random_seed
from funasr_detach.train_utils.load_pretrained_model import load_pretrained_model

# from funasr_detach.tokenizer.build_tokenizer import build_tokenizer
# from funasr_detach.tokenizer.token_id_converter import TokenIDConverter
# from funasr_detach.tokenizer.funtoken import build_tokenizer


@hydra.main(config_name=None, version_base=None)
def main_hydra(kwargs: DictConfig):
    if kwargs.get("debug", False):
        import pdb

        pdb.set_trace()

    assert "model" in kwargs
    if "model_conf" not in kwargs:
        logging.info(
            "download models from model hub: {}".format(kwargs.get("model_hub", "ms"))
        )
        kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)

    main(**kwargs)


def main(**kwargs):
    print(kwargs)

    # set random seed
    set_all_random_seed(kwargs.get("seed", 0))
    torch.backends.cudnn.enabled = kwargs.get(
        "cudnn_enabled", torch.backends.cudnn.enabled
    )
    torch.backends.cudnn.benchmark = kwargs.get(
        "cudnn_benchmark", torch.backends.cudnn.benchmark
    )
    torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)

    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if local_rank == 0:
        tables.print()
    # Check if we are using DDP or FSDP
    use_ddp = "WORLD_SIZE" in os.environ and int(os.environ["WORLD_SIZE"]) > 1
    use_fsdp = kwargs.get("use_fsdp", None)
    if use_ddp or use_fsdp:
        dist.init_process_group(
            backend=kwargs.get("backend", "nccl"), init_method="env://"
        )
        torch.cuda.set_device(local_rank)

    # save config.yaml
    if (
        (use_ddp or use_fsdp)
        and dist.get_rank() == 0
        or not (use_ddp or use_fsdp)
        and local_rank == 0
    ):
        os.makedirs(kwargs.get("output_dir", "./"), exist_ok=True)
        yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml")
        OmegaConf.save(config=kwargs, f=yaml_file)
        logging.info("config.yaml is saved to: %s", yaml_file)

    tokenizer = kwargs.get("tokenizer", None)
    if tokenizer is not None:
        tokenizer_class = tables.tokenizer_classes.get(tokenizer)
        tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
        kwargs["tokenizer"] = tokenizer

    # build frontend if frontend is not None
    frontend = kwargs.get("frontend", None)
    if frontend is not None:
        frontend_class = tables.frontend_classes.get(frontend)
        frontend = frontend_class(**kwargs["frontend_conf"])
        kwargs["frontend"] = frontend
        kwargs["input_size"] = frontend.output_size()

    # build model
    model_class = tables.model_classes.get(kwargs["model"])
    model = model_class(
        **kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list)
    )

    # init_param
    init_param = kwargs.get("init_param", None)
    if init_param is not None:
        if not isinstance(init_param, (list, tuple)):
            init_param = (init_param,)
        logging.info("init_param is not None: %s", init_param)
        for p in init_param:
            logging.info(f"Loading pretrained params from {p}")
            load_pretrained_model(
                model=model,
                path=p,
                ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
                oss_bucket=kwargs.get("oss_bucket", None),
                scope_map=kwargs.get("scope_map", None),
                excludes=kwargs.get("excludes", None),
            )
    else:
        initialize(model, kwargs.get("init", "kaiming_normal"))

    # freeze_param
    freeze_param = kwargs.get("freeze_param", None)
    if freeze_param is not None:
        freeze_param = eval(freeze_param)
        if not isinstance(freeze_param, (list, tuple)):
            freeze_param = (freeze_param,)
        logging.info("freeze_param is not None: %s", freeze_param)
        for t in freeze_param:
            for k, p in model.named_parameters():
                if k.startswith(t + ".") or k == t:
                    logging.info(f"Setting {k}.requires_grad = False")
                    p.requires_grad = False

    if use_ddp:
        model = model.cuda(local_rank)
        model = DDP(
            model,
            device_ids=[local_rank],
            find_unused_parameters=kwargs.get("train_conf", {}).get(
                "find_unused_parameters", False
            ),
        )
    elif use_fsdp:
        model = FSDP(model).cuda(local_rank)
    else:
        model = model.to(device=kwargs.get("device", "cuda"))

    # optim
    optim = kwargs.get("optim", "adam")
    assert optim in optim_classes
    optim_class = optim_classes.get(optim)
    optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))

    # scheduler
    scheduler = kwargs.get("scheduler", "warmuplr")
    assert scheduler in scheduler_classes
    scheduler_class = scheduler_classes.get(scheduler)
    scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))

    # dataset
    dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
    dataset_tr = dataset_class(
        kwargs.get("train_data_set_list"),
        frontend=frontend,
        tokenizer=tokenizer,
        is_training=True,
        **kwargs.get("dataset_conf"),
    )
    dataset_val = dataset_class(
        kwargs.get("valid_data_set_list"),
        frontend=frontend,
        tokenizer=tokenizer,
        is_training=False,
        **kwargs.get("dataset_conf"),
    )

    # dataloader
    batch_sampler = kwargs["dataset_conf"].get(
        "batch_sampler", "DynamicBatchLocalShuffleSampler"
    )
    batch_sampler_val = None
    if batch_sampler is not None:
        batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
        batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
        batch_sampler_val = batch_sampler_class(
            dataset_val, is_training=False, **kwargs.get("dataset_conf")
        )
    dataloader_tr = torch.utils.data.DataLoader(
        dataset_tr,
        collate_fn=dataset_tr.collator,
        batch_sampler=batch_sampler,
        num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
        pin_memory=True,
    )

    dataloader_val = torch.utils.data.DataLoader(
        dataset_val,
        collate_fn=dataset_val.collator,
        batch_sampler=batch_sampler_val,
        num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
        pin_memory=True,
    )
    trainer = Trainer(
        model=model,
        optim=optim,
        scheduler=scheduler,
        dataloader_train=dataloader_tr,
        dataloader_val=dataloader_val,
        local_rank=local_rank,
        use_ddp=use_ddp,
        use_fsdp=use_fsdp,
        output_dir=kwargs.get("output_dir", "./exp"),
        resume=kwargs.get("resume", True),
        **kwargs.get("train_conf"),
    )
    trainer.run()

    if use_ddp or use_fsdp:
        torch.distributed.destroy_process_group()


if __name__ == "__main__":
    main_hydra()
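An illustrative multi-GPU launch of the training entry point above; torchrun sets the WORLD_SIZE/LOCAL_RANK variables the script reads, and the config and data paths are placeholders:

torchrun --nproc_per_node=2 funasr_detach/bin/train.py \
    --config-path conf --config-name train_asr.yaml \
    ++train_data_set_list="data/train.jsonl" \
    ++valid_data_set_list="data/val.jsonl" \
    ++output_dir="./exp"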
File without changes
File without changes
@@ -0,0 +1,112 @@
import torch

from funasr_detach.register import tables
from funasr_detach.utils.load_utils import extract_fbank, load_audio_text_image_video


@tables.register("dataset_classes", "AudioDataset")
class AudioDataset(torch.utils.data.Dataset):
    """
    AudioDataset
    """

    def __init__(
        self,
        path,
        index_ds: str = None,
        frontend=None,
        tokenizer=None,
        int_pad_value: int = -1,
        float_pad_value: float = 0.0,
        **kwargs
    ):
        super().__init__()
        index_ds_class = tables.index_ds_classes.get(index_ds)
        self.index_ds = index_ds_class(path, **kwargs)
        preprocessor_speech = kwargs.get("preprocessor_speech", None)
        if preprocessor_speech:
            preprocessor_speech_class = tables.preprocessor_classes.get(
                preprocessor_speech
            )
            preprocessor_speech = preprocessor_speech_class(
                **kwargs.get("preprocessor_speech_conf")
            )
        self.preprocessor_speech = preprocessor_speech
        preprocessor_text = kwargs.get("preprocessor_text", None)
        if preprocessor_text:
            preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text)
            preprocessor_text = preprocessor_text_class(
                **kwargs.get("preprocessor_text_conf")
            )
        self.preprocessor_text = preprocessor_text

        self.frontend = frontend
        self.fs = 16000 if frontend is None else frontend.fs
        self.data_type = "sound"
        self.tokenizer = tokenizer

        self.int_pad_value = int_pad_value
        self.float_pad_value = float_pad_value

    def get_source_len(self, index):
        item = self.index_ds[index]
        return self.index_ds.get_source_len(item)

    def get_target_len(self, index):
        item = self.index_ds[index]
        return self.index_ds.get_target_len(item)

    def __len__(self):
        return len(self.index_ds)

    def __getitem__(self, index):
        item = self.index_ds[index]
        # import pdb;
        # pdb.set_trace()
        source = item["source"]
        data_src = load_audio_text_image_video(source, fs=self.fs)
        if self.preprocessor_speech:
            data_src = self.preprocessor_speech(data_src, fs=self.fs)
        speech, speech_lengths = extract_fbank(
            data_src, data_type=self.data_type, frontend=self.frontend, is_final=True
        )  # speech: [b, T, d]

        target = item["target"]
        if self.preprocessor_text:
            target = self.preprocessor_text(target)
        if self.tokenizer:
            ids = self.tokenizer.encode(target)
            text = torch.tensor(ids, dtype=torch.int64)
        else:
            ids = target
            text = ids
        ids_lengths = len(ids)
        text_lengths = torch.tensor([ids_lengths], dtype=torch.int32)

        return {
            "speech": speech[0, :, :],
            "speech_lengths": speech_lengths,
            "text": text,
            "text_lengths": text_lengths,
        }

    def collator(self, samples: list = None):
        outputs = {}
        for sample in samples:
            for key in sample.keys():
                if key not in outputs:
                    outputs[key] = []
                outputs[key].append(sample[key])

        for key, data_list in outputs.items():
            if isinstance(data_list[0], torch.Tensor):
                if data_list[0].dtype == torch.int64:
                    pad_value = self.int_pad_value
                else:
                    pad_value = self.float_pad_value

                outputs[key] = torch.nn.utils.rnn.pad_sequence(
                    data_list, batch_first=True, padding_value=pad_value
                )
        return outputs
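A toy illustration of the collator's padding rule above: int64 tensors are padded with int_pad_value (-1 by default), float tensors with float_pad_value (0.0):

# Minimal sketch of the padding behavior on a two-sample batch.
import torch

texts = [
    torch.tensor([5, 6, 7], dtype=torch.int64),
    torch.tensor([8], dtype=torch.int64),
]
padded = torch.nn.utils.rnn.pad_sequence(texts, batch_first=True, padding_value=-1)
# padded -> tensor([[ 5,  6,  7],
#                   [ 8, -1, -1]])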
@@ -0,0 +1,150 @@
import os
import json
import torch
import logging
import concurrent.futures
import librosa
import torch.distributed as dist

from funasr_detach.register import tables


@tables.register("index_ds_classes", "IndexDSJsonlRankSplit")
class IndexDSJsonlRankSplit(torch.utils.data.Dataset):

    def __init__(self, path):
        super().__init__()

        contents = []
        with open(path, encoding="utf-8") as fin:
            for line in fin:
                data = json.loads(line.strip())
                if "text" in data:  # for sft
                    contents.append(data["text"])
                if "source" in data:  # for speech lab pretrain
                    prompt = data["prompt"]
                    source = data["source"]
                    target = data["target"]
                    source_len = data["source_len"]
                    target_len = data["target_len"]

                    contents.append(
                        {
                            "source": source,
                            "prompt": prompt,
                            "target": target,
                            "source_len": source_len,
                            "target_len": target_len,
                        }
                    )

        self.contents = []
        total_num = len(contents)
        try:
            rank = dist.get_rank()
            world_size = dist.get_world_size()
        except:
            rank = 0
            world_size = 1
            logging.warning("distributed is not initialized, only single shard")
        num_per_rank = total_num // world_size

        # rank = 0
        # import ipdb; ipdb.set_trace()
        self.contents = contents[rank * num_per_rank : (rank + 1) * num_per_rank]

        logging.info(
            "in rank: {}, num of samplers: {}, total_num of samplers across ranks: {}".format(
                rank, len(self.contents), len(contents)
            )
        )

    def __len__(self):
        return len(self.contents)

    def __getitem__(self, index):
        try:
            data = self.contents[index]
        except:
            print(index)
            raise
        return data

    def get_source_len(self, data_dict):
        return data_dict["source_len"]

    def get_target_len(self, data_dict):

        return data_dict["target_len"] if "target_len" in data_dict else 0


@tables.register("index_ds_classes", "IndexDSJsonl")
@tables.register("index_ds_classes", "IndexDSJsonlRankFull")
class IndexDSJsonlRankFull(torch.utils.data.Dataset):

    def __init__(self, path: str, **kwargs):
        super().__init__()

        if isinstance(path, (list, tuple)):  # wav.scp, text.txt/text.trans
            from funasr_detach.datasets.audio_datasets.scp2jsonl import (
                gen_jsonl_from_wav_text_list,
            )

            jsonl_outdir = os.path.dirname(path[0])
            jsonl_name = (
                "datalist_train.jsonl"
                if kwargs.get("is_training", True)
                else "datalist_val.jsonl"
            )
            jsonl_file_out = os.path.join(jsonl_outdir, jsonl_name)
            if not os.path.exists(jsonl_file_out):
                print(f"datalist is: {path}, generate jsonl from it")
                gen_jsonl_from_wav_text_list(
                    path, jsonl_file_out=jsonl_file_out, **kwargs
                )
            path = jsonl_file_out

        contents = []
        with open(path, encoding="utf-8") as fin:
            for line in fin:
                data = json.loads(line.strip())
                if "text" in data:  # for sft
                    contents.append(data["text"])
                if "source" in data:  # for speech lab pretrain
                    prompt = data.get("prompt", "<ASR>")
                    source = data["source"]
                    target = data["target"]
                    source_len = data.get("source_len", 1)
                    target_len = data.get("target_len", 0)

                    contents.append(
                        {
                            "source": source,
                            "prompt": prompt,
                            "target": target,
                            "source_len": source_len,
                            "target_len": target_len,
                        }
                    )

        self.contents = contents

        logging.info(
            "total_num of samplers across ranks: {}".format(len(self.contents))
        )

    def __len__(self):
        return len(self.contents)

    def __getitem__(self, index):
        try:
            data = self.contents[index]
        except:
            print(index)
            raise
        return data

    def get_source_len(self, data_dict):
        return data_dict.get("source_len", 1)

    def get_target_len(self, data_dict):

        return data_dict.get("target_len", 0)
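For reference, one jsonl record in the shape these index datasets parse (field values are illustrative); IndexDSJsonlRankFull falls back to prompt="<ASR>", source_len=1, and target_len=0 when those fields are missing:

{"source": "wavs/utt1.wav", "source_len": 320, "target": "hello world", "target_len": 2, "prompt": "<ASR>"}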
@@ -0,0 +1,55 @@
import os
import json
import torch
import logging
import concurrent.futures
import librosa
import torch.distributed as dist
from typing import Collection
import torchaudio
from torch import nn
import random
import re
from funasr_detach.tokenizer.cleaner import TextCleaner
from funasr_detach.register import tables


@tables.register("preprocessor_classes", "SpeechPreprocessSpeedPerturb")
class SpeechPreprocessSpeedPerturb(nn.Module):
    def __init__(self, speed_perturb: list = None, **kwargs):
        super().__init__()
        self.speed_perturb = speed_perturb

    def forward(self, waveform, fs, **kwargs):
        if self.speed_perturb is None:
            return waveform
        speed = random.choice(self.speed_perturb)
        if speed != 1.0:
            if not isinstance(waveform, torch.Tensor):
                waveform = torch.tensor(waveform)
            waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
                waveform.view(1, -1), fs, [["speed", str(speed)], ["rate", str(fs)]]
            )
            waveform = waveform.view(-1)

        return waveform


@tables.register("preprocessor_classes", "TextPreprocessSegDict")
class TextPreprocessSegDict(nn.Module):
    def __init__(
        self,
        seg_dict: str = None,
        text_cleaner: Collection[str] = None,
        split_with_space: bool = False,
        **kwargs
    ):
        super().__init__()

        self.text_cleaner = TextCleaner(text_cleaner)

    def forward(self, text, **kwargs):
        text = self.text_cleaner(text)

        return text
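A sketch of applying the speed-perturb preprocessor above to a raw waveform; the perturbation factors are illustrative, the import path is assumed from this diff's layout, and torchaudio's sox effects must be available at runtime:

# Minimal usage sketch under the assumptions stated above.
import torch

from funasr_detach.datasets.audio_datasets.preprocessor import (
    SpeechPreprocessSpeedPerturb,
)

perturb = SpeechPreprocessSpeedPerturb(speed_perturb=[0.9, 1.0, 1.1])
waveform = torch.randn(16000)        # one second of audio at 16 kHz
out = perturb(waveform, fs=16000)    # resampled back to fs; length may change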
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch
import numpy as np
import logging
import torch.distributed as dist

from funasr_detach.register import tables


@tables.register("batch_sampler_classes", "DynamicBatchLocalShuffleSampler")
class BatchSampler(torch.utils.data.BatchSampler):

    def __init__(
        self,
        dataset,
        batch_type: str = "example",
        batch_size: int = 100,
        buffer_size: int = 30,
        drop_last: bool = False,
        shuffle: bool = True,
        is_training: bool = True,
        **kwargs
    ):

        self.drop_last = drop_last
        self.pre_idx = -1
        self.dataset = dataset
        self.total_samples = len(dataset)
        self.batch_type = batch_type
        self.batch_size = int(batch_size)
        self.buffer_size = buffer_size
        self.max_token_length = kwargs.get("max_token_length", 5000)
        self.shuffle_idx = np.arange(self.total_samples)
        self.shuffle = shuffle and is_training
        self.length_scale_source = kwargs.get("length_scale_source", 1.0)

    def __len__(self):
        return (self.total_samples - 1) // self.batch_size + 1

    def set_epoch(self, epoch):
        np.random.seed(epoch)

    def __iter__(self):

        if self.shuffle:
            np.random.shuffle(self.shuffle_idx)

        batch = []
        max_token = 0
        num_sample = 0

        iter_num = (self.total_samples - 1) // self.buffer_size + 1
        # print("iter_num: ", iter_num)
        for iter in range(self.pre_idx + 1, iter_num):
            datalen_with_index = []
            for i in range(self.buffer_size):
                idx = iter * self.buffer_size + i
                if idx >= self.total_samples:
                    continue

                idx_map = self.shuffle_idx[idx]
                # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
                target_len = (
                    self.dataset.get_target_len(idx_map)
                    if self.batch_type == "length"
                    else 0.0
                )
                source_len = (
                    self.dataset.get_source_len(idx_map) / self.length_scale_source
                )
                sample_len_cur = source_len + target_len

                datalen_with_index.append([idx, sample_len_cur])

            datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
            for item in datalen_with_index_sort:
                idx, sample_len_cur_raw = item
                if sample_len_cur_raw > self.max_token_length:
                    continue

                max_token_cur = max(max_token, sample_len_cur_raw)
                max_token_padding = 1 + num_sample
                if self.batch_type != "example":
                    max_token_padding *= max_token_cur
                if max_token_padding <= self.batch_size:
                    batch.append(idx)
                    max_token = max_token_cur
                    num_sample += 1
                else:
                    yield batch
                    batch = [idx]
                    max_token = sample_len_cur_raw
                    num_sample = 1
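A worked example of the budget check above in "length" mode (numbers are illustrative): the candidate cost is (1 + num_sample) * max_token_cur, i.e. samples-so-far-plus-one times the longest length seen, which models the padded batch size.

# batch_size = 400 tokens, sorted buffer lengths = [90, 100, 120, 150]
#   take 90  -> padded cost (1 + 0) * 90  = 90   <= 400
#   take 100 -> padded cost (1 + 1) * 100 = 200  <= 400
#   take 120 -> padded cost (1 + 2) * 120 = 360  <= 400
#   take 150 -> padded cost (1 + 3) * 150 = 600  >  400
#            -> yield the [90, 100, 120] batch and start a new one from 150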
@tables.register("batch_sampler_classes", "BatchSampler")
@tables.register("batch_sampler_classes", "RankFullLocalShuffleBatchSampler")
class RankFullLocalShuffleBatchSampler(torch.utils.data.BatchSampler):

    def __init__(
        self,
        dataset,
        batch_type: str = "example",
        batch_size: int = 100,
        buffer_size: int = 30,
        drop_last: bool = True,
        shuffle: bool = True,
        is_training: bool = True,
        **kwargs
    ):

        self.drop_last = drop_last
        self.pre_idx = -1
        self.dataset = dataset
        self.total_samples = len(dataset)
        self.batch_type = batch_type
        self.batch_size = int(batch_size)
        self.buffer_size = buffer_size
        self.max_token_length = kwargs.get("max_token_length", 1500)
        self.shuffle_idx = np.arange(self.total_samples)
        self.shuffle = shuffle and is_training
        self.length_scale_source = kwargs.get("length_scale_source", 1.0)

        try:
            rank = dist.get_rank()
            world_size = dist.get_world_size()
        except:
            rank = 0
            world_size = 1
        self.rank = rank
        self.world_size = world_size

    def __len__(self):
        return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1

    def set_epoch(self, epoch):
        np.random.seed(epoch)

    def __iter__(self):

        batch_size_total = self.batch_size * self.world_size

        if self.shuffle:
            np.random.shuffle(self.shuffle_idx)

        batch = []
        max_token = 0
        num_sample = 0

        iter_num = (self.total_samples - 1) // self.buffer_size + 1
        # print("iter_num: ", iter_num)
        for iter in range(self.pre_idx + 1, iter_num):
            # if iter == iter_num - 1 and self.drop_last:
            #     continue
            datalen_with_index = []
            for i in range(self.buffer_size):
                idx = iter * self.buffer_size + i
                if idx >= self.total_samples:
                    continue

                idx_map = self.shuffle_idx[idx]
                # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]

                source_len = (
                    self.dataset.get_source_len(idx_map) / self.length_scale_source
                )
                target_len = (
                    self.dataset.get_target_len(idx_map)
                    if self.batch_type == "length"
                    else 0.0
                )
                sample_len_cur = source_len + target_len

                datalen_with_index.append([idx, sample_len_cur])

            datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
            for item in datalen_with_index_sort:
                idx, sample_len_cur_raw = item
                if sample_len_cur_raw > self.max_token_length:
                    continue

                max_token_cur = max(max_token, sample_len_cur_raw)
                max_token_padding = 1 + num_sample
                # if self.batch_type != 'example':
                #     max_token_padding *= max_token_cur
                if max_token_padding <= batch_size_total:
                    batch.append(idx)
                    max_token = max_token_cur
                    num_sample += 1
                else:
                    batch_rank = batch[
                        self.rank * self.batch_size : (self.rank + 1) * self.batch_size
                    ]
                    yield batch_rank
                    batch = [idx]
                    max_token = sample_len_cur_raw
                    num_sample = 1
@tables.register("batch_sampler_classes", "RankFullLocalShuffleDynamicBatchSampler")
class RankFullLocalShuffleDynamicBatchSampler(torch.utils.data.BatchSampler):

    def __init__(
        self,
        dataset,
        batch_type: str = "example",
        batch_size: int = 100,
        buffer_size: int = 30,
        drop_last: bool = True,
        shuffle: bool = True,
        is_training: bool = True,
        **kwargs
    ):

        self.drop_last = drop_last
        self.pre_idx = -1
        self.dataset = dataset
        self.total_samples = len(dataset)
        self.batch_type = batch_type
        self.batch_size = int(batch_size)
        self.buffer_size = buffer_size
        self.max_token_length = kwargs.get("max_token_length", 1500)
        self.shuffle_idx = np.arange(self.total_samples)
        self.shuffle = shuffle and is_training
        self.length_scale_source = kwargs.get("length_scale_source", 1.0)

        try:
            rank = dist.get_rank()
            world_size = dist.get_world_size()
        except:
            rank = 0
            world_size = 1
        self.rank = rank
        self.world_size = world_size

    def __len__(self):
        return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1

    def set_epoch(self, epoch):
        np.random.seed(epoch)

    def __iter__(self):

        batch_size_total = self.batch_size * self.world_size
        if self.shuffle:
            np.random.shuffle(self.shuffle_idx)

        batch_list_all_rank = []
        batch_list_cur = []
        max_token = 0
        num_sample = 0

        iter_num = (self.total_samples - 1) // self.buffer_size + 1
        # print("iter_num: ", iter_num)
        for iter in range(self.pre_idx + 1, iter_num):
            # if iter == iter_num - 1 and self.drop_last:
            #     continue
            datalen_with_index = []
            for i in range(self.buffer_size):
                idx = iter * self.buffer_size + i
                if idx >= self.total_samples:
                    continue

                idx_map = self.shuffle_idx[idx]
                # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]

                source_len = (
                    self.dataset.get_source_len(idx_map) / self.length_scale_source
                )
                target_len = (
                    self.dataset.get_target_len(idx_map)
                    if self.batch_type == "length"
                    else 0.0
                )
                sample_len_cur = source_len + target_len

                datalen_with_index.append([idx, sample_len_cur])

            datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
            for ii, item in enumerate(datalen_with_index_sort):
                # flags the final item of the final buffer (currently unused)
                is_last_batch = iter == iter_num - 1 and ii == (
                    len(datalen_with_index_sort) - 1
                )
                idx, sample_len_cur_raw = item
                if sample_len_cur_raw > self.max_token_length:
                    continue

                max_token_cur = max(max_token, sample_len_cur_raw)
                max_token_padding = 1 + num_sample

                if self.batch_type != "example":
                    max_token_padding *= max_token_cur
                if len(batch_list_all_rank) < self.world_size:

                    if max_token_padding <= self.batch_size:
                        batch_list_cur.append(idx)
                        max_token = max_token_cur
                        num_sample += 1
                    else:
                        batch_list_all_rank.append(batch_list_cur)
                        batch_list_cur = []
                else:
                    batch_rank = batch_list_all_rank[self.rank]
                    yield batch_rank
                    # reset the per-rank batches and start a new per-rank batch
                    batch_list_all_rank = []
                    batch_list_cur = [idx]
                    max_token = sample_len_cur_raw
                    num_sample = 1
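A minimal wiring sketch (assumed API): `dataset`, `my_collate`, and `num_epochs` are hypothetical; the dataset only needs __len__, get_source_len, and get_target_len, which the samplers above call.

import torch

sampler = BatchSampler(
    dataset,
    batch_type="length",   # budget batches by padded source+target length
    batch_size=6000,       # max padded tokens per batch in "length" mode
    buffer_size=30,        # local shuffle window, sorted by length
    shuffle=True,
    is_training=True,
)
loader = torch.utils.data.DataLoader(
    dataset, batch_sampler=sampler, collate_fn=my_collate
)
for epoch in range(num_epochs):
    sampler.set_epoch(epoch)  # reseed the per-epoch shuffle
    for batch in loader:
        ...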
@@ -0,0 +1,116 @@
import os
import json
import torch
import logging
import hydra
from omegaconf import DictConfig, OmegaConf
import concurrent.futures
import librosa
import torch.distributed as dist


def gen_jsonl_from_wav_text_list(
    path, data_type_list=("source", "target"), jsonl_file_out: str = None, **kwargs
):
    try:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    except:
        rank = 0
        world_size = 1

    cpu_cores = os.cpu_count() or 1
    print(f"convert wav.scp text to jsonl, ncpu: {cpu_cores}")
    if rank == 0:
        json_dict = {}
        for data_type, data_file in zip(data_type_list, path):
            json_dict[data_type] = {}
            with open(data_file, "r") as f:

                data_file_lists = f.readlines()
                task_num = cpu_cores if len(data_file_lists) > cpu_cores else 1
                # chunk size must be derived from task_num, not cpu_cores,
                # so that a short file is still fully covered by one task
                lines_for_each_th = (len(data_file_lists) - 1) // task_num + 1
                with concurrent.futures.ThreadPoolExecutor(
                    max_workers=cpu_cores
                ) as executor:

                    futures = [
                        executor.submit(
                            parse_context_length,
                            data_file_lists[
                                i * lines_for_each_th : (i + 1) * lines_for_each_th
                            ],
                            data_type,
                        )
                        for i in range(task_num)
                    ]

                    for future in concurrent.futures.as_completed(futures):

                        json_dict[data_type].update(future.result())
        # print(json_dict)

        with open(jsonl_file_out, "w") as f:
            for key in json_dict[data_type_list[0]].keys():
                jsonl_line = {"key": key}
                for data_file in data_type_list:
                    jsonl_line.update(json_dict[data_file][key])
                jsonl_line = json.dumps(jsonl_line, ensure_ascii=False)
                f.write(jsonl_line + "\n")
                f.flush()

    else:
        pass

    if world_size > 1:
        dist.barrier()


def parse_context_length(data_list: list, data_type: str):

    res = {}
    for i, line in enumerate(data_list):
        key, line = line.strip().split(maxsplit=1)
        line = line.strip()
        if os.path.exists(line):
            waveform, _ = librosa.load(line, sr=16000)
            sample_num = len(waveform)
            context_len = int(sample_num // 16000 * 1000 / 10)
        else:
            context_len = len(line.split()) if " " in line else len(line)
        res[key] = {data_type: line, f"{data_type}_len": context_len}
    return res


@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):

    kwargs = OmegaConf.to_container(cfg, resolve=True)

    scp_file_list = kwargs.get(
        "scp_file_list",
        (
            "/Users/zhifu/funasr1.0/test_local/wav.scp",
            "/Users/zhifu/funasr1.0/test_local/text.txt",
        ),
    )
    if isinstance(scp_file_list, str):
        scp_file_list = eval(scp_file_list)
    data_type_list = kwargs.get("data_type_list", ("source", "target"))
    jsonl_file_out = kwargs.get(
        "jsonl_file_out", "/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl"
    )
    gen_jsonl_from_wav_text_list(
        scp_file_list, data_type_list=data_type_list, jsonl_file_out=jsonl_file_out
    )


"""
python -m funasr_detach.datasets.audio_datasets.scp2jsonl \
++scp_file_list='["/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"]' \
++data_type_list='["source", "target"]' \
++jsonl_file_out=/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl
"""

if __name__ == "__main__":
    main_hydra()
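For orientation, the assumed input and output line formats (illustrative key and path; the lengths follow parse_context_length above):

# wav.scp : "utt1 /path/to/utt1.wav"
# text.txt: "utt1 hello world"
# resulting jsonl line for a one-second wav (source_len is in 10 ms frames,
# target_len in whitespace-separated tokens):
# {"key": "utt1", "source": "/path/to/utt1.wav", "source_len": 100,
#  "target": "hello world", "target_len": 2}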
File without changes
@@ -0,0 +1,19 @@
def download_dataset():
    pass


def download_dataset_from_ms(**kwargs):
    from modelscope.msdatasets import MsDataset

    dataset_name = kwargs.get(
        "dataset_name", "speech_asr/speech_asr_aishell1_trainsets"
    )
    subset_name = kwargs.get("subset_name", "default")
    split = kwargs.get("split", "train")
    data_dump_dir = kwargs.get("data_dump_dir", None)
    ds = MsDataset.load(
        dataset_name=dataset_name,
        subset_name=subset_name,
        split=split,
        cache_dir=data_dump_dir,
    )
    return ds
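A minimal invocation sketch (the cache directory is an illustrative path; the dataset name is the default shown above):

ds = download_dataset_from_ms(
    dataset_name="speech_asr/speech_asr_aishell1_trainsets",
    subset_name="default",
    split="train",
    data_dump_dir="./data_cache",  # hypothetical cache directory
)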
@@ -0,0 +1,231 @@
import os
import json
import threading
from omegaconf import OmegaConf

from funasr_detach.download.name_maps_from_hub import name_maps_ms, name_maps_hf

# Global cache for downloaded models to avoid repeated downloads
# Key: (repo_id, model_revision, model_hub)
# Value: repo_cache_dir
_model_cache = {}
_cache_lock = threading.Lock()


def download_model(**kwargs):
    model_hub = kwargs.get("model_hub", "ms")
    model_or_path = kwargs.get("model")
    repo_path = kwargs.get("repo_path", "")

    # Handle name mapping based on model_hub
    if model_hub == "ms" and model_or_path in name_maps_ms:
        model_or_path = name_maps_ms[model_or_path]
    elif model_hub == "hf" and model_or_path in name_maps_hf:
        model_or_path = name_maps_hf[model_or_path]

    model_revision = kwargs.get("model_revision")

    # Download model if it doesn't exist locally
    if not os.path.exists(model_or_path):
        if model_hub == "local":
            # For local models, the path should already exist
            raise FileNotFoundError(f"Local model path does not exist: {model_or_path}")
        elif model_hub in ["ms", "hf"]:
            repo_path, model_or_path = get_or_download_model_dir(
                model_or_path,
                model_revision,
                is_training=kwargs.get("is_training"),
                check_latest=kwargs.get("check_latest", True),
                model_hub=model_hub,
            )
        else:
            raise ValueError(f"Unsupported model_hub: {model_hub}")

    print(f"Using model path: {model_or_path}")
    kwargs["model_path"] = model_or_path
    kwargs["repo_path"] = repo_path

    # Common logic for processing configuration files (same for all model hubs)
    if os.path.exists(os.path.join(model_or_path, "configuration.json")):
        with open(
            os.path.join(model_or_path, "configuration.json"), "r", encoding="utf-8"
        ) as f:
            conf_json = json.load(f)
            cfg = {}
            add_file_root_path(model_or_path, conf_json["file_path_metas"], cfg)
            cfg.update(kwargs)
            config = OmegaConf.load(cfg["config"])
            kwargs = OmegaConf.merge(config, cfg)
            kwargs["model"] = config["model"]
    elif os.path.exists(os.path.join(model_or_path, "config.yaml")) and os.path.exists(
        os.path.join(model_or_path, "model.pt")
    ):
        config = OmegaConf.load(os.path.join(model_or_path, "config.yaml"))
        kwargs = OmegaConf.merge(config, kwargs)
        init_param = os.path.join(model_or_path, "model.pt")
        kwargs["init_param"] = init_param
        if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
            kwargs["tokenizer_conf"]["token_list"] = os.path.join(
                model_or_path, "tokens.txt"
            )
        if os.path.exists(os.path.join(model_or_path, "tokens.json")):
            kwargs["tokenizer_conf"]["token_list"] = os.path.join(
                model_or_path, "tokens.json"
            )
        if os.path.exists(os.path.join(model_or_path, "seg_dict")):
            kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(
                model_or_path, "seg_dict"
            )
        if os.path.exists(os.path.join(model_or_path, "bpe.model")):
            kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(
                model_or_path, "bpe.model"
            )
        kwargs["model"] = config["model"]
        if os.path.exists(os.path.join(model_or_path, "am.mvn")):
            kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
        if os.path.exists(os.path.join(model_or_path, "jieba_usr_dict")):
            kwargs["jieba_usr_dict"] = os.path.join(model_or_path, "jieba_usr_dict")

    return OmegaConf.to_container(kwargs, resolve=True)


def add_file_root_path(model_or_path: str, file_path_metas: dict, cfg={}):

    if isinstance(file_path_metas, dict):
        for k, v in file_path_metas.items():
            if isinstance(v, str):
                p = os.path.join(model_or_path, v)
                if os.path.exists(p):
                    cfg[k] = p
            elif isinstance(v, dict):
                if k not in cfg:
                    cfg[k] = {}
                add_file_root_path(model_or_path, v, cfg[k])

    return cfg


def get_or_download_model_dir(
    model,
    model_revision=None,
    is_training=False,
    check_latest=True,
    model_hub="ms",
):
    """Get local model directory or download model if necessary.

    Args:
        model (str): model id or path to local model directory.
            For HF subfolders, use format: "repo_id/subfolder_path"
        model_revision (str, optional): model version number.
        is_training (bool): Whether this is for training
        check_latest (bool): Whether to check for latest version
        model_hub (str): Model hub type ("ms" for ModelScope, "hf" for HuggingFace)
    """
    # Extract repo_id for caching (handle subfolder case)
    if "/" in model and len(model.split("/")) > 2:
        parts = model.split("/")
        repo_id = "/".join(parts[:2])  # e.g., "organization/repo" or "stepfun-ai/Step-Audio-EditX"
        subfolder = "/".join(parts[2:])  # e.g., "subfolder/model"
    else:
        repo_id = model
        subfolder = None

    # Create cache key
    cache_key = (repo_id, model_revision, model_hub)

    # Check cache first
    with _cache_lock:
        if cache_key in _model_cache:
            cached_repo_dir = _model_cache[cache_key]
            print(f"Using cached model for {repo_id}: {cached_repo_dir}")

            # For subfolder case, construct the model_cache_dir from cached repo
            if subfolder:
                model_cache_dir = os.path.join(cached_repo_dir, subfolder)
                if not os.path.exists(model_cache_dir):
                    raise FileNotFoundError(
                        f"Subfolder {subfolder} not found in cached repo {repo_id}"
                    )
            else:
                model_cache_dir = cached_repo_dir

            return cached_repo_dir, model_cache_dir

    # Cache miss, need to download
    if model_hub == "ms":
        # ModelScope download
        from modelscope.hub.snapshot_download import snapshot_download
        from modelscope.utils.constant import Invoke, ThirdParty

        key = Invoke.LOCAL_TRAINER if is_training else Invoke.PIPELINE

        # Download the repo (use repo_id, not the full model path with subfolder)
        repo_cache_dir = snapshot_download(
            repo_id,
            revision=model_revision,
            user_agent={Invoke.KEY: key, ThirdParty.KEY: "funasr"},
        )
        repo_cache_dir = normalize_cache_path(repo_cache_dir)

        # Construct model_cache_dir
        if subfolder:
            model_cache_dir = os.path.join(repo_cache_dir, subfolder)
            if not os.path.exists(model_cache_dir):
                raise FileNotFoundError(
                    f"Subfolder {subfolder} not found in downloaded repo {repo_id}"
                )
        else:
            model_cache_dir = normalize_cache_path(repo_cache_dir)

    elif model_hub == "hf":
        # HuggingFace download
        try:
            from huggingface_hub import snapshot_download
        except ImportError:
            raise ImportError(
                "huggingface_hub is required for downloading from HuggingFace. "
                "Please install it with: pip install huggingface_hub"
            )

        # Download the repo (use repo_id, not the full model path with subfolder)
        repo_cache_dir = snapshot_download(
            repo_id=repo_id,
            revision=model_revision,
            allow_patterns=None,  # Download all files to ensure resource files are available
        )
        repo_cache_dir = normalize_cache_path(repo_cache_dir)

        # Construct model_cache_dir
        if subfolder:
            model_cache_dir = os.path.join(repo_cache_dir, subfolder)
            if not os.path.exists(model_cache_dir):
                raise FileNotFoundError(
                    f"Subfolder {subfolder} not found in downloaded repo {repo_id}"
                )
        else:
            model_cache_dir = normalize_cache_path(repo_cache_dir)
    else:
        raise ValueError(f"Unsupported model_hub: {model_hub}")

    # Cache the result before returning
    with _cache_lock:
        _model_cache[cache_key] = repo_cache_dir

    print(f"Model downloaded to: {model_cache_dir}")
    return repo_cache_dir, model_cache_dir


def normalize_cache_path(cache_path):
    """Normalize cache path to ensure consistent format with snapshots/{commit_id}."""
    # Check if the cache_path directory contains a snapshots folder
    snapshots_dir = os.path.join(cache_path, "snapshots")
    if os.path.exists(snapshots_dir) and os.path.isdir(snapshots_dir):
        # Find the commit_id subdirectory in snapshots
        try:
            snapshot_items = os.listdir(snapshots_dir)
            # Look for the first directory (should be the commit_id)
            for item in snapshot_items:
                item_path = os.path.join(snapshots_dir, item)
                if os.path.isdir(item_path):
                    # Found commit_id directory, return the full path
                    return os.path.join(cache_path, "snapshots", item)
        except OSError:
            pass

    # If no snapshots directory found or error occurred, return original path
    return cache_path
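A usage sketch for the resolver above; the "repo_id/subfolder" value is hypothetical, and the call assumes the resolved directory carries one of the config layouts handled in download_model:

kwargs = download_model(
    model="org/repo/subfolder",  # hypothetical "repo_id/subfolder" spec
    model_hub="hf",
    model_revision=None,
)
print(kwargs["model_path"])  # e.g. .../snapshots/<commit_id>/subfolder
print(kwargs["repo_path"])   # repo root under the HF cache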
@@ -0,0 +1,335 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import contextlib
import os
import tempfile
from abc import ABCMeta, abstractmethod
from pathlib import Path
from typing import Generator, Union

import requests
from urllib.parse import urlparse


def download_from_url(url):
    result = urlparse(url)
    file_path = None
    if result.scheme is not None and len(result.scheme) > 0:
        storage = HTTPStorage()
        # bytes
        data = storage.read(url)
        work_dir = tempfile.TemporaryDirectory().name
        if not os.path.exists(work_dir):
            os.makedirs(work_dir)
        file_path = os.path.join(work_dir, os.path.basename(url))
        with open(file_path, "wb") as fb:
            fb.write(data)
    assert file_path is not None, f"failed to download: {url}"
    return file_path


class Storage(metaclass=ABCMeta):
    """Abstract class of storage.

    All backends need to implement two apis: ``read()`` and ``read_text()``.
    ``read()`` reads the file as a byte stream and ``read_text()`` reads
    the file as texts.
    """

    @abstractmethod
    def read(self, filepath: str):
        pass

    @abstractmethod
    def read_text(self, filepath: str):
        pass

    @abstractmethod
    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
        pass

    @abstractmethod
    def write_text(
        self, obj: str, filepath: Union[str, Path], encoding: str = "utf-8"
    ) -> None:
        pass


class LocalStorage(Storage):
    """Local hard disk storage"""

    def read(self, filepath: Union[str, Path]) -> bytes:
        """Read data from a given ``filepath`` with 'rb' mode.

        Args:
            filepath (str or Path): Path to read data.

        Returns:
            bytes: Expected bytes object.
        """
        with open(filepath, "rb") as f:
            content = f.read()
        return content

    def read_text(self, filepath: Union[str, Path], encoding: str = "utf-8") -> str:
        """Read data from a given ``filepath`` with 'r' mode.

        Args:
            filepath (str or Path): Path to read data.
            encoding (str): The encoding format used to open the ``filepath``.
                Default: 'utf-8'.

        Returns:
            str: Expected text reading from ``filepath``.
        """
        with open(filepath, "r", encoding=encoding) as f:
            value_buf = f.read()
        return value_buf

    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
        """Write data to a given ``filepath`` with 'wb' mode.

        Note:
            ``write`` will create a directory if the directory of ``filepath``
            does not exist.

        Args:
            obj (bytes): Data to be written.
            filepath (str or Path): Path to write data.
        """
        dirname = os.path.dirname(filepath)
        if dirname and not os.path.exists(dirname):
            os.makedirs(dirname, exist_ok=True)

        with open(filepath, "wb") as f:
            f.write(obj)

    def write_text(
        self, obj: str, filepath: Union[str, Path], encoding: str = "utf-8"
    ) -> None:
        """Write data to a given ``filepath`` with 'w' mode.

        Note:
            ``write_text`` will create a directory if the directory of
            ``filepath`` does not exist.

        Args:
            obj (str): Data to be written.
            filepath (str or Path): Path to write data.
            encoding (str): The encoding format used to open the ``filepath``.
                Default: 'utf-8'.
        """
        dirname = os.path.dirname(filepath)
        if dirname and not os.path.exists(dirname):
            os.makedirs(dirname, exist_ok=True)

        with open(filepath, "w", encoding=encoding) as f:
            f.write(obj)

    @contextlib.contextmanager
    def as_local_path(
        self, filepath: Union[str, Path]
    ) -> Generator[Union[str, Path], None, None]:
        """Only for unified API and do nothing."""
        yield filepath


class HTTPStorage(Storage):
    """HTTP and HTTPS storage."""

    def read(self, url):
        # TODO @wenmeng.zwm add progress bar if file is too large
        r = requests.get(url)
        r.raise_for_status()
        return r.content

    def read_text(self, url):
        r = requests.get(url)
        r.raise_for_status()
        return r.text

    @contextlib.contextmanager
    def as_local_path(self, filepath: str) -> Generator[Union[str, Path], None, None]:
        """Download a file from ``filepath``.

        ``as_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
        can be called with a ``with`` statement, and when the ``with`` statement
        exits, the temporary path will be released.

        Args:
            filepath (str): Download a file from ``filepath``.

        Examples:
            >>> storage = HTTPStorage()
            >>> # After exiting the ``with`` clause,
            >>> # the path will be removed
            >>> with storage.as_local_path('http://path/to/file') as path:
            ...     # do something here
        """
        try:
            f = tempfile.NamedTemporaryFile(delete=False)
            f.write(self.read(filepath))
            f.close()
            yield f.name
        finally:
            os.remove(f.name)

    def write(self, obj: bytes, url: Union[str, Path]) -> None:
        raise NotImplementedError("write is not supported by HTTP Storage")

    def write_text(
        self, obj: str, url: Union[str, Path], encoding: str = "utf-8"
    ) -> None:
        raise NotImplementedError("write_text is not supported by HTTP Storage")


class OSSStorage(Storage):
    """OSS storage."""

    def __init__(self, oss_config_file=None):
        # read from config file or env var
        raise NotImplementedError("OSSStorage.__init__ to be implemented in the future")

    def read(self, filepath):
        raise NotImplementedError("OSSStorage.read to be implemented in the future")

    def read_text(self, filepath, encoding="utf-8"):
        raise NotImplementedError(
            "OSSStorage.read_text to be implemented in the future"
        )

    @contextlib.contextmanager
    def as_local_path(self, filepath: str) -> Generator[Union[str, Path], None, None]:
        """Download a file from ``filepath``.

        ``as_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
        can be called with a ``with`` statement, and when the ``with`` statement
        exits, the temporary path will be released.

        Args:
            filepath (str): Download a file from ``filepath``.

        Examples:
            >>> storage = OSSStorage()
            >>> # After exiting the ``with`` clause,
            >>> # the path will be removed
            >>> with storage.as_local_path('oss://path/to/file') as path:
            ...     # do something here
        """
        try:
            f = tempfile.NamedTemporaryFile(delete=False)
            f.write(self.read(filepath))
            f.close()
            yield f.name
        finally:
            os.remove(f.name)

    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
        raise NotImplementedError("OSSStorage.write to be implemented in the future")

    def write_text(
        self, obj: str, filepath: Union[str, Path], encoding: str = "utf-8"
    ) -> None:
        raise NotImplementedError(
            "OSSStorage.write_text to be implemented in the future"
        )


G_STORAGES = {}


class File(object):
    _prefix_to_storage: dict = {
        "oss": OSSStorage,
        "http": HTTPStorage,
        "https": HTTPStorage,
        "local": LocalStorage,
    }

    @staticmethod
    def _get_storage(uri):
        assert isinstance(uri, str), f"uri should be str type, but got {type(uri)}"

        if "://" not in uri:
            # local path
            storage_type = "local"
        else:
            prefix, _ = uri.split("://")
            storage_type = prefix

        assert storage_type in File._prefix_to_storage, (
            f"Unsupported uri {uri}, valid prefixes: "
            f"{list(File._prefix_to_storage.keys())}"
        )

        if storage_type not in G_STORAGES:
            G_STORAGES[storage_type] = File._prefix_to_storage[storage_type]()

        return G_STORAGES[storage_type]

    @staticmethod
    def read(uri: str) -> bytes:
        """Read data from a given ``filepath`` with 'rb' mode.

        Args:
            filepath (str or Path): Path to read data.

        Returns:
            bytes: Expected bytes object.
        """
        storage = File._get_storage(uri)
        return storage.read(uri)

    @staticmethod
    def read_text(uri: Union[str, Path], encoding: str = "utf-8") -> str:
        """Read data from a given ``filepath`` with 'r' mode.

        Args:
            filepath (str or Path): Path to read data.
            encoding (str): The encoding format used to open the ``filepath``.
                Default: 'utf-8'.

        Returns:
            str: Expected text reading from ``filepath``.
        """
        storage = File._get_storage(uri)
        return storage.read_text(uri)

    @staticmethod
    def write(obj: bytes, uri: Union[str, Path]) -> None:
        """Write data to a given ``filepath`` with 'wb' mode.

        Note:
            ``write`` will create a directory if the directory of ``filepath``
            does not exist.

        Args:
            obj (bytes): Data to be written.
            filepath (str or Path): Path to write data.
        """
        storage = File._get_storage(uri)
        return storage.write(obj, uri)

    @staticmethod
    def write_text(obj: str, uri: str, encoding: str = "utf-8") -> None:
        """Write data to a given ``filepath`` with 'w' mode.

        Note:
            ``write_text`` will create a directory if the directory of
            ``filepath`` does not exist.

        Args:
            obj (str): Data to be written.
            filepath (str or Path): Path to write data.
            encoding (str): The encoding format used to open the ``filepath``.
                Default: 'utf-8'.
        """
        storage = File._get_storage(uri)
        return storage.write_text(obj, uri)

    @staticmethod
    @contextlib.contextmanager
    def as_local_path(uri: str) -> Generator[Union[str, Path], None, None]:
        """Provide a local path for ``uri``, delegating to the matching backend."""
        storage = File._get_storage(uri)
        with storage.as_local_path(uri) as local_path:
            yield local_path
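A short usage sketch of the unified File API above (URL and paths are placeholders):

data = File.read("https://example.com/model.bin")   # dispatched to HTTPStorage
File.write(data, "/tmp/models/model.bin")           # LocalStorage; creates parent dirs
with File.as_local_path("https://example.com/a.txt") as p:
    print(p)  # temporary local copy, deleted when the block exits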
@@ -0,0 +1,13 @@
name_maps_ms = {
    "paraformer-zh": "damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
    "paraformer-en": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
    "paraformer-en-spk": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
    "paraformer-zh-streaming": "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
    "fsmn-vad": "damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    "ct-punc": "damo/punc_ct-transformer_cn-en-common-vocab471067-large",
    "ct-punc-c": "damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
    "fa-zh": "damo/speech_timestamp_prediction-v1-16k-offline",
    "cam++": "damo/speech_campplus_sv_zh-cn_16k-common",
}

name_maps_hf = {}
@@ -0,0 +1,60 @@
import os
import argparse
from pathlib import Path

from funasr_detach.utils.types import str2bool


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", type=str, required=True)
    parser.add_argument("--export-dir", type=str, required=True)
    parser.add_argument(
        "--export", type=str2bool, default=True, help="whether to export model"
    )
    parser.add_argument("--type", type=str, default="onnx", help='["onnx", "torch"]')
    parser.add_argument("--device", type=str, default="cpu", help='["cpu", "cuda"]')
    parser.add_argument(
        "--quantize", type=str2bool, default=False, help="export quantized model"
    )
    parser.add_argument(
        "--fallback-num", type=int, default=0, help="amp fallback number"
    )
    parser.add_argument("--audio_in", type=str, default=None, help='["wav", "wav.scp"]')
    parser.add_argument(
        "--model_revision", type=str, default=None, help="model_revision"
    )
    parser.add_argument("--calib_num", type=int, default=200, help="calib max num")
    args = parser.parse_args()

    model_dir = args.model_name
    if not Path(args.model_name).exists():
        from modelscope.hub.snapshot_download import snapshot_download

        try:
            model_dir = snapshot_download(
                args.model_name, cache_dir=args.export_dir, revision=args.model_revision
            )
        except:
            raise ValueError(
                "model_dir must be a model_name in modelscope or a local path "
                "downloaded from modelscope, but is {}".format(model_dir)
            )
    if args.export:
        model_file = os.path.join(model_dir, "model.onnx")
        if args.quantize:
            model_file = os.path.join(model_dir, "model_quant.onnx")
        if not os.path.exists(model_file):
            print("model.onnx does not exist, beginning to export onnx")
            from funasr_detach.bin.export_model import ModelExport

            export_model = ModelExport(
                cache_dir=args.export_dir,
                onnx=True,
                device="cpu",
                quant=args.quantize,
            )
            export_model.export(model_dir)


if __name__ == "__main__":
    main()
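An illustrative invocation (the script path is hypothetical; the model id is one of the ModelScope names mapped earlier in this commit):

python path/to/this_script.py \
    --model-name damo/speech_fsmn_vad_zh-cn-16k-common-pytorch \
    --export-dir ./export \
    --quantize false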
File without changes
@@ -0,0 +1,347 @@
import copy
from typing import Optional
from typing import Tuple
from typing import Union
import logging
import humanfriendly
import numpy as np
import torch
import torch.nn as nn

try:
    from torch_complex.tensor import ComplexTensor
except:
    print("Please install torch_complex firstly")

from funasr_detach.frontends.utils.log_mel import LogMel
from funasr_detach.frontends.utils.stft import Stft
from funasr_detach.frontends.utils.frontend import Frontend
from funasr_detach.models.transformer.utils.nets_utils import make_pad_mask


class DefaultFrontend(nn.Module):
    """Conventional frontend structure for ASR.
    Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
    """

    def __init__(
        self,
        fs: Union[int, str] = 16000,
        n_fft: int = 512,
        win_length: int = None,
        hop_length: int = 128,
        window: Optional[str] = "hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
        n_mels: int = 80,
        fmin: int = None,
        fmax: int = None,
        htk: bool = False,
        frontend_conf: Optional[dict] = None,
        apply_stft: bool = True,
        use_channel: int = None,
    ):
        super().__init__()
        if isinstance(fs, str):
            fs = humanfriendly.parse_size(fs)

        # Deepcopy (In general, dict shouldn't be used as default arg)
        frontend_conf = copy.deepcopy(frontend_conf)
        self.hop_length = hop_length

        if apply_stft:
            self.stft = Stft(
                n_fft=n_fft,
                win_length=win_length,
                hop_length=hop_length,
                center=center,
                window=window,
                normalized=normalized,
                onesided=onesided,
            )
        else:
            self.stft = None
        self.apply_stft = apply_stft

        if frontend_conf is not None:
            self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
        else:
            self.frontend = None

        self.logmel = LogMel(
            fs=fs,
            n_fft=n_fft,
            n_mels=n_mels,
            fmin=fmin,
            fmax=fmax,
            htk=htk,
        )
        self.n_mels = n_mels
        self.use_channel = use_channel
        self.frontend_type = "default"

    def output_size(self) -> int:
        return self.n_mels

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # 1. Domain-conversion: e.g. Stft: time -> time-freq
        if self.stft is not None:
            input_stft, feats_lens = self._compute_stft(input, input_lengths)
        else:
            input_stft = ComplexTensor(input[..., 0], input[..., 1])
            feats_lens = input_lengths
        # 2. [Option] Speech enhancement
        if self.frontend is not None:
            assert isinstance(input_stft, ComplexTensor), type(input_stft)
            # input_stft: (Batch, Length, [Channel], Freq)
            input_stft, _, mask = self.frontend(input_stft, feats_lens)

        # 3. [Multi channel case]: Select a channel
        if input_stft.dim() == 4:
            # h: (B, T, C, F) -> h: (B, T, F)
            if self.training:
                if self.use_channel is not None:
                    input_stft = input_stft[:, :, self.use_channel, :]
                else:
                    # Select 1ch randomly
                    ch = np.random.randint(input_stft.size(2))
                    input_stft = input_stft[:, :, ch, :]
            else:
                # Use the first channel
                input_stft = input_stft[:, :, 0, :]

        # 4. STFT -> Power spectrum
        # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
        input_power = input_stft.real**2 + input_stft.imag**2

        # 5. Feature transform e.g. Stft -> Log-Mel-Fbank
        # input_power: (Batch, [Channel,] Length, Freq)
        # -> input_feats: (Batch, Length, Dim)
        input_feats, _ = self.logmel(input_power, feats_lens)

        return input_feats, feats_lens

    def _compute_stft(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        input_stft, feats_lens = self.stft(input, input_lengths)

        assert input_stft.dim() >= 4, input_stft.shape
        # "2" refers to the real/imag parts of Complex
        assert input_stft.shape[-1] == 2, input_stft.shape

        # Change torch.Tensor to ComplexTensor
        # input_stft: (..., F, 2) -> (..., F)
        input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
        return input_stft, feats_lens
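A minimal sketch of extracting 80-dim log-mel features with the frontend above (batch shapes and lengths are illustrative):

import torch

frontend = DefaultFrontend(fs=16000, n_fft=512, hop_length=128, n_mels=80)
speech = torch.randn(2, 16000)          # (batch, samples) of dummy audio
lengths = torch.tensor([16000, 12000])  # valid samples per utterance
feats, feats_lens = frontend(speech, lengths)
# feats: (batch, frames, 80); feats_lens: valid frame counts per utterance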
class MultiChannelFrontend(nn.Module):
    """Conventional frontend structure for ASR.
    Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
    """

    def __init__(
        self,
        fs: Union[int, str] = 16000,
        n_fft: int = 512,
        win_length: int = None,
        hop_length: int = None,
        frame_length: int = None,
        frame_shift: int = None,
        window: Optional[str] = "hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
        n_mels: int = 80,
        fmin: int = None,
        fmax: int = None,
        htk: bool = False,
        frontend_conf: Optional[dict] = None,
        apply_stft: bool = True,
        use_channel: int = None,
        lfr_m: int = 1,
        lfr_n: int = 1,
        cmvn_file: str = None,
        mc: bool = True,
    ):
        super().__init__()
        if isinstance(fs, str):
            fs = humanfriendly.parse_size(fs)

        # Deepcopy (In general, dict shouldn't be used as default arg)
        frontend_conf = copy.deepcopy(frontend_conf)
        if win_length is None and hop_length is None:
            # frame_length / frame_shift are in ms; 16 samples per ms at 16 kHz
            self.win_length = frame_length * 16
            self.hop_length = frame_shift * 16
        elif frame_length is None and frame_shift is None:
            self.win_length = win_length
            self.hop_length = hop_length
        else:
            logging.error(
                "Only one of (win_length, hop_length) and (frame_length, frame_shift) "
                "can be set."
            )
            exit(1)

        if apply_stft:
            self.stft = Stft(
                n_fft=n_fft,
                win_length=self.win_length,
                hop_length=self.hop_length,
                center=center,
                window=window,
                normalized=normalized,
                onesided=onesided,
            )
        else:
            self.stft = None
        self.apply_stft = apply_stft

        if frontend_conf is not None:
            self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
        else:
            self.frontend = None

        self.logmel = LogMel(
            fs=fs,
            n_fft=n_fft,
            n_mels=n_mels,
            fmin=fmin,
            fmax=fmax,
            htk=htk,
        )
        self.n_mels = n_mels
        self.use_channel = use_channel
        self.mc = mc
        if not self.mc:
            if self.use_channel is not None:
                logging.info("use the channel %d" % (self.use_channel))
            else:
                logging.info("random select channel")
        self.cmvn_file = cmvn_file
        if self.cmvn_file is not None:
            mean, std = self._load_cmvn(self.cmvn_file)
            self.register_buffer("mean", torch.from_numpy(mean))
            self.register_buffer("std", torch.from_numpy(std))
        self.frontend_type = "multichannelfrontend"

    def output_size(self) -> int:
        return self.n_mels

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # 1. Domain-conversion: e.g. Stft: time -> time-freq
        # import pdb;pdb.set_trace()
        if self.stft is not None:
            input_stft, feats_lens = self._compute_stft(input, input_lengths)
        else:
            input_stft = ComplexTensor(input[..., 0], input[..., 1])
            feats_lens = input_lengths
        # 2. [Option] Speech enhancement
        if self.frontend is not None:
            assert isinstance(input_stft, ComplexTensor), type(input_stft)
            # input_stft: (Batch, Length, [Channel], Freq)
            input_stft, _, mask = self.frontend(input_stft, feats_lens)

        # 3. [Multi channel case]: Select a channel (sa_asr)
        if input_stft.dim() == 4 and not self.mc:
            # h: (B, T, C, F) -> h: (B, T, F)
            if self.training:
                if self.use_channel is not None:
                    input_stft = input_stft[:, :, self.use_channel, :]
                else:
                    # Select 1ch randomly
                    ch = np.random.randint(input_stft.size(2))
                    input_stft = input_stft[:, :, ch, :]
            else:
                # Use the first channel
                input_stft = input_stft[:, :, 0, :]

        # 4. STFT -> Power spectrum
        # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
        input_power = input_stft.real**2 + input_stft.imag**2

        # 5. Feature transform e.g. Stft -> Log-Mel-Fbank
        # input_power: (Batch, [Channel,] Length, Freq)
        # -> input_feats: (Batch, Length, Dim)
        input_feats, _ = self.logmel(input_power, feats_lens)
        if self.mc:
            # MFCCA
            if input_feats.dim() == 4:
                bt = input_feats.size(0)
                channel_size = input_feats.size(2)
                input_feats = (
                    input_feats.transpose(1, 2)
                    .reshape(bt * channel_size, -1, 80)
                    .contiguous()
                )
                feats_lens = feats_lens.repeat(1, channel_size).squeeze()
            else:
                channel_size = 1
            return input_feats, feats_lens, channel_size
        else:
            # 6. Apply CMVN
            if self.cmvn_file is not None:
                if feats_lens is None:
                    feats_lens = input_feats.new_full(
                        [input_feats.size(0)], input_feats.size(1)
                    )
                self.mean = self.mean.to(input_feats.device, input_feats.dtype)
                self.std = self.std.to(input_feats.device, input_feats.dtype)
                mask = make_pad_mask(feats_lens, input_feats, 1)

                if input_feats.requires_grad:
                    input_feats = input_feats + self.mean
                else:
                    input_feats += self.mean
                if input_feats.requires_grad:
                    input_feats = input_feats.masked_fill(mask, 0.0)
                else:
                    input_feats.masked_fill_(mask, 0.0)

                input_feats *= self.std

            return input_feats, feats_lens

    def _compute_stft(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        input_stft, feats_lens = self.stft(input, input_lengths)

        assert input_stft.dim() >= 4, input_stft.shape
        # "2" refers to the real/imag parts of Complex
        assert input_stft.shape[-1] == 2, input_stft.shape

        # Change torch.Tensor to ComplexTensor
        # input_stft: (..., F, 2) -> (..., F)
| 323 |
+
input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
|
| 324 |
+
return input_stft, feats_lens
|
| 325 |
+
|
| 326 |
+
def _load_cmvn(self, cmvn_file):
|
| 327 |
+
with open(cmvn_file, "r", encoding="utf-8") as f:
|
| 328 |
+
lines = f.readlines()
|
| 329 |
+
means_list = []
|
| 330 |
+
vars_list = []
|
| 331 |
+
for i in range(len(lines)):
|
| 332 |
+
line_item = lines[i].split()
|
| 333 |
+
if line_item[0] == "<AddShift>":
|
| 334 |
+
line_item = lines[i + 1].split()
|
| 335 |
+
if line_item[0] == "<LearnRateCoef>":
|
| 336 |
+
add_shift_line = line_item[3 : (len(line_item) - 1)]
|
| 337 |
+
means_list = list(add_shift_line)
|
| 338 |
+
continue
|
| 339 |
+
elif line_item[0] == "<Rescale>":
|
| 340 |
+
line_item = lines[i + 1].split()
|
| 341 |
+
if line_item[0] == "<LearnRateCoef>":
|
| 342 |
+
rescale_line = line_item[3 : (len(line_item) - 1)]
|
| 343 |
+
vars_list = list(rescale_line)
|
| 344 |
+
continue
|
| 345 |
+
means = np.array(means_list).astype(np.float)
|
| 346 |
+
vars = np.array(vars_list).astype(np.float)
|
| 347 |
+
return means, vars
|
|
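
A minimal sketch (not part of the repo) of the Kaldi-style CMVN applied in step 6 above: the stats file stores the negative mean under `<AddShift>` and the inverse standard deviation under `<Rescale>`, so normalization reduces to `(x + neg_mean) * inv_std` with padded frames zeroed. The tensors below are toy stand-ins for the loaded buffers.

import torch

B, T, D = 2, 5, 4
feats = torch.randn(B, T, D)
feats_lens = torch.tensor([5, 3])

neg_mean = -feats.mean(dim=(0, 1))     # stand-in for the <AddShift> row
inv_std = 1.0 / feats.std(dim=(0, 1))  # stand-in for the <Rescale> row

# pad mask: True where frame index >= utterance length
mask = torch.arange(T)[None, :, None] >= feats_lens[:, None, None]

feats = (feats + neg_mean).masked_fill(mask, 0.0) * inv_std
print(feats.shape)  # torch.Size([2, 5, 4])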
@@ -0,0 +1,49 @@
# Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita)
# Licensed under the MIT license.
#
# This module is for computing audio features

import librosa
import numpy as np


def transform(Y, dtype=np.float32):
    Y = np.abs(Y)
    n_fft = 2 * (Y.shape[1] - 1)
    sr = 8000
    n_mels = 23
    mel_basis = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels)
    Y = np.dot(Y**2, mel_basis.T)
    Y = np.log10(np.maximum(Y, 1e-10))
    mean = np.mean(Y, axis=0)
    Y = Y - mean
    return Y.astype(dtype)


def subsample(Y, T, subsampling=1):
    Y_ss = Y[::subsampling]
    T_ss = T[::subsampling]
    return Y_ss, T_ss


def splice(Y, context_size=0):
    Y_pad = np.pad(Y, [(context_size, context_size), (0, 0)], "constant")
    Y_spliced = np.lib.stride_tricks.as_strided(
        np.ascontiguousarray(Y_pad),
        (Y.shape[0], Y.shape[1] * (2 * context_size + 1)),
        (Y.itemsize * Y.shape[1], Y.itemsize),
        writeable=False,
    )
    return Y_spliced


def stft(data, frame_size=1024, frame_shift=256):
    fft_size = 1 << (frame_size - 1).bit_length()
    if len(data) % frame_shift == 0:
        return librosa.stft(
            data, n_fft=fft_size, win_length=frame_size, hop_length=frame_shift
        ).T[:-1]
    else:
        return librosa.stft(
            data, n_fft=fft_size, win_length=frame_size, hop_length=frame_shift
        ).T
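
A hypothetical end-to-end use of the helpers above on one 8 kHz waveform; the frame sizes mirror the defaults, and the waveform and label array are made up. Requires librosa installed alongside this module.

import numpy as np

wav = np.random.randn(8000 * 2).astype(np.float32)  # 2 s of noise
Y = stft(wav, frame_size=1024, frame_shift=256)     # (frames, 513) complex
feats = transform(Y)                                # (frames, 23) log-mel
feats = splice(feats, context_size=7)               # (frames, 23 * 15)
T = np.zeros(len(feats))                            # dummy frame labels
feats_ss, T_ss = subsample(feats, T, subsampling=10)
print(feats_ss.shape)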
@@ -0,0 +1,144 @@
from funasr_detach.frontends.default import DefaultFrontend
from funasr_detach.frontends.s3prl import S3prlFrontend
import numpy as np
import torch
import torch.nn as nn
from typing import Tuple


class FusedFrontends(nn.Module):
    def __init__(
        self, frontends=None, align_method="linear_projection", proj_dim=100, fs=16000
    ):

        super().__init__()
        self.align_method = (
            align_method  # fusing method: linear_projection only for now
        )
        self.proj_dim = proj_dim  # dim of the projection done on each frontend
        self.frontends = []  # list of the frontends to combine

        for i, frontend in enumerate(frontends):
            frontend_type = frontend["frontend_type"]
            if frontend_type == "default":
                n_mels, fs, n_fft, win_length, hop_length = (
                    frontend.get("n_mels", 80),
                    fs,
                    frontend.get("n_fft", 512),
                    frontend.get("win_length"),
                    frontend.get("hop_length", 128),
                )
                window, center, normalized, onesided = (
                    frontend.get("window", "hann"),
                    frontend.get("center", True),
                    frontend.get("normalized", False),
                    frontend.get("onesided", True),
                )
                fmin, fmax, htk, apply_stft = (
                    frontend.get("fmin", None),
                    frontend.get("fmax", None),
                    frontend.get("htk", False),
                    frontend.get("apply_stft", True),
                )

                self.frontends.append(
                    DefaultFrontend(
                        n_mels=n_mels,
                        n_fft=n_fft,
                        fs=fs,
                        win_length=win_length,
                        hop_length=hop_length,
                        window=window,
                        center=center,
                        normalized=normalized,
                        onesided=onesided,
                        fmin=fmin,
                        fmax=fmax,
                        htk=htk,
                        apply_stft=apply_stft,
                    )
                )
            elif frontend_type == "s3prl":
                frontend_conf, download_dir, multilayer_feature = (
                    frontend.get("frontend_conf"),
                    frontend.get("download_dir"),
                    frontend.get("multilayer_feature"),
                )
                self.frontends.append(
                    S3prlFrontend(
                        fs=fs,
                        frontend_conf=frontend_conf,
                        download_dir=download_dir,
                        multilayer_feature=multilayer_feature,
                    )
                )

            else:
                raise NotImplementedError  # frontends are only default or s3prl

        self.frontends = torch.nn.ModuleList(self.frontends)

        self.gcd = np.gcd.reduce([frontend.hop_length for frontend in self.frontends])
        self.factors = [frontend.hop_length // self.gcd for frontend in self.frontends]
        if torch.cuda.is_available():
            dev = "cuda"
        else:
            dev = "cpu"
        if self.align_method == "linear_projection":
            self.projection_layers = [
                torch.nn.Linear(
                    in_features=frontend.output_size(),
                    out_features=self.factors[i] * self.proj_dim,
                )
                for i, frontend in enumerate(self.frontends)
            ]
            self.projection_layers = torch.nn.ModuleList(self.projection_layers)
            self.projection_layers = self.projection_layers.to(torch.device(dev))

    def output_size(self) -> int:
        return len(self.frontends) * self.proj_dim

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        # step 0: get all frontends features
        self.feats = []
        for frontend in self.frontends:
            with torch.no_grad():
                input_feats, feats_lens = frontend.forward(input, input_lengths)
            self.feats.append([input_feats, feats_lens])

        if (
            self.align_method == "linear_projection"
        ):  # TODO(Dan): to add other align methods

            # first step: projections
            self.feats_proj = []
            for i, frontend in enumerate(self.frontends):
                input_feats = self.feats[i][0]
                self.feats_proj.append(self.projection_layers[i](input_feats))

            # 2nd step: reshape
            self.feats_reshaped = []
            for i, frontend in enumerate(self.frontends):
                input_feats_proj = self.feats_proj[i]
                bs, nf, dim = input_feats_proj.shape
                input_feats_reshaped = torch.reshape(
                    input_feats_proj, (bs, nf * self.factors[i], dim // self.factors[i])
                )
                self.feats_reshaped.append(input_feats_reshaped)

            # 3rd step: drop the few last frames
            m = min([x.shape[1] for x in self.feats_reshaped])
            self.feats_final = [x[:, :m, :] for x in self.feats_reshaped]

            input_feats = torch.cat(
                self.feats_final, dim=-1
            )  # change the input size of the preencoder: proj_dim * n_frontends
            feats_lens = torch.ones_like(self.feats[0][1]) * m

        else:
            raise NotImplementedError

        return input_feats, feats_lens
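
A small sketch of the frame-rate alignment arithmetic used above: with toy hop lengths 128 and 160 the gcd is 32, so each stream is projected to factor * proj_dim features and reshaped to a common frame rate before concatenation.

import numpy as np

hop_lengths = [128, 160]                    # assumed per-frontend hops
gcd = np.gcd.reduce(hop_lengths)            # 32
factors = [h // gcd for h in hop_lengths]   # [4, 5]

proj_dim = 100
nf = 30  # frames produced by one frontend (toy value)
for h, f in zip(hop_lengths, factors):
    # (bs, nf, f * proj_dim) -> (bs, nf * f, proj_dim)
    print(f"hop {h}: {nf} frames -> {nf * f} frames of dim {proj_dim}")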
@@ -0,0 +1,139 @@
import copy
import logging
import os
from argparse import Namespace
from typing import Optional
from typing import Tuple
from typing import Union

import humanfriendly
import torch
import torch.nn as nn

from funasr_detach.frontends.utils.frontend import Frontend
from funasr_detach.models.transformer.utils.nets_utils import pad_list


def base_s3prl_setup(args):
    args.upstream_feature_selection = getattr(args, "upstream_feature_selection", None)
    args.upstream_model_config = getattr(args, "upstream_model_config", None)
    args.upstream_refresh = getattr(args, "upstream_refresh", False)
    args.upstream_ckpt = getattr(args, "upstream_ckpt", None)
    args.init_ckpt = getattr(args, "init_ckpt", None)
    args.verbose = getattr(args, "verbose", False)
    args.tile_factor = getattr(args, "tile_factor", 1)
    return args


class S3prlFrontend(nn.Module):
    """Speech Pretrained Representation frontend structure for ASR."""

    def __init__(
        self,
        fs: Union[int, str] = 16000,
        frontend_conf: Optional[dict] = None,
        download_dir: str = None,
        multilayer_feature: bool = False,
    ):
        super().__init__()
        if isinstance(fs, str):
            fs = humanfriendly.parse_size(fs)

        if download_dir is not None:
            torch.hub.set_dir(download_dir)

        self.multilayer_feature = multilayer_feature
        self.upstream, self.featurizer = self._get_upstream(frontend_conf)
        self.pretrained_params = copy.deepcopy(self.upstream.state_dict())
        self.output_dim = self.featurizer.output_dim
        self.frontend_type = "s3prl"
        self.hop_length = self.upstream.get_downsample_rates("key")

    def _get_upstream(self, frontend_conf):
        """Get S3PRL upstream model."""
        s3prl_args = base_s3prl_setup(
            Namespace(**frontend_conf, device="cpu"),
        )
        self.args = s3prl_args

        s3prl_path = None
        python_path_list = os.environ.get("PYTHONPATH", "(None)").split(":")
        for p in python_path_list:
            if p.endswith("s3prl"):
                s3prl_path = p
                break
        assert s3prl_path is not None

        s3prl_upstream = torch.hub.load(
            s3prl_path,
            s3prl_args.upstream,
            ckpt=s3prl_args.upstream_ckpt,
            model_config=s3prl_args.upstream_model_config,
            refresh=s3prl_args.upstream_refresh,
            source="local",
        ).to("cpu")

        if getattr(
            s3prl_upstream, "model", None
        ) is not None and s3prl_upstream.model.__class__.__name__ in [
            "Wav2Vec2Model",
            "HubertModel",
        ]:
            s3prl_upstream.model.encoder.layerdrop = 0.0

        from s3prl.upstream.interfaces import Featurizer

        if self.multilayer_feature is None:
            feature_selection = "last_hidden_state"
        else:
            feature_selection = "hidden_states"
        s3prl_featurizer = Featurizer(
            upstream=s3prl_upstream,
            feature_selection=feature_selection,
            upstream_device="cpu",
        )

        return s3prl_upstream, s3prl_featurizer

    def _tile_representations(self, feature):
        """Tile up the representations by `tile_factor`.

        Input - sequence of representations
            shape: (batch_size, seq_len, feature_dim)
        Output - sequence of tiled representations
            shape: (batch_size, seq_len * factor, feature_dim)
        """
        assert (
            len(feature.shape) == 3
        ), "Input argument `feature` has invalid shape: {}".format(feature.shape)
        tiled_feature = feature.repeat(1, 1, self.args.tile_factor)
        tiled_feature = tiled_feature.reshape(
            feature.size(0), feature.size(1) * self.args.tile_factor, feature.size(2)
        )
        return tiled_feature

    def output_size(self) -> int:
        return self.output_dim

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        wavs = [wav[: input_lengths[i]] for i, wav in enumerate(input)]
        self.upstream.eval()
        with torch.no_grad():
            feats = self.upstream(wavs)
        feats = self.featurizer(wavs, feats)

        if self.args.tile_factor != 1:
            feats = self._tile_representations(feats)

        input_feats = pad_list(feats, 0.0)
        feats_lens = torch.tensor([f.shape[0] for f in feats], dtype=torch.long)

        # Saving CUDA Memory
        del feats

        return input_feats, feats_lens

    def reload_pretrained_parameters(self):
        self.upstream.load_state_dict(self.pretrained_params)
        logging.info("Pretrained S3PRL frontend model parameters reloaded!")
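
A standalone sketch of the tiling in `_tile_representations`: repeating along the feature axis and reshaping turns (B, T, D) into (B, T * factor, D), duplicating each frame `factor` times. Values here are illustrative only.

import torch

factor = 2
feature = torch.arange(12.0).reshape(1, 3, 4)  # (B=1, T=3, D=4)
tiled = feature.repeat(1, 1, factor).reshape(1, 3 * factor, 4)
print(tiled.shape)  # torch.Size([1, 6, 4]); frames appear twice in a row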
@@ -0,0 +1 @@
"""Initialize sub package."""
@@ -0,0 +1,84 @@
import torch
from torch_complex import functional as FC
from torch_complex.tensor import ComplexTensor


def get_power_spectral_density_matrix(
    xs: ComplexTensor, mask: torch.Tensor, normalization=True, eps: float = 1e-15
) -> ComplexTensor:
    """Return cross-channel power spectral density (PSD) matrix

    Args:
        xs (ComplexTensor): (..., F, C, T)
        mask (torch.Tensor): (..., F, C, T)
        normalization (bool):
        eps (float):
    Returns:
        psd (ComplexTensor): (..., F, C, C)

    """
    # outer product: (..., C_1, T) x (..., C_2, T) -> (..., T, C_1, C_2)
    psd_Y = FC.einsum("...ct,...et->...tce", [xs, xs.conj()])

    # Averaging mask along C: (..., C, T) -> (..., T)
    mask = mask.mean(dim=-2)

    # Normalized mask along T: (..., T)
    if normalization:
        # If assuming the tensor is padded with zero, the summation along
        # the time axis is same regardless of the padding length.
        mask = mask / (mask.sum(dim=-1, keepdim=True) + eps)

    # psd: (..., T, C, C)
    psd = psd_Y * mask[..., None, None]
    # (..., T, C, C) -> (..., C, C)
    psd = psd.sum(dim=-3)

    return psd


def get_mvdr_vector(
    psd_s: ComplexTensor,
    psd_n: ComplexTensor,
    reference_vector: torch.Tensor,
    eps: float = 1e-15,
) -> ComplexTensor:
    """Return the MVDR (Minimum Variance Distortionless Response) vector:

        h = (Npsd^-1 @ Spsd) / (Tr(Npsd^-1 @ Spsd)) @ u

    Reference:
        On optimal frequency-domain multichannel linear filtering
        for noise reduction; M. Souden et al., 2010;
        https://ieeexplore.ieee.org/document/5089420

    Args:
        psd_s (ComplexTensor): (..., F, C, C)
        psd_n (ComplexTensor): (..., F, C, C)
        reference_vector (torch.Tensor): (..., C)
        eps (float):
    Returns:
        beamform_vector (ComplexTensor): (..., F, C)
    """
    # Add eps
    C = psd_n.size(-1)
    eye = torch.eye(C, dtype=psd_n.dtype, device=psd_n.device)
    shape = [1 for _ in range(psd_n.dim() - 2)] + [C, C]
    eye = eye.view(*shape)
    psd_n += eps * eye

    # numerator: (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3)
    numerator = FC.einsum("...ec,...cd->...ed", [psd_n.inverse(), psd_s])
    # ws: (..., C, C) / (...,) -> (..., C, C)
    ws = numerator / (FC.trace(numerator)[..., None, None] + eps)
    # h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    beamform_vector = FC.einsum("...fec,...c->...fe", [ws, reference_vector])
    return beamform_vector


def apply_beamforming_vector(
    beamform_vector: ComplexTensor, mix: ComplexTensor
) -> ComplexTensor:
    # (..., C) x (..., C, T) -> (..., T)
    es = FC.einsum("...c,...ct->...t", [beamform_vector.conj(), mix])
    return es
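
A toy run (not from the repo's tests) of the three helpers above: build speech and noise PSD matrices from random masks, pick microphone 0 as the reference, and check the beamformed output shape. This assumes torch_complex is installed; all sizes are invented.

import torch
from torch_complex.tensor import ComplexTensor

B, F_, C, T = 1, 4, 3, 10
xs = ComplexTensor(torch.randn(B, F_, C, T), torch.randn(B, F_, C, T))
mask_s = torch.rand(B, F_, C, T)
mask_n = 1.0 - mask_s

psd_s = get_power_spectral_density_matrix(xs, mask_s)  # (B, F, C, C)
psd_n = get_power_spectral_density_matrix(xs, mask_n)  # (B, F, C, C)

u = torch.zeros(B, C)
u[:, 0] = 1.0                                 # one-hot reference microphone

ws = get_mvdr_vector(psd_s, psd_n, u)         # (B, F, C)
enhanced = apply_beamforming_vector(ws, xs)   # (B, F, T)
print(enhanced.real.shape)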
@@ -0,0 +1,194 @@
"""Beamformer module."""

from distutils.version import LooseVersion
from typing import Sequence
from typing import Tuple
from typing import Union

import torch

try:
    from torch_complex import functional as FC
    from torch_complex.tensor import ComplexTensor
except ImportError:
    print("Please install torch_complex first")


EPS = torch.finfo(torch.double).eps
is_torch_1_8_plus = LooseVersion(torch.__version__) >= LooseVersion("1.8.0")
is_torch_1_9_plus = LooseVersion(torch.__version__) >= LooseVersion("1.9.0")


def new_complex_like(
    ref: Union[torch.Tensor, ComplexTensor],
    real_imag: Tuple[torch.Tensor, torch.Tensor],
):
    if isinstance(ref, ComplexTensor):
        return ComplexTensor(*real_imag)
    elif is_torch_complex_tensor(ref):
        return torch.complex(*real_imag)
    else:
        raise ValueError(
            "Please update your PyTorch version to 1.9+ for complex support."
        )


def is_torch_complex_tensor(c):
    return (
        not isinstance(c, ComplexTensor) and is_torch_1_9_plus and torch.is_complex(c)
    )


def is_complex(c):
    return isinstance(c, ComplexTensor) or is_torch_complex_tensor(c)


def to_double(c):
    if not isinstance(c, ComplexTensor) and is_torch_1_9_plus and torch.is_complex(c):
        return c.to(dtype=torch.complex128)
    else:
        return c.double()


def to_float(c):
    if not isinstance(c, ComplexTensor) and is_torch_1_9_plus and torch.is_complex(c):
        return c.to(dtype=torch.complex64)
    else:
        return c.float()


def cat(seq: Sequence[Union[ComplexTensor, torch.Tensor]], *args, **kwargs):
    if not isinstance(seq, (list, tuple)):
        raise TypeError(
            "cat(): argument 'tensors' (position 1) must be tuple of Tensors, "
            "not Tensor"
        )
    if isinstance(seq[0], ComplexTensor):
        return FC.cat(seq, *args, **kwargs)
    else:
        return torch.cat(seq, *args, **kwargs)


def complex_norm(
    c: Union[torch.Tensor, ComplexTensor], dim=-1, keepdim=False
) -> torch.Tensor:
    if not is_complex(c):
        raise TypeError("Input is not a complex tensor.")
    if is_torch_complex_tensor(c):
        return torch.norm(c, dim=dim, keepdim=keepdim)
    else:
        return torch.sqrt((c.real**2 + c.imag**2).sum(dim=dim, keepdim=keepdim) + EPS)


def einsum(equation, *operands):
    # NOTE: Do not mix ComplexTensor and torch.complex in the input!
    # NOTE (wangyou): Until PyTorch 1.9.0, torch.einsum does not support
    # mixed input with complex and real tensors.
    if len(operands) == 1:
        if isinstance(operands[0], (tuple, list)):
            operands = operands[0]
        complex_module = FC if isinstance(operands[0], ComplexTensor) else torch
        return complex_module.einsum(equation, *operands)
    elif len(operands) != 2:
        op0 = operands[0]
        same_type = all(op.dtype == op0.dtype for op in operands[1:])
        if same_type:
            _einsum = FC.einsum if isinstance(op0, ComplexTensor) else torch.einsum
            return _einsum(equation, *operands)
        else:
            raise ValueError("0 or more than 2 operands are not supported.")
    a, b = operands
    if isinstance(a, ComplexTensor) or isinstance(b, ComplexTensor):
        return FC.einsum(equation, a, b)
    elif is_torch_1_9_plus and (torch.is_complex(a) or torch.is_complex(b)):
        if not torch.is_complex(a):
            o_real = torch.einsum(equation, a, b.real)
            o_imag = torch.einsum(equation, a, b.imag)
            return torch.complex(o_real, o_imag)
        elif not torch.is_complex(b):
            o_real = torch.einsum(equation, a.real, b)
            o_imag = torch.einsum(equation, a.imag, b)
            return torch.complex(o_real, o_imag)
        else:
            return torch.einsum(equation, a, b)
    else:
        return torch.einsum(equation, a, b)


def inverse(
    c: Union[torch.Tensor, ComplexTensor],
) -> Union[torch.Tensor, ComplexTensor]:
    if isinstance(c, ComplexTensor):
        return c.inverse2()
    else:
        return c.inverse()


def matmul(
    a: Union[torch.Tensor, ComplexTensor], b: Union[torch.Tensor, ComplexTensor]
) -> Union[torch.Tensor, ComplexTensor]:
    # NOTE: Do not mix ComplexTensor and torch.complex in the input!
    # NOTE (wangyou): Until PyTorch 1.9.0, torch.matmul does not support
    # multiplication between complex and real tensors.
    if isinstance(a, ComplexTensor) or isinstance(b, ComplexTensor):
        return FC.matmul(a, b)
    elif is_torch_1_9_plus and (torch.is_complex(a) or torch.is_complex(b)):
        if not torch.is_complex(a):
            o_real = torch.matmul(a, b.real)
            o_imag = torch.matmul(a, b.imag)
            return torch.complex(o_real, o_imag)
        elif not torch.is_complex(b):
            o_real = torch.matmul(a.real, b)
            o_imag = torch.matmul(a.imag, b)
            return torch.complex(o_real, o_imag)
        else:
            return torch.matmul(a, b)
    else:
        return torch.matmul(a, b)


def trace(a: Union[torch.Tensor, ComplexTensor]):
    # NOTE (wangyou): until PyTorch 1.9.0, torch.trace does not
    # support batch processing. Use FC.trace() as fallback.
    return FC.trace(a)


def reverse(a: Union[torch.Tensor, ComplexTensor], dim=0):
    if isinstance(a, ComplexTensor):
        return FC.reverse(a, dim=dim)
    else:
        return torch.flip(a, dims=(dim,))


def solve(b: Union[torch.Tensor, ComplexTensor], a: Union[torch.Tensor, ComplexTensor]):
    """Solve the linear equation ax = b."""
    # NOTE: Do not mix ComplexTensor and torch.complex in the input!
    # NOTE (wangyou): Until PyTorch 1.9.0, torch.solve does not support
    # mixed input with complex and real tensors.
    if isinstance(a, ComplexTensor) or isinstance(b, ComplexTensor):
        if isinstance(a, ComplexTensor) and isinstance(b, ComplexTensor):
            return FC.solve(b, a, return_LU=False)
        else:
            return matmul(inverse(a), b)
    elif is_torch_1_9_plus and (torch.is_complex(a) or torch.is_complex(b)):
        if torch.is_complex(a) and torch.is_complex(b):
            return torch.linalg.solve(a, b)
        else:
            return matmul(inverse(a), b)
    else:
        if is_torch_1_8_plus:
            return torch.linalg.solve(a, b)
        else:
            return torch.solve(b, a)[0]


def stack(seq: Sequence[Union[ComplexTensor, torch.Tensor]], *args, **kwargs):
    if not isinstance(seq, (list, tuple)):
        raise TypeError(
            "stack(): argument 'tensors' (position 1) must be tuple of Tensors, "
            "not Tensor"
        )
    if isinstance(seq[0], ComplexTensor):
        return FC.stack(seq, *args, **kwargs)
    else:
        return torch.stack(seq, *args, **kwargs)
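
A quick check (illustrative only, assuming PyTorch 1.9+) that the `einsum` wrapper above handles a mixed real/complex pair by splitting the complex operand into real and imaginary parts.

import torch

a = torch.randn(2, 3)                      # real operand
b = torch.randn(3, 4, dtype=torch.cfloat)  # complex operand
out = einsum("ij,jk->ik", a, b)
ref = torch.complex(a @ b.real, a @ b.imag)
print(torch.allclose(out, ref))  # True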
@@ -0,0 +1,173 @@
"""DNN beamformer module."""

from typing import Tuple

import torch
from torch.nn import functional as F

from funasr_detach.frontends.utils.beamformer import apply_beamforming_vector
from funasr_detach.frontends.utils.beamformer import get_mvdr_vector
from funasr_detach.frontends.utils.beamformer import (
    get_power_spectral_density_matrix,  # noqa: H301
)
from funasr_detach.frontends.utils.mask_estimator import MaskEstimator
from torch_complex.tensor import ComplexTensor


class DNN_Beamformer(torch.nn.Module):
    """DNN mask based Beamformer

    Citation:
        Multichannel End-to-end Speech Recognition; T. Ochiai et al., 2017;
        https://arxiv.org/abs/1703.04783

    """

    def __init__(
        self,
        bidim,
        btype="blstmp",
        blayers=3,
        bunits=300,
        bprojs=320,
        bnmask=2,
        dropout_rate=0.0,
        badim=320,
        ref_channel: int = -1,
        beamformer_type="mvdr",
    ):
        super().__init__()
        self.mask = MaskEstimator(
            btype, bidim, blayers, bunits, bprojs, dropout_rate, nmask=bnmask
        )
        self.ref = AttentionReference(bidim, badim)
        self.ref_channel = ref_channel

        self.nmask = bnmask

        if beamformer_type != "mvdr":
            raise ValueError(
                "Not supporting beamformer_type={}".format(beamformer_type)
            )
        self.beamformer_type = beamformer_type

    def forward(
        self, data: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
        """The forward function

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq

        Args:
            data (ComplexTensor): (B, T, C, F)
            ilens (torch.Tensor): (B,)
        Returns:
            enhanced (ComplexTensor): (B, T, F)
            ilens (torch.Tensor): (B,)

        """

        def apply_beamforming(data, ilens, psd_speech, psd_noise):
            # u: (B, C)
            if self.ref_channel < 0:
                u, _ = self.ref(psd_speech, ilens)
            else:
                # (optional) Create onehot vector for fixed reference microphone
                u = torch.zeros(
                    *(data.size()[:-3] + (data.size(-2),)), device=data.device
                )
                u[..., self.ref_channel].fill_(1)

            ws = get_mvdr_vector(psd_speech, psd_noise, u)
            enhanced = apply_beamforming_vector(ws, data)

            return enhanced, ws

        # data (B, T, C, F) -> (B, F, C, T)
        data = data.permute(0, 3, 2, 1)

        # mask: (B, F, C, T)
        masks, _ = self.mask(data, ilens)
        assert self.nmask == len(masks)

        if self.nmask == 2:  # (mask_speech, mask_noise)
            mask_speech, mask_noise = masks

            psd_speech = get_power_spectral_density_matrix(data, mask_speech)
            psd_noise = get_power_spectral_density_matrix(data, mask_noise)

            enhanced, ws = apply_beamforming(data, ilens, psd_speech, psd_noise)

            # (..., F, T) -> (..., T, F)
            enhanced = enhanced.transpose(-1, -2)
            mask_speech = mask_speech.transpose(-1, -3)
        else:  # multi-speaker case: (mask_speech1, ..., mask_noise)
            mask_speech = list(masks[:-1])
            mask_noise = masks[-1]

            psd_speeches = [
                get_power_spectral_density_matrix(data, mask) for mask in mask_speech
            ]
            psd_noise = get_power_spectral_density_matrix(data, mask_noise)

            enhanced = []
            ws = []
            for i in range(self.nmask - 1):
                psd_speech = psd_speeches.pop(i)
                # treat all other speakers' psd_speech as noises
                enh, w = apply_beamforming(
                    data, ilens, psd_speech, sum(psd_speeches) + psd_noise
                )
                psd_speeches.insert(i, psd_speech)

                # (..., F, T) -> (..., T, F)
                enh = enh.transpose(-1, -2)
                mask_speech[i] = mask_speech[i].transpose(-1, -3)

                enhanced.append(enh)
                ws.append(w)

        return enhanced, ilens, mask_speech


class AttentionReference(torch.nn.Module):
    def __init__(self, bidim, att_dim):
        super().__init__()
        self.mlp_psd = torch.nn.Linear(bidim, att_dim)
        self.gvec = torch.nn.Linear(att_dim, 1)

    def forward(
        self, psd_in: ComplexTensor, ilens: torch.LongTensor, scaling: float = 2.0
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """The forward function

        Args:
            psd_in (ComplexTensor): (B, F, C, C)
            ilens (torch.Tensor): (B,)
            scaling (float):
        Returns:
            u (torch.Tensor): (B, C)
            ilens (torch.Tensor): (B,)
        """
        B, _, C = psd_in.size()[:3]
        assert psd_in.size(2) == psd_in.size(3), psd_in.size()
        # psd_in: (B, F, C, C)
        psd = psd_in.masked_fill(
            torch.eye(C, dtype=torch.bool, device=psd_in.device), 0
        )
        # psd: (B, F, C, C) -> (B, C, F)
        psd = (psd.sum(dim=-1) / (C - 1)).transpose(-1, -2)

        # Calculate amplitude
        psd_feat = (psd.real**2 + psd.imag**2) ** 0.5

        # (B, C, F) -> (B, C, F2)
        mlp_psd = self.mlp_psd(psd_feat)
        # (B, C, F2) -> (B, C, 1) -> (B, C)
        e = self.gvec(torch.tanh(mlp_psd)).squeeze(-1)
        u = F.softmax(scaling * e, dim=-1)
        return u, ilens
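
A sketch of the fixed-reference branch in `apply_beamforming` above: when `ref_channel >= 0`, the reference vector u is a one-hot over channels, broadcast over the batch. Sizes here are toy values.

import torch

B, C = 2, 4          # batch and channel counts (assumed)
ref_channel = 1

u = torch.zeros(B, C)
u[..., ref_channel].fill_(1)
print(u)  # each row is [0., 1., 0., 0.]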
@@ -0,0 +1,93 @@
from typing import Tuple

from pytorch_wpe import wpe_one_iteration
import torch
from torch_complex.tensor import ComplexTensor

from funasr_detach.frontends.utils.mask_estimator import MaskEstimator
from funasr_detach.models.transformer.utils.nets_utils import make_pad_mask


class DNN_WPE(torch.nn.Module):
    def __init__(
        self,
        wtype: str = "blstmp",
        widim: int = 257,
        wlayers: int = 3,
        wunits: int = 300,
        wprojs: int = 320,
        dropout_rate: float = 0.0,
        taps: int = 5,
        delay: int = 3,
        use_dnn_mask: bool = True,
        iterations: int = 1,
        normalization: bool = False,
    ):
        super().__init__()
        self.iterations = iterations
        self.taps = taps
        self.delay = delay

        self.normalization = normalization
        self.use_dnn_mask = use_dnn_mask

        self.inverse_power = True

        if self.use_dnn_mask:
            self.mask_est = MaskEstimator(
                wtype, widim, wlayers, wunits, wprojs, dropout_rate, nmask=1
            )

    def forward(
        self, data: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
        """The forward function

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq or Some dimension of the feature vector

        Args:
            data: (B, T, C, F)
            ilens: (B,)
        Returns:
            data: (B, T, C, F)
            ilens: (B,)
        """
        # (B, T, C, F) -> (B, F, C, T)
        enhanced = data = data.permute(0, 3, 2, 1)
        mask = None

        for i in range(self.iterations):
            # Calculate power: (..., C, T)
            power = enhanced.real**2 + enhanced.imag**2
            if i == 0 and self.use_dnn_mask:
                # mask: (B, F, C, T)
                (mask,), _ = self.mask_est(enhanced, ilens)
                if self.normalization:
                    # Normalize along T
                    mask = mask / mask.sum(dim=-1)[..., None]
                # (..., C, T) * (..., C, T) -> (..., C, T)
                power = power * mask

            # Averaging along the channel axis: (..., C, T) -> (..., T)
            power = power.mean(dim=-2)

            # enhanced: (..., C, T) -> (..., C, T)
            enhanced = wpe_one_iteration(
                data.contiguous(),
                power,
                taps=self.taps,
                delay=self.delay,
                inverse_power=self.inverse_power,
            )

            enhanced.masked_fill_(make_pad_mask(ilens, enhanced.real), 0)

        # (B, F, C, T) -> (B, T, C, F)
        enhanced = enhanced.permute(0, 3, 2, 1)
        if mask is not None:
            mask = mask.transpose(-1, -3)
        return enhanced, ilens, mask
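
A toy illustration (shapes invented) of the power estimate that drives WPE above: element-wise mask weighting of the per-channel power followed by averaging over the channel axis.

import torch

B, F_, C, T = 1, 3, 2, 6
enhanced_real = torch.randn(B, F_, C, T)
enhanced_imag = torch.randn(B, F_, C, T)
mask = torch.rand(B, F_, C, T)  # stand-in for the DNN mask

power = enhanced_real**2 + enhanced_imag**2  # (..., C, T)
power = power * mask                         # mask weighting
power = power.mean(dim=-2)                   # (..., T)
print(power.shape)  # torch.Size([1, 3, 6])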
@@ -0,0 +1,263 @@
from typing import List
from typing import Tuple
from typing import Union

import librosa
import numpy as np
import torch
from torch_complex.tensor import ComplexTensor

from funasr_detach.models.transformer.utils.nets_utils import make_pad_mask


class FeatureTransform(torch.nn.Module):
    def __init__(
        self,
        # Mel options,
        fs: int = 16000,
        n_fft: int = 512,
        n_mels: int = 80,
        fmin: float = 0.0,
        fmax: float = None,
        # Normalization
        stats_file: str = None,
        apply_uttmvn: bool = True,
        uttmvn_norm_means: bool = True,
        uttmvn_norm_vars: bool = False,
    ):
        super().__init__()
        self.apply_uttmvn = apply_uttmvn

        self.logmel = LogMel(fs=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
        self.stats_file = stats_file
        if stats_file is not None:
            self.global_mvn = GlobalMVN(stats_file)
        else:
            self.global_mvn = None

        if self.apply_uttmvn is not None:
            self.uttmvn = UtteranceMVN(
                norm_means=uttmvn_norm_means, norm_vars=uttmvn_norm_vars
            )
        else:
            self.uttmvn = None

    def forward(
        self, x: ComplexTensor, ilens: Union[torch.LongTensor, np.ndarray, List[int]]
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        # (B, T, F) or (B, T, C, F)
        if x.dim() not in (3, 4):
            raise ValueError(f"Input dim must be 3 or 4: {x.dim()}")
        if not torch.is_tensor(ilens):
            ilens = torch.from_numpy(np.asarray(ilens)).to(x.device)

        if x.dim() == 4:
            # h: (B, T, C, F) -> h: (B, T, F)
            if self.training:
                # Select 1ch randomly
                ch = np.random.randint(x.size(2))
                h = x[:, :, ch, :]
            else:
                # Use the first channel
                h = x[:, :, 0, :]
        else:
            h = x

        # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
        h = h.real**2 + h.imag**2

        h, _ = self.logmel(h, ilens)
        if self.stats_file is not None:
            h, _ = self.global_mvn(h, ilens)
        if self.apply_uttmvn:
            h, _ = self.uttmvn(h, ilens)

        return h, ilens


class LogMel(torch.nn.Module):
    """Convert STFT to fbank feats

    The arguments are the same as librosa.filters.mel

    Args:
        fs: number > 0 [scalar] sampling rate of the incoming signal
        n_fft: int > 0 [scalar] number of FFT components
        n_mels: int > 0 [scalar] number of Mel bands to generate
        fmin: float >= 0 [scalar] lowest frequency (in Hz)
        fmax: float >= 0 [scalar] highest frequency (in Hz).
            If `None`, use `fmax = fs / 2.0`
        htk: use HTK formula instead of Slaney
        norm: {None, 1, np.inf} [scalar]
            if 1, divide the triangular mel weights by the width of the mel band
            (area normalization). Otherwise, leave all the triangles aiming for
            a peak value of 1.0

    """

    def __init__(
        self,
        fs: int = 16000,
        n_fft: int = 512,
        n_mels: int = 80,
        fmin: float = 0.0,
        fmax: float = None,
        htk: bool = False,
        norm=1,
    ):
        super().__init__()

        _mel_options = dict(
            sr=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax, htk=htk, norm=norm
        )
        self.mel_options = _mel_options

        # Note(kamo): The mel matrix of librosa is different from kaldi.
        melmat = librosa.filters.mel(**_mel_options)
        # melmat: (D2, D1) -> (D1, D2)
        self.register_buffer("melmat", torch.from_numpy(melmat.T).float())

    def extra_repr(self):
        return ", ".join(f"{k}={v}" for k, v in self.mel_options.items())

    def forward(
        self, feat: torch.Tensor, ilens: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        # feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2)
        mel_feat = torch.matmul(feat, self.melmat)

        logmel_feat = (mel_feat + 1e-20).log()
        # Zero padding
        logmel_feat = logmel_feat.masked_fill(make_pad_mask(ilens, logmel_feat, 1), 0.0)
        return logmel_feat, ilens


class GlobalMVN(torch.nn.Module):
    """Apply global mean and variance normalization

    Args:
        stats_file(str): npy file of 1-dim array or text file.
            From the first element to
            the {(len(array) - 1) / 2}th element are treated as
            the sum of features,
            and the rest excluding the last element are
            treated as the sum of the square value of features,
            and the last element equals the number of samples.
        std_floor(float):
    """

    def __init__(
        self,
        stats_file: str,
        norm_means: bool = True,
        norm_vars: bool = True,
        eps: float = 1.0e-20,
    ):
        super().__init__()
        self.norm_means = norm_means
        self.norm_vars = norm_vars

        self.stats_file = stats_file
        stats = np.load(stats_file)

        stats = stats.astype(float)
        assert (len(stats) - 1) % 2 == 0, stats.shape

        count = stats.flatten()[-1]
        mean = stats[: (len(stats) - 1) // 2] / count
        var = stats[(len(stats) - 1) // 2 : -1] / count - mean * mean
        std = np.maximum(np.sqrt(var), eps)

        self.register_buffer("bias", torch.from_numpy(-mean.astype(np.float32)))
        self.register_buffer("scale", torch.from_numpy(1 / std.astype(np.float32)))

    def extra_repr(self):
        return (
            f"stats_file={self.stats_file}, "
            f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
        )

    def forward(
        self, x: torch.Tensor, ilens: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        # feat: (B, T, D)
        if self.norm_means:
            x += self.bias.type_as(x)
            x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0)

        if self.norm_vars:
            x *= self.scale.type_as(x)
        return x, ilens


class UtteranceMVN(torch.nn.Module):
    def __init__(
        self, norm_means: bool = True, norm_vars: bool = False, eps: float = 1.0e-20
    ):
        super().__init__()
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        self.eps = eps

    def extra_repr(self):
        return f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"

    def forward(
        self, x: torch.Tensor, ilens: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        return utterance_mvn(
            x, ilens, norm_means=self.norm_means, norm_vars=self.norm_vars, eps=self.eps
        )


def utterance_mvn(
    x: torch.Tensor,
    ilens: torch.LongTensor,
    norm_means: bool = True,
    norm_vars: bool = False,
    eps: float = 1.0e-20,
) -> Tuple[torch.Tensor, torch.LongTensor]:
    """Apply utterance mean and variance normalization

    Args:
        x: (B, T, D), assumed zero padded
        ilens: (B,)
        norm_means:
        norm_vars:
        eps:

    """
    ilens_ = ilens.type_as(x)
    # mean: (B, D)
    mean = x.sum(dim=1) / ilens_[:, None]

    if norm_means:
        x -= mean[:, None, :]
        x_ = x
    else:
        x_ = x - mean[:, None, :]

    # Zero padding
    x_.masked_fill_(make_pad_mask(ilens, x_, 1), 0.0)
    if norm_vars:
        var = x_.pow(2).sum(dim=1) / ilens_[:, None]
        var = torch.clamp(var, min=eps)
        x /= var.sqrt()[:, None, :]
        x_ = x
    return x_, ilens
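
A minimal run (outside the repo, assuming this module's imports are available) of `utterance_mvn` on zero-padded input: the per-utterance mean uses the true lengths, not the padded frame count, so short utterances are normalized correctly.

import torch

x = torch.zeros(2, 5, 3)
x[0, :5] = torch.randn(5, 3)
x[1, :3] = torch.randn(3, 3)   # only 3 valid frames, rest is padding
ilens = torch.tensor([5, 3])

x_norm, _ = utterance_mvn(x.clone(), ilens, norm_means=True, norm_vars=False)
# the valid frames of each utterance are now (approximately) zero-mean
print(x_norm[1, :3].mean(dim=0))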
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import numpy
+import torch
+import torch.nn as nn
+from torch_complex.tensor import ComplexTensor
+
+from funasr_detach.frontends.utils.dnn_beamformer import DNN_Beamformer
+from funasr_detach.frontends.utils.dnn_wpe import DNN_WPE
+
+
+class Frontend(nn.Module):
+    def __init__(
+        self,
+        idim: int,
+        # WPE options
+        use_wpe: bool = False,
+        wtype: str = "blstmp",
+        wlayers: int = 3,
+        wunits: int = 300,
+        wprojs: int = 320,
+        wdropout_rate: float = 0.0,
+        taps: int = 5,
+        delay: int = 3,
+        use_dnn_mask_for_wpe: bool = True,
+        # Beamformer options
+        use_beamformer: bool = False,
+        btype: str = "blstmp",
+        blayers: int = 3,
+        bunits: int = 300,
+        bprojs: int = 320,
+        bnmask: int = 2,
+        badim: int = 320,
+        ref_channel: int = -1,
+        bdropout_rate=0.0,
+    ):
+        super().__init__()
+
+        self.use_beamformer = use_beamformer
+        self.use_wpe = use_wpe
+        self.use_dnn_mask_for_wpe = use_dnn_mask_for_wpe
+        # use frontend for all the data,
+        # e.g. in the case of multi-speaker speech separation
+        self.use_frontend_for_all = bnmask > 2
+
+        if self.use_wpe:
+            if self.use_dnn_mask_for_wpe:
+                # Use DNN for power estimation
+                # (Not observed significant gains)
+                iterations = 1
+            else:
+                # Performing as conventional WPE, without DNN Estimator
+                iterations = 2
+
+            self.wpe = DNN_WPE(
+                wtype=wtype,
+                widim=idim,
+                wunits=wunits,
+                wprojs=wprojs,
+                wlayers=wlayers,
+                taps=taps,
+                delay=delay,
+                dropout_rate=wdropout_rate,
+                iterations=iterations,
+                use_dnn_mask=use_dnn_mask_for_wpe,
+            )
+        else:
+            self.wpe = None
+
+        if self.use_beamformer:
+            self.beamformer = DNN_Beamformer(
+                btype=btype,
+                bidim=idim,
+                bunits=bunits,
+                bprojs=bprojs,
+                blayers=blayers,
+                bnmask=bnmask,
+                dropout_rate=bdropout_rate,
+                badim=badim,
+                ref_channel=ref_channel,
+            )
+        else:
+            self.beamformer = None
+
+    def forward(
+        self, x: ComplexTensor, ilens: Union[torch.LongTensor, numpy.ndarray, List[int]]
+    ) -> Tuple[ComplexTensor, torch.LongTensor, Optional[ComplexTensor]]:
+        assert len(x) == len(ilens), (len(x), len(ilens))
+        # (B, T, F) or (B, T, C, F)
+        if x.dim() not in (3, 4):
+            raise ValueError(f"Input dim must be 3 or 4: {x.dim()}")
+        if not torch.is_tensor(ilens):
+            ilens = torch.from_numpy(numpy.asarray(ilens)).to(x.device)
+
+        mask = None
+        h = x
+        if h.dim() == 4:
+            if self.training:
+                choices = [(False, False)] if not self.use_frontend_for_all else []
+                if self.use_wpe:
+                    choices.append((True, False))
+
+                if self.use_beamformer:
+                    choices.append((False, True))
+
+                use_wpe, use_beamformer = choices[numpy.random.randint(len(choices))]
+
+            else:
+                use_wpe = self.use_wpe
+                use_beamformer = self.use_beamformer
+
+            # 1. WPE
+            if use_wpe:
+                # h: (B, T, C, F) -> h: (B, T, C, F)
+                h, ilens, mask = self.wpe(h, ilens)
+
+            # 2. Beamformer
+            if use_beamformer:
+                # h: (B, T, C, F) -> h: (B, T, F)
+                h, ilens, mask = self.beamformer(h, ilens)
+
+        return h, ilens, mask
+
+
+def frontend_for(args, idim):
+    return Frontend(
+        idim=idim,
+        # WPE options
+        use_wpe=args.use_wpe,
+        wtype=args.wtype,
+        wlayers=args.wlayers,
+        wunits=args.wunits,
+        wprojs=args.wprojs,
+        wdropout_rate=args.wdropout_rate,
+        taps=args.wpe_taps,
+        delay=args.wpe_delay,
+        use_dnn_mask_for_wpe=args.use_dnn_mask_for_wpe,
+        # Beamformer options
+        use_beamformer=args.use_beamformer,
+        btype=args.btype,
+        blayers=args.blayers,
+        bunits=args.bunits,
+        bprojs=args.bprojs,
+        bnmask=args.bnmask,
+        badim=args.badim,
+        ref_channel=args.ref_channel,
+        bdropout_rate=args.bdropout_rate,
+    )
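A hedged sketch of driving this `Frontend` (not part of the diff; assumes `torch_complex` and the funasr_detach package are importable). With both stages disabled the module is a passthrough, which makes the shapes easy to check:

```python
import torch
from torch_complex.tensor import ComplexTensor

# Hypothetical multi-channel STFT input: batch=2, T=100 frames, C=4 mics, F=257 bins.
x = ComplexTensor(torch.randn(2, 100, 4, 257), torch.randn(2, 100, 4, 257))
ilens = torch.LongTensor([100, 80])

frontend = Frontend(idim=257, use_wpe=False, use_beamformer=False)
frontend.eval()
h, olens, mask = frontend(x, ilens)  # passthrough: h keeps (B, T, C, F), mask is None
```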
@@ -0,0 +1,83 @@
+import librosa
+import torch
+from typing import Tuple
+
+from funasr_detach.models.transformer.utils.nets_utils import make_pad_mask
+
+
+class LogMel(torch.nn.Module):
+    """Convert STFT to fbank feats
+
+    The arguments are the same as librosa.filters.mel
+
+    Args:
+        fs: number > 0 [scalar] sampling rate of the incoming signal
+        n_fft: int > 0 [scalar] number of FFT components
+        n_mels: int > 0 [scalar] number of Mel bands to generate
+        fmin: float >= 0 [scalar] lowest frequency (in Hz)
+        fmax: float >= 0 [scalar] highest frequency (in Hz).
+            If `None`, use `fmax = fs / 2.0`
+        htk: use HTK formula instead of Slaney
+    """
+
+    def __init__(
+        self,
+        fs: int = 16000,
+        n_fft: int = 512,
+        n_mels: int = 80,
+        fmin: float = None,
+        fmax: float = None,
+        htk: bool = False,
+        log_base: float = None,
+    ):
+        super().__init__()
+
+        fmin = 0 if fmin is None else fmin
+        fmax = fs / 2 if fmax is None else fmax
+        _mel_options = dict(
+            sr=fs,
+            n_fft=n_fft,
+            n_mels=n_mels,
+            fmin=fmin,
+            fmax=fmax,
+            htk=htk,
+        )
+        self.mel_options = _mel_options
+        self.log_base = log_base
+
+        # Note(kamo): The mel matrix of librosa is different from kaldi.
+        melmat = librosa.filters.mel(**_mel_options)
+        # melmat: (D2, D1) -> (D1, D2)
+        self.register_buffer("melmat", torch.from_numpy(melmat.T).float())
+
+    def extra_repr(self):
+        return ", ".join(f"{k}={v}" for k, v in self.mel_options.items())
+
+    def forward(
+        self,
+        feat: torch.Tensor,
+        ilens: torch.Tensor = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2)
+        mel_feat = torch.matmul(feat, self.melmat)
+        mel_feat = torch.clamp(mel_feat, min=1e-10)
+
+        if self.log_base is None:
+            logmel_feat = mel_feat.log()
+        elif self.log_base == 2.0:
+            logmel_feat = mel_feat.log2()
+        elif self.log_base == 10.0:
+            logmel_feat = mel_feat.log10()
+        else:
+            # log_base is a Python float; wrap it so torch.log accepts it
+            logmel_feat = mel_feat.log() / torch.log(torch.tensor(self.log_base))
+
+        # Zero padding
+        if ilens is not None:
+            logmel_feat = logmel_feat.masked_fill(
+                make_pad_mask(ilens, logmel_feat, 1), 0.0
+            )
+        else:
+            ilens = feat.new_full(
+                [feat.size(0)], fill_value=feat.size(1), dtype=torch.long
+            )
+        return logmel_feat, ilens
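A minimal sketch of `LogMel` on a power spectrogram (not part of the diff; the input is assumed to be a (B, T, n_fft // 2 + 1) power/magnitude tensor such as the one derived from the `Stft` module below):

```python
import torch

logmel = LogMel(fs=16000, n_fft=512, n_mels=80)
power = torch.rand(2, 100, 257)        # (B, T, n_fft // 2 + 1), illustrative values
ilens = torch.LongTensor([100, 60])
feats, olens = logmel(power, ilens)    # feats: (2, 100, 80); padded frames zeroed
```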
@@ -0,0 +1,77 @@
+from typing import Tuple
+
+import numpy as np
+import torch
+from torch.nn import functional as F
+from torch_complex.tensor import ComplexTensor
+
+from funasr_detach.models.transformer.utils.nets_utils import make_pad_mask
+from funasr_detach.models.language_model.rnn.encoders import RNN
+from funasr_detach.models.language_model.rnn.encoders import RNNP
+
+
+class MaskEstimator(torch.nn.Module):
+    def __init__(self, type, idim, layers, units, projs, dropout, nmask=1):
+        super().__init__()
+        subsample = np.ones(layers + 1, dtype=np.int32)
+
+        typ = type.lstrip("vgg").rstrip("p")
+        if type[-1] == "p":
+            self.brnn = RNNP(idim, layers, units, projs, subsample, dropout, typ=typ)
+        else:
+            self.brnn = RNN(idim, layers, units, projs, dropout, typ=typ)
+
+        self.type = type
+        self.nmask = nmask
+        self.linears = torch.nn.ModuleList(
+            [torch.nn.Linear(projs, idim) for _ in range(nmask)]
+        )
+
+    def forward(
+        self, xs: ComplexTensor, ilens: torch.LongTensor
+    ) -> Tuple[Tuple[torch.Tensor, ...], torch.LongTensor]:
+        """The forward function
+
+        Args:
+            xs: (B, F, C, T)
+            ilens: (B,)
+        Returns:
+            masks: A tuple of the masks, each (B, F, C, T)
+            ilens: (B,)
+        """
+        assert xs.size(0) == ilens.size(0), (xs.size(0), ilens.size(0))
+        _, _, C, input_length = xs.size()
+        # (B, F, C, T) -> (B, C, T, F)
+        xs = xs.permute(0, 2, 3, 1)
+
+        # Calculate amplitude: (B, C, T, F) -> (B, C, T, F)
+        xs = (xs.real**2 + xs.imag**2) ** 0.5
+        # xs: (B, C, T, F) -> xs: (B * C, T, F)
+        xs = xs.contiguous().view(-1, xs.size(-2), xs.size(-1))
+        # ilens: (B,) -> ilens_: (B * C)
+        ilens_ = ilens[:, None].expand(-1, C).contiguous().view(-1)
+
+        # xs: (B * C, T, F) -> xs: (B * C, T, D)
+        xs, _, _ = self.brnn(xs, ilens_)
+        # xs: (B * C, T, D) -> xs: (B, C, T, D)
+        xs = xs.view(-1, C, xs.size(-2), xs.size(-1))
+
+        masks = []
+        for linear in self.linears:
+            # xs: (B, C, T, D) -> mask: (B, C, T, F)
+            mask = linear(xs)
+
+            mask = torch.sigmoid(mask)
+            # Zero padding (in-place: a bare masked_fill() would discard its result)
+            mask.masked_fill_(make_pad_mask(ilens, mask, length_dim=2), 0)
+
+            # (B, C, T, F) -> (B, F, C, T)
+            mask = mask.permute(0, 3, 1, 2)
+
+            # Takes care of multi-GPU cases: if input_length > max(ilens)
+            if mask.size(-1) < input_length:
+                mask = F.pad(mask, [0, input_length - mask.size(-1)], value=0)
+            masks.append(mask)
+
+        return tuple(masks), ilens
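A hedged usage sketch (not part of the diff; assumes the funasr_detach RNN encoders construct with these arguments, and follows the shapes in the docstring above):

```python
import torch
from torch_complex.tensor import ComplexTensor

est = MaskEstimator(type="blstmp", idim=257, layers=3, units=300, projs=320,
                    dropout=0.0, nmask=2)
xs = ComplexTensor(torch.randn(2, 257, 4, 100), torch.randn(2, 257, 4, 100))
ilens = torch.LongTensor([100, 80])
(mask_speech, mask_noise), olens = est(xs, ilens)  # each mask: (B, F, C, T)
```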
@@ -0,0 +1,239 @@
+from distutils.version import LooseVersion
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import torch
+
+try:
+    from torch_complex.tensor import ComplexTensor
+except ImportError:
+    print("Please install torch_complex first")
+from funasr_detach.models.transformer.utils.nets_utils import make_pad_mask
+from funasr_detach.frontends.utils.complex_utils import is_complex
+
+import librosa
+import numpy as np
+
+is_torch_1_9_plus = LooseVersion(torch.__version__) >= LooseVersion("1.9.0")
+
+
+is_torch_1_7_plus = LooseVersion(torch.__version__) >= LooseVersion("1.7")
+
+
+class Stft(torch.nn.Module):
+    def __init__(
+        self,
+        n_fft: int = 512,
+        win_length: int = None,
+        hop_length: int = 128,
+        window: Optional[str] = "hann",
+        center: bool = True,
+        normalized: bool = False,
+        onesided: bool = True,
+    ):
+        super().__init__()
+        self.n_fft = n_fft
+        if win_length is None:
+            self.win_length = n_fft
+        else:
+            self.win_length = win_length
+        self.hop_length = hop_length
+        self.center = center
+        self.normalized = normalized
+        self.onesided = onesided
+        if window is not None and not hasattr(torch, f"{window}_window"):
+            if window.lower() != "povey":
+                raise ValueError(f"{window} window is not implemented")
+        self.window = window
+
+    def extra_repr(self):
+        return (
+            f"n_fft={self.n_fft}, "
+            f"win_length={self.win_length}, "
+            f"hop_length={self.hop_length}, "
+            f"center={self.center}, "
+            f"normalized={self.normalized}, "
+            f"onesided={self.onesided}"
+        )
+
+    def forward(
+        self, input: torch.Tensor, ilens: torch.Tensor = None
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """STFT forward function.
+
+        Args:
+            input: (Batch, Nsamples) or (Batch, Nsample, Channels)
+            ilens: (Batch)
+        Returns:
+            output: (Batch, Frames, Freq, 2) or (Batch, Frames, Channels, Freq, 2)
+
+        """
+        bs = input.size(0)
+        if input.dim() == 3:
+            multi_channel = True
+            # input: (Batch, Nsample, Channels) -> (Batch * Channels, Nsample)
+            input = input.transpose(1, 2).reshape(-1, input.size(1))
+        else:
+            multi_channel = False
+
+        # NOTE(kamo):
+        # The default behaviour of torch.stft is compatible with librosa.stft
+        # about padding and scaling.
+        # Note that it's different from scipy.signal.stft
+
+        # output: (Batch, Freq, Frames, 2=real_imag)
+        # or (Batch, Channel, Freq, Frames, 2=real_imag)
+        if self.window is not None:
+            if self.window.lower() == "povey":
+                window = torch.hann_window(
+                    self.win_length,
+                    periodic=False,
+                    device=input.device,
+                    dtype=input.dtype,
+                ).pow(0.85)
+            else:
+                window_func = getattr(torch, f"{self.window}_window")
+                window = window_func(
+                    self.win_length, dtype=input.dtype, device=input.device
+                )
+        else:
+            window = None
+
+        # For the compatibility of ARM devices, which do not support
+        # torch.stft() due to the lack of MKL.
+        if input.is_cuda or torch.backends.mkl.is_available():
+            stft_kwargs = dict(
+                n_fft=self.n_fft,
+                win_length=self.win_length,
+                hop_length=self.hop_length,
+                center=self.center,
+                window=window,
+                normalized=self.normalized,
+                onesided=self.onesided,
+            )
+            if is_torch_1_7_plus:
+                stft_kwargs["return_complex"] = False
+            output = torch.stft(input, **stft_kwargs)
+        else:
+            if self.training:
+                raise NotImplementedError(
+                    "stft is implemented with librosa on this device, which does not "
+                    "support the training mode."
+                )
+
+            # use stft_kwargs to flexibly control different PyTorch versions' kwargs
+            stft_kwargs = dict(
+                n_fft=self.n_fft,
+                win_length=self.win_length,
+                hop_length=self.hop_length,
+                center=self.center,
+                window=window,
+            )
+
+            if window is not None:
+                # pad the given window to n_fft
+                n_pad_left = (self.n_fft - window.shape[0]) // 2
+                n_pad_right = self.n_fft - window.shape[0] - n_pad_left
+                stft_kwargs["window"] = torch.cat(
+                    [torch.zeros(n_pad_left), window, torch.zeros(n_pad_right)], 0
+                ).numpy()
+            else:
+                win_length = (
+                    self.win_length if self.win_length is not None else self.n_fft
+                )
+                stft_kwargs["window"] = torch.ones(win_length)
+
+            output = []
+            # iterate over instances in a batch
+            for i, instance in enumerate(input):
+                stft = librosa.stft(input[i].numpy(), **stft_kwargs)
+                output.append(torch.tensor(np.stack([stft.real, stft.imag], -1)))
+            output = torch.stack(output, 0)
+            if not self.onesided:
+                len_conj = self.n_fft - output.shape[1]
+                conj = output[:, 1 : 1 + len_conj].flip(1)
+                conj[:, :, :, -1].data *= -1
+                output = torch.cat([output, conj], 1)
+            if self.normalized:
+                output = output * (stft_kwargs["window"].shape[0] ** (-0.5))
+
+        # output: (Batch, Freq, Frames, 2=real_imag)
+        # -> (Batch, Frames, Freq, 2=real_imag)
+        output = output.transpose(1, 2)
+        if multi_channel:
+            # output: (Batch * Channel, Frames, Freq, 2=real_imag)
+            # -> (Batch, Frame, Channel, Freq, 2=real_imag)
+            output = output.view(bs, -1, output.size(1), output.size(2), 2).transpose(
+                1, 2
+            )
+
+        if ilens is not None:
+            if self.center:
+                pad = self.n_fft // 2
+                ilens = ilens + 2 * pad
+
+            olens = (ilens - self.n_fft) // self.hop_length + 1
+            output.masked_fill_(make_pad_mask(olens, output, 1), 0.0)
+        else:
+            olens = None
+
+        return output, olens
+
+    def inverse(
+        self, input: Union[torch.Tensor, ComplexTensor], ilens: torch.Tensor = None
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """Inverse STFT.
+
+        Args:
+            input: Tensor(batch, T, F, 2) or ComplexTensor(batch, T, F)
+            ilens: (batch,)
+        Returns:
+            wavs: (batch, samples)
+            ilens: (batch,)
+        """
+        if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+            istft = torch.functional.istft
+        else:
+            try:
+                import torchaudio
+            except ImportError:
+                raise ImportError(
+                    "Please install torchaudio>=0.3.0 or use torch>=1.6.0"
+                )
+
+            if not hasattr(torchaudio.functional, "istft"):
+                raise ImportError(
+                    "Please install torchaudio>=0.3.0 or use torch>=1.6.0"
+                )
+            istft = torchaudio.functional.istft
+
+        if self.window is not None:
+            window_func = getattr(torch, f"{self.window}_window")
+            if is_complex(input):
+                datatype = input.real.dtype
+            else:
+                datatype = input.dtype
+            window = window_func(self.win_length, dtype=datatype, device=input.device)
+        else:
+            window = None
+
+        if is_complex(input):
+            input = torch.stack([input.real, input.imag], dim=-1)
+        elif input.shape[-1] != 2:
+            raise TypeError("Invalid input type")
+        input = input.transpose(1, 2)
+
+        wavs = istft(
+            input,
+            n_fft=self.n_fft,
+            hop_length=self.hop_length,
+            win_length=self.win_length,
+            window=window,
+            center=self.center,
+            normalized=self.normalized,
+            onesided=self.onesided,
+            length=ilens.max() if ilens is not None else ilens,
+        )
+
+        return wavs, ilens
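A hedged forward-pass sketch (not part of the diff). The module returns real/imag stacked on the last axis rather than a complex tensor; note that `return_complex=False` is deprecated on recent torch releases, so behavior on the newest versions is an assumption:

```python
import torch

stft = Stft(n_fft=512, hop_length=128, window="hann")
wav = torch.randn(1, 16000)        # 1 s at 16 kHz, illustrative values
ilens = torch.LongTensor([16000])
spec, olens = stft(wav, ilens)     # spec: (1, frames, 257, 2), olens: (1,)
power = spec.pow(2).sum(-1)        # (1, frames, 257), feedable to LogMel above
```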
@@ -0,0 +1,556 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+# Part of the implementation is borrowed from espnet/espnet.
+from typing import Tuple
+import copy
+import numpy as np
+import torch
+import torch.nn as nn
+import torchaudio.compliance.kaldi as kaldi
+from torch.nn.utils.rnn import pad_sequence
+
+import funasr_detach.frontends.eend_ola_feature as eend_ola_feature
+from funasr_detach.register import tables
+
+
+def load_cmvn(cmvn_file):
+    with open(cmvn_file, "r", encoding="utf-8") as f:
+        lines = f.readlines()
+    means_list = []
+    vars_list = []
+    for i in range(len(lines)):
+        line_item = lines[i].split()
+        if line_item[0] == "<AddShift>":
+            line_item = lines[i + 1].split()
+            if line_item[0] == "<LearnRateCoef>":
+                add_shift_line = line_item[3 : (len(line_item) - 1)]
+                means_list = list(add_shift_line)
+                continue
+        elif line_item[0] == "<Rescale>":
+            line_item = lines[i + 1].split()
+            if line_item[0] == "<LearnRateCoef>":
+                rescale_line = line_item[3 : (len(line_item) - 1)]
+                vars_list = list(rescale_line)
+                continue
+    means = np.array(means_list).astype(np.float32)
+    vars = np.array(vars_list).astype(np.float32)
+    cmvn = np.array([means, vars])
+    cmvn = torch.as_tensor(cmvn, dtype=torch.float32)
+    return cmvn
+
+
+def apply_cmvn(inputs, cmvn):  # noqa
+    """
+    Apply CMVN with mvn data
+    """
+
+    device = inputs.device
+    dtype = inputs.dtype
+    frame, dim = inputs.shape
+
+    means = cmvn[0:1, :dim]
+    vars = cmvn[1:2, :dim]
+    inputs += means.to(device)
+    inputs *= vars.to(device)
+
+    return inputs.type(torch.float32)
+
+
+def apply_lfr(inputs, lfr_m, lfr_n):
+    LFR_inputs = []
+    T = inputs.shape[0]
+    T_lfr = int(np.ceil(T / lfr_n))
+    left_padding = inputs[0].repeat((lfr_m - 1) // 2, 1)
+    inputs = torch.vstack((left_padding, inputs))
+    T = T + (lfr_m - 1) // 2
+    for i in range(T_lfr):
+        if lfr_m <= T - i * lfr_n:
+            LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).view(1, -1))
+        else:  # process last LFR frame
+            num_padding = lfr_m - (T - i * lfr_n)
+            frame = (inputs[i * lfr_n :]).view(-1)
+            for _ in range(num_padding):
+                frame = torch.hstack((frame, inputs[-1]))
+            LFR_inputs.append(frame)
+    LFR_outputs = torch.vstack(LFR_inputs)
+    return LFR_outputs.type(torch.float32)
+
+
+@tables.register("frontend_classes", "WavFrontend")
+class WavFrontend(nn.Module):
+    """Conventional frontend structure for ASR."""
+
+    def __init__(
+        self,
+        cmvn_file: str = None,
+        fs: int = 16000,
+        window: str = "hamming",
+        n_mels: int = 80,
+        frame_length: int = 25,
+        frame_shift: int = 10,
+        filter_length_min: int = -1,
+        filter_length_max: int = -1,
+        lfr_m: int = 1,
+        lfr_n: int = 1,
+        dither: float = 1.0,
+        snip_edges: bool = True,
+        upsacle_samples: bool = True,
+        **kwargs,
+    ):
+        super().__init__()
+        self.fs = fs
+        self.window = window
+        self.n_mels = n_mels
+        self.frame_length = frame_length
+        self.frame_shift = frame_shift
+        self.filter_length_min = filter_length_min
+        self.filter_length_max = filter_length_max
+        self.lfr_m = lfr_m
+        self.lfr_n = lfr_n
+        self.cmvn_file = cmvn_file
+        self.dither = dither
+        self.snip_edges = snip_edges
+        self.upsacle_samples = upsacle_samples
+        self.cmvn = None if self.cmvn_file is None else load_cmvn(self.cmvn_file)
+
+    def output_size(self) -> int:
+        return self.n_mels * self.lfr_m
+
+    def forward(
+        self,
+        input: torch.Tensor,
+        input_lengths,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        batch_size = input.size(0)
+        feats = []
+        feats_lens = []
+        for i in range(batch_size):
+            waveform_length = input_lengths[i]
+            waveform = input[i][:waveform_length]
+            if self.upsacle_samples:
+                waveform = waveform * (1 << 15)
+            waveform = waveform.unsqueeze(0)
+            mat = kaldi.fbank(
+                waveform,
+                num_mel_bins=self.n_mels,
+                frame_length=self.frame_length,
+                frame_shift=self.frame_shift,
+                dither=self.dither,
+                energy_floor=0.0,
+                window_type=self.window,
+                sample_frequency=self.fs,
+                snip_edges=self.snip_edges,
+            )
+
+            if self.lfr_m != 1 or self.lfr_n != 1:
+                mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
+            if self.cmvn is not None:
+                mat = apply_cmvn(mat, self.cmvn)
+            feat_length = mat.size(0)
+            feats.append(mat)
+            feats_lens.append(feat_length)
+
+        feats_lens = torch.as_tensor(feats_lens)
+        if batch_size == 1:
+            feats_pad = feats[0][None, :, :]
+        else:
+            feats_pad = pad_sequence(feats, batch_first=True, padding_value=0.0)
+        return feats_pad, feats_lens
+
+    def forward_fbank(
+        self, input: torch.Tensor, input_lengths: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        batch_size = input.size(0)
+        feats = []
+        feats_lens = []
+        for i in range(batch_size):
+            waveform_length = input_lengths[i]
+            waveform = input[i][:waveform_length]
+            waveform = waveform * (1 << 15)
+            waveform = waveform.unsqueeze(0)
+            mat = kaldi.fbank(
+                waveform,
+                num_mel_bins=self.n_mels,
+                frame_length=self.frame_length,
+                frame_shift=self.frame_shift,
+                dither=self.dither,
+                energy_floor=0.0,
+                window_type=self.window,
+                sample_frequency=self.fs,
+            )
+
+            feat_length = mat.size(0)
+            feats.append(mat)
+            feats_lens.append(feat_length)
+
+        feats_lens = torch.as_tensor(feats_lens)
+        feats_pad = pad_sequence(feats, batch_first=True, padding_value=0.0)
+        return feats_pad, feats_lens
+
+    def forward_lfr_cmvn(
+        self, input: torch.Tensor, input_lengths: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        batch_size = input.size(0)
+        feats = []
+        feats_lens = []
+        for i in range(batch_size):
+            mat = input[i, : input_lengths[i], :]
+            if self.lfr_m != 1 or self.lfr_n != 1:
+                mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
+            if self.cmvn is not None:
+                mat = apply_cmvn(mat, self.cmvn)
+            feat_length = mat.size(0)
+            feats.append(mat)
+            feats_lens.append(feat_length)
+
+        feats_lens = torch.as_tensor(feats_lens)
+        feats_pad = pad_sequence(feats, batch_first=True, padding_value=0.0)
+        return feats_pad, feats_lens
+
+
+@tables.register("frontend_classes", "WavFrontendOnline")
+class WavFrontendOnline(nn.Module):
+    """Conventional frontend structure for streaming ASR/VAD."""
+
+    def __init__(
+        self,
+        cmvn_file: str = None,
+        fs: int = 16000,
+        window: str = "hamming",
+        n_mels: int = 80,
+        frame_length: int = 25,
+        frame_shift: int = 10,
+        filter_length_min: int = -1,
+        filter_length_max: int = -1,
+        lfr_m: int = 1,
+        lfr_n: int = 1,
+        dither: float = 1.0,
+        snip_edges: bool = True,
+        upsacle_samples: bool = True,
+        **kwargs,
+    ):
+        super().__init__()
+        self.fs = fs
+        self.window = window
+        self.n_mels = n_mels
+        self.frame_length = frame_length
+        self.frame_shift = frame_shift
+        self.frame_sample_length = int(self.frame_length * self.fs / 1000)
+        self.frame_shift_sample_length = int(self.frame_shift * self.fs / 1000)
+        self.filter_length_min = filter_length_min
+        self.filter_length_max = filter_length_max
+        self.lfr_m = lfr_m
+        self.lfr_n = lfr_n
+        self.cmvn_file = cmvn_file
+        self.dither = dither
+        self.snip_edges = snip_edges
+        self.upsacle_samples = upsacle_samples
+        # self.waveforms = None
+        # self.reserve_waveforms = None
+        # self.fbanks = None
+        # self.fbanks_lens = None
+        self.cmvn = None if self.cmvn_file is None else load_cmvn(self.cmvn_file)
+        # self.input_cache = None
+        # self.lfr_splice_cache = []
+
+    def output_size(self) -> int:
+        return self.n_mels * self.lfr_m
+
+    @staticmethod
+    def apply_cmvn(inputs: torch.Tensor, cmvn: torch.Tensor) -> torch.Tensor:
+        """
+        Apply CMVN with mvn data
+        """
+
+        device = inputs.device
+        dtype = inputs.dtype
+        frame, dim = inputs.shape
+
+        means = np.tile(cmvn[0:1, :dim], (frame, 1))
+        vars = np.tile(cmvn[1:2, :dim], (frame, 1))
+        inputs += torch.from_numpy(means).type(dtype).to(device)
+        inputs *= torch.from_numpy(vars).type(dtype).to(device)
+
+        return inputs.type(torch.float32)
+
+    @staticmethod
+    def apply_lfr(
+        inputs: torch.Tensor, lfr_m: int, lfr_n: int, is_final: bool = False
+    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
+        """
+        Apply lfr with data
+        """
+
+        LFR_inputs = []
+        # inputs = torch.vstack((inputs_lfr_cache, inputs))
+        T = inputs.shape[0]  # include the right context
+        T_lfr = int(
+            np.ceil((T - (lfr_m - 1) // 2) / lfr_n)
+        )  # minus the right context: (lfr_m - 1) // 2
+        splice_idx = T_lfr
+        for i in range(T_lfr):
+            if lfr_m <= T - i * lfr_n:
+                LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).view(1, -1))
+            else:  # process last LFR frame
+                if is_final:
+                    num_padding = lfr_m - (T - i * lfr_n)
+                    frame = (inputs[i * lfr_n :]).view(-1)
+                    for _ in range(num_padding):
+                        frame = torch.hstack((frame, inputs[-1]))
+                    LFR_inputs.append(frame)
+                else:
+                    # update splice_idx and break the loop
+                    splice_idx = i
+                    break
+        splice_idx = min(T - 1, splice_idx * lfr_n)
+        lfr_splice_cache = inputs[splice_idx:, :]
+        LFR_outputs = torch.vstack(LFR_inputs)
+        return LFR_outputs.type(torch.float32), lfr_splice_cache, splice_idx
+
+    @staticmethod
+    def compute_frame_num(
+        sample_length: int, frame_sample_length: int, frame_shift_sample_length: int
+    ) -> int:
+        frame_num = int(
+            (sample_length - frame_sample_length) / frame_shift_sample_length + 1
+        )
+        return (
+            frame_num if frame_num >= 1 and sample_length >= frame_sample_length else 0
+        )
+
+    def forward_fbank(
+        self,
+        input: torch.Tensor,
+        input_lengths: torch.Tensor,
+        cache: dict = {},
+        **kwargs,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        batch_size = input.size(0)
+        assert batch_size == 1
+        input = torch.cat((cache["input_cache"], input), dim=1)
+        frame_num = self.compute_frame_num(
+            input.shape[-1], self.frame_sample_length, self.frame_shift_sample_length
+        )
+        # update the input cache with the leftover samples
+        cache["input_cache"] = input[
+            :, -(input.shape[-1] - frame_num * self.frame_shift_sample_length) :
+        ]
+        waveforms = torch.empty(0)
+        feats_pad = torch.empty(0)
+        feats_lens = torch.empty(0)
+        if frame_num:
+            waveforms = []
+            feats = []
+            feats_lens = []
+            for i in range(batch_size):
+                waveform = input[i].cuda()
+                # keep exactly the wave samples that are used for fbank extraction
+                waveforms.append(
+                    waveform[
+                        : (
+                            (frame_num - 1) * self.frame_shift_sample_length
+                            + self.frame_sample_length
+                        )
+                    ]
+                )
+                waveform = waveform * (1 << 15)
+                waveform = waveform.unsqueeze(0)
+                mat = kaldi.fbank(
+                    waveform,
+                    num_mel_bins=self.n_mels,
+                    frame_length=self.frame_length,
+                    frame_shift=self.frame_shift,
+                    dither=self.dither,
+                    energy_floor=0.0,
+                    window_type=self.window,
+                    sample_frequency=self.fs,
+                )
+
+                feat_length = mat.size(0)
+                feats.append(mat)
+                feats_lens.append(feat_length)
+
+            waveforms = torch.stack(waveforms)
+            feats_lens = torch.as_tensor(feats_lens)
+            feats_pad = pad_sequence(feats, batch_first=True, padding_value=0.0)
+            cache["fbanks"] = feats_pad
+            cache["fbanks_lens"] = copy.deepcopy(feats_lens)
+        return waveforms, feats_pad, feats_lens
+
+    def forward_lfr_cmvn(
+        self,
+        input: torch.Tensor,
+        input_lengths: torch.Tensor,
+        is_final: bool = False,
+        cache: dict = {},
+        **kwargs,
+    ):
+        batch_size = input.size(0)
+        feats = []
+        feats_lens = []
+        lfr_splice_frame_idxs = []
+        for i in range(batch_size):
+            mat = input[i, : input_lengths[i], :]
+            lfr_splice_frame_idx = -1  # default when LFR is disabled
+            if self.lfr_m != 1 or self.lfr_n != 1:
+                # update self.lfr_splice_cache in self.apply_lfr
+                # mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(mat, self.lfr_m, self.lfr_n, self.lfr_splice_cache[i],
+                mat, cache["lfr_splice_cache"][i], lfr_splice_frame_idx = (
+                    self.apply_lfr(mat, self.lfr_m, self.lfr_n, is_final)
+                )
+            if self.cmvn_file is not None:
+                mat = self.apply_cmvn(mat, self.cmvn)
+            feat_length = mat.size(0)
+            feats.append(mat)
+            feats_lens.append(feat_length)
+            lfr_splice_frame_idxs.append(lfr_splice_frame_idx)
+        feats_lens = torch.as_tensor(feats_lens)
+        feats_pad = pad_sequence(feats, batch_first=True, padding_value=0.0)
+        lfr_splice_frame_idxs = torch.as_tensor(lfr_splice_frame_idxs)
+        return feats_pad, feats_lens, lfr_splice_frame_idxs
+
+    def forward(self, input: torch.Tensor, input_lengths: torch.Tensor, **kwargs):
+        is_final = kwargs.get("is_final", False)
+        cache = kwargs.get("cache", {})
+        if len(cache) == 0:
+            self.init_cache(cache)
+
+        batch_size = input.shape[0]
+        assert (
+            batch_size == 1
+        ), "online feature extraction currently supports batch_size == 1 only"
+
+        waveforms, feats, feats_lengths = self.forward_fbank(
+            input, input_lengths, cache=cache
+        )  # input shape: B T D
+
+        if feats.shape[0]:
+
+            cache["waveforms"] = torch.cat(
+                (cache["reserve_waveforms"], waveforms.cpu()), dim=1
+            )
+
+            if not cache["lfr_splice_cache"]:  # initialize the splice cache
+                for i in range(batch_size):
+                    cache["lfr_splice_cache"].append(
+                        feats[i][0, :].unsqueeze(dim=0).repeat((self.lfr_m - 1) // 2, 1)
+                    )
+            # the number of input frames plus self.lfr_splice_cache[0].shape[0] must reach self.lfr_m
+            if feats_lengths[0] + cache["lfr_splice_cache"][0].shape[0] >= self.lfr_m:
+                lfr_splice_cache_tensor = torch.stack(
+                    cache["lfr_splice_cache"]
+                )  # B T D
+                feats = torch.cat((lfr_splice_cache_tensor, feats), dim=1)
+
+                feats_lengths += lfr_splice_cache_tensor[0].shape[0]
+                frame_from_waveforms = int(
+                    (cache["waveforms"].shape[1] - self.frame_sample_length)
+                    / self.frame_shift_sample_length
+                    + 1
+                )
+                minus_frame = (
+                    (self.lfr_m - 1) // 2
+                    if cache["reserve_waveforms"].numel() == 0
+                    else 0
+                )
+                feats, feats_lengths, lfr_splice_frame_idxs = self.forward_lfr_cmvn(
+                    feats, feats_lengths, is_final, cache=cache
+                )
+                if self.lfr_m == 1:
+                    cache["reserve_waveforms"] = torch.empty(0)
+                else:
+                    reserve_frame_idx = lfr_splice_frame_idxs[0] - minus_frame
+                    # print('reserve_frame_idx: ' + str(reserve_frame_idx))
+                    # print('frame_frame: ' + str(frame_from_waveforms))
+                    cache["reserve_waveforms"] = cache["waveforms"][
+                        :,
+                        reserve_frame_idx
+                        * self.frame_shift_sample_length : frame_from_waveforms
+                        * self.frame_shift_sample_length,
+                    ]
+                    sample_length = (
+                        frame_from_waveforms - 1
+                    ) * self.frame_shift_sample_length + self.frame_sample_length
+                    cache["waveforms"] = cache["waveforms"][:, :sample_length]
+            else:
+                # update self.reserve_waveforms and self.lfr_splice_cache
+                cache["reserve_waveforms"] = cache["waveforms"][
+                    :, : -(self.frame_sample_length - self.frame_shift_sample_length)
+                ]
+                for i in range(batch_size):
+                    cache["lfr_splice_cache"][i] = torch.cat(
+                        (cache["lfr_splice_cache"][i], feats[i]), dim=0
+                    )
+                return torch.empty(0), feats_lengths
+        else:
+            if is_final:
+                cache["waveforms"] = (
+                    waveforms
+                    if cache["reserve_waveforms"].numel() == 0
+                    else cache["reserve_waveforms"]
+                )
+                feats = torch.stack(cache["lfr_splice_cache"])
+                feats_lengths = (
+                    torch.zeros(batch_size, dtype=torch.int) + feats.shape[1]
+                )
+                feats, feats_lengths, _ = self.forward_lfr_cmvn(
+                    feats, feats_lengths, is_final, cache=cache
+                )
+        # if is_final:
+        #     self.init_cache(cache)
+        return feats, feats_lengths
+
+    def init_cache(self, cache: dict = {}):
+        cache["reserve_waveforms"] = torch.empty(0)
+        cache["input_cache"] = torch.empty(0)
+        cache["lfr_splice_cache"] = []
+        cache["waveforms"] = None
+        cache["fbanks"] = None
+        cache["fbanks_lens"] = None
+        return cache
+
+
+class WavFrontendMel23(nn.Module):
+    """Conventional frontend structure for ASR."""
+
+    def __init__(
+        self,
+        fs: int = 16000,
+        frame_length: int = 25,
+        frame_shift: int = 10,
+        lfr_m: int = 1,
+        lfr_n: int = 1,
+        **kwargs,
+    ):
+        super().__init__()
+        self.fs = fs
+        self.frame_length = frame_length
+        self.frame_shift = frame_shift
+        self.lfr_m = lfr_m
+        self.lfr_n = lfr_n
+        self.n_mels = 23
+
+    def output_size(self) -> int:
+        return self.n_mels * (2 * self.lfr_m + 1)
+
+    def forward(
+        self, input: torch.Tensor, input_lengths: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        batch_size = input.size(0)
+        feats = []
+        feats_lens = []
+        for i in range(batch_size):
+            waveform_length = input_lengths[i]
+            waveform = input[i][:waveform_length]
+            waveform = waveform.numpy()
+            mat = eend_ola_feature.stft(waveform, self.frame_length, self.frame_shift)
+            mat = eend_ola_feature.transform(mat)
+            mat = eend_ola_feature.splice(mat, context_size=self.lfr_m)
+            mat = mat[:: self.lfr_n]
+            mat = torch.from_numpy(mat)
+            feat_length = mat.size(0)
+            feats.append(mat)
+            feats_lens.append(feat_length)
+
+        feats_lens = torch.as_tensor(feats_lens)
+        feats_pad = pad_sequence(feats, batch_first=True, padding_value=0.0)
+        return feats_pad, feats_lens
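A hedged offline-usage sketch for `WavFrontend` (not part of the diff). It assumes float waveforms normalized to [-1, 1], which `upsacle_samples` rescales to 16-bit range before `kaldi.fbank`; the LFR values are illustrative:

```python
import torch

frontend = WavFrontend(fs=16000, n_mels=80, lfr_m=7, lfr_n=6)  # hypothetical LFR setup
wav = torch.randn(1, 16000) * 0.1      # (B, samples), float in [-1, 1]
wav_lens = torch.LongTensor([16000])
feats, feats_lens = frontend(wav, wav_lens)
print(feats.shape, frontend.output_size())  # feature dim = n_mels * lfr_m = 560
```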
@@ -0,0 +1,74 @@
+#!/usr/bin/env python3
+#  2020, Technische Universität München; Ludwig Kürzinger
+#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Sliding Window for raw audio input data."""
+
+import torch
+import torch.nn as nn
+from typing import Tuple
+
+
+class SlidingWindow(nn.Module):
+    """Sliding Window.
+
+    Provides a sliding window over a batched continuous raw audio tensor.
+    Optionally, provides padding (Currently not implemented).
+    Combine this module with a pre-encoder compatible with raw audio data,
+    for example Sinc convolutions.
+
+    Known issues:
+        Output length is calculated incorrectly if audio shorter than win_length.
+        WARNING: trailing values are discarded - padding not implemented yet.
+        There is currently no additional window function applied to input values.
+    """
+
+    def __init__(
+        self,
+        win_length: int = 400,
+        hop_length: int = 160,
+        channels: int = 1,
+        padding: int = None,
+        fs=None,
+    ):
+        """Initialize.
+
+        Args:
+            win_length: Length of frame.
+            hop_length: Relative starting point of next frame.
+            channels: Number of input channels.
+            padding: Padding (placeholder, currently not implemented).
+            fs: Sampling rate (placeholder for compatibility, not used).
+        """
+        super().__init__()
+        self.fs = fs
+        self.win_length = win_length
+        self.hop_length = hop_length
+        self.channels = channels
+        self.padding = padding
+
+    def forward(
+        self, input: torch.Tensor, input_lengths: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Apply a sliding window on the input.
+
+        Args:
+            input: Input (B, T, C*D) or (B, T*C*D), with D=C=1.
+            input_lengths: Input lengths within batch.
+        Returns:
+            Tensor: Output with dimensions (B, T, C, D), with D=win_length.
+            Tensor: Output lengths within batch.
+        """
+        input_size = input.size()
+        B = input_size[0]
+        T = input_size[1]
+        C = self.channels
+        D = self.win_length
+        # (B, T, C) --> (T, B, C)
+        continuous = input.view(B, T, C).permute(1, 0, 2)
+        windowed = continuous.unfold(0, D, self.hop_length)
+        # (T, B, C, D) --> (B, T, C, D)
+        output = windowed.permute(1, 0, 2, 3).contiguous()
+        # After unfold(), windowed lengths change:
+        output_lengths = (input_lengths - self.win_length) // self.hop_length + 1
+        return output, output_lengths
+
+    def output_size(self) -> int:
+        """Return output length of feature dimension D, i.e. the window length."""
+        return self.win_length
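A minimal sketch of `SlidingWindow` framing raw audio (not part of the diff; values are illustrative):

```python
import torch

sw = SlidingWindow(win_length=400, hop_length=160)
wav = torch.randn(2, 16000, 1)          # (B, T, C), single channel
lens = torch.LongTensor([16000, 12000])
frames, frame_lens = sw(wav, lens)      # frames: (2, 98, 1, 400)
```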
File without changes
@@ -0,0 +1,125 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+# Copyright 2019 Shigeki Karita
+#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Label smoothing module."""
+
+import torch
+from torch import nn
+from funasr_detach.models.transformer.utils.nets_utils import make_pad_mask
+
+
+class LabelSmoothingLoss(nn.Module):
+    """Label-smoothing loss.
+
+    :param int size: the number of classes
+    :param int padding_idx: ignored class id
+    :param float smoothing: smoothing rate (0.0 means the conventional CE)
+    :param bool normalize_length: normalize loss by sequence length if True
+    :param torch.nn.Module criterion: loss function to be smoothed
+    """
+
+    def __init__(
+        self,
+        size,
+        padding_idx,
+        smoothing,
+        normalize_length=False,
+        criterion=nn.KLDivLoss(reduction="none"),
+    ):
+        """Construct a LabelSmoothingLoss object."""
+        super(LabelSmoothingLoss, self).__init__()
+        self.criterion = criterion
+        self.padding_idx = padding_idx
+        self.confidence = 1.0 - smoothing
+        self.smoothing = smoothing
+        self.size = size
+        self.true_dist = None
+        self.normalize_length = normalize_length
+
+    def forward(self, x, target):
+        """Compute loss between x and target.
+
+        :param torch.Tensor x: prediction (batch, seqlen, class)
+        :param torch.Tensor target:
+            target signal masked with self.padding_idx (batch, seqlen)
+        :return: scalar float value
+        :rtype: torch.Tensor
+        """
+        assert x.size(2) == self.size
+        batch_size = x.size(0)
+        x = x.view(-1, self.size)
+        target = target.view(-1)
+        with torch.no_grad():
+            true_dist = x.clone()
+            true_dist.fill_(self.smoothing / (self.size - 1))
+            ignore = target == self.padding_idx  # (B,)
+            total = len(target) - ignore.sum().item()
+            target = target.masked_fill(ignore, 0)  # avoid -1 index
+            true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
+        kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
+        denom = total if self.normalize_length else batch_size
+        return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
+
+
+class SequenceBinaryCrossEntropy(nn.Module):
+    def __init__(
+        self, normalize_length=False, criterion=nn.BCEWithLogitsLoss(reduction="none")
+    ):
+        super().__init__()
+        self.normalize_length = normalize_length
+        self.criterion = criterion
+
+    def forward(self, pred, label, lengths):
+        pad_mask = make_pad_mask(lengths, maxlen=pred.shape[1]).to(pred.device)
+        loss = self.criterion(pred, label)
+        denom = (~pad_mask).sum() if self.normalize_length else pred.shape[0]
+        return loss.masked_fill(pad_mask.unsqueeze(-1), 0).sum() / denom
+
+
+class NllLoss(nn.Module):
+    """NLL loss.
+
+    :param int size: the number of classes
+    :param int padding_idx: ignored class id
+    :param bool normalize_length: normalize loss by sequence length if True
+    :param torch.nn.Module criterion: loss function
+    """
+
+    def __init__(
+        self,
+        size,
+        padding_idx,
+        normalize_length=False,
+        criterion=nn.NLLLoss(reduction="none"),
+    ):
+        """Construct an NllLoss object."""
+        super(NllLoss, self).__init__()
+        self.criterion = criterion
+        self.padding_idx = padding_idx
+        self.size = size
+        self.true_dist = None
+        self.normalize_length = normalize_length
+
+    def forward(self, x, target):
+        """Compute loss between x and target.
+
+        :param torch.Tensor x: prediction (batch, seqlen, class)
+        :param torch.Tensor target:
+            target signal masked with self.padding_idx (batch, seqlen)
+        :return: scalar float value
+        :rtype: torch.Tensor
+        """
+        assert x.size(2) == self.size
+        batch_size = x.size(0)
+        x = x.view(-1, self.size)
+        target = target.view(-1)
+        with torch.no_grad():
+            ignore = target == self.padding_idx  # (B,)
+            total = len(target) - ignore.sum().item()
+            target = target.masked_fill(ignore, 0)  # avoid -1 index
+        loss = self.criterion(x, target)
+        denom = total if self.normalize_length else batch_size
+        return loss.masked_fill(ignore, 0).sum() / denom
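A minimal sketch of `LabelSmoothingLoss` (not part of the diff; vocabulary size and sequence lengths are illustrative). Positions equal to `padding_idx` are excluded from the loss:

```python
import torch

crit = LabelSmoothingLoss(size=1000, padding_idx=-1, smoothing=0.1)
logits = torch.randn(4, 20, 1000, requires_grad=True)  # (batch, seqlen, class)
targets = torch.randint(0, 1000, (4, 20))
targets[:, 15:] = -1                                   # padded positions are ignored
loss = crit(logits, targets)
loss.backward()
```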