import requests
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse
import io
import numpy as np
import tensorflow as tf
import librosa
from pydub import AudioSegment
import wave
import os

os.environ['NUMBA_CACHE_DIR'] = '/tmp'

app = FastAPI()

# Load the model
model = tf.keras.models.load_model('model.h5')
def extract_features(y, sr):
    mfccs = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13)
    return np.mean(mfccs.T, axis=0)
def segment_audio(y, sr, segment_length=2, hop_length=1):
    frames = []
    for start in range(0, len(y) - int(segment_length * sr), int(hop_length * sr)):
        end = start + int(segment_length * sr)
        segment = y[start:end]
        frames.append(extract_features(segment, sr))
    return np.array(frames, dtype=np.float32)
def predict_periods(model, y, sr, segment_length=2, hop_length=1):
    frames = segment_audio(y, sr, segment_length, hop_length)
    predictions = model.predict(frames)
    predicted_labels = np.argmax(predictions, axis=1)
    # Start/end time of each segment, aligned with the hop used in segment_audio
    num_segments = len(predicted_labels)
    durations = [(hop_length * i, hop_length * i + segment_length) for i in range(num_segments)]
    return predicted_labels, durations
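
# Offline sketch (assumptions: 'sample.wav' is a hypothetical local file and the model in
# 'model.h5' expects the 13-dim MFCC features produced above; librosa.load returns mono
# float audio, and sr=None preserves the file's native sample rate):
#
#   y, sr = librosa.load("sample.wav", sr=None)
#   labels, durations = predict_periods(model, y, sr)
#   for label, (start, end) in zip(labels, durations):
#       print(label, start, end)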
def process_audio(audio_file):
    # Load the audio file using pydub
    audio = AudioSegment.from_file(audio_file.file, format="m4a")
    # Convert the audio file to WAV format in memory
    wav_io = io.BytesIO()
    audio.export(wav_io, format="wav")
    wav_io.seek(0)
    # Load the WAV file using wave
    with wave.open(wav_io, 'rb') as wav_file:
        # Get the audio data and sample rate
        audio_data = wav_file.readframes(wav_file.getnframes())
        sample_rate = wav_file.getframerate()
    # Convert the audio data to a numpy array (16-bit PCM assumed)
    y = np.frombuffer(audio_data, dtype=np.int16)
    y = y / 32768.0  # Normalize the audio data to [-1.0, 1.0]
    return y, sample_rate
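
# Quick local check (assumption: a plain object with a .file attribute is enough to mimic
# FastAPI's UploadFile here, and 'sample.m4a' is a placeholder path):
#
#   from types import SimpleNamespace
#   with open("sample.m4a", "rb") as f:
#       y, sr = process_audio(SimpleNamespace(file=f))
#   print(y.shape, sr)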
def upload_to_tmpfiles(file_path):
    upload_url = 'https://tmpfiles.org/api/v1/upload'
    try:
        with open(file_path, 'rb') as f:
            response = requests.post(upload_url, files={'file': f})
        response.raise_for_status()  # Raise an error for bad HTTP status codes
        result = response.json()
        # Debug: print the response to see its structure
        print("Upload response:", result)
        # Return the parsed response so callers can pull out the uploaded file URL
        return result
    except Exception as e:
        print(f"Error uploading to tmpfiles: {e}")
        raise
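
# Sketch of how the upload response could be used (assumption: tmpfiles.org wraps the
# uploaded file's URL in a 'data' object with a 'url' key; verify against the live API
# before relying on this):
#
#   result = upload_to_tmpfiles('/tmp/tmpgraph.png')
#   graph_url = result.get('data', {}).get('url')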
@app.post("/predict")  # route path assumed; the original listing omitted the decorator
async def predict(audio: UploadFile = File(...)):
    try:
        # Process the uploaded audio file
        y, sr = process_audio(audio)
        # Process the audio file contents
        predicted_labels, durations = predict_periods(model, y, sr)
        # Map predicted labels to periods and their durations
        periods = []
        for i, label in enumerate(predicted_labels):
            period_type = 'Inhale' if label == 1 else 'Exhale'  # Adjust label mapping as needed
            period_start, period_end = durations[i]
            duration = period_end - period_start
            periods.append({
                "Period": i + 1,
                "Type": period_type,
                "Duration": round(duration, 2)  # Round duration to 2 decimal places
            })
        # Save the graph image
        # graph_path = '/tmp/tmpgraph.png'
        # Replace with your actual graph generation logic
        # plt.savefig(graph_path)
        # Upload the graph image
        # graph_url = upload_to_tmpfiles(graph_path)
        return JSONResponse(content={
            "periods": periods,
            "results": "File uploaded successfully"
        })
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during prediction: {str(e)}")
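
# Running and calling the service (assumptions: this module is saved as app.py, uvicorn is
# installed, port 7860 and the '/predict' route above are placeholders, and 'breath.m4a'
# is a hypothetical recording):
#
#   uvicorn app:app --host 0.0.0.0 --port 7860
#
#   import requests
#   with open("breath.m4a", "rb") as f:
#       r = requests.post("http://localhost:7860/predict", files={"audio": f})
#   print(r.json())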