import glob
import logging
import os
import re
import sys

import debugpy
import torch
import torchaudio

def count_params_by_module(model_name, model):
    """Log parameter counts (total / trainable / frozen) per top-level module and overall."""
    logging.info(f"Counting num_parameters of {model_name}:")
    
    param_stats = {}
    total_params = 0                # all parameters
    total_requires_grad_params = 0  # parameters with requires_grad=True
    total_no_grad_params = 0        # parameters with requires_grad=False
    
    for name, param in model.named_parameters():
        module_name = name.split('.')[0]
        if module_name not in param_stats:
            param_stats[module_name] = {'total': 0, 'requires_grad': 0, 'no_grad': 0}
        
        param_num = param.numel()
        param_stats[module_name]['total'] += param_num
        total_params += param_num
        
        if param.requires_grad:
            param_stats[module_name]['requires_grad'] += param_num
            total_requires_grad_params += param_num
        else:
            param_stats[module_name]['no_grad'] += param_num
            total_no_grad_params += param_num
    
    # Calculate maximum width for each column
    max_module_name_length = max(len(module) for module in param_stats)
    max_param_length = max(len(f"{stats['total'] / 1e6:.2f}M") for stats in param_stats.values())
    
    # Output parameter statistics for each module
    for module, stats in param_stats.items():
        logging.info(f"\t{module:<{max_module_name_length}}: "
                     f"Total: {stats['total'] / 1e6:<{max_param_length}.2f}M, "
                     f"Requires Grad: {stats['requires_grad'] / 1e6:<{max_param_length}.2f}M, "
                     f"No Grad: {stats['no_grad'] / 1e6:<{max_param_length}.2f}M")
    
    # Output total parameter statistics
    logging.info(f"\tTotal parameters: {total_params / 1e6:.2f}M parameters")
    logging.info(f"\tRequires Grad parameters: {total_requires_grad_params / 1e6:.2f}M parameters")
    logging.info(f"\tNo Grad parameters: {total_no_grad_params / 1e6:.2f}M parameters")
    logging.info(f"################################################################")


def load_and_resample_audio(audio_path, target_sample_rate):
    """Load an audio file and resample it to target_sample_rate if needed."""
    wav, raw_sample_rate = torchaudio.load(audio_path)  # (channels, T); assumes mono input
    if raw_sample_rate != target_sample_rate:
        wav = torchaudio.functional.resample(wav, raw_sample_rate, target_sample_rate)
    return wav.squeeze()  # (T,) for mono input

def set_logging():
    rank = os.environ.get("RANK", "0")
    logging.basicConfig(
        level=logging.INFO,
        stream=sys.stdout,
        format=f"%(asctime)s [RANK {rank}] (%(module)s:%(lineno)d) %(levelname)s : %(message)s",
    )
    
def waiting_for_debug(ip, port):
    rank = os.environ.get("RANK", "0")
    debugpy.listen((ip, port))  # listen on the given IP/port (use the node's IP on a cluster, not localhost)
    logging.info(f"[rank = {rank}] Waiting for debugger attach...")
    debugpy.wait_for_client()
    logging.info(f"[rank = {rank}] Debugger attached")
    
def load_audio(audio_path, target_sample_rate):
    # Load audio file, wav shape: (channels, time)
    wav, raw_sample_rate = torchaudio.load(audio_path)
    
    # If multi-channel, convert to mono by averaging across channels
    if wav.shape[0] > 1:
        wav = torch.mean(wav, dim=0, keepdim=True)  # Average across channels, keep channel dim
    
    # Resample if necessary
    if raw_sample_rate != target_sample_rate:
        wav = torchaudio.functional.resample(wav, raw_sample_rate, target_sample_rate)
    
    # Reshape to (1, 1, time) so downstream code gets a batched, single-channel tensor
    wav = wav.reshape(1, 1, -1)
    
    return wav

def save_audio(audio_outpath, audio_out, sample_rate):
    torchaudio.save(
        audio_outpath, 
        audio_out, 
        sample_rate=sample_rate, 
        encoding='PCM_S', 
        bits_per_sample=16
    )
    logging.info(f"Successfully saved audio at {audio_outpath}")
    
def find_audio_files(input_dir):
    audio_extensions = ['*.flac', '*.mp3', '*.wav']
    audios_input = []
    for ext in audio_extensions:
        audios_input.extend(glob.glob(os.path.join(input_dir, '**', ext), recursive=True))
    logging.info(f"Found {len(audios_input)} audio files in {input_dir}")
    return sorted(audios_input)

def normalize_text(text):
    # Remove all punctuation (including English and Chinese punctuation)
    text = re.sub(r'[^\w\s\u4e00-\u9fff]', '', text)
    # Convert to lowercase (effective for English, no effect on Chinese)
    text = text.lower()
    # Remove extra spaces
    text = ' '.join(text.split())
    return text
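

if __name__ == "__main__":
    # Minimal usage sketch: the "./data" directory, "_16k" suffix, and 16 kHz
    # target sample rate are illustrative assumptions, not fixed conventions.
    set_logging()
    target_sr = 16000
    for audio_path in find_audio_files("./data"):
        wav = load_audio(audio_path, target_sr)                 # (1, 1, time)
        out_path = os.path.splitext(audio_path)[0] + "_16k.wav"
        save_audio(out_path, wav.squeeze(0), target_sr)         # torchaudio.save expects (channels, time)
    logging.info(normalize_text("Hello, World!  你好，世界！"))  # -> "hello world 你好世界"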