MathFrenchToast's picture
feat: reach 35% so pass the test with the xagent
5675d05
raw
history blame
2.49 kB
import io
import os
import shutil
import subprocess
import requests
import uuid
from smolagents import tool
@tool
def stt(file_url: str, language: str = "en-US") -> str:
    """
    Convert speech to text using a local whisper model.

    This function downloads an audio file from a given URL, converts it to a
    mono 16 kHz 16-bit PCM WAV file with ffmpeg, runs the go-whisper Docker
    image on that file and returns the content of the text file found next to
    the audio once the container has finished.

    All files live under a working folder that contains the tmp, testdata and
    models sub-folders. The folder can be overridden with the STT_WORK_DIR
    environment variable; when it is not set the historical hard-coded
    location is used.

    Args:
        file_url (str): The URL of the audio file to transcribe.
        language (str): The language code for the transcription. Default is "en-US". Currently accepted but not forwarded to the whisper command.

    Returns:
        str: The transcribed text.

    Raises:
        RuntimeError: If the download does not answer with HTTP status 200.
        subprocess.CalledProcessError: If ffmpeg or docker exits with a non-zero status.
    """
    # Working folder shared with the container. The original path stays the
    # default so existing setups keep working without any configuration.
    dest_folder = os.environ.get("STT_WORK_DIR", "c:\\Users\\mathi\\dev\\stt")
    tmp_folder = os.path.join(dest_folder, "tmp")
    data_folder = os.path.join(dest_folder, "testdata")
    # Create the folders up front so a first run on a clean machine does not
    # fail when writing the download or copying the WAV file.
    os.makedirs(tmp_folder, exist_ok=True)
    os.makedirs(data_folder, exist_ok=True)

    # A fresh random stem per call keeps concurrent or repeated calls from
    # overwriting each other's files.
    stem = uuid.uuid4().hex
    file_path = os.path.join(tmp_folder, stem + ".mp3")

    # 1. Download the file from the URL (pure Python, no wget or curl).
    # The name is unique, so there is never a cached copy to reuse. The
    # timeout prevents an unresponsive server from blocking the agent forever.
    response = requests.get(file_url, timeout=60)
    if response.status_code != 200:
        raise RuntimeError(f"Error downloading file: {response.status_code}")
    with open(file_path, "wb") as f:
        f.write(response.content)

    # 2. Convert to WAV with ffmpeg. The download is always stored with an
    # .mp3 name, so the conversion always runs. The command is passed as an
    # argument list so that paths containing spaces stay intact.
    wav_name = stem + ".wav"
    wav_path = os.path.join(tmp_folder, wav_name)
    ffmpeg_command = [
        "ffmpeg",
        "-i", file_path,
        "-ac", "1",
        "-ar", "16000",
        "-c:a", "pcm_s16le",
        wav_path,
    ]
    subprocess.run(ffmpeg_command, cwd=dest_folder, check=True)

    # 3. Copy the WAV file into the folder that is mounted in the container.
    shutil.copy2(wav_path, data_folder)

    # 4. Run go-whisper in Docker on the mounted file.
    docker_command = [
        "docker", "run",
        "-v", f"{dest_folder}/models:/app/models",
        "-v", f"{dest_folder}/testdata:/app/testdata",
        "ghcr.io/appleboy/go-whisper:latest",
        "--model", "/app/models/ggml-small.bin",
        "--audio-path", f"/app/testdata/{wav_name}",
    ]
    subprocess.run(docker_command, cwd=dest_folder, check=True)

    # 5. Read the transcript expected next to the audio file and return it.
    # NOTE(review): assumes the container writes UTF-8 text; being explicit
    # avoids the platform default code page on Windows.
    output_filepath = os.path.join(data_folder, f"{stem}.txt")
    with open(output_filepath, "r", encoding="utf-8") as f:
        return f.read()
if __name__ == "__main__":
    # Manual smoke test: transcribe the sample file at the URL below and
    # print the resulting text.
    sample_url = "https://agents-course-unit4-scoring.hf.space/files/99c9cc74-fdc8-46c6-8f8d-3ce2d3bfeea3"
    print(stt(sample_url))