File size: 6,502 Bytes
bd4eae4
55d6080
8a0cdd8
 
0db692a
 
 
55d6080
 
 
 
7b52294
55d6080
7b52294
55d6080
 
 
 
3c1a68d
55d6080
204b035
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1fdec9b
 
 
 
204b035
 
1fdec9b
 
204b035
 
 
 
 
 
 
 
1fdec9b
 
 
 
204b035
 
 
 
 
1fdec9b
 
 
55d6080
 
 
0db692a
3105d32
0db692a
3105d32
0db692a
1fdec9b
0db692a
 
 
 
30ffa0e
 
 
 
 
f6dabb0
30ffa0e
 
0db692a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204b035
0db692a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33b9b1f
 
 
 
 
 
 
 
 
 
0db692a
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import base64
import os
import tempfile
from io import BytesIO

from huggingface_hub import InferenceClient
from langchain_core.tools import tool as langchain_tool
from pyAudioAnalysis import audioSegmentation as aS
from pydub import AudioSegment
from smolagents.tools import Tool, tool

class TranscribeAudioTool(Tool):
    """Transcribe audio supplied either as a base64 string or a pydub AudioSegment.

    The audio is normalized to mono/16 kHz/16-bit WAV before being sent to a
    hosted Whisper model via the Hugging Face inference API.
    """

    name = "transcribe_audio"
    description = "Transcribe an audio file (in base64 format or as an AudioSegment object)"
    inputs = {
        "audio": {"type": "any", "description": "The audio file in base64 format or as an AudioSegment object only"}
    }
    output_type = "string"

    def setup(self):
        # Inference client for Whisper; requires HUGGINGFACE_API_KEY in the environment.
        self.model = InferenceClient(model="openai/whisper-large-v3", provider="hf-inference", token=os.getenv("HUGGINGFACE_API_KEY"))

    def _convert_audio_segment_to_wav(self, audio_segment: AudioSegment) -> bytes:
        """Normalize an AudioSegment to mono/16 kHz/16-bit and return WAV bytes.

        Raises:
            RuntimeError: if pydub fails to convert or export the segment.
        """
        try:
            # Whisper expects mono 16 kHz 16-bit PCM; coerce each property only if needed.
            if audio_segment.channels > 1:
                audio_segment = audio_segment.set_channels(1)
            if audio_segment.frame_rate != 16000:
                audio_segment = audio_segment.set_frame_rate(16000)
            if audio_segment.sample_width != 2:  # 2 bytes = 16 bits
                audio_segment = audio_segment.set_sample_width(2)

            buffer = BytesIO()
            audio_segment.export(buffer, format="wav")
            return buffer.getvalue()
        except Exception as e:
            raise RuntimeError(f"Error converting audio segment: {str(e)}") from e

    def forward(self, audio: object) -> str:
        """Transcribe the given audio and return the recognized text.

        Args:
            audio: Either a pydub AudioSegment or a base64-encoded audio file.
                (Annotation fixed: the previous ``any`` referenced the builtin
                function, not a type.)

        Raises:
            RuntimeError: on any decoding, conversion, or transcription failure.
        """
        try:
            # Handle AudioSegment object: direct conversion to standardized WAV bytes.
            if isinstance(audio, AudioSegment):
                audio_data = self._convert_audio_segment_to_wav(audio)
            # Handle base64 string: decode, parse, then standardize the format.
            elif isinstance(audio, str):
                try:
                    audio_data = base64.b64decode(audio)
                    audio_segment = AudioSegment.from_file(BytesIO(audio_data))
                    audio_data = self._convert_audio_segment_to_wav(audio_segment)
                except Exception as e:
                    raise ValueError(f"Invalid base64 audio data: {str(e)}") from e
            else:
                raise ValueError(f"Unsupported audio type: {type(audio)}. Expected base64 string or AudioSegment object.")

            # Transcribe using the model. NOTE: no inner try/except here — the
            # previous version wrapped the error twice, yielding messages like
            # "Error in transcription: Error in transcription: ...".
            result = self.model.automatic_speech_recognition(audio_data)
            return result["text"]
        except Exception as e:
            raise RuntimeError(f"Error in transcription: {str(e)}") from e

