File size: 3,426 Bytes
9558cab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c8da27d
 
 
 
 
 
9558cab
471d1d7
 
 
 
 
 
9558cab
 
 
 
 
 
 
ed4e382
 
 
9558cab
 
 
 
 
 
 
 
 
 
 
 
 
 
47fbb2d
162b3af
 
9558cab
ed4e382
 
9558cab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c8da27d
9558cab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122



import torch
from datasets import load_dataset
from transformers import AutoFeatureExtractor
import os
from tqdm import tqdm
import matplotlib.pyplot as plt

from pipeline_utils import compute_speaker_stats, plot_reconstruction





def _average_metrics(metrics_list):
    """Return the per-key arithmetic mean over a list of metric dicts.

    The keys are taken from the first entry, so every entry is expected to
    carry at least those keys. An empty list yields an empty dict.
    """
    if not metrics_list:
        return {}
    count = len(metrics_list)
    return {
        name: sum(entry[name] for entry in metrics_list) / count
        for name in metrics_list[0]
    }


def _save_reconstruction_plots(results, out_dir='plots'):
    """Call ``plot_reconstruction`` for every sample and save one PNG each.

    Files are written as ``<out_dir>/reconstruction_sample<i>.png``.
    """
    # The output directory does not change between samples: create it once.
    os.makedirs(out_dir, exist_ok=True)
    for i, sample in enumerate(tqdm(results)):
        # NOTE(review): the return value of plot_reconstruction is not used;
        # saving and closing go through pyplot's *current* figure, so this
        # assumes plot_reconstruction leaves its figure current - confirm.
        plot_reconstruction(sample, i)
        plt.savefig(f'{out_dir}/reconstruction_sample{i}.png')
        plt.close()


def main():
    """Run the demo pipeline end to end.

    Steps, in order:
      1. Load the validation split of the dummy VoxPopuli dataset.
      2. Extract features with the pretrained preprocessor and save the
         result to ``processed_dataset``.
      3. Compute speaker statistics and save them to ``speaker_stats.pt``.
      4. Run the ``prosody-embedding`` pipeline over the processed dataset
         and save the result to ``embeddings_dataset``.
      5. Print the codebook shape, an example of the codes, and the metrics
         averaged over all samples.
      6. Write one reconstruction plot per sample into ``plots/``.
    """
    # Kept as a function-scope import, as in the original script.
    from transformers import pipeline

    dataset = load_dataset(
        "sanchit-gandhi/voxpopuli_dummy",
        # "train",
        split="validation"
    )

    # Alternative datasets that can be swapped in:
    #
    # dataset = load_dataset(
    #     "mythicinfinity/libritts",
    #     "clean",
    #     split="test.clean",
    #     #trust_remote_code=True
    # )
    # dataset = load_dataset(
    #     "facebook/voxpopuli",
    #     "en",
    #     split="test"
    # )

    preprocessor = AutoFeatureExtractor.from_pretrained(
        'MU-NLPC/F0_Energy_joint_VQVAE_embeddings-preprocessor',
        #trust_remote_code=True
    )

    # load_from_cache_file=False forces the map to be recomputed on each run.
    processed_dataset = dataset.map(
        lambda x: preprocessor.extract_features(x['audio']['array']),
        load_from_cache_file=False,
        # num_proc=4
    )
    processed_dataset.save_to_disk("processed_dataset")

    speaker_stats = compute_speaker_stats(processed_dataset)
    torch.save(speaker_stats, "speaker_stats.pt")

    embedding_pipeline = pipeline(
        task="prosody-embedding",
        model="MU-NLPC/F0_Energy_joint_VQVAE_embeddings",
        f0_interp=False,
        f0_normalize=True,
        speaker_stats=speaker_stats,
        #trust_remote_code=True
    )

    # Only the pipeline outputs are kept; the input columns are dropped.
    results = processed_dataset.map(
        lambda x: embedding_pipeline(x),
        remove_columns=processed_dataset.column_names,
        load_from_cache_file=False
        # num_proc=4
    )
    results.save_to_disk("embeddings_dataset")

    print(f"Processed {len(results)} samples")

    embedding_codebook = embedding_pipeline.model.vq.level_blocks[0].k
    print("embedding_codebook.shape", embedding_codebook.shape)

    # Everything below indexes the first sample, so stop here instead of
    # failing with an IndexError when the dataset produced no rows.
    if len(results) == 0:
        print("No samples were produced; skipping metrics and plots.")
        return

    embeddings_example = results[0]['codes'][0][0]
    print("Embeddings example:", embeddings_example)

    # inspect the embeddings in the codebook as follows

    # code_point = embeddings_example[0]
    # print(f"code_point", code_point)
    # code_point_embedding = embedding_codebook[code_point]
    # print(f"code_point_embedding", code_point_embedding)
    # print(f"code_point_embedding.shape", code_point_embedding.shape)

    # check that they are the same as the hidden states used in the model
    # (the snippet below uses `np`, which this file's visible imports do not
    # provide - import numpy before enabling it)

    # hidden_states = np.array(results[0]['hidden_states'])
    # hidden_state = hidden_states[0, 0, :, 0]
    # print(f"hidden_state", hidden_state)

    metrics_list = [row['metrics'] for row in results]
    avg_metrics = _average_metrics(metrics_list)

    print("\nAverage metrics across dataset:")
    print(avg_metrics)

    print("Plotting reconstruction curves...")
    _save_reconstruction_plots(results)
    print("Done.")


# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()