b1nay committed on
Commit
7774278
1 Parent(s): ce9048a

ajsdasdjks

Files changed (2)
  1. app.py +11 -1
  2. requirements.txt +4 -1
app.py CHANGED
@@ -2,10 +2,20 @@ import streamlit as st
 from transformers import HubertForSequenceClassification, HubertConfig, Wav2Vec2FeatureExtractor
 import torch
 import soundfile as sf
+import gdown
+
+file_id = "1xm9Uf7_wn3VR2ivuftCW0jkz5bDC0YxF"
 
 # Load model and tokenizer
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-model_name = "model_hubert_finetuned_nopeft.pth"  # Replace with your model path or Hugging Face model hub path
+model_name = "model_hubert_finetuned_nopeft.pth"
+if not os.path.exists(model_name):
+    print(f"Downloading {model_name} from Google Drive...")
+    gdown.download(f'https://drive.google.com/uc?id={file_id}', model_name, quiet=False)
+else:
+    print(f"{model_name} already exists, skipping download.")
+
+# Replace with your model path or Hugging Face model hub path
 config = HubertConfig.from_pretrained("superb/hubert-large-superb-er")
 config.id2label = {0: 'neu', 1: 'hap', 2: 'ang', 3: 'sad', 4: 'dis', 5: 'sur', 6: 'fea', 7: 'cal'}
 config.label2id = {"neu": 0, "hap": 1, "ang": 2, "sad": 3, "dis": 4, "sur": 5, "fea": 6, "cal": 7}
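The hunk above adds the gdown download but no `import os` for the `os.path.exists` check, and the checkpoint loading itself sits outside this hunk. A minimal sketch of how the downloaded .pth file could be loaded with the 8-label config is shown below; the torch.load / load_state_dict pattern (and a state-dict-style checkpoint) are assumptions, not something this diff confirms.

import os
import torch
from transformers import HubertForSequenceClassification, HubertConfig

# Rebuild the config with the 8 emotion labels used in app.py
config = HubertConfig.from_pretrained("superb/hubert-large-superb-er")
config.id2label = {0: 'neu', 1: 'hap', 2: 'ang', 3: 'sad', 4: 'dis', 5: 'sur', 6: 'fea', 7: 'cal'}
config.label2id = {label: idx for idx, label in config.id2label.items()}

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Assumption: the Google Drive file holds a state dict saved with torch.save(model.state_dict(), ...)
model = HubertForSequenceClassification(config)
state_dict = torch.load("model_hubert_finetuned_nopeft.pth", map_location=device)
model.load_state_dict(state_dict)
model.to(device).eval()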
requirements.txt CHANGED
@@ -1,4 +1,7 @@
 streamlit
 transformers
 torch
-soundfile
+soundfile
+gdown
+flask_cors
+flask
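This commit adds flask and flask_cors to requirements.txt but contains no Flask code, so the intended API surface is not visible here. A hypothetical minimal sketch of how the two are typically wired together follows; the route name and response are illustrative assumptions only.

from flask import Flask, jsonify
from flask_cors import CORS

app = Flask(__name__)
CORS(app)  # allow cross-origin requests, e.g. from a separate front end

@app.route("/health")
def health():
    # Placeholder endpoint; the real API is not part of this commit
    return jsonify(status="ok")

if __name__ == "__main__":
    app.run()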