# breath-training / app.py
import os

# NUMBA_CACHE_DIR must be exported BEFORE librosa is imported: librosa pulls
# in numba, and numba decides where its JIT cache lives based on this variable.
# (In the original file the variable was set after `import librosa`, which can
# be too late on read-only filesystems such as HF Spaces.)
os.environ['NUMBA_CACHE_DIR'] = '/tmp'

import io
import wave

import numpy as np
import requests
import tensorflow as tf
import librosa
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse
from pydub import AudioSegment

app = FastAPI()

# Load the breath-phase classification model once at startup; `model.predict`
# is reused by every request.
model = tf.keras.models.load_model('model.h5')
def extract_features(y, sr):
    """Summarize an audio signal by the mean of its 13 MFCC coefficients.

    Returns a 1-D array of length 13: the per-coefficient mean over time.
    """
    mfcc_matrix = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13)
    # Transpose to (frames, coefficients) and average over the frame axis.
    return mfcc_matrix.T.mean(axis=0)
def segment_audio(y, sr, segment_length=2, hop_length=1):
    """Slice audio into overlapping windows and extract MFCC features per window.

    Parameters:
        y: mono audio samples (1-D numpy array).
        sr: sample rate in Hz.
        segment_length: window size in seconds.
        hop_length: step between consecutive window starts, in seconds.

    Returns:
        float32 array of shape (num_windows, 13) — one MFCC-mean row per window.
    """
    seg_samples = int(segment_length * sr)
    hop_samples = int(hop_length * sr)
    frames = []
    # `+ 1` so the final full window is included: the original upper bound of
    # `len(y) - seg_samples` dropped it, and an input exactly segment_length
    # seconds long produced zero windows (crashing model.predict downstream).
    for start in range(0, len(y) - seg_samples + 1, hop_samples):
        frames.append(extract_features(y[start:start + seg_samples], sr))
    return np.array(frames, dtype=np.float32)
def predict_periods(model, y, sr, segment_length=2, hop_length=1):
    """Classify each audio window and report its time span.

    Parameters:
        model: Keras classifier whose `predict` yields per-class scores.
        y, sr: audio samples and sample rate (as from `process_audio`).
        segment_length: window size in seconds.
        hop_length: step between window starts in seconds.

    Returns:
        (predicted_labels, durations) — argmax class per window, and a list of
        (start_sec, end_sec) spans aligned with the windows.
    """
    frames = segment_audio(y, sr, segment_length, hop_length)
    predictions = model.predict(frames)
    predicted_labels = np.argmax(predictions, axis=1)
    # Window i starts at i * hop_length seconds and spans segment_length
    # seconds.  (The original computed starts with segment_length * i, which
    # misplaces every window whenever hop_length != segment_length.)
    durations = [
        (hop_length * i, hop_length * i + segment_length)
        for i in range(len(predicted_labels))
    ]
    return predicted_labels, durations
def process_audio(audio_file):
    """Decode an uploaded audio file into normalized mono float samples.

    The container format is inferred from the upload's filename extension,
    falling back to the historical "m4a" default when none is present, so
    clients are no longer restricted to m4a.  The audio is forced to mono
    16-bit before decoding so the int16 buffer interpretation below is valid
    (the original silently interleaved channels for stereo input).

    Parameters:
        audio_file: FastAPI UploadFile (provides `.file` and `.filename`).

    Returns:
        (y, sample_rate): samples normalized to [-1, 1) and the rate in Hz.
    """
    filename = getattr(audio_file, 'filename', '') or ''
    fmt = filename.rsplit('.', 1)[-1].lower() if '.' in filename else 'm4a'
    audio = AudioSegment.from_file(audio_file.file, format=fmt)
    # Guarantee the mono / 16-bit assumptions used when reading the frames.
    audio = audio.set_channels(1).set_sample_width(2)
    # Re-export to an in-memory WAV so the stdlib `wave` reader can parse it.
    wav_io = io.BytesIO()
    audio.export(wav_io, format="wav")
    wav_io.seek(0)
    with wave.open(wav_io, 'rb') as wav_file:
        audio_data = wav_file.readframes(wav_file.getnframes())
        sample_rate = wav_file.getframerate()
    # int16 -> float in [-1, 1)
    y = np.frombuffer(audio_data, dtype=np.int16) / 32768.0
    return y, sample_rate
def upload_to_tmpfiles(file_path):
    """Upload a local file to tmpfiles.org and return its public URL.

    Parameters:
        file_path: path to the file to upload.

    Returns:
        The uploaded file's URL string.

    Raises:
        requests.HTTPError on bad status codes; KeyError if the response
        does not have the expected shape.  Errors are logged, then re-raised.
    """
    upload_url = 'https://tmpfiles.org/api/v1/upload'
    try:
        with open(file_path, 'rb') as f:
            response = requests.post(upload_url, files={'file': f})
        response.raise_for_status()  # Raise an error for bad HTTP status codes
        result = response.json()
        # Debug: Print the response to see its structure
        print("Upload response:", result)
        # tmpfiles.org v1 wraps the link as {"status": ..., "data": {"url": ...}}.
        # The original function returned nothing, so every caller got None;
        # return the URL they actually need.
        return result['data']['url']
    except Exception as e:
        print(f"Error uploading to tmpfiles: {e}")
        raise
@app.post("/predict/")
async def predict(audio: UploadFile = File(...)):
    """Classify breathing phases in an uploaded audio clip.

    Decodes the upload, runs the model over 2-second windows, and returns a
    JSON list of periods: ordinal, Inhale/Exhale type, and duration in
    seconds.  Any failure is surfaced as an HTTP 500 with the error message.
    """
    try:
        # Decode the upload into samples, then classify each window.
        y, sr = process_audio(audio)
        labels, durations = predict_periods(model, y, sr)

        periods = []
        for index, (label, (start, end)) in enumerate(zip(labels, durations), start=1):
            periods.append({
                "Period": index,
                # Label 1 is treated as inhale, anything else as exhale.
                "Type": 'Inhale' if label == 1 else 'Exhale',
                # Round to two decimal places for the response payload.
                "Duration": round(end - start, 2),
            })

        return JSONResponse(content={
            "periods": periods,
            "results": "File uploaded successfully"
        })
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during prediction: {str(e)}")