|
import os |
|
import torch |
|
import random |
|
import numpy as np |
|
|
|
|
|
|
|
def find_pt_files(root_dir):
    """Recursively collect the paths of all ``.pt`` files beneath *root_dir*.

    The tree is traversed with ``os.walk``; each matching file name is joined
    onto the directory in which it was found. Paths are returned in the order
    ``os.walk`` produces them.

    Args:
        root_dir: Directory to search. A directory that does not exist simply
            yields no files.

    Returns:
        A list of paths whose file names end with ``'.pt'``.
    """
    return [
        os.path.join(current_dir, name)
        for current_dir, _subdirs, names in os.walk(root_dir)
        for name in names
        if name.endswith('.pt')
    ]
|
|
|
|
|
|
|
def compute_statistics(tensor_list):
    """Compute global summary statistics over all elements of several tensors.

    Every tensor is flattened and all elements are pooled together before the
    statistics are taken, so tensors of different shapes (including 0-d
    scalars) can be mixed. The pooled data is converted to float64 on the CPU,
    which lets integer tensors be used (``torch.mean``/``torch.std`` reject
    integer dtypes) and keeps accumulation accurate over many elements.

    Args:
        tensor_list: Iterable of tensors (anything ``torch.as_tensor`` accepts).

    Returns:
        Tuple ``(mean, std, max_val, min_val)`` of Python floats. ``std`` is
        torch's default sample (Bessel-corrected) standard deviation, so it is
        ``nan`` when only a single element is pooled.

    Raises:
        ValueError: If ``tensor_list`` is empty or all of its tensors are empty.
    """
    # Flatten up front so torch.cat never fails on 0-d or mismatched shapes.
    flat = [torch.as_tensor(t).reshape(-1) for t in tensor_list]
    if not flat:
        raise ValueError("compute_statistics() needs at least one tensor")

    all_data = torch.cat(flat).to(device="cpu", dtype=torch.float64)
    if all_data.numel() == 0:
        # torch.max/torch.min raise on empty input; report it clearly instead.
        raise ValueError("compute_statistics() received only empty tensors")

    mean = torch.mean(all_data).item()
    std = torch.std(all_data).item()
    max_val = torch.max(all_data).item()
    min_val = torch.min(all_data).item()
    return mean, std, max_val, min_val
|
|
|
|
|
|
|
# Directory searched (recursively) for .pt files.
root_dir = "spk"

# Upper bound on how many files are loaded; statistics are estimated from
# this random subset rather than from every file.
SAMPLE_SIZE = 1000

pt_files = find_pt_files(root_dir)
if not pt_files:
    # Fail with a readable message instead of an opaque torch.cat() error
    # raised later from compute_statistics().
    raise SystemExit(f"No .pt files found under {root_dir!r}")

# Sample without replacement. No seed is set, so the chosen subset (and the
# resulting statistics) can differ between runs.
sampled_files = random.sample(pt_files, min(SAMPLE_SIZE, len(pt_files)))

tensor_list = []
for file in sampled_files:
    # map_location="cpu" lets tensors saved from a GPU load on CPU-only hosts.
    # NOTE(review): torch.load unpickles arbitrary objects - only run this on
    # trusted files (consider weights_only=True where the torch version has it).
    # Assumes each file holds a single tensor - TODO confirm.
    tensor = torch.load(file, map_location="cpu")
    # reshape (unlike view) also works for non-contiguous tensors.
    tensor_list.append(tensor.reshape(-1))

mean, std, max_val, min_val = compute_statistics(tensor_list)

print(f"Mean: {mean}")
print(f"Std: {std}")
print(f"Max: {max_val}")
print(f"Min: {min_val}")