jcho02 committed · verified
Commit 4485862 · 1 Parent(s): e2369f4

Update app.py

Files changed (1):
app.py +26 -32
app.py CHANGED
@@ -7,6 +7,9 @@ from transformers import WhisperModel, WhisperFeatureExtractor
 import datasets
 from datasets import load_dataset, DatasetDict, Audio
 from huggingface_hub import PyTorchModelHubMixin
+import numpy as np
+import tempfile
+import os
 
 # Ensure you have the device setup (cuda or cpu)
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -25,7 +28,8 @@ class SpeechInferenceDataset(Dataset):
                                         return_tensors="pt",
                                         sampling_rate=self.audio_data[index]["audio"]["sampling_rate"])
         input_features = inputs.input_features
-        decoder_input_ids = torch.tensor([[1, 1]]) # Modify as per your model's requirements
+        # Modify decoder_input_ids as per your model's requirements
+        decoder_input_ids = torch.tensor([[1, 1]])
         return input_features, decoder_input_ids
 
 # Define model class
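Note: the dataset above still hard-codes decoder_input_ids = torch.tensor([[1, 1]]), and the new comment says to adjust it per model. A minimal sketch of one alternative, assuming the encoder checkpoint's config exposes decoder_start_token_id (the openai/whisper-* configs do):

import torch
from transformers import WhisperConfig

whisper_config = WhisperConfig.from_pretrained("openai/whisper-base")
start_id = whisper_config.decoder_start_token_id  # <|startoftranscript|> for the Whisper checkpoints
decoder_input_ids = torch.tensor([[start_id, start_id]])  # keeps the (1, 2) shape used above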
@@ -51,44 +55,34 @@ class SpeechClassifier(nn.Module, PyTorchModelHubMixin):
         logits = self.classifier(pooled_output)
         return logits
 
-# Prepare data function
-def prepare_data(audio_file_path, model_checkpoint="openai/whisper-base"):
+# Prepare data function (may need to update for numpy input)
+def prepare_data(audio_data, sampling_rate, model_checkpoint="openai/whisper-base"):
     feature_extractor = WhisperFeatureExtractor.from_pretrained(model_checkpoint)
-    inference_data = datasets.Dataset.from_dict({"path": [audio_file_path], "audio": [audio_file_path]}).cast_column("audio", Audio(sampling_rate=16_000))
-    inference_dataset = SpeechInferenceDataset(inference_data, feature_extractor)
-    inference_loader = DataLoader(inference_dataset, batch_size=1, shuffle=False)
-    input_features, decoder_input_ids = next(iter(inference_loader))
-    input_features = input_features.squeeze(1).to(device)
-    decoder_input_ids = decoder_input_ids.squeeze(1).to(device)
-    return input_features, decoder_input_ids
+    # ... your logic for preparing data ...
+    # Must return tensor that your model can process
+    pass
 
-# Prediction function
-def predict(audio_file_path, config={"encoder": "openai/whisper-base", "num_labels": 2}):
-    input_features, decoder_input_ids = prepare_data(audio_file_path)
-
-    # Load the model from Hugging Face Hub
-    model = SpeechClassifier(config)
-    model.to(device)
-    # Use the correct method to load your model (this is an example and may not directly apply)
-    model.load_state_dict(torch.load(model.push_from_hub("jcho02/whisper_cleft")))
-    model.eval()
-
-    with torch.no_grad():
-        logits = model(input_features, decoder_input_ids)
-        predicted_ids = int(torch.argmax(logits, dim=-1))
-    return predicted_ids
+# Prediction function (may need to update for numpy input)
+def predict(audio_data, sampling_rate, config={"encoder": "openai/whisper-base", "num_labels": 2}):
+    # Load the model from Hugging Face Hub (ensure correct loading mechanism)
+    model = SpeechClassifier(config).to(device)
+    # ... your logic for prediction using model ...
+    # Must return a prediction
+    pass
 
 # Gradio Interface function for uploaded files
 def gradio_file_interface(uploaded_file):
-    with open(uploaded_file.name, "wb") as f:
-        f.write(uploaded_file.read())
-    prediction = predict(uploaded_file.name)
+    # Gradio passes a file path as a string for uploaded files
+    prediction = predict(uploaded_file, config)
     label = "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"
     return label
 
 # Gradio Interface function for microphone input
 def gradio_mic_interface(mic_input):
-    prediction = predict(mic_input.name)
+    # Gradio passes mic input as a numpy array and sample rate
+    audio_data = mic_input['data']
+    sampling_rate = mic_input['sample_rate']
+    prediction = predict(audio_data, sampling_rate, config)
     label = "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"
     return label
 
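Note: prepare_data and predict are left as pass stubs in this revision. A hedged sketch of how they might be filled in for numpy input, calling the feature extractor directly and loading weights through the from_pretrained classmethod that PyTorchModelHubMixin provides; the jcho02/whisper_cleft repo id is taken from the removed code and is assumed to hold compatible weights, and the config keyword is assumed to match the SpeechClassifier constructor:

def prepare_data(audio_data, sampling_rate, model_checkpoint="openai/whisper-base"):
    feature_extractor = WhisperFeatureExtractor.from_pretrained(model_checkpoint)
    # Whisper's extractor expects 16 kHz float audio; resample first if the source rate differs
    inputs = feature_extractor(audio_data, sampling_rate=sampling_rate, return_tensors="pt")
    input_features = inputs.input_features.to(device)
    decoder_input_ids = torch.tensor([[1, 1]]).to(device)  # same placeholder ids as the dataset class
    return input_features, decoder_input_ids

def predict(audio_data, sampling_rate, config={"encoder": "openai/whisper-base", "num_labels": 2}):
    input_features, decoder_input_ids = prepare_data(audio_data, sampling_rate, config["encoder"])
    # from_pretrained forwards keyword args to __init__, so config reaches the constructor
    model = SpeechClassifier.from_pretrained("jcho02/whisper_cleft", config=config).to(device)
    model.eval()
    with torch.no_grad():
        logits = model(input_features, decoder_input_ids)
    return int(torch.argmax(logits, dim=-1))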
 
@@ -99,18 +93,18 @@ demo = gr.Blocks()
 with demo:
     mic_transcribe = gr.Interface(
         fn=gradio_mic_interface,
-        inputs=gr.Audio(),  # No type needed for microphone input
+        inputs=gr.Audio(type="numpy"),  # Receives numpy array for mic input
         outputs=gr.Textbox(label="Prediction")
     )
 
     file_transcribe = gr.Interface(
         fn=gradio_file_interface,
-        inputs=gr.Audio(type="filepath"),  # Specify filepath for file upload
+        inputs=gr.Audio(type="filepath"),  # Receives file path for file upload
         outputs=gr.Textbox(label="Prediction")
     )
 
     # Use a tabbed interface to switch between the microphone and file upload interfaces
     gr.TabbedInterface([mic_transcribe, file_transcribe], ["Transcribe Microphone", "Transcribe Audio File"])
 
-# Launch the demo with debugging enabled to catch any potential errors early on
+# Launch the demo with debugging enabled
 demo.launch(debug=True)
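Note: in current Gradio releases gr.Audio(type="numpy") hands the handler a (sample_rate, data) tuple rather than a dict, and type="filepath" hands it a plain path string, so the two interface functions could unpack their inputs roughly as below (a sketch under those assumptions, reusing the datasets Audio loader already imported at the top of the file):

def gradio_file_interface(uploaded_file):
    # type="filepath": uploaded_file is a path string
    ds = datasets.Dataset.from_dict({"audio": [uploaded_file]}).cast_column("audio", Audio(sampling_rate=16_000))
    audio = ds[0]["audio"]  # decoded and resampled to 16 kHz on access
    prediction = predict(audio["array"], audio["sampling_rate"])
    return "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"

def gradio_mic_interface(mic_input):
    # type="numpy": mic_input is a (sample_rate, numpy array) tuple
    sampling_rate, audio_data = mic_input
    if audio_data.dtype != np.float32:
        audio_data = audio_data.astype(np.float32) / 32768.0  # mic recordings commonly arrive as int16
    # resample to 16 kHz before predicting if the mic rate differs (Whisper's extractor requires it)
    prediction = predict(audio_data, sampling_rate)
    return "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"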
 