transcribe_audio_tool = TranscribeAudioTool()

@tool
def get_audio_from_file_path(file_path: str) -> str:
    """
    Load an audio file from a file path and convert it to a base64-encoded WAV string.

    The path is first tried as given; if loading fails, it is retried relative to
    this module's directory (useful when the agent runs from a different cwd).

    Args:
        file_path: Path to the audio file (any format pydub/ffmpeg can read)
    Returns:
        The audio re-encoded as WAV, in base64 format
    """
    # Load the audio file, falling back to a module-relative path on failure.
    try:
        audio = AudioSegment.from_file(file_path)
    except Exception:
        # If this retry also fails, the raised error points at the resolved
        # path that was actually attempted.
        module_dir = os.path.dirname(os.path.abspath(__file__))
        file_path = os.path.join(module_dir, file_path)
        audio = AudioSegment.from_file(file_path)

    # Export the audio to an in-memory WAV buffer (standard format for the
    # downstream transcription tool).
    buffer = BytesIO()
    audio.export(buffer, format="wav")

    # Encode the audio data to base64.
    return base64.b64encode(buffer.getvalue()).decode('utf-8')

@tool
def noise_reduction(audio: str) -> str:
    """
    Reduce noise from an audio file
    Args:
        audio: The audio file in base64 format
    Returns:
        The denoised audio file in base64 format
    """
    # Rebuild the AudioSegment from its base64 representation.
    segment = AudioSegment.from_file(BytesIO(base64.b64decode(audio)))

    # A 3 kHz low-pass filter serves as a simple denoising pass.
    filtered = segment.low_pass_filter(3000)

    # Serialize back to WAV and re-encode as base64.
    out = BytesIO()
    filtered.export(out, format="wav")
    return base64.b64encode(out.getvalue()).decode('utf-8')

@tool
def audio_segmentation(audio: str, segment_length: int = 30) -> list:
    """
    Segment an audio file into smaller chunks
    Args:
        audio: The audio file in base64 format
        segment_length: Length of each segment in seconds
    Returns:
        List of audio segments in base64 format. Each of these segments can be used as input for the `transcribe_audio` tool.
    """
    # Decode the base64 payload into an AudioSegment.
    full_audio = AudioSegment.from_file(BytesIO(base64.b64decode(audio)))

    # pydub slices are indexed in milliseconds.
    step_ms = segment_length * 1000

    def _encode(chunk):
        # Export one chunk to WAV and base64-encode it.
        buf = BytesIO()
        chunk.export(buf, format="wav")
        return base64.b64encode(buf.getvalue()).decode('utf-8')

    # Walk the audio in fixed-size windows; the last chunk may be shorter.
    return [_encode(full_audio[start:start + step_ms])
            for start in range(0, len(full_audio), step_ms)]

@tool
def speaker_diarization(audio: str) -> list:
    """
    Diarize an audio file into speakers
    Args:
        audio: The audio file in base64 format
    Returns:
        List of (frame_index, speaker_flag) tuples, one per analysis frame
    """
    # Decode the base64 audio into an AudioSegment.
    audio_data = base64.b64decode(audio)
    audio_segment = AudioSegment.from_file(BytesIO(audio_data))

    # BUG FIX: pyAudioAnalysis expects a real file *path* (it reopens the file
    # itself via its audio I/O layer), so passing a BytesIO fails. Export the
    # WAV to a temporary file and hand over the path instead.
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            tmp_path = tmp.name
            audio_segment.export(tmp, format="wav")

        # Perform speaker diarization (assuming 2 speakers, as before).
        [flags, classes, centers] = aS.speakerDiarization(tmp_path, 2)
    finally:
        # Always remove the temporary file, even if diarization fails.
        if tmp_path is not None and os.path.exists(tmp_path):
            os.remove(tmp_path)

    # Pair each frame index with its speaker flag (same output shape as before).
    return list(enumerate(flags))