File size: 3,982 Bytes
e13d732
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122


import lightning_module
import torch
import torchaudio
import unittest

class Score:
    """Predict a score for each audio clip with a pretrained UTMOS checkpoint.

    Wraps ``lightning_module.BaselineLightningModule`` and takes care of
    reshaping the input, moving it to the model's device, and resampling it
    to the 16 kHz rate the model is fed.
    """

    def __init__(
        self,
        ckpt_path: str = "epoch=3-step=7459.ckpt",
        input_sample_rate: int = 16000,
        device: str = "cpu"):
        """
        Args:
            ckpt_path: path to pretrained checkpoint of UTMOS strong learner.
            input_sample_rate: sampling rate of the input audio tensor. When it
                differs from 16000 the input is resampled to 16 kHz before
                scoring.
            device: torch device string (e.g. "cpu" or "cuda"). The model and
                the resampler are placed on it, and inputs passed to
                ``score`` are moved to it.
        """
        print(f"Using device: {device}")
        self.device = device
        self.model = lightning_module.BaselineLightningModule.load_from_checkpoint(
            ckpt_path).eval().to(device)
        self.in_sr = input_sample_rate
        # Built unconditionally, but only applied in `score` when in_sr != 16000.
        self.resampler = torchaudio.transforms.Resample(
            orig_freq=input_sample_rate,
            new_freq=16000,
            resampling_method="sinc_interpolation",
            lowpass_filter_width=6,
            dtype=torch.float32,
        ).to(device)

    def score(self, wavs: torch.Tensor):
        """Return one predicted score per audio clip.

        Args:
            wavs: audio waveform(s) to evaluate. How the input is interpreted
                depends on ``wavs.dim()``:
                  * 1 -> ``(time,)``: a single clip, reshaped to ``(1, 1, time)``;
                  * 2 -> a single clip (``(1, time)`` in the tests); a batch
                    dimension of 1 is prepended;
                  * 3 -> ``(batch, 1, time)`` in the tests: batch processing.
                The tensor is moved to ``self.device`` before use.

        Returns:
            A NumPy array with, presumably, one score per batch item.
            ``(batch,)`` is expected if the model output is
            ``(batch, frames, 1)`` — confirm against the model.

        Raises:
            ValueError: if ``wavs`` has fewer than 1 or more than 3 dimensions.
        """
        ndim = wavs.dim()
        if ndim == 1:
            out_wavs = wavs.unsqueeze(0).unsqueeze(0)
        elif ndim == 2:
            out_wavs = wavs.unsqueeze(0)
        elif ndim == 3:
            out_wavs = wavs
        else:
            raise ValueError('Dimension of input tensor needs to be <= 3.')
        # The resampler and model live on self.device; move the input there
        # too, otherwise a CPU tensor fails against a CUDA model/resampler.
        out_wavs = out_wavs.to(self.device)
        if self.in_sr != 16000:
            out_wavs = self.resampler(out_wavs)
        bs = out_wavs.shape[0]
        batch = {
            'wav': out_wavs,
            # Fixed domain 0 and judge id 288 for every clip. These are
            # presumably the checkpoint's default domain / "mean listener"
            # id (NOTE(review): confirm against the training setup).
            'domains': torch.zeros(bs, dtype=torch.int, device=self.device),
            'judge_id': torch.full((bs,), 288, dtype=torch.int, device=self.device),
        }
        with torch.no_grad():
            output = self.model(batch)

        # Average over dim 1, drop the trailing singleton, then apply the
        # affine map x*2 + 3. This presumably rescales a normalized [-1, 1]
        # output to the 1..5 MOS scale — confirm against training.
        return output.mean(dim=1).squeeze(1).detach().cpu().numpy() * 2 + 3


class TestFunc(unittest.TestCase):
    """Smoke tests for ``Score.score`` across input ranks and sample rates.

    These construct real ``Score`` objects with the default ``ckpt_path``.
    That is a relative path, so the checkpoint presumably must be reachable
    from the working directory.
    """

    SEQ_LEN = 10000
    BATCH = 8

    # Scorers keyed by input sample rate, shared by all tests so the
    # checkpoint is loaded at most once per rate instead of once per test.
    # Safe because Score.score does not modify the scorer's attributes.
    _scorers: dict = {}

    def _scorer(self, sample_rate):
        """Return a cached ``Score`` for ``sample_rate``, creating it on first use."""
        cache = type(self)._scorers
        if sample_rate not in cache:
            cache[sample_rate] = Score(input_sample_rate=sample_rate)
        return cache[sample_rate]

    def _assert_scores_in_range(self, pred, expected_len):
        """Check ``pred`` holds ``expected_len`` scores, each within [0, 5]."""
        self.assertEqual(len(pred), expected_len)
        for p in pred:
            self.assertGreaterEqual(p, 0.)
            self.assertLessEqual(p, 5.)

    def test_1dim_0(self):
        pred = self._scorer(16000).score(torch.ones(self.SEQ_LEN))
        self._assert_scores_in_range(pred, 1)

    def test_1dim_1(self):
        pred = self._scorer(24000).score(torch.ones(self.SEQ_LEN))
        self._assert_scores_in_range(pred, 1)

    def test_2dim_0(self):
        pred = self._scorer(16000).score(torch.ones(1, self.SEQ_LEN))
        self._assert_scores_in_range(pred, 1)

    def test_2dim_1(self):
        pred = self._scorer(24000).score(torch.ones(1, self.SEQ_LEN))
        self._assert_scores_in_range(pred, 1)

    def test_3dim_0(self):
        pred = self._scorer(16000).score(torch.ones(self.BATCH, 1, self.SEQ_LEN))
        self._assert_scores_in_range(pred, self.BATCH)

    def test_3dim_1(self):
        pred = self._scorer(24000).score(torch.ones(self.BATCH, 1, self.SEQ_LEN))
        self._assert_scores_in_range(pred, self.BATCH)

# Run the unittest suite only when executed as a script, not on import.
if __name__ == "__main__":
    unittest.main()