"""Audio processing tools: transcription, noise reduction, segmentation, and speaker diarization."""
import base64
import os
import tempfile
from io import BytesIO

from huggingface_hub import InferenceClient
from langchain_core.tools import tool as langchain_tool
from pyAudioAnalysis import audioSegmentation as aS
from pydub import AudioSegment
from smolagents.tools import Tool, tool
class TranscribeAudioTool(Tool):
    """smolagents Tool that transcribes audio with openai/whisper-large-v3 via HF Inference."""

    name = "transcribe_audio"
    description = "Transcribe an audio file (in base64 format or as an AudioSegment object)"
    inputs = {
        "audio": {"type": "any", "description": "The audio file in base64 format or as an AudioSegment object only"}
    }
    output_type = "string"

    def setup(self):
        # Inference client for the remote Whisper model; reads the API token
        # from the HUGGINGFACE_API_KEY environment variable.
        self.model = InferenceClient(
            model="openai/whisper-large-v3",
            provider="hf-inference",
            token=os.getenv("HUGGINGFACE_API_KEY"),
        )

    def _convert_audio_segment_to_wav(self, audio_segment: AudioSegment) -> bytes:
        """Normalize an AudioSegment to mono, 16 kHz, 16-bit WAV bytes for Whisper.

        Raises:
            RuntimeError: If conversion or export fails.
        """
        try:
            # Convert to mono if stereo
            if audio_segment.channels > 1:
                audio_segment = audio_segment.set_channels(1)
            # Convert to 16kHz if different sample rate
            if audio_segment.frame_rate != 16000:
                audio_segment = audio_segment.set_frame_rate(16000)
            # Convert to 16-bit if different bit depth (2 bytes = 16 bits)
            if audio_segment.sample_width != 2:
                audio_segment = audio_segment.set_sample_width(2)
            buffer = BytesIO()
            audio_segment.export(buffer, format="wav")
            return buffer.getvalue()
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise RuntimeError(f"Error converting audio segment: {str(e)}") from e

    def forward(self, audio: any) -> str:
        """Transcribe `audio` (base64-encoded bytes or a pydub AudioSegment) to text.

        Raises:
            RuntimeError: On any failure (invalid input, conversion, or inference),
                with the underlying exception chained as the cause.
        """
        try:
            # Normalize the input to WAV bytes in Whisper's preferred format.
            if isinstance(audio, AudioSegment):
                audio_data = self._convert_audio_segment_to_wav(audio)
            elif isinstance(audio, str):
                try:
                    audio_data = base64.b64decode(audio)
                    audio_segment = AudioSegment.from_file(BytesIO(audio_data))
                    audio_data = self._convert_audio_segment_to_wav(audio_segment)
                except Exception as e:
                    raise ValueError(f"Invalid base64 audio data: {str(e)}") from e
            else:
                raise ValueError(f"Unsupported audio type: {type(audio)}. Expected base64 string or AudioSegment object.")
            # Transcribe using the remote model.
            result = self.model.automatic_speech_recognition(audio_data)
            return result["text"]
        except Exception as e:
            # Single wrapping point (the original double-wrapped the same message);
            # `from e` keeps the real cause visible to callers.
            raise RuntimeError(f"Error in transcription: {str(e)}") from e


# Module-level singleton instance used by agent setups importing this module.
transcribe_audio_tool = TranscribeAudioTool()
def get_audio_from_file_path(file_path: str) -> str:
    """
    Load an audio file from a file path and convert it to a base64 string.

    If the path cannot be loaded as given, it is retried relative to this
    module's directory (useful when the current working directory differs).

    Args:
        file_path: Path to the audio file (any format pydub/ffmpeg can decode)

    Returns:
        The audio, re-exported as WAV, in base64 format
    """
    try:
        audio = AudioSegment.from_file(file_path)
    except Exception:
        # Fallback: resolve relative to this source file's directory. If this
        # also fails, let that exception propagate with its own traceback.
        module_dir = os.path.dirname(os.path.abspath(__file__))
        audio = AudioSegment.from_file(os.path.join(module_dir, file_path))
    # Export the audio to an in-memory WAV buffer, then base64-encode it.
    buffer = BytesIO()
    audio.export(buffer, format="wav")
    return base64.b64encode(buffer.getvalue()).decode('utf-8')
def noise_reduction(audio: str) -> str:
    """
    Reduce noise from an audio file.

    Args:
        audio: The audio file in base64 format

    Returns:
        The denoised audio file in base64 format
    """
    # Decode the base64 payload into an AudioSegment.
    raw_bytes = base64.b64decode(audio)
    segment = AudioSegment.from_file(BytesIO(raw_bytes))
    # Crude denoising: attenuate frequency content above 3 kHz.
    filtered = segment.low_pass_filter(3000)
    # Re-export as WAV and return the base64 encoding.
    out_buffer = BytesIO()
    filtered.export(out_buffer, format="wav")
    return base64.b64encode(out_buffer.getvalue()).decode('utf-8')
def audio_segmentation(audio: str, segment_length: int = 30) -> list:
    """
    Segment an audio file into smaller chunks.

    Args:
        audio: The audio file in base64 format
        segment_length: Length of each segment in seconds (must be positive)

    Returns:
        List of audio segments in base64 format. Each of these segments can be
        used as input for the `transcribe_audio` tool.

    Raises:
        ValueError: If segment_length is not a positive number of seconds
            (previously a zero value surfaced as an opaque range() error and a
            negative value silently returned an empty list).
    """
    if segment_length <= 0:
        raise ValueError(f"segment_length must be positive, got {segment_length}")
    # Decode the base64 audio.
    audio_data = base64.b64decode(audio)
    audio_segment = AudioSegment.from_file(BytesIO(audio_data))
    # Hoisted loop invariant: chunk size in milliseconds (pydub slices in ms).
    step_ms = segment_length * 1000
    segments = []
    for start in range(0, len(audio_segment), step_ms):
        buffer = BytesIO()
        audio_segment[start:start + step_ms].export(buffer, format="wav")
        segments.append(base64.b64encode(buffer.getvalue()).decode('utf-8'))
    return segments
def speaker_diarization(audio: str, n_speakers: int = 2) -> list:
    """
    Diarize an audio file into speakers.

    Args:
        audio: The audio file in base64 format
        n_speakers: Expected number of speakers (default 2, matching the
            previous hard-coded behavior)

    Returns:
        List of (frame_index, speaker_flag) tuples, one per analysis frame.
    """
    # Decode the base64 audio and normalize it to WAV.
    audio_data = base64.b64decode(audio)
    audio_segment = AudioSegment.from_file(BytesIO(audio_data))
    # pyAudioAnalysis reads its input from a file *path*, not a file-like
    # object, so the WAV must be materialized on disk temporarily.
    # (The previous version passed a BytesIO, which speakerDiarization
    # cannot open.)
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            tmp_path = tmp.name
            audio_segment.export(tmp, format="wav")
        [flags, classes, centers] = aS.speakerDiarization(tmp_path, n_speakers)
    finally:
        # Always remove the temp file, even if diarization raises.
        if tmp_path is not None:
            os.remove(tmp_path)
    # One (frame_index, speaker_flag) tuple per analysis frame.
    return list(enumerate(flags))