Hammad712 committed on
Commit
6682f41
·
verified ·
1 Parent(s): 85a59b8

Upload 3 files

Browse files
Files changed (4) hide show
  1. .gitattributes +1 -0
  2. depression_audio_model1.keras +3 -0
  3. main.py +75 -0
  4. requirements.txt +5 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ depression_audio_model1.keras filter=lfs diff=lfs merge=lfs -text
depression_audio_model1.keras ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0ca9cc124fa7d6d0ca747f170db61331020c0b0844b7bf296413ad56065b7edb
3
+ size 36825812
main.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, HTTPException
2
+ from fastapi.responses import JSONResponse
3
+ import tensorflow as tf
4
+ import librosa
5
+ import numpy as np
6
+ import uvicorn
7
+ import os
8
+
9
# Load the pre-trained model.
# Prefer the copy that sits next to this script so start-up does not depend on
# the process's current working directory; fall back to the original relative
# path when no such file exists there.
MODEL_FILENAME = 'depression_audio_model1.keras'
_script_dir_model = os.path.join(os.path.dirname(os.path.abspath(__file__)), MODEL_FILENAME)
MODEL_PATH = _script_dir_model if os.path.exists(_script_dir_model) else MODEL_FILENAME
loaded_model = tf.keras.models.load_model(MODEL_PATH)
print("Model loaded successfully.")

# Constants (defaults for extract_mel_spectrogram)
N_MELS = 128          # passed as n_mels to librosa.feature.melspectrogram
N_FFT = 2048          # passed as n_fft to librosa.feature.melspectrogram
HOP_LENGTH = 512      # passed as hop_length to librosa.feature.melspectrogram
DURATION = 10         # passed as duration to librosa.load
SAMPLE_RATE = 22050   # passed as sr to librosa.load and melspectrogram
# Spectrogram shape the service pads/truncates to: (N_MELS, frames), i.e. (128, 430).
FIXED_SHAPE = (N_MELS, int(DURATION * SAMPLE_RATE / HOP_LENGTH))

# Create the FastAPI app
app = FastAPI()
23
+
24
def extract_mel_spectrogram(file_path, n_mels=N_MELS, n_fft=N_FFT, hop_length=HOP_LENGTH, duration=DURATION, sample_rate=SAMPLE_RATE):
    """Turn an audio file into a normalised, fixed-width mel spectrogram.

    The file is loaded with librosa (resampled to ``sample_rate`` and limited
    to ``duration``), converted to a mel power spectrogram, expressed in dB
    relative to its maximum, and standardised with its own mean and standard
    deviation. The second axis is then zero-padded or truncated to
    ``FIXED_SHAPE[1]`` columns.

    Args:
        file_path: Path of the audio file to read.
        n_mels: Forwarded to ``librosa.feature.melspectrogram``.
        n_fft: Forwarded to ``librosa.feature.melspectrogram``.
        hop_length: Forwarded to ``librosa.feature.melspectrogram``.
        duration: Forwarded to ``librosa.load``.
        sample_rate: Forwarded to ``librosa.load`` and the spectrogram call.

    Returns:
        A 2-D numpy array whose second axis has exactly ``FIXED_SHAPE[1]``
        entries.
    """
    audio, _ = librosa.load(file_path, sr=sample_rate, duration=duration)
    power_spec = librosa.feature.melspectrogram(
        y=audio,
        sr=sample_rate,
        n_mels=n_mels,
        n_fft=n_fft,
        hop_length=hop_length,
    )
    spec_db = librosa.power_to_db(power_spec, ref=np.max)

    # Standardise: always centre; only scale when the spread is non-zero so a
    # constant spectrogram does not trigger a division by zero.
    centre = spec_db.mean()
    spread = spec_db.std()
    spec_db = spec_db - centre
    if spread > 0:
        spec_db = spec_db / spread

    # Force the width to the module-level target (note: the global
    # FIXED_SHAPE is used here, not a value derived from the arguments).
    target_frames = FIXED_SHAPE[1]
    missing = target_frames - spec_db.shape[1]
    if missing > 0:
        return np.pad(spec_db, ((0, 0), (0, missing)), mode='constant')
    return spec_db[:, :target_frames]
41
+
42
def inference(file_path):
    """Run the loaded model on one audio file and return the predicted class.

    Args:
        file_path: Path of the audio file to classify.

    Returns:
        The index of the highest-scoring model output for the file, as a
        plain ``int``.
    """
    features = extract_mel_spectrogram(file_path)
    # The model is fed a batch, so wrap the single spectrogram in a leading axis.
    batch = np.expand_dims(features, axis=0)
    scores = loaded_model.predict(batch)
    best_per_sample = np.argmax(scores, axis=-1)
    return int(best_per_sample[0])
49
+
50
@app.post("/predict")
async def predict(file: UploadFile):
    """Classify an uploaded audio file and return the predicted label.

    Args:
        file: The uploaded file; its name must end in ``.wav`` or ``.mp3``
            (compared case-insensitively).

    Returns:
        A JSON response of the form ``{"prediction": <int>}``.

    Raises:
        HTTPException: 400 when the file name is missing or does not have an
            accepted extension; 500 when reading, decoding or inference fails.
    """
    # Function-scope import so this block is self-contained.
    import tempfile

    # Validate before the try block so the 400 is never rewrapped as a 500.
    # The filename may be absent, which must not crash the check.
    lowered_name = (file.filename or "").lower()
    if not lowered_name.endswith(('.wav', '.mp3')):
        raise HTTPException(status_code=400, detail="Invalid file type. Please upload an audio file.")
    suffix = '.wav' if lowered_name.endswith('.wav') else '.mp3'

    temp_file_path = None
    try:
        # Use a unique, server-chosen temp path instead of the client-supplied
        # name: avoids collisions between concurrent uploads and keeps path
        # separators in the filename from influencing where data is written.
        # The extension is kept in case the audio decoder relies on it.
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as temp_file:
            temp_file_path = temp_file.name
            temp_file.write(await file.read())

        # Perform inference
        predicted_label = inference(temp_file_path)
    except HTTPException:
        # Preserve deliberate HTTP errors raised further down the stack.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during prediction: {str(e)}") from e
    finally:
        # Always clean up, including when inference fails.
        if temp_file_path is not None and os.path.exists(temp_file_path):
            os.remove(temp_file_path)

    return JSONResponse(content={"prediction": predicted_label})
72
+
73
# Start the uvicorn server only when this file is run as a script
# (not when it is imported, e.g. by an external ASGI launcher).
if __name__ == "__main__":
    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ tensorflow
4
+ librosa
5
+ numpy