Update app.py
app.py CHANGED
@@ -7,6 +7,9 @@ from transformers import WhisperModel, WhisperFeatureExtractor
 import datasets
 from datasets import load_dataset, DatasetDict, Audio
 from huggingface_hub import PyTorchModelHubMixin
+import numpy as np
+import tempfile
+import os
 
 # Ensure you have the device setup (cuda or cpu)
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -25,7 +28,8 @@ class SpeechInferenceDataset(Dataset):
                                  return_tensors="pt",
                                  sampling_rate=self.audio_data[index]["audio"]["sampling_rate"])
         input_features = inputs.input_features
-
+        # Modify decoder_input_ids as per your model's requirements
+        decoder_input_ids = torch.tensor([[1, 1]])
         return input_features, decoder_input_ids
 
 # Define model class
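A note on the hard-coded `decoder_input_ids = torch.tensor([[1, 1]])` added above: Whisper decoders normally begin generation from the checkpoint's `decoder_start_token_id`, so a less brittle variant might derive the tensor from the model config. A minimal sketch, assuming the `openai/whisper-base` checkpoint used elsewhere in this file (this is an alternative, not part of the commit):

```python
import torch
from transformers import WhisperConfig

# Sketch: derive the decoder start token from the checkpoint config
# instead of hard-coding [[1, 1]] (assumption, not part of the commit).
config = WhisperConfig.from_pretrained("openai/whisper-base")
decoder_input_ids = torch.tensor([[config.decoder_start_token_id]])
```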
@@ -51,44 +55,34 @@ class SpeechClassifier(nn.Module, PyTorchModelHubMixin):
         logits = self.classifier(pooled_output)
         return logits
 
-# Prepare data function
-def prepare_data(
+# Prepare data function (may need to update for numpy input)
+def prepare_data(audio_data, sampling_rate, model_checkpoint="openai/whisper-base"):
     feature_extractor = WhisperFeatureExtractor.from_pretrained(model_checkpoint)
-
-
-
-    input_features, decoder_input_ids = next(iter(inference_loader))
-    input_features = input_features.squeeze(1).to(device)
-    decoder_input_ids = decoder_input_ids.squeeze(1).to(device)
-    return input_features, decoder_input_ids
+    # ... your logic for preparing data ...
+    # Must return tensor that your model can process
+    pass
 
-# Prediction function
-def predict(
-
-
-    #
-
-
-    # Use the correct method to load your model (this is an example and may not directly apply)
-    model.load_state_dict(torch.load(model.push_from_hub("jcho02/whisper_cleft")))
-    model.eval()
-
-    with torch.no_grad():
-        logits = model(input_features, decoder_input_ids)
-        predicted_ids = int(torch.argmax(logits, dim=-1))
-    return predicted_ids
+# Prediction function (may need to update for numpy input)
+def predict(audio_data, sampling_rate, config={"encoder": "openai/whisper-base", "num_labels": 2}):
+    # Load the model from Hugging Face Hub (ensure correct loading mechanism)
+    model = SpeechClassifier(config).to(device)
+    # ... your logic for prediction using model ...
+    # Must return a prediction
+    pass
 
 # Gradio Interface function for uploaded files
 def gradio_file_interface(uploaded_file):
-
-
-    prediction = predict(uploaded_file.name)
+    # Gradio passes a file path as a string for uploaded files
+    prediction = predict(uploaded_file, config)
     label = "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"
     return label
 
 # Gradio Interface function for microphone input
 def gradio_mic_interface(mic_input):
-
+    # Gradio passes mic input as a numpy array and sample rate
+    audio_data = mic_input['data']
+    sampling_rate = mic_input['sample_rate']
+    prediction = predict(audio_data, sampling_rate, config)
     label = "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"
     return label
 
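The committed `prepare_data` and `predict` bodies end in `pass`, so the app cannot yet produce a prediction. Below is one way the stubs might be filled in: a sketch under assumptions (mono numpy waveform input at 16 kHz, the `SpeechClassifier` class defined earlier in this file, and the `jcho02/whisper_cleft` repo id taken from the removed code), not the author's implementation. `PyTorchModelHubMixin` does provide a `from_pretrained` classmethod for reloading weights pushed to the Hub:

```python
import torch
from transformers import WhisperFeatureExtractor

# Same device setup as at the top of app.py.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def prepare_data(audio_data, sampling_rate, model_checkpoint="openai/whisper-base"):
    # Convert a raw mono waveform into Whisper log-mel input features.
    # Whisper feature extractors expect 16 kHz audio.
    feature_extractor = WhisperFeatureExtractor.from_pretrained(model_checkpoint)
    inputs = feature_extractor(audio_data, sampling_rate=sampling_rate, return_tensors="pt")
    input_features = inputs.input_features.to(device)
    # Placeholder start tokens, mirroring the dataset class above.
    decoder_input_ids = torch.tensor([[1, 1]]).to(device)
    return input_features, decoder_input_ids

def predict(audio_data, sampling_rate, config={"encoder": "openai/whisper-base", "num_labels": 2}):
    # PyTorchModelHubMixin adds from_pretrained to SpeechClassifier; the repo id
    # below comes from the removed code and is assumed to hold the trained weights.
    model = SpeechClassifier.from_pretrained("jcho02/whisper_cleft").to(device)
    model.eval()
    input_features, decoder_input_ids = prepare_data(audio_data, sampling_rate, config["encoder"])
    with torch.no_grad():
        logits = model(input_features, decoder_input_ids)
    return int(torch.argmax(logits, dim=-1))
```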
@@ -99,18 +93,18 @@ demo = gr.Blocks()
 with demo:
     mic_transcribe = gr.Interface(
         fn=gradio_mic_interface,
-        inputs=gr.Audio(), #
+        inputs=gr.Audio(type="numpy"),  # Receives numpy array for mic input
         outputs=gr.Textbox(label="Prediction")
     )
 
     file_transcribe = gr.Interface(
         fn=gradio_file_interface,
-        inputs=gr.Audio(type="filepath"), #
+        inputs=gr.Audio(type="filepath"),  # Receives file path for file upload
         outputs=gr.Textbox(label="Prediction")
     )
 
     # Use a tabbed interface to switch between the microphone and file upload interfaces
     gr.TabbedInterface([mic_transcribe, file_transcribe], ["Transcribe Microphone", "Transcribe Audio File"])
 
-# Launch the demo with debugging enabled
+# Launch the demo with debugging enabled
 demo.launch(debug=True)
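One caveat on the Gradio handlers as committed: in current Gradio releases, `gr.Audio(type="numpy")` passes microphone input as a `(sample_rate, numpy_array)` tuple rather than a dict with `'data'`/`'sample_rate'` keys, and the committed `gradio_file_interface` passes `config` where `predict` expects `sampling_rate`. A hedged sketch of handlers matching the actual input formats, reusing the `predict` sketch above (the `librosa` dependency is an assumption, not in the commit):

```python
import librosa

def gradio_mic_interface(mic_input):
    # gr.Audio(type="numpy") yields a (sample_rate, data) tuple.
    sampling_rate, audio_data = mic_input
    prediction = predict(audio_data, sampling_rate)
    return "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"

def gradio_file_interface(filepath):
    # gr.Audio(type="filepath") yields a path string; resample to the
    # 16 kHz mono audio Whisper feature extractors expect.
    audio_data, sampling_rate = librosa.load(filepath, sr=16000)
    prediction = predict(audio_data, sampling_rate)
    return "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"
```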