# Hugging Face Space status: Runtime error
import os
import tempfile

import torchaudio
import uvicorn
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from transformers import pipeline
# CORS policy: accept requests from any origin, using any HTTP method and
# carrying any request header.
_cors_options = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}

# The ASGI application object; the CORS middleware wraps every route.
app = FastAPI()
app.add_middleware(CORSMiddleware, **_cors_options)
# Checkpoint identifier passed to the transformers audio-classification pipeline.
_MODEL_ID = "superb/wav2vec2-base-superb-er"

# Build the pipeline once at import time. Request handlers call this shared
# instance rather than reloading the model on each request.
pipe = pipeline("audio-classification", model=_MODEL_ID)
# NOTE(review): this handler originally had no route decorator, so the app
# exposed no endpoint at all. The "/predict" path is assumed from the function
# name; confirm it matches what the client calls.
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify the emotion expressed in an uploaded audio clip.

    The upload is written to a temporary ``.wav`` file, which is passed to the
    module-level ``pipe`` audio-classification pipeline. The label of the
    top-ranked prediction is reported in lowercase. The temporary file is
    always removed before returning.

    Args:
        file: The uploaded audio file (multipart form field ``file``).

    Returns:
        ``{"emotion": <label>}`` on success, or ``{"error": <message>}`` if any
        step raises. Failures are reported in the response body, not as an
        HTTP error status.
    """
    tmp_path = None
    try:
        # The pipeline is given a filesystem path, so persist the upload first.
        # delete=False lets the file outlive the ``with`` block; leaving the
        # block closes the file and flushes the bytes. ``finally`` deletes it.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
            # Bind the path before writing so a failed write is still cleaned up.
            tmp_path = tmp.name
            tmp.write(await file.read())

        # The pipeline decodes the audio itself; no separate load step is needed.
        # NOTE(review): pipe() is synchronous, so inference blocks the event
        # loop for the duration of the call.
        result = pipe(tmp_path)

        # Predictions come back ordered by score, highest first.
        top_emotion = result[0]["label"].lower()
        return {"emotion": top_emotion}
    except Exception as e:
        return {"error": str(e)}
    finally:
        # Best-effort cleanup: a failed delete must not replace the response.
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                pass