D3V1L1810 committed on
Commit
03e3784
·
verified ·
1 Parent(s): e9970a8

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +127 -0
app.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tensorflow as tf
3
+ import tensorflow_hub as hub
4
+ import numpy as np
5
+ import csv
6
+ import requests
7
+ import json
8
+ import logging
9
+ import scipy
10
+ from scipy.io import wavfile
11
+ from pydub import AudioSegment
12
+ import io
13
+ from io import BytesIO
14
+
15
+
16
+ # Load the audio classification model through TensorFlow Hub at import time (handle 'Audio_Multiple_v1'; presumably a local SavedModel directory - confirm).
17
+ model = hub.load('Audio_Multiple_v1')
18
+
19
def class_names_from_csv(class_map_csv_text):
    """Return the ``display_name`` column of the class-map CSV as a list.

    Args:
        class_map_csv_text: Path of the class-map CSV; it is opened with
            ``tf.io.gfile.GFile``.

    Returns:
        The ``display_name`` value of every row, in file order, so that the
        position of a name matches the row position in the CSV.
    """
    with tf.io.gfile.GFile(class_map_csv_text) as csvfile:
        return [row['display_name'] for row in csv.DictReader(csvfile)]
27
+
28
+ class_map_path = model.class_map_path().numpy()  # location of the class-map CSV, as reported by the loaded model
29
+ class_names = class_names_from_csv(class_map_path)  # display names in CSV row order; indexed by score position below
30
+
31
def ensure_sample_rate(original_sample_rate, waveform, desired_sample_rate=16000):
    """Resample ``waveform`` to ``desired_sample_rate`` when the rates differ.

    Args:
        original_sample_rate: Sample rate of ``waveform`` in Hz.
        waveform: Array of samples; its length along axis 0 is the number of
            samples.
        desired_sample_rate: Target sample rate in Hz (default 16000).

    Returns:
        A ``(desired_sample_rate, waveform)`` tuple. When the rates already
        match, ``waveform`` is returned untouched (same object, same dtype);
        otherwise it is a new ``float32`` array whose length is scaled by
        ``desired_sample_rate / original_sample_rate``.
    """
    if original_sample_rate == desired_sample_rate:
        return desired_sample_rate, waveform

    # Import the submodule explicitly: a bare ``import scipy`` does not
    # guarantee that ``scipy.signal`` is loaded on every SciPy version.
    from scipy import signal

    desired_length = int(round(float(len(waveform)) / original_sample_rate * desired_sample_rate))
    resampled = signal.resample(waveform, desired_length)
    return desired_sample_rate, np.array(resampled, dtype=np.float32)
36
+
37
def convert_mp3_to_wav(mp3_data):
    """Decode MP3 bytes with pydub and return the same audio as WAV bytes.

    Args:
        mp3_data: Raw bytes of an MP3 file.

    Returns:
        The bytes of the WAV file that pydub exports for the decoded audio.
    """
    segment = AudioSegment.from_file(io.BytesIO(mp3_data), format="mp3")
    wav_buffer = io.BytesIO()
    segment.export(wav_buffer, format='wav')
    # getvalue() returns the entire buffer regardless of the stream position.
    return wav_buffer.getvalue()
43
+
44
def process_audio_file(file_data, url):
    """Classify one WAV payload and report the inferred class for ``url``.

    Args:
        file_data: Raw bytes of a WAV file.
        url: Source URL of the audio; echoed in the result and in error logs.

    Returns:
        ``{'url': url, 'answer': [class_name]}`` on success, or ``None`` when
        anything in the pipeline raises (the error is logged, not re-raised,
        so one bad file does not abort a batch).
    """
    try:
        sample_rate, wav_data = wavfile.read(BytesIO(file_data))
        # Down-mix multi-channel audio to mono by averaging the channels.
        if wav_data.ndim > 1:
            wav_data = np.mean(wav_data, axis=1)

        _, wav_data = ensure_sample_rate(sample_rate, wav_data)
        # Scale by the int16 maximum; assumes 16-bit PCM samples -
        # TODO confirm behaviour for other WAV encodings.
        waveform = wav_data / tf.int16.max

        # Only the scores are used; the other two model outputs are ignored.
        scores, _embeddings, _spectrogram = model(waveform)

        # Average the scores over axis 0 (presumably one row per audio
        # frame) to get a single score per class for the whole clip.
        mean_scores = np.mean(scores.numpy(), axis=0)

        # Indices of the two highest mean scores, best first.
        top_two_indices = np.argsort(mean_scores)[-2:][::-1]
        inferred_class = class_names[top_two_indices[0]]

        # Prefer the runner-up when the top prediction is just "Silence".
        if inferred_class == "Silence" and len(top_two_indices) > 1:
            inferred_class = class_names[top_two_indices[1]]

        return {'url': url, 'answer': [inferred_class]}
    except Exception as e:
        logging.error(f"Error processing {url}: {e}")
        return None
