File size: 4,180 Bytes
260b46d
 
 
 
 
 
 
93b48cb
260b46d
 
 
 
 
 
 
 
9fbfaa6
260b46d
 
 
 
 
 
 
93b48cb
 
260b46d
 
 
 
93b48cb
 
260b46d
93b48cb
260b46d
 
 
 
 
 
 
 
 
 
 
84d4ed6
260b46d
 
93b48cb
260b46d
84d4ed6
260b46d
9fbfaa6
260b46d
 
 
84d4ed6
 
 
260b46d
 
ac059f4
c1b9ba0
260b46d
 
 
c1b9ba0
 
260b46d
3815be3
 
 
93b48cb
 
 
 
 
 
 
 
ac059f4
93b48cb
 
c1b9ba0
260b46d
ac059f4
260b46d
 
ac059f4
 
 
 
260b46d
 
 
275afd0
260b46d
 
 
275afd0
260b46d
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from pathlib import Path
import os
from functools import partial

from frechet_audio_distance import FrechetAudioDistance
import pandas
import argbind
import torch
from tqdm import tqdm

import audiotools
from audiotools import AudioSignal

@argbind.bind(without_prefix=True)
def eval(
    exp_dir: str = None,
    baseline_key: str = "baseline",
    audio_ext: str = ".wav",
):
    """Score every condition directory of an experiment against a baseline.

    Each sub-directory of ``exp_dir`` other than ``baseline_key`` is treated
    as a condition. For every condition a Frechet audio distance is computed
    between the baseline directory and the condition directory, and a mel
    spectrogram loss is computed for each pair of equally named files.
    Results are written into ``exp_dir`` as ``metrics-all.csv`` (one row per
    file pair) and ``stats-<metric>.csv`` (mean, count and std per condition).

    Parameters
    ----------
    exp_dir : str
        Experiment directory containing one sub-directory per condition.
        Required; must exist.
    baseline_key : str
        Name of the sub-directory holding the reference audio.
    audio_ext : str
        Extension (including the dot) of the audio files to compare. File
        stems are parsed as integers to order the files.

    Raises
    ------
    AssertionError
        If ``exp_dir`` is missing or does not exist, if ``baseline_key`` is
        not a sub-directory of ``exp_dir``, or if a paired baseline and
        condition file have different stems.
    """
    assert exp_dir is not None
    exp_dir = Path(exp_dir)
    assert exp_dir.exists(), f"exp_dir {exp_dir} does not exist"

    # set up our metrics
    mel_loss = audiotools.metrics.spectral.MelSpectrogramLoss()
    frechet = FrechetAudioDistance(
        use_pca=False,
        use_activation=False,
        verbose=True,
        audio_load_worker=4,
    )
    frechet.model.to("cuda" if torch.cuda.is_available() else "cpu")

    # figure out what conditions we have
    conditions = [d.name for d in exp_dir.iterdir() if d.is_dir()]

    assert baseline_key in conditions, f"baseline_key {baseline_key} not found in {exp_dir}"
    conditions.remove(baseline_key)

    print(f"Found {len(conditions)} conditions in {exp_dir}")
    print(f"conditions: {conditions}")

    def sorted_audio_files(directory):
        # Numeric sort so that "10" follows "9" rather than "1".
        return sorted(directory.glob(f"*{audio_ext}"), key=lambda x: int(x.stem))

    baseline_dir = exp_dir / baseline_key
    baseline_files = sorted_audio_files(baseline_dir)

    metrics = []
    for condition in tqdm(conditions):
        cond_dir = exp_dir / condition
        cond_files = sorted_audio_files(cond_dir)

        # The Frechet distance is computed over the full directories, not
        # over the truncated pairing below.
        print(f"computing fad for {baseline_dir} and {cond_dir}")
        frechet_score = frechet.score(baseline_dir, cond_dir)

        # Pair up only as many files as both sides have. The slices go into
        # per-condition names so that a short condition does not shrink the
        # baseline list seen by later conditions.
        num_files = min(len(baseline_files), len(cond_files))
        paired_baseline = baseline_files[:num_files]
        paired_cond = cond_files[:num_files]

        # NOTE: `process` closes over `condition` and `frechet_score`; this is
        # safe because the map below is fully consumed within this iteration.
        def process(baseline_file, cond_file):
            # make sure the files match (same name)
            assert baseline_file.stem == cond_file.stem, f"baseline file {baseline_file} and cond file {cond_file} do not match"

            # load the files
            baseline_sig = AudioSignal(str(baseline_file))
            cond_sig = AudioSignal(str(cond_file))

            # bring the condition signal to the baseline's rate and length
            cond_sig.resample(baseline_sig.sample_rate)
            cond_sig.truncate_samples(baseline_sig.length)

            # if our condition is inpainting, we need to trim the conditioning off
            if "inpaint" in condition:
                # the trailing "_<number>" of the directory name is read as
                # the context amount and scaled by the sample rate
                # (presumably seconds -- confirm against the generator script)
                ctx_amt = float(condition.split("_")[-1])
                ctx_samples = int(ctx_amt * baseline_sig.sample_rate)
                print(f"found inpainting condition. trimming off {ctx_samples} samples from {cond_file} and {baseline_file}")
                cond_sig.trim(ctx_samples, ctx_samples)
                baseline_sig.trim(ctx_samples, ctx_samples)

            return {
                "mel": mel_loss(baseline_sig, cond_sig).item(),
                "frechet": frechet_score,
                "condition": condition,
                "file": baseline_file.stem,
            }

        print(f"processing {num_files} files in {baseline_dir} and {cond_dir}")
        metrics.extend(tqdm(map(process, paired_baseline, paired_cond), total=num_files))

    if not metrics:
        # Nothing was paired (no conditions or no matching audio files), so
        # there is nothing to aggregate or write.
        print(f"no metrics computed for {exp_dir}; nothing to write")
        return

    # Build the frame once and reuse it for every per-metric summary.
    df = pandas.DataFrame(metrics)
    metric_keys = [k for k in metrics[0].keys() if k not in ("condition", "file")]

    for mk in metric_keys:
        stat = df.groupby(['condition'])[mk].agg(['mean', 'count', 'std'])
        stat.to_csv(exp_dir / f"stats-{mk}.csv")

    df.to_csv(exp_dir / "metrics-all.csv", index=False)


if __name__ == "__main__":
    # Parse the command line and run the evaluation with those arguments
    # bound for the duration of the argbind scope.
    with argbind.scope(argbind.parse_args()):
        eval()