AB739 committed on
Commit
d4152c9
·
verified ·
1 Parent(s): 9a3271e

Create audio_onnx.py

Browse files
Files changed (1) hide show
  1. tasks/audio_onnx.py +118 -0
tasks/audio_onnx.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter
2
+ from datetime import datetime
3
+ from datasets import load_dataset
4
+ from sklearn.metrics import accuracy_score
5
+ import os
6
+ import torch
7
+ from torch.utils.data import DataLoader, TensorDataset
8
+ from torchaudio import transforms
9
+ from torchvision import models
10
+ import onnxruntime as ort # Add ONNX Runtime
11
+ from .utils.evaluation import AudioEvaluationRequest
12
+ from .utils.emissions import tracker, clean_emissions_data, get_space_info
13
+
14
+ from dotenv import load_dotenv
15
+ load_dotenv()
16
+
17
# Router on which this module's evaluation endpoint is registered.
router = APIRouter()

# Model description echoed back in the results, and the path the endpoint
# is served under.
DESCRIPTION = "Tiny_DNN"
ROUTE = "/audio"

# Cap PyTorch's CPU parallelism for the whole process: 4 intra-op threads
# and 2 inter-op threads.
torch.set_num_threads(4)
torch.set_num_interop_threads(2)
24
+
25
@router.post(ROUTE, tags=["Audio Task"], description=DESCRIPTION)
async def evaluate_audio(request: AudioEvaluationRequest):
    """Evaluate the exported ONNX audio classifier on a held-out split.

    The dataset named in ``request`` is loaded, its ``train`` split is
    divided with ``request.test_size`` / ``request.test_seed``, and every
    test clip is turned into a dB-scaled mel spectrogram. The spectrograms
    are then run through ``./output_model.onnx`` in batches while the
    emissions tracker measures the inference task.

    Args:
        request: Evaluation request; ``dataset_name``, ``test_size`` and
            ``test_seed`` are read from it.

    Returns:
        A dict with the submitter's space info, a timestamp, the model
        description, the accuracy on the test split, the tracked energy and
        emissions figures, the API route and the dataset configuration.
    """
    # Get space info
    username, space_url = get_space_info()

    # Labels are used exactly as stored in the dataset's "label" column.
    # NOTE(review): the original code declared (but never used) the mapping
    # {"chainsaw": 0, "environment": 1} -- presumably the meaning of the
    # integer labels; confirm against the dataset card.

    # Load and prepare the dataset
    dataset = load_dataset(request.dataset_name, token=os.getenv("HF_TOKEN"))
    train_test = dataset["train"].train_test_split(
        test_size=request.test_size,
        seed=request.test_seed,
    )
    test_dataset = train_test["test"]
    true_labels = test_dataset["label"]

    resampler = transforms.Resample(orig_freq=12000, new_freq=16000)
    mel_transform = transforms.MelSpectrogram(sample_rate=16000, n_mels=64)
    amplitude_to_db = transforms.AmplitudeToDB()

    # One Resample module per distinct clip length; building it once per
    # length avoids recomputing the same resampling kernel for every clip.
    resizers = {}

    def resize_audio(_waveform, target_length):
        """Stretch or compress ``_waveform`` to ``target_length`` samples.

        The clip length and the target length are passed to ``Resample`` as
        if they were frequencies, so the resampling ratio maps one onto the
        other. Clips that already have the target length are returned as is.
        """
        num_frames = _waveform.shape[-1]
        if num_frames == target_length:
            return _waveform
        key = (num_frames, target_length)
        if key not in resizers:
            resizers[key] = transforms.Resample(
                orig_freq=num_frames, new_freq=target_length
            )
        return resizers[key](_waveform)

    resized_waveforms = [
        resize_audio(
            torch.tensor(sample['audio']['array'], dtype=torch.float32).unsqueeze(0),
            target_length=72000,
        )
        for sample in test_dataset
    ]

    # Resample 12 kHz -> 16 kHz, then mel spectrogram, then dB scale.
    waveforms = torch.stack(
        [
            amplitude_to_db(mel_transform(resampler(waveform)))
            for waveform in resized_waveforms
        ]
    )
    labels = torch.tensor(list(true_labels))

    # The tensors already live in memory and each batch is handed to ONNX
    # Runtime as a numpy array, so worker processes and pinned memory would
    # only add overhead inside the tracked section.
    test_loader = DataLoader(
        TensorDataset(waveforms, labels),
        batch_size=128,
        shuffle=False,
    )

    # Load ONNX model
    onnx_model_path = "./output_model.onnx"
    session_options = ort.SessionOptions()
    session_options.intra_op_num_threads = 4
    session_options.inter_op_num_threads = 2
    ort_session = ort.InferenceSession(onnx_model_path, session_options)
    # Ask the session for its input name instead of assuming the model was
    # exported with an input called "input".
    input_name = ort_session.get_inputs()[0].name

    # Start tracking emissions
    tracker.start()
    tracker.start_task("inference")

    # ONNX inference
    predictions = []
    try:
        for batch, _ in test_loader:
            ort_outputs = ort_session.run(None, {input_name: batch.numpy()})
            # Assuming output shape is [batch_size, num_classes]
            predictions.extend(ort_outputs[0].argmax(axis=1).tolist())
    finally:
        # Stop tracking emissions even when inference fails, so the shared
        # tracker is not left with a dangling "inference" task.
        emissions_data = tracker.stop_task()

    # Calculate accuracy
    accuracy = accuracy_score(true_labels, predictions)

    # Prepare results dictionary
    results = {
        "username": username,
        "space_url": space_url,
        "submission_timestamp": datetime.now().isoformat(),
        "model_description": DESCRIPTION,
        "accuracy": float(accuracy),
        "energy_consumed_wh": emissions_data.energy_consumed * 1000,
        "emissions_gco2eq": emissions_data.emissions * 1000,
        "emissions_data": clean_emissions_data(emissions_data),
        "api_route": ROUTE,
        "dataset_config": {
            "dataset_name": request.dataset_name,
            "test_size": request.test_size,
            "test_seed": request.test_seed,
        },
    }

    return results