File size: 2,466 Bytes
df24f56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# ======================
# app.py
# ======================

import os
import tempfile

from flask import Flask, request, jsonify
import tensorflow as tf
import numpy as np
import pandas as pd
import librosa
from transformers import AutoTokenizer, TFAutoModel

# Trained classifier plus the preprocessing artefacts it was trained with.
# NOTE(review): pd.read_pickle unpickles arbitrary objects (and can execute
# code while doing so), so these files must come from a trusted source.
model = tf.keras.models.load_model("model.h5")
scaler = pd.read_pickle("scaler.pkl")
encoder = pd.read_pickle("label_encoder.pkl")
meta = pd.read_excel("raga_metadata.xlsx")

# IndicBERT embeds the raga description text fed to the model's second input.
# from_pt=True converts the published PyTorch weights for use in TensorFlow.
_INDICBERT_CHECKPOINT = "ai4bharat/IndicBERTv2-MLM-only"
tokenizer = AutoTokenizer.from_pretrained(_INDICBERT_CHECKPOINT)
bert_model = TFAutoModel.from_pretrained(_INDICBERT_CHECKPOINT, from_pt=True)

app = Flask(__name__)

def extract_features(file_path):
    """Summarise an audio file as a one-row DataFrame of averaged features.

    The file is loaded with librosa at a 22050 Hz sample rate. The columns
    appear in this order:

    * ``chroma_stft``: mean of librosa's chroma-STFT feature.
    * ``spec_cent``: mean of the spectral centroid.
    * ``mfcc1`` ... ``mfcc18``: the mean of each of the 18 MFCC rows.

    NOTE(review): the column names and order presumably have to match what
    ``scaler`` was fitted on. Confirm this before changing them.
    """
    signal, rate = librosa.load(file_path, sr=22050)

    row = {
        "chroma_stft": np.mean(librosa.feature.chroma_stft(y=signal, sr=rate)),
        "spec_cent": np.mean(librosa.feature.spectral_centroid(y=signal, sr=rate)),
    }
    coefficients = librosa.feature.mfcc(y=signal, sr=rate, n_mfcc=18)
    # One column per MFCC row; walking the matrix avoids repeating the count.
    row.update({f"mfcc{i+1}": np.mean(band) for i, band in enumerate(coefficients)})

    return pd.DataFrame([row])

def predict_raga(audio_df, raga_name):
    """Predict a raga label from audio features plus a raga description.

    Args:
        audio_df: One-row DataFrame of audio features, such as the output of
            ``extract_features``. It is transformed with the module-level
            ``scaler`` before inference.
        raga_name: Value looked up in ``meta["raga"]`` to fetch the
            description text that feeds the model's text input.

    Returns:
        The class label decoded by ``encoder.inverse_transform``.

    Edge cases:
        * An unknown ``raga_name``, or a missing (NaN) description, falls
          back to an empty description instead of raising.
        * If several metadata rows match, only the first usable one is used.
          This keeps the text input's batch size at 1, matching the audio
          input.
    """
    audio_scaled = scaler.transform(audio_df)
    # Shape (1, 1, n_features). This is presumably (batch, timesteps, features)
    # for the model's LSTM branch; confirm against the model definition.
    audio_lstm_input = audio_scaled.reshape((1, 1, audio_scaled.shape[1]))

    # Previously an unknown raga produced a plain list and crashed on
    # `.tolist()`. Multiple matches also produced a batch larger than the
    # audio batch.
    matches = meta.loc[meta["raga"] == raga_name, "description"].dropna()
    description = str(matches.iloc[0]) if not matches.empty else ""

    desc_tok = tokenizer(
        [description],
        padding=True,
        truncation=True,
        max_length=64,
        return_tensors="tf",
    )
    # Output [0] is the per-token hidden-state sequence; [:, 0, :] keeps the
    # first ([CLS]) token's vector as the sentence embedding.
    desc_embed = bert_model(
        desc_tok["input_ids"], attention_mask=desc_tok["attention_mask"]
    )[0][:, 0, :]

    pred = model.predict([audio_lstm_input, desc_embed])
    # With a batch of one, a flat argmax is the winning class index.
    return encoder.inverse_transform([np.argmax(pred)])[0]

@app.route("/")
def home():
    """Return a fixed plain-text message showing the API is reachable."""
    banner = "🎶 Raga Prediction API is Live!"
    return banner

@app.route("/predict", methods=["POST"])
def predict():
    """Handle POST /predict by classifying an uploaded audio clip.

    The request must be multipart form data with an ``audio`` file and a
    ``raga_name`` field.

    Responses:
        * Success: ``{"predicted_raga": ...}``.
        * Missing input: ``{"error": ...}`` with HTTP 400.
        * Failure during saving, feature extraction or inference:
          ``{"error": ...}`` with HTTP 500.
    """
    audio_file = request.files.get("audio")
    raga_name = request.form.get("raga_name")
    if audio_file is None or raga_name is None:
        return jsonify({
            "error": "request must include an 'audio' file and a 'raga_name' form field"
        }), 400

    temp_audio_path = None
    try:
        # Each request gets its own temp file. A shared fixed filename let
        # concurrent requests overwrite each other's audio.
        fd, temp_audio_path = tempfile.mkstemp(suffix=".wav")
        os.close(fd)  # audio_file.save() reopens the path itself
        audio_file.save(temp_audio_path)

        features = extract_features(temp_audio_path)
        predicted_raga = predict_raga(features, raga_name)
    except Exception as e:
        # Top-level request boundary: log the traceback and report the failure.
        app.logger.exception("Raga prediction failed")
        return jsonify({"error": str(e)}), 500
    finally:
        if temp_audio_path is not None:
            try:
                os.remove(temp_audio_path)
            except OSError:
                # Best-effort cleanup; a leftover temp file should not fail the request.
                app.logger.warning("Could not remove temp file %s", temp_audio_path)

    return jsonify({
        "predicted_raga": predicted_raga
    })

if __name__ == "__main__":
    # Start Flask's built-in server on all network interfaces, port 7860.
    app.run(port=7860, host="0.0.0.0")