Spaces:
Build error
Build error
Upload app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import tensorflow as tf
|
| 3 |
+
import tensorflow_hub as hub
|
| 4 |
+
import numpy as np
|
| 5 |
+
import csv
|
| 6 |
+
import requests
|
| 7 |
+
import json
|
| 8 |
+
import logging
|
| 9 |
+
import scipy
|
| 10 |
+
from scipy.io import wavfile
|
| 11 |
+
from pydub import AudioSegment
|
| 12 |
+
import io
|
| 13 |
+
from io import BytesIO
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# Load the model
|
| 17 |
+
# Loaded once at import time and reused by every call to process_audio_file.
# NOTE(review): 'Audio_Multiple_v1' is presumably a local SavedModel directory
# shipped next to this script — confirm it exists in the deployment.
model = hub.load('Audio_Multiple_v1')
|
| 18 |
+
|
| 19 |
+
def class_names_from_csv(class_map_csv_text):
    """Read the class-map CSV and collect its ``display_name`` column.

    Args:
        class_map_csv_text: Path of the CSV file, opened through
            ``tf.io.gfile.GFile``.

    Returns:
        A list of display names in the order the rows appear in the file,
        so that position ``i`` names entry ``i`` of the model's score vector.
    """
    with tf.io.gfile.GFile(class_map_csv_text) as handle:
        return [record['display_name'] for record in csv.DictReader(handle)]
|
| 27 |
+
|
| 28 |
+
# Ask the loaded model where its class-map CSV lives; .numpy() unwraps the
# returned value (presumably a TF string tensor — confirm) so it can be handed
# to class_names_from_csv.
class_map_path = model.class_map_path().numpy()
|
| 29 |
+
# Display names indexed by class position; process_audio_file uses this list
# to turn the index of the best score into a human-readable label.
class_names = class_names_from_csv(class_map_path)
|
| 30 |
+
|
| 31 |
+
def ensure_sample_rate(original_sample_rate, waveform, desired_sample_rate=16000):
    """Resample ``waveform`` to ``desired_sample_rate`` when the rates differ.

    Args:
        original_sample_rate: Sample rate of ``waveform`` in Hz.
        waveform: Array of samples; the first axis is resampled.
        desired_sample_rate: Target sample rate in Hz (default 16000).

    Returns:
        A ``(desired_sample_rate, waveform)`` tuple. When resampling happens
        the returned waveform is a ``float32`` array; when the rates already
        match the input array is returned unchanged.
    """
    if original_sample_rate != desired_sample_rate:
        # Import the submodule explicitly: a bare ``import scipy`` does not
        # guarantee that ``scipy.signal`` has been loaded.
        from scipy import signal

        desired_length = int(
            round(float(len(waveform)) / original_sample_rate * desired_sample_rate)
        )
        waveform = np.array(signal.resample(waveform, desired_length), dtype=np.float32)
    return desired_sample_rate, waveform
|
| 36 |
+
|
| 37 |
+
def convert_mp3_to_wav(mp3_data):
    """Decode MP3 bytes with pydub and return the same audio as WAV bytes.

    Args:
        mp3_data: Raw bytes of an MP3 file.

    Returns:
        The bytes of the WAV file that pydub exports for the decoded audio.
    """
    segment = AudioSegment.from_file(io.BytesIO(mp3_data), format="mp3")

    # Export into an in-memory buffer instead of touching the filesystem.
    sink = io.BytesIO()
    segment.export(sink, format='wav')
    sink.seek(0)
    wav_bytes = sink.getvalue()
    return wav_bytes
|
| 43 |
+
|
| 44 |
+
def _to_unit_range(samples):
    """Scale raw WAV samples to floats in roughly [-1.0, 1.0] based on dtype."""
    kind = samples.dtype.kind
    if kind == "f":
        # Floating-point WAV data is already expressed in [-1.0, 1.0].
        return samples
    if kind == "u":
        # Unsigned PCM (8-bit) is offset-binary: silence sits at mid-scale.
        midpoint = (np.iinfo(samples.dtype).max + 1) / 2
        return (samples - midpoint) / midpoint
    # Signed PCM: for int16 this is the 32767 divisor the code always used.
    return samples / np.iinfo(samples.dtype).max


def process_audio_file(file_data, url):
    """Classify one WAV payload and report the most likely class.

    The audio is normalised to the unit range according to its sample
    format, mixed down to mono, resampled through ``ensure_sample_rate`` and
    passed to the module-level ``model``. Scores are averaged over frames and
    the best class is reported; if the best class is "Silence" the runner-up
    is reported instead.

    Args:
        file_data: Raw bytes of a WAV file.
        url: Source URL of the audio; echoed in the result and in log output.

    Returns:
        ``{'url': url, 'answer': [class_name]}`` on success, or ``None`` when
        anything goes wrong (the error is logged, never raised).
    """
    try:
        sample_rate, wav_data = wavfile.read(BytesIO(file_data))

        # Normalise by the actual sample format rather than assuming int16,
        # otherwise float or 32-bit files end up wildly mis-scaled.
        wav_data = _to_unit_range(wav_data)

        if wav_data.ndim > 1:
            # Multi-channel audio: average the channels into one.
            wav_data = np.mean(wav_data, axis=1)

        sample_rate, waveform = ensure_sample_rate(sample_rate, wav_data)

        # Only the per-frame class scores are needed here.
        scores, _embeddings, _spectrogram = model(waveform)
        mean_scores = np.mean(scores.numpy(), axis=0)

        # Class indices ordered from highest to lowest mean score.
        ranked = np.argsort(mean_scores)[::-1]
        inferred_class = class_names[ranked[0]]

        if inferred_class == "Silence" and len(ranked) > 1:
            inferred_class = class_names[ranked[1]]

        return {'url': url, 'answer': [inferred_class]}
    except Exception as e:
        # Per-file boundary: one bad file must not abort the whole batch.
        logging.error(f"Error processing {url}: {e}")
        return None
|
| 70 |
+
|
| 71 |
+
def get_audio_data(url, timeout=30):
    """Download ``url`` and return the response body as bytes.

    Args:
        url: Address of the audio file to fetch.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` waits indefinitely, which would hang the
            whole request on a single unresponsive host.

    Returns:
        The raw bytes of the response body.

    Raises:
        requests.exceptions.HTTPError: If the server answers with an error
            status (raised by ``raise_for_status``).
        requests.exceptions.RequestException: For connection problems and
            timeouts.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return response.content
|
| 75 |
+
|
| 76 |
+
# def send_results_to_api(data, result_url):
|
| 77 |
+
# headers = {"Content-Type": "application/json"}
|
| 78 |
+
# try:
|
| 79 |
+
# response = requests.post(result_url, json=data, headers=headers)
|
| 80 |
+
# response.raise_for_status() # Raise error for non-200 responses
|
| 81 |
+
# return response.json() # Return any JSON response from the API
|
| 82 |
+
# except requests.exceptions.HTTPError as http_err:
|
| 83 |
+
# logging.error(f"HTTP error occurred: {http_err}")
|
| 84 |
+
# return {"error": f"HTTP error occurred: {http_err}"}
|
| 85 |
+
# except requests.exceptions.RequestException as req_err:
|
| 86 |
+
# logging.error(f"Request error occurred: {req_err}")
|
| 87 |
+
# return {"error": f"Request error occurred: {req_err}"}
|
| 88 |
+
# except ValueError as val_err:
|
| 89 |
+
# logging.error(f"Error decoding JSON response: {val_err}")
|
| 90 |
+
# return {"error": f"Error decoding JSON response: {val_err}"}
|
| 91 |
+
|
| 92 |
+
def _audio_format(url):
    """Return "mp3" or "wav" when ``url`` names a supported file, else None."""
    from urllib.parse import urlparse  # only needed by this helper

    try:
        path = urlparse(url).path
    except ValueError:
        path = ""
    # Look at the raw string first (the historical behaviour), then at the
    # path component so links with a query string or fragment still match.
    for candidate in (url.lower(), path.lower()):
        for extension in ("mp3", "wav"):
            if candidate.endswith("." + extension):
                return extension
    return None


def process_audio(params):
    """Classify every audio URL listed in a JSON request.

    Args:
        params: JSON text of the form ``{"urls": ["...mp3", "...wav"]}``.

    Returns:
        A JSON string ``{"solutions": [...]}`` holding one entry per URL that
        was classified successfully, or a dict ``{"error": "..."}`` when the
        input is not valid JSON or does not have the expected shape.

    URLs that are not strings, have an unsupported extension, or fail to
    download, convert or classify are logged and skipped, so one bad entry
    never aborts the rest of the batch.
    """
    try:
        params = json.loads(params)
    except json.JSONDecodeError as e:
        return {"error": f"Invalid JSON input: {e.msg} at line {e.lineno} column {e.colno}"}

    if not isinstance(params, dict):
        return {"error": "Invalid JSON input: expected an object with a \"urls\" list"}

    audio_files = params.get("urls", [])
    if not isinstance(audio_files, list):
        return {"error": "Invalid JSON input: \"urls\" must be a list of strings"}

    # NOTE: posting the results back to an API (the "api" / "job_id" params
    # and send_results_to_api) is currently disabled.

    solutions = []
    for audio_url in audio_files:
        if not isinstance(audio_url, str):
            logging.error(f"Skipping non-string URL entry: {audio_url!r}")
            continue

        # Decide the format before downloading so unsupported files cost
        # nothing and can never reuse a previous iteration's result.
        audio_format = _audio_format(audio_url)
        if audio_format is None:
            logging.error(f"Skipping {audio_url}: unsupported file type")
            continue

        try:
            audio_data = get_audio_data(audio_url)
            if audio_format == "mp3":
                audio_data = convert_mp3_to_wav(audio_data)
        except Exception as e:
            # Same best-effort policy as process_audio_file: log and move on.
            logging.error(f"Error fetching {audio_url}: {e}")
            continue

        result = process_audio_file(audio_data, audio_url)
        if result:
            solutions.append(result)

    return json.dumps({"solutions": solutions})
|
| 120 |
+
|
| 121 |
+
import gradio as gr

# Single free-text input carrying the request parameters as JSON. The example
# in the label uses double quotes because process_audio parses the text with
# json.loads, which rejects single-quoted strings.
inputt = gr.Textbox(label='Parameters (JSON format) Eg. {"urls":["file1.mp3","file2.wav"]}')
outputs = gr.JSON()

# Wire the classifier into a Gradio interface and start serving it.
application = gr.Interface(fn=process_audio, inputs=inputt, outputs=outputs, title="Audio Classification with API Integration")
application.launch()
|