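# Rivalcoder · Add files · 2981dff · raw · history blame · 2.21 kB
# (The line above appears to be repository file-viewer metadata captured with
# the source — author, commit message, commit hash, viewer links, file size.
# It is not part of the program.)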
import os
import shutil
import tempfile
from typing import Optional

import gradio as gr
from fastapi import FastAPI, File, Request, UploadFile
from fastapi.responses import JSONResponse
from transformers import pipeline
# Initialize classifier
# Built once at import time so every request reuses the same pipeline object.
# transformers' `pipeline` is asked for the "audio-classification" task with
# the model id "superb/hubert-large-superb-er".
classifier = pipeline("audio-classification", model="superb/hubert-large-superb-er")
# Create FastAPI app (works with Gradio)
# The Gradio UI is mounted onto this same app further down in the file.
app = FastAPI()
def save_upload_file(upload_file: UploadFile) -> str:
    """Copy an uploaded file to a temporary file on disk and return its path.

    The temporary file keeps the extension of ``upload_file.filename`` (when
    there is one). It is created with ``delete=False``, so the caller owns the
    returned path and is responsible for removing it.

    Args:
        upload_file: Object exposing ``filename`` (may be ``None`` or empty)
            and ``file``, a readable binary file-like object.

    Returns:
        Path of the newly written temporary file.

    Raises:
        Exception: Any error raised while creating or writing the temporary
            file is re-raised after the partially written file is removed.
    """
    tmp_path = None
    try:
        # An upload may arrive without a filename; fall back to "no suffix"
        # instead of letting os.path.splitext(None) raise TypeError.
        suffix = os.path.splitext(upload_file.filename or "")[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            tmp_path = tmp.name
            # Stream in chunks rather than loading the whole upload into memory.
            shutil.copyfileobj(upload_file.file, tmp)
        return tmp_path
    except BaseException:
        # delete=False means nothing else will clean up a half-written file.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass  # best-effort cleanup; the original error matters more
        raise
    finally:
        upload_file.file.close()
@app.post("/api/predict")
async def predict_from_upload(file: UploadFile = File(...)):
    """API endpoint for FormData uploads.

    Saves the uploaded audio to a temporary file, runs the module-level
    ``classifier`` on it, and always removes the temporary file afterwards.

    Args:
        file: The multipart upload to classify.

    Returns:
        ``{"predictions": <classifier output>}`` on success. On any failure,
        a JSON response with HTTP status 500 and body ``{"error": <message>}``.
    """
    # NOTE(review): save_upload_file and classifier are blocking calls made
    # from an async handler, so they hold the event loop while running.
    temp_path = None
    try:
        # Save the uploaded file temporarily
        temp_path = save_upload_file(file)
        # Process the audio
        predictions = classifier(temp_path)
        return {"predictions": predictions}
    except Exception as e:
        # FastAPI does not treat a ``(body, status)`` tuple as a status code
        # (that is a Flask idiom); an explicit response is needed to send 500.
        return JSONResponse(status_code=500, content={"error": str(e)})
    finally:
        # Clean up on success and on failure alike.
        if temp_path is not None:
            try:
                os.unlink(temp_path)
            except OSError:
                pass  # best-effort: a cleanup failure must not mask the result
# Gradio interface for testing
def gradio_predict(audio_file):
    """Gradio interface that handles both file objects and paths.

    Args:
        audio_file: Either a filesystem path (``str``), an upload-like object
            accepted by ``save_upload_file``, or ``None`` when nothing was
            provided.

    Returns:
        Mapping of each predicted label to its score.

    Raises:
        gr.Error: If ``audio_file`` is ``None``.
    """
    if audio_file is None:
        # Without this guard the call fails with an opaque AttributeError
        # inside save_upload_file.
        raise gr.Error("Please provide an audio file to analyze.")

    temp_path = None  # set only when this function creates a temp file
    if isinstance(audio_file, str):  # Path from Gradio upload
        audio_path = audio_file
    else:  # Direct file object
        temp_path = save_upload_file(audio_file)
        audio_path = temp_path

    try:
        predictions = classifier(audio_path)
    finally:
        # Remove only the file we created, and do so even if inference fails.
        if temp_path is not None:
            try:
                os.unlink(temp_path)
            except OSError:
                pass  # best-effort cleanup
    return {p["label"]: p["score"] for p in predictions}
# Create Gradio interface
# The Audio input uses type="filepath", so gradio_predict is handed a path
# string; the Label output shows at most the five top classes.
demo = gr.Interface(
    fn=gradio_predict,
    inputs=gr.Audio(type="filepath", label="Upload Audio"),
    outputs=gr.Label(num_top_classes=5),
    title="Audio Emotion Recognition",
    description="Upload an audio file to analyze emotional content"
)
# Mount Gradio app
# `app` is rebound to the value returned by mount_gradio_app, with the UI
# mounted at the root path "/".
app = gr.mount_gradio_app(app, demo, path="/")
# Local entry point: when this file is executed directly, serve the combined
# FastAPI + Gradio app with uvicorn on all interfaces, port 7860.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        app,
        port=7860,
        host="0.0.0.0",
    )