import os
import sys
import csv
from typing import Tuple

import numpy as np
import torch
import librosa
from tqdm import tqdm
import lightning.pytorch as pl

# Make the AudioSep repository importable before pulling in its local modules.
sys.path.append('../AudioSep/')
from models.clap_encoder import CLAP_Encoder
from utils import (
    load_ss_model,
    calculate_sdr,
    calculate_sisdr,
    parse_yaml,
)
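
# calculate_sdr / calculate_sisdr are assumed to implement the standard
# definitions (not verified against the AudioSep utils source):
#   SDR    = 10 * log10(||ref||^2 / ||ref - est||^2)
#   SI-SDR = 10 * log10(||a*ref||^2 / ||a*ref - est||^2),  a = <est, ref> / ||ref||^2
# The "improvement" variants (SDRi, SI-SDRi) subtract the score obtained when
# the raw mixture itself is used as the estimate.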


class VGGSoundEvaluator:
    def __init__(
        self,
        sampling_rate: int = 32000
    ) -> None:
        r"""Evaluator for text-queried source separation on the VGGSound
        evaluation set.

        Args:
            sampling_rate (int): sampling rate used to load audio, in Hz
        Returns:
            None
        """

        self.sampling_rate = sampling_rate

        # Each metadata row describes a mixture and its two constituent
        # sources: file_id, mix_wav, s0_wav, s0_text, s1_wav, s1_text.
        with open('evaluation/metadata/vggsound_eval.csv') as csv_file:
            csv_reader = csv.reader(csv_file, delimiter=',')
            next(csv_reader)  # skip the header row
            self.eval_list = list(csv_reader)

        self.audio_dir = 'evaluation/data/vggsound'

    def __call__(
        self,
        pl_model: pl.LightningModule
    ) -> Tuple[float, float]:
        r"""Evaluate a separation model; returns (mean SI-SDR, mean SDRi)."""

        print('Evaluation on VGGSound+ with [text label] queries.')

        pl_model.eval()
        device = pl_model.device

        sisdrs_list = []
        sdris_list = []

        with torch.no_grad():
            for eval_data in tqdm(self.eval_list):

                file_id, mix_wav, s0_wav, s0_text, s1_wav, s1_text = eval_data

                labels = s0_text

                mixture_path = os.path.join(self.audio_dir, mix_wav)
                source_path = os.path.join(self.audio_dir, s0_wav)

                source, fs = librosa.load(source_path, sr=self.sampling_rate, mono=True)
                mixture, fs = librosa.load(mixture_path, sr=self.sampling_rate, mono=True)

                # SDR of the unprocessed mixture: the baseline for SDRi.
                sdr_no_sep = calculate_sdr(ref=source, est=mixture)

                # Embed the target source's text label as the query condition.
                text = [labels]
                conditions = pl_model.query_encoder.get_query_embed(
                    modality='text',
                    text=text,
                    device=device,
                )

                input_dict = {
                    "mixture": torch.Tensor(mixture)[None, None, :].to(device),
                    "condition": conditions,
                }

                sep_segment = pl_model.ss_model(input_dict)["waveform"]
                # sep_segment: (batch_size=1, channels_num=1, segment_samples)

                sep_segment = sep_segment.squeeze(0).squeeze(0).data.cpu().numpy()
                # sep_segment: (segment_samples,)

                sdr = calculate_sdr(ref=source, est=sep_segment)
                sdri = sdr - sdr_no_sep

                sisdr = calculate_sisdr(ref=source, est=sep_segment)

                sisdrs_list.append(sisdr)
                sdris_list.append(sdri)

        mean_sisdr = np.mean(sisdrs_list)
        mean_sdri = np.mean(sdris_list)

        return mean_sisdr, mean_sdri
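

# --- Usage sketch ---
# A minimal, illustrative driver, assuming an AudioSep-style YAML config and
# checkpoint compatible with parse_yaml / load_ss_model; the paths below are
# placeholders, not shipped defaults, and CLAP_Encoder is assumed to locate
# its pretrained weights via its own defaults.
if __name__ == '__main__':
    config_yaml = 'config/audiosep_base.yaml'           # placeholder path
    checkpoint_path = 'checkpoint/audiosep_base.ckpt'   # placeholder path

    configs = parse_yaml(config_yaml)
    query_encoder = CLAP_Encoder().eval()

    # load_ss_model wraps the separation model and query encoder in the
    # LightningModule interface that the evaluator expects.
    pl_model = load_ss_model(
        configs=configs,
        checkpoint_path=checkpoint_path,
        query_encoder=query_encoder,
    ).to('cuda' if torch.cuda.is_available() else 'cpu')

    evaluator = VGGSoundEvaluator(sampling_rate=32000)
    mean_sisdr, mean_sdri = evaluator(pl_model)
    print(f'Mean SI-SDR: {mean_sisdr:.3f} dB | Mean SDRi: {mean_sdri:.3f} dB')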