spycoder commited on
Commit
ba42b9f
·
1 Parent(s): 0d63a0e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -4
app.py CHANGED
@@ -1,7 +1,49 @@
1
  import gradio as gr
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import torch
3
+ import soundfile as sf
4
+ import os
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
 
8
+ import os
9
+ import soundfile as sf
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from torch.utils.data import Dataset, DataLoader
14
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ForSequenceClassification
15
+ from sklearn.model_selection import train_test_split
16
+ import re
17
+ from collections import Counter
18
+ from sklearn.metrics import classification_report
19
+ model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/wav2vec2-base-960h", num_labels=2).to(device)
20
+ model_path = "dysarthria_classifier12.pth"
21
+ if os.path.exists(model_path):
22
+ print(f"Loading saved model {model_path}")
23
+ model.load_state_dict(torch.load(model_path))
24
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
+ processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
26
+ def predict(file_path):
27
+ max_length = 100000
28
 
29
+ model.eval()
30
+ with torch.no_grad():
31
+ wav_data, _ = sf.read(file_path.name)
32
+ inputs = processor(wav_data, sampling_rate=16000, return_tensors="pt", padding=True)
33
+
34
+ input_values = inputs.input_values.squeeze(0)
35
+ if max_length - input_values.shape[-1] > 0:
36
+ input_values = torch.cat([input_values, torch.zeros((max_length - input_values.shape[-1],))], dim=-1)
37
+ else:
38
+ input_values = input_values[:max_length]
39
+ input_values = input_values.unsqueeze(0).to(device)
40
+ inputs = {"input_values": input_values}
41
+
42
+ logits = model(**inputs).logits
43
+ logits = logits.squeeze()
44
+ predicted_class_id = torch.argmax(logits, dim=-1).item()
45
+
46
+ return predicted_class_id
47
+
48
+ iface = gr.Interface(fn=predict, inputs="file", outputs="text")
49
+ iface.launch()