File size: 1,893 Bytes
1e0a8f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d9f750
1e0a8f3
 
 
 
 
 
 
 
 
 
7d9f750
 
 
1e0a8f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d9f750
1e0a8f3
7d9f750
 
 
1e0a8f3
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
"""
AudioClassifier class: classifies spoken commands from WAV files and maps
them to mouse actions.

Author: HenryAreiza
Date: 08/09/2023
"""

from scipy.io import wavfile
from scipy.signal import decimate
from transformers import pipeline

class AudioClassifier:
    """
    Classify spoken commands stored in WAV files using a pre-trained model.

    The label predicted by the model is looked up in ``vocab`` and reported
    together with the mouse action stored at the same position in
    ``commands``.

    Attributes:
        TARGET_SAMPLE_RATE (int): Sample rate, in Hz, that recordings are
            downsampled towards before classification.
        vocab (list): Vocabulary of valid commands.
        commands (list): Mouse actions, index-aligned with ``vocab``.
        pipe: The Hugging Face Transformers pipeline for audio classification.
    """

    # NOTE(review): 16 kHz is assumed to be the rate the wav2vec2 model
    # expects; it is consistent with the previously hard-coded decimation
    # factor of 3 applied to 48 kHz recordings -- confirm against the model.
    TARGET_SAMPLE_RATE = 16000

    def __init__(self):
        """
        Set up the command tables and load the classification pipeline.
        """
        # ``vocab`` and ``commands`` are parallel lists: commands[i] is the
        # mouse action triggered by the spoken word vocab[i].
        self.vocab = ["left", "right", "up", "down", "go", "follow",
                      "on", "off", "one", "two", "three", "stop"]

        self.commands = ["left click", "right click", "scroll up", "scroll down", "double click", "sustained click", "enable cursor movement",
                         "disable cursor movement", "slow cursor speed", "medium cursor speed", "fast cursor speed", "finish the application"]

        # Load the audio classification pipeline
        self.pipe = pipeline("audio-classification", model="0xb1/wav2vec2-base-finetuned-speech_commands-v0.02")

    def predict(self, audio_path):
        """
        Classify the audio in a WAV file into a command label.

        Args:
            audio_path (str): Path of the WAV file to classify.

        Returns:
            result (str): ``"<label> ---> (<mouse action>)"`` when the
                predicted label is in ``vocab``, otherwise
                ``'unknown command'``.
        """
        sample_rate, audio = wavfile.read(audio_path)

        # Multi-channel files are read as (samples, channels). Average the
        # channels so that time is the only axis; otherwise decimate() would
        # run along the channel axis instead of along time.
        if audio.ndim > 1:
            audio = audio.mean(axis=1)

        # Derive the downsampling factor from the file's real sample rate
        # instead of assuming 48 kHz. The nearest integer ratio still gives
        # 3 for 48 kHz (and 44.1 kHz) input, and 1 -- no downsampling -- for
        # audio that is already at or below the target rate.
        factor = max(1, round(sample_rate / self.TARGET_SAMPLE_RATE))
        if factor > 1:
            audio = decimate(audio, factor)
        else:
            # Hand the pipeline floating-point samples on this path too,
            # as it receives on the decimated path.
            audio = audio.astype(float)

        label = self.pipe(audio)[0]["label"]

        # Single lookup: a label outside the vocabulary raises ValueError.
        try:
            position = self.vocab.index(label)
        except ValueError:
            return 'unknown command'

        return f"{label} ---> ({self.commands[position]})"