|
import os |
|
import requests |
|
from smolagents import Tool |
|
|
|
class AudioTranscriptionTool(Tool):
    """Tool that transcribes an .mp3 or .wav file via the Hugging Face Inference API.

    The audio file is read from disk and posted as raw bytes to a hosted
    Whisper model. Failures are reported as human-readable strings rather
    than raised, so the calling agent always receives a string result.
    """

    name = "audio_transcriber"
    description = "Transcribe a given audio file in .mp3 or .wav format using Whisper via Hugging Face API."
    inputs = {
        "file_path": {
            "type": "string",
            "description": "Path to the audio file (.mp3 or .wav)"
        }
    }
    output_type = "string"

    def __init__(self):
        """Read the API token from ``HF_API_TOKEN`` and prepare request settings.

        Raises:
            EnvironmentError: If ``HF_API_TOKEN`` is unset or empty.
        """
        super().__init__()
        api_token = os.getenv("HF_API_TOKEN")
        if not api_token:
            raise EnvironmentError("HF_API_TOKEN not found in environment variables.")
        # The Inference API is served from huggingface.co (not .com).
        self.api_url = "https://api-inference.huggingface.co/models/openai/whisper-large"
        self.headers = {
            "Authorization": f"Bearer {api_token}"
        }

    def forward(self, file_path: str) -> str:
        """Transcribe the audio file at ``file_path``.

        Args:
            file_path: Path to an audio file whose name ends in ``.mp3`` or
                ``.wav`` (case-insensitive).

        Returns:
            The stripped transcription text on success, otherwise a string
            starting with ``"Error"`` that describes what went wrong.
        """
        lowered_path = file_path.lower()
        if lowered_path.endswith(".mp3"):
            content_type = "audio/mpeg"
        elif lowered_path.endswith(".wav"):
            content_type = "audio/wav"
        else:
            return "Error: File must be .mp3 or .wav format."

        try:
            with open(file_path, "rb") as audio_file:
                audio_bytes = audio_file.read()

            response = requests.post(
                self.api_url,
                # Declare the payload type explicitly instead of relying on
                # server-side detection of the raw bytes.
                headers={**self.headers, "Content-Type": content_type},
                data=audio_bytes,
                timeout=60,
            )

            if response.status_code != 200:
                return f"Error transcribing audio: {response.status_code} - {response.text}"

            result = response.json()
        except Exception as e:
            # Deliberately broad: this is the tool boundary, and any failure
            # (missing file, network error, invalid JSON) is reported to the
            # caller as a string instead of being raised.
            return f"Error transcribing audio: {e}"

        # A non-dict payload (e.g. a list) carries no "text" field to read.
        transcription = result.get("text") if isinstance(result, dict) else None
        if not transcription:
            return "Error: No transcription found in API response."
        return transcription.strip()
|
|
|
|