70
+
71
def get_audio_data(url, timeout=30):
    """Download ``url`` and return the response body as bytes.

    Args:
        url: Location of the audio file.
        timeout: Value passed to ``requests.get`` as ``timeout`` so a stalled
            server cannot block the caller forever (default 30).

    Returns:
        The raw response content.

    Raises:
        requests.exceptions.RequestException: On connection problems,
            timeouts, or (via ``raise_for_status``) an HTTP error status.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return response.content
75
+
76
+ # def send_results_to_api(data, result_url):
77
+ # headers = {"Content-Type": "application/json"}
78
+ # try:
79
+ # response = requests.post(result_url, json=data, headers=headers)
80
+ # response.raise_for_status() # Raise error for non-200 responses
81
+ # return response.json() # Return any JSON response from the API
82
+ # except requests.exceptions.HTTPError as http_err:
83
+ # logging.error(f"HTTP error occurred: {http_err}")
84
+ # return {"error": f"HTTP error occurred: {http_err}"}
85
+ # except requests.exceptions.RequestException as req_err:
86
+ # logging.error(f"Request error occurred: {req_err}")
87
+ # return {"error": f"Request error occurred: {req_err}"}
88
+ # except ValueError as val_err:
89
+ # logging.error(f"Error decoding JSON response: {val_err}")
90
+ # return {"error": f"Error decoding JSON response: {val_err}"}
91
+
92
def process_audio(params):
    """Classify every audio URL listed in a JSON parameter string.

    Args:
        params: JSON text of the form ``{"urls": ["a.mp3", "b.wav"]}``.

    Returns:
        On success, a JSON string ``{"solutions": [...]}`` holding one entry
        per URL that was downloaded and classified. URLs with an unsupported
        extension, failed downloads, and files that could not be classified
        are logged and skipped. On malformed input, a dict with an
        ``"error"`` key describing the problem.
    """
    from urllib.parse import urlparse

    try:
        params = json.loads(params)
    except json.JSONDecodeError as e:
        return {"error": f"Invalid JSON input: {e.msg} at line {e.lineno} column {e.colno}"}
    except TypeError:
        return {"error": "Invalid JSON input: expected a JSON string"}

    if not isinstance(params, dict):
        return {"error": "Invalid JSON input: expected an object with a 'urls' list"}

    audio_files = params.get("urls", [])
    if not isinstance(audio_files, list):
        return {"error": "Invalid JSON input: 'urls' must be a list"}

    solutions = []
    for audio_url in audio_files:
        # Decide on the URL path only, so query strings and upper-case
        # extensions do not hide the file type.
        path = urlparse(str(audio_url)).path.lower()
        is_mp3 = path.endswith(".mp3")
        if not is_mp3 and not path.endswith(".wav"):
            # Skip before downloading; previously this case reused the
            # previous iteration's result (or raised on the first URL).
            logging.warning(f"Skipping unsupported audio URL: {audio_url}")
            continue

        try:
            audio_data = get_audio_data(audio_url)
            if is_mp3:
                audio_data = convert_mp3_to_wav(audio_data)
        except Exception as e:
            # Best effort per file: one bad URL must not abort the batch.
            logging.error(f"Error fetching {audio_url}: {e}")
            continue

        result = process_audio_file(audio_data, audio_url)
        if result:
            solutions.append(result)

    return json.dumps({"solutions": solutions})
120
+
121
+ import gradio as gr
122
+
123
# The example in the label uses double quotes because process_audio parses the
# textbox content with json.loads, which rejects single-quoted strings.
inputt = gr.Textbox(label='Parameters (JSON format) Eg. {"urls":["file1.mp3","file2.wav"]}')
outputs = gr.JSON()

# Expose process_audio as a text-in / JSON-out Gradio interface and start it.
application = gr.Interface(fn=process_audio, inputs=inputt, outputs=outputs, title="Audio Classification with API Integration")
application.launch()