# Final_Assignment_Template / audio_tools.py
# Commit 204b035 by huytofu92 ("Prompts engineering!")
import base64
import os
from io import BytesIO
from typing import Any

from huggingface_hub import InferenceClient
from langchain_core.tools import tool as langchain_tool
from pyAudioAnalysis import audioSegmentation as aS
from pydub import AudioSegment
from smolagents.tools import Tool, tool
class TranscribeAudioTool(Tool):
    """Tool that transcribes audio (a base64 string or a pydub AudioSegment) with Whisper."""
    name = "transcribe_audio"
    description = "Transcribe an audio file (in base64 format or as an AudioSegment object)"
    inputs = {
        "audio": {"type": "any", "description": "The audio file in base64 format or as an AudioSegment object only"}
    }
    output_type = "string"

    def setup(self):
        # Create the HF inference client for Whisper large-v3 once, at tool setup time
        self.model = InferenceClient(model="openai/whisper-large-v3", provider="hf-inference", token=os.getenv("HUGGINGFACE_API_KEY"))

    def _convert_audio_segment_to_wav(self, audio_segment: AudioSegment) -> bytes:
        """Convert an AudioSegment to WAV bytes normalized for Whisper (mono, 16 kHz, 16-bit).

        Raises:
            RuntimeError: if any conversion or the WAV export fails.
        """
        try:
            # Whisper expects mono input
            if audio_segment.channels > 1:
                audio_segment = audio_segment.set_channels(1)
            # ...sampled at 16 kHz
            if audio_segment.frame_rate != 16000:
                audio_segment = audio_segment.set_frame_rate(16000)
            # ...with 16-bit samples (2 bytes per sample)
            if audio_segment.sample_width != 2:
                audio_segment = audio_segment.set_sample_width(2)
            buffer = BytesIO()
            audio_segment.export(buffer, format="wav")
            return buffer.getvalue()
        except Exception as e:
            raise RuntimeError(f"Error converting audio segment: {str(e)}")

    def forward(self, audio: Any) -> str:
        """Transcribe `audio` and return the recognized text.

        Args:
            audio: A pydub AudioSegment, or a base64-encoded audio blob.
        Returns:
            The transcribed text from the Whisper model.
        Raises:
            RuntimeError: on unsupported input, invalid base64 data, or any
                transcription failure (all errors are wrapped, as before).
        """
        try:
            # Normalize either accepted input type to Whisper-ready WAV bytes
            if isinstance(audio, AudioSegment):
                audio_data = self._convert_audio_segment_to_wav(audio)
            elif isinstance(audio, str):
                try:
                    decoded = base64.b64decode(audio)
                    segment = AudioSegment.from_file(BytesIO(decoded))
                    audio_data = self._convert_audio_segment_to_wav(segment)
                except Exception as e:
                    raise ValueError(f"Invalid base64 audio data: {str(e)}")
            else:
                raise ValueError(f"Unsupported audio type: {type(audio)}. Expected base64 string or AudioSegment object.")
            # No inner try/except here: the previous nested handler re-raised with
            # the same message, yielding "Error in transcription: Error in
            # transcription: ..." — the outer handler below wraps exactly once.
            result = self.model.automatic_speech_recognition(audio_data)
            return result["text"]
        except Exception as e:
            raise RuntimeError(f"Error in transcription: {str(e)}")
transcribe_audio_tool = TranscribeAudioTool()
@tool
def get_audio_from_file_path(file_path: str) -> str:
    """
    Load an audio file from a file path and convert it to a base64 string

    Args:
        file_path: Path to the audio file (any format pydub/ffmpeg can read, e.g. mp3)
    Returns:
        The audio re-encoded as WAV, in base64 format
    """
    # Try the path as given first; fall back to resolving it relative to
    # this module's directory (callers sometimes pass repo-relative paths).
    try:
        audio = AudioSegment.from_file(file_path)
    except Exception:
        module_dir = os.path.dirname(os.path.abspath(__file__))
        fallback_path = os.path.join(module_dir, file_path)
        audio = AudioSegment.from_file(fallback_path)
    # Export the audio to an in-memory WAV buffer
    buffer = BytesIO()
    audio.export(buffer, format="wav")
    # Encode the WAV bytes to base64 text
    return base64.b64encode(buffer.getvalue()).decode('utf-8')
@tool
def noise_reduction(audio: str) -> str:
    """
    Reduce noise from an audio file

    Args:
        audio: The audio file in base64 format
    Returns:
        The denoised audio file in base64 format
    """
    # Rebuild an AudioSegment from the base64 payload
    raw_bytes = base64.b64decode(audio)
    segment = AudioSegment.from_file(BytesIO(raw_bytes))
    # Crude denoising: attenuate frequencies above 3 kHz with a low-pass filter
    filtered = segment.low_pass_filter(3000)
    # Re-encode the filtered audio as WAV and return it as base64 text
    out = BytesIO()
    filtered.export(out, format="wav")
    return base64.b64encode(out.getvalue()).decode('utf-8')
@tool
def audio_segmentation(audio: str, segment_length: int = 30) -> list:
    """
    Segment an audio file into smaller chunks

    Args:
        audio: The audio file in base64 format
        segment_length: Length of each segment in seconds
    Returns:
        List of audio segments in base64 format. Each of these segments can be used as input for the `transcribe_audio` tool.
    """
    # Decode the base64 payload into an AudioSegment
    raw = base64.b64decode(audio)
    full_audio = AudioSegment.from_file(BytesIO(raw))
    # pydub slices in milliseconds; hoist the conversion out of the loop
    step_ms = segment_length * 1000

    def _to_base64_wav(chunk: AudioSegment) -> str:
        # Export one chunk as WAV and encode it to base64 text
        out = BytesIO()
        chunk.export(out, format="wav")
        return base64.b64encode(out.getvalue()).decode('utf-8')

    return [
        _to_base64_wav(full_audio[start:start + step_ms])
        for start in range(0, len(full_audio), step_ms)
    ]
@tool
def speaker_diarization(audio: str) -> list:
    """
    Diarize an audio file into speakers

    Args:
        audio: The audio file in base64 format
    Returns:
        List of (frame_index, speaker_label) tuples, one per analysis frame
    """
    import tempfile

    # Decode the base64 audio and normalize it through pydub
    audio_data = base64.b64decode(audio)
    audio_segment = AudioSegment.from_file(BytesIO(audio_data))
    # pyAudioAnalysis expects a file *path*, not a file-like object, so the
    # previous BytesIO-based call could not work — export to a temp WAV file.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        tmp_path = tmp.name
    try:
        audio_segment.export(tmp_path, format="wav")
        # Perform speaker diarization (assuming 2 speakers)
        [flags, classes, centers] = aS.speakerDiarization(tmp_path, 2)
    finally:
        # Always remove the temporary file, even if diarization fails
        os.remove(tmp_path)
    # Pair each frame index with its assigned speaker label
    return [(i, flag) for i, flag in enumerate(flags)]