File size: 3,565 Bytes
5dfda47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4ad99ca
5dfda47
598e684
 
 
 
 
 
 
 
 
 
 
 
 
5dfda47
 
 
 
 
 
 
 
 
 
 
 
14395b7
 
 
5dfda47
 
 
 
 
 
 
 
 
 
 
 
 
 
c32b67b
4ad99ca
 
5dfda47
14395b7
 
5dfda47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
598e684
5dfda47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126

import torch
from datasets import load_dataset
from transformers import AutoFeatureExtractor
import os
from tqdm import tqdm
import matplotlib.pyplot as plt

from pipeline_utils import compute_speaker_stats, plot_reconstruction





def _average_metrics(results):
    """Return a dict mapping each metric name to its mean over *results*.

    Each element of *results* is expected to carry a 'metrics' dict with the
    same keys (as produced by the embedding pipeline). Returns an empty dict
    for an empty dataset instead of raising IndexError.
    """
    metrics_list = [r['metrics'] for r in results]
    if not metrics_list:
        return {}
    return {
        name: sum(m[name] for m in metrics_list) / len(metrics_list)
        for name in metrics_list[0]
    }


def main():
    """Extract prosody (F0/energy VQ-VAE) embeddings from a speech dataset.

    Pipeline:
      1. Load an audio dataset from the HuggingFace hub.
      2. Run the MU-NLPC feature preprocessor over every example.
      3. Compute per-speaker statistics and persist them.
      4. Run the embedding pipeline, save results, report average metrics,
         and plot each sample's reconstruction.

    Side effects: writes "processed_dataset/", "speaker_stats.pt",
    "embeddings_dataset/" and "plots/*.png" into the working directory.
    """
    # Small dummy split for quick runs. Alternatives used during development:
    # ("mythicinfinity/libritts", "clean", split="test.clean") or
    # ("facebook/voxpopuli", "en", split="test").
    dataset = load_dataset(
        "sanchit-gandhi/voxpopuli_dummy",
        split="validation",
    )

    preprocessor = AutoFeatureExtractor.from_pretrained(
        'MU-NLPC/F0_Energy_joint_VQVAE_embeddings-preprocessor',
        # trust_remote_code=True
    )

    processed_dataset = dataset.map(
        lambda x: preprocessor.extract_features(x['audio']['array']),
        load_from_cache_file=False,
        # num_proc=4
    )
    processed_dataset.save_to_disk("processed_dataset")

    # Per-speaker statistics are consumed by the embedding pipeline
    # (used for the optional F0 speaker normalization below).
    speaker_stats = compute_speaker_stats(processed_dataset)
    torch.save(speaker_stats, "speaker_stats.pt")

    # Imported lazily so the heavy pipeline machinery loads only when needed.
    from transformers import pipeline
    embedding_pipeline = pipeline(
        task="prosody-embedding",
        model="MU-NLPC/F0_Energy_joint_VQVAE_embeddings-interp",
        f0_interp=True,
        f0_normalize=False,
        speaker_stats=speaker_stats,
        # trust_remote_code=True
    )

    results = processed_dataset.map(
        lambda x: embedding_pipeline(x),
        remove_columns=processed_dataset.column_names,
        load_from_cache_file=False,
        # num_proc=4
    )
    results.save_to_disk("embeddings_dataset")

    print(f"Processed {len(results)} samples")

    # The VQ codebook maps discrete code indices back to embedding vectors;
    # inspect a single code via embedding_codebook[code_point].
    embedding_codebook = embedding_pipeline.model.vq.level_blocks[0].k
    print("embedding_codebook.shape", embedding_codebook.shape)

    embeddings_example = results[0]['codes'][0][0]
    print("Embeddings example:", embeddings_example)

    avg_metrics = _average_metrics(results)
    print("\nAverage metrics across dataset:")
    print(avg_metrics)

    print("Plotting reconstruction curves...")
    # Hoisted out of the loop: the directory only needs creating once.
    os.makedirs('plots', exist_ok=True)
    for i in tqdm(range(len(results))):
        fig = plot_reconstruction(results[i], i)
        # Save/close the returned figure explicitly rather than relying on
        # matplotlib's implicit "current figure" (assumes plot_reconstruction
        # returns a Figure, as the assignment suggests — confirm).
        fig.savefig(f'plots/reconstruction_sample{i}.png')
        plt.close(fig)
    print("Done.")


# Run the full extraction pipeline only when executed as a script,
# not when this module is imported.
if __name__ == '__main__':
    main()