gautamtata
commited on
Commit
•
3c03da2
1
Parent(s):
f41ae34
Modifying to accept file URL from the Supabase
Browse files- handler.py +32 -8
handler.py
CHANGED
@@ -5,6 +5,9 @@ import torchaudio
|
|
5 |
import torch.nn.functional as F
|
6 |
from typing import Dict, List, Any
|
7 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
|
|
|
|
|
|
8 |
|
9 |
from transformers.models.wav2vec2.modeling_wav2vec2 import (
|
10 |
Wav2Vec2PreTrainedModel,
|
@@ -145,15 +148,36 @@ class EndpointHandler():
|
|
145 |
outputs = [{"label": self.config.id2label[i], "score": score} for i, score in enumerate(scores)]
|
146 |
return outputs
|
147 |
|
148 |
-
def
|
149 |
"""
|
150 |
-
|
151 |
"""
|
152 |
-
|
153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
154 |
|
155 |
-
|
156 |
-
|
157 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
158 |
else:
|
159 |
-
return {"error": "
|
|
|
5 |
import torch.nn.functional as F
|
6 |
from typing import Dict, List, Any
|
7 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
8 |
+
import requests
|
9 |
+
import tempfile
|
10 |
+
import os
|
11 |
|
12 |
from transformers.models.wav2vec2.modeling_wav2vec2 import (
|
13 |
Wav2Vec2PreTrainedModel,
|
|
|
148 |
outputs = [{"label": self.config.id2label[i], "score": score} for i, score in enumerate(scores)]
|
149 |
return outputs
|
150 |
|
151 |
+
def download_file(self, url):
|
152 |
"""
|
153 |
+
Downloads the file from the given URL and returns the path to the saved temporary file.
|
154 |
"""
|
155 |
+
response = requests.get(url)
|
156 |
+
if response.status_code == 200:
|
157 |
+
# Create a temporary file
|
158 |
+
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.wav')
|
159 |
+
temp_file.write(response.content)
|
160 |
+
temp_file.close()
|
161 |
+
return temp_file.name
|
162 |
+
else:
|
163 |
+
return None
|
164 |
|
165 |
+
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
166 |
+
"""
|
167 |
+
The method called during inference. Expects data to have a 'url' to the audio file.
|
168 |
+
"""
|
169 |
+
# Get the URL to the audio file from the request data
|
170 |
+
url = data.get("url")
|
171 |
+
|
172 |
+
# If the URL is provided, download the file and run the prediction
|
173 |
+
if url:
|
174 |
+
file_path = self.download_file(url)
|
175 |
+
if file_path:
|
176 |
+
output = self.predict(file_path)
|
177 |
+
# Optionally, delete the temporary file
|
178 |
+
os.remove(file_path)
|
179 |
+
return output
|
180 |
+
else:
|
181 |
+
return {"error": "Could not download the file from the provided URL."}
|
182 |
else:
|
183 |
+
return {"error": "URL to the audio file is required."}
|