|
import os |
|
import tensorflow as tf |
|
import tensorflow_hub as hub |
|
import numpy as np |
|
import csv |
|
import requests |
|
import json |
|
import logging |
|
import scipy |
|
from scipy.io import wavfile |
|
from pydub import AudioSegment |
|
import io |
|
from io import BytesIO |
|
|
|
|
|
|
|
# Load the classifier once at import time so every request reuses the same
# in-memory model. The handle is resolved by tensorflow_hub as given.
model = hub.load("Audio_Multiple_v1")
|
|
|
def class_names_from_csv(class_map_csv_text):
    """Return the class display names listed in the class-map CSV.

    Args:
        class_map_csv_text: Location of the CSV file, opened through
            ``tf.io.gfile.GFile``. The file must have a ``display_name``
            column.

    Returns:
        A list with one ``display_name`` per CSV row, in file order, so the
        position of a name matches the position of its score.
    """
    with tf.io.gfile.GFile(class_map_csv_text) as csvfile:
        return [row['display_name'] for row in csv.DictReader(csvfile)]
|
|
|
# Ask the loaded model where its class-map CSV lives, then read the display
# names once; ``class_names`` is later indexed by score position.
class_map_path = model.class_map_path().numpy()
class_names = class_names_from_csv(class_map_path)
|
|
|
def ensure_sample_rate(original_sample_rate, waveform, desired_sample_rate=16000):
    """Resample ``waveform`` to ``desired_sample_rate`` when the rates differ.

    Args:
        original_sample_rate: Sample rate (Hz) of ``waveform`` as read.
        waveform: 1-D sequence of audio samples.
        desired_sample_rate: Target sample rate in Hz (default 16000).

    Returns:
        A ``(desired_sample_rate, waveform)`` tuple. If the rates already
        match, ``waveform`` is returned untouched; otherwise it is a new
        ``float32`` array whose length is scaled by the rate ratio.
    """
    if original_sample_rate == desired_sample_rate:
        return desired_sample_rate, waveform

    # Import the submodule explicitly: a bare ``import scipy`` does not
    # guarantee that ``scipy.signal`` is reachable as an attribute.
    from scipy import signal

    desired_length = int(
        round(float(len(waveform)) / original_sample_rate * desired_sample_rate)
    )
    resampled = signal.resample(waveform, desired_length)
    return desired_sample_rate, np.asarray(resampled, dtype=np.float32)
|
|
|
def convert_mp3_to_wav(mp3_data):
    """Decode MP3 bytes with pydub and return the audio re-encoded as WAV bytes.

    Args:
        mp3_data: Raw bytes of an MP3 file.

    Returns:
        The bytes of the exported WAV file.
    """
    source = io.BytesIO(mp3_data)
    segment = AudioSegment.from_file(source, format="mp3")

    target = io.BytesIO()
    segment.export(target, format='wav')
    # getvalue() returns the entire buffer regardless of the stream position.
    return target.getvalue()
|
|
|
def _pcm_to_unit_range(samples, source_dtype):
    """Scale decoded samples to about [-1.0, 1.0] using the WAV's own dtype."""
    if np.issubdtype(source_dtype, np.floating):
        # Floating-point WAV data is already stored in unit range.
        return samples
    if np.issubdtype(source_dtype, np.unsignedinteger):
        # Unsigned PCM (8-bit WAV) is centred on the midpoint, not on zero.
        midpoint = (np.iinfo(source_dtype).max + 1) / 2
        return (samples - midpoint) / midpoint
    # Signed PCM: divide by full scale (32767 for int16, as before).
    return samples / np.iinfo(source_dtype).max


def process_audio_file(file_data, url):
    """Classify one WAV payload with the module-level ``model``.

    The audio is decoded, down-mixed to mono when it has several channels,
    resampled through ``ensure_sample_rate`` and normalised before being
    passed to the model. Scores are averaged over the first axis and the
    best-scoring class name is reported; when that class is ``"Silence"``
    the runner-up is reported instead.

    Args:
        file_data: Raw bytes of a WAV file.
        url: Origin of the audio; echoed in the result and in error logs.

    Returns:
        ``{'url': url, 'answer': [class_name]}`` on success, or ``None`` if
        any step fails (the error is logged).
    """
    try:
        sample_rate, wav_data = wavfile.read(BytesIO(file_data))
        # Capture the on-disk dtype now: down-mixing and resampling both
        # convert to float, which would hide the original full-scale value.
        source_dtype = wav_data.dtype

        if wav_data.ndim > 1:
            wav_data = np.mean(wav_data, axis=1)

        sample_rate, wav_data = ensure_sample_rate(sample_rate, wav_data)
        waveform = _pcm_to_unit_range(wav_data, source_dtype)

        # Only the scores are needed; the other two outputs are ignored.
        scores, _, _ = model(waveform)
        mean_scores = scores.numpy().mean(axis=0)

        # Two best classes, best first.
        top_two_indices = np.argsort(mean_scores)[-2:][::-1]
        inferred_class = class_names[top_two_indices[0]]

        # Prefer the runner-up over a bare "Silence" verdict.
        if inferred_class == "Silence" and len(top_two_indices) > 1:
            inferred_class = class_names[top_two_indices[1]]

        return {'url': url, 'answer': [inferred_class]}
    except Exception as e:
        # Best-effort per file: report the failure and let the caller go on.
        logging.error("Error processing %s: %s", url, e)
        return None
|
|
|
def get_audio_data(url, timeout=30):
    """Download ``url`` and return the response body as bytes.

    Args:
        url: Address of the audio file to fetch.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` may block indefinitely on a stalled server.

    Returns:
        The raw response content.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: For connection problems or timeouts.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return response.content
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def process_audio(params):
    """Classify every audio URL listed in a JSON request string.

    Args:
        params: JSON text of the form ``{"urls": ["a.mp3", "b.wav"]}``.

    Returns:
        A JSON string ``{"solutions": [...]}`` holding one entry per URL that
        was classified successfully, or a dict ``{"error": "..."}`` when the
        request itself is malformed. URLs that are not strings, have an
        unsupported extension, cannot be downloaded or cannot be decoded are
        logged and skipped so one bad entry does not abort the batch.
    """
    try:
        params = json.loads(params)
    except json.JSONDecodeError as e:
        return {"error": f"Invalid JSON input: {e.msg} at line {e.lineno} column {e.colno}"}

    if not isinstance(params, dict):
        return {"error": "Invalid input: expected a JSON object with a 'urls' list"}

    audio_files = params.get("urls", [])
    if not isinstance(audio_files, list):
        return {"error": "Invalid input: 'urls' must be a list of audio URLs"}

    solutions = []
    for audio_url in audio_files:
        if not isinstance(audio_url, str):
            logging.warning("Skipping non-string URL entry: %r", audio_url)
            continue

        # Decide on the format before downloading so unsupported entries cost
        # nothing and can never reuse the previous iteration's result.
        lowered = audio_url.lower()
        is_mp3 = lowered.endswith(".mp3")
        if not is_mp3 and not lowered.endswith(".wav"):
            logging.warning("Skipping unsupported audio format: %s", audio_url)
            continue

        try:
            audio_data = get_audio_data(audio_url)
        except requests.RequestException as e:
            logging.error("Error downloading %s: %s", audio_url, e)
            continue

        if is_mp3:
            try:
                audio_data = convert_mp3_to_wav(audio_data)
            except Exception as e:
                # Mirror process_audio_file: a corrupt file is logged and
                # skipped rather than failing the whole request.
                logging.error("Error converting %s: %s", audio_url, e)
                continue

        result = process_audio_file(audio_data, audio_url)
        if result:
            solutions.append(result)

    return json.dumps({"solutions": solutions})
|
|
|
import gradio as gr |
|
|
|
# Free-text box carrying the JSON request. The example uses double quotes
# because process_audio parses the text with json.loads, which rejects
# single-quoted strings.
inputt = gr.Textbox(
    label='Parameters (JSON format) Eg. {"urls": ["file1.mp3", "file2.wav"]}'
)
outputs = gr.JSON()

# Wire the classifier into a simple request/response UI and start serving.
application = gr.Interface(
    fn=process_audio,
    inputs=inputt,
    outputs=outputs,
    title="Audio Classification with API Integration",
)
application.launch()