gautamtata commited on
Commit
3c03da2
1 Parent(s): f41ae34

Modifying to accept file URL from the Supabase

Browse files
Files changed (1) hide show
  1. handler.py +32 -8
handler.py CHANGED
@@ -5,6 +5,9 @@ import torchaudio
5
  import torch.nn.functional as F
6
  from typing import Dict, List, Any
7
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 
 
8
 
9
  from transformers.models.wav2vec2.modeling_wav2vec2 import (
10
  Wav2Vec2PreTrainedModel,
@@ -145,15 +148,36 @@ class EndpointHandler():
145
  outputs = [{"label": self.config.id2label[i], "score": score} for i, score in enumerate(scores)]
146
  return outputs
147
 
148
- def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
149
  """
150
- The actual method called during inference. Expects data to have a 'path' to the audio file.
151
  """
152
- # Get the path to the audio file from the request data
153
- path = data.get("path")
 
 
 
 
 
 
 
154
 
155
- # If the path is provided, we run the prediction, else return an error message
156
- if path:
157
- return self.predict(path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  else:
159
- return {"error": "Path to the audio file is required."}
 
5
  import torch.nn.functional as F
6
  from typing import Dict, List, Any
7
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
8
+ import requests
9
+ import tempfile
10
+ import os
11
 
12
  from transformers.models.wav2vec2.modeling_wav2vec2 import (
13
  Wav2Vec2PreTrainedModel,
 
148
  outputs = [{"label": self.config.id2label[i], "score": score} for i, score in enumerate(scores)]
149
  return outputs
150
 
151
+ def download_file(self, url):
152
  """
153
+ Downloads the file from the given URL and returns the path to the saved temporary file.
154
  """
155
+ response = requests.get(url)
156
+ if response.status_code == 200:
157
+ # Create a temporary file
158
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.wav')
159
+ temp_file.write(response.content)
160
+ temp_file.close()
161
+ return temp_file.name
162
+ else:
163
+ return None
164
 
165
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
166
+ """
167
+ The method called during inference. Expects data to have a 'url' to the audio file.
168
+ """
169
+ # Get the URL to the audio file from the request data
170
+ url = data.get("url")
171
+
172
+ # If the URL is provided, download the file and run the prediction
173
+ if url:
174
+ file_path = self.download_file(url)
175
+ if file_path:
176
+ output = self.predict(file_path)
177
+ # Optionally, delete the temporary file
178
+ os.remove(file_path)
179
+ return output
180
+ else:
181
+ return {"error": "Could not download the file from the provided URL."}
182
  else:
183
+ return {"error": "URL to the audio file is required."}