ajsdasdjks
Browse files- app.py +11 -1
- requirements.txt +4 -1
app.py
CHANGED
@@ -2,10 +2,20 @@ import streamlit as st
|
|
2 |
from transformers import HubertForSequenceClassification, HubertConfig, Wav2Vec2FeatureExtractor
|
3 |
import torch
|
4 |
import soundfile as sf
|
|
|
|
|
|
|
5 |
|
6 |
# Load model and tokenizer
|
7 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
8 |
-
model_name = "model_hubert_finetuned_nopeft.pth"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
config = HubertConfig.from_pretrained("superb/hubert-large-superb-er")
|
10 |
config.id2label = {0: 'neu', 1: 'hap', 2: 'ang', 3: 'sad', 4: 'dis', 5: 'sur', 6: 'fea', 7: 'cal'}
|
11 |
config.label2id = {"neu": 0, "hap": 1, "ang": 2, "sad": 3, "dis": 4, "sur": 5, "fea": 6, "cal": 7}
|
|
|
2 |
from transformers import HubertForSequenceClassification, HubertConfig, Wav2Vec2FeatureExtractor
|
3 |
import torch
|
4 |
import soundfile as sf
|
5 |
+
import gdown
|
6 |
+
|
7 |
+
file_id = "1xm9Uf7_wn3VR2ivuftCW0jkz5bDC0YxF"
|
8 |
|
9 |
# Load model and tokenizer
|
10 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
11 |
+
model_name = "model_hubert_finetuned_nopeft.pth"
|
12 |
+
if not os.path.exists(model_name):
|
13 |
+
print(f"Downloading {model_name} from Google Drive...")
|
14 |
+
gdown.download(f'https://drive.google.com/uc?id={file_id}', model_name, quiet=False)
|
15 |
+
else:
|
16 |
+
print(f"{output} already exists, skipping download.")
|
17 |
+
|
18 |
+
# Replace with your model path or Hugging Face model hub path
|
19 |
config = HubertConfig.from_pretrained("superb/hubert-large-superb-er")
|
20 |
config.id2label = {0: 'neu', 1: 'hap', 2: 'ang', 3: 'sad', 4: 'dis', 5: 'sur', 6: 'fea', 7: 'cal'}
|
21 |
config.label2id = {"neu": 0, "hap": 1, "ang": 2, "sad": 3, "dis": 4, "sur": 5, "fea": 6, "cal": 7}
|
requirements.txt
CHANGED
@@ -1,4 +1,7 @@
|
|
1 |
streamlit
|
2 |
transformers
|
3 |
torch
|
4 |
-
soundfile
|
|
|
|
|
|
|
|
1 |
streamlit
|
2 |
transformers
|
3 |
torch
|
4 |
+
soundfile
|
5 |
+
gdown
|
6 |
+
flask_cors
|
7 |
+
flask
|