Pavankalyan committed on
Commit
8d5928a
1 Parent(s): e1522cf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -61
app.py CHANGED
@@ -3,7 +3,7 @@ from model import Wav2VecModel
3
  from dataset import S2IDataset, collate_fn
4
  import requests
5
  requests.packages.urllib3.disable_warnings()
6
-
7
  import torch
8
  import torch.nn as nn
9
  import torchaudio
@@ -103,72 +103,12 @@ class LightningModel(pl.LightningModule):
103
  return predicted_class
104
 
105
  if __name__ == "__main__":
106
-
107
- dataset = S2IDataset(
108
- csv_path="./speech-to-intent/train.csv",
109
- wav_dir_path="/home/development/pavan/Telesoft/speech-to-intent-dataset/baselines/speech-to-intent"
110
- )
111
-
112
- test_dataset = S2IDataset(
113
- csv_path="./speech-to-intent/test.csv",
114
- wav_dir_path="/home/development/pavan/Telesoft/speech-to-intent-dataset/baselines/speech-to-intent"
115
- )
116
-
117
- train_len = int(len(dataset) * 0.90)
118
- val_len = len(dataset) - train_len
119
- print(train_len, val_len)
120
- train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_len, val_len], generator=torch.Generator().manual_seed(SEED))
121
- print(len(test_dataset))
122
-
123
- trainloader = torch.utils.data.DataLoader(
124
- train_dataset,
125
- batch_size=4,
126
- shuffle=True,
127
- num_workers=4,
128
- collate_fn = collate_fn,
129
- )
130
-
131
- valloader = torch.utils.data.DataLoader(
132
- val_dataset,
133
- batch_size=4,
134
- num_workers=4,
135
- collate_fn = collate_fn,
136
- )
137
-
138
- testloader = torch.utils.data.DataLoader(
139
- test_dataset,
140
- #batch_size=4,
141
- num_workers=4,
142
- collate_fn = collate_fn,
143
- )
144
-
145
  print(torch.cuda.mem_get_info())
146
 
147
  model = LightningModel()
148
 
149
  run_name = "wav2vec"
150
- logger = WandbLogger(
151
- name=run_name,
152
- project='S2I-baseline'
153
- )
154
-
155
- model_checkpoint_callback = ModelCheckpoint(
156
- dirpath='checkpoints',
157
- monitor='val/acc',
158
- mode='max',
159
- verbose=1,
160
- filename=run_name + "-epoch={epoch}.ckpt")
161
 
162
- trainer = Trainer(
163
- fast_dev_run=False,
164
- gpus=1,
165
- max_epochs=5,
166
- checkpoint_callback=True,
167
- callbacks=[
168
- model_checkpoint_callback,
169
- ],
170
- logger=logger,
171
- )
172
  checkpoint_path = "./checkpoints/wav2vec-epoch=epoch=4.ckpt.ckpt"
173
  checkpoint = torch.load(checkpoint_path)
174
  model.load_state_dict(checkpoint['state_dict'])
@@ -187,6 +127,29 @@ if __name__ == "__main__":
187
  #with torch.no_grad():
188
  # y_hat = model(wav_tensor)
189
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  print(y_hat)
 
 
191
 
192
 
 
3
  from dataset import S2IDataset, collate_fn
4
  import requests
5
  requests.packages.urllib3.disable_warnings()
6
+ import gradio as gr
7
  import torch
8
  import torch.nn as nn
9
  import torchaudio
 
103
  return predicted_class
104
 
105
  if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  print(torch.cuda.mem_get_info())
107
 
108
  model = LightningModel()
109
 
110
  run_name = "wav2vec"
 
 
 
 
 
 
 
 
 
 
 
111
 
 
 
 
 
 
 
 
 
 
 
112
  checkpoint_path = "./checkpoints/wav2vec-epoch=epoch=4.ckpt.ckpt"
113
  checkpoint = torch.load(checkpoint_path)
114
  model.load_state_dict(checkpoint['state_dict'])
 
127
  #with torch.no_grad():
128
  # y_hat = model(wav_tensor)
129
 
130
def transcribe(audio):
    """Predict a human-readable banking intent for a recorded utterance.

    Renamed from the misspelled ``trabscribe`` — the Gradio interface below
    registers the callback as ``transcribe``, so the old name raised NameError.

    Parameters
    ----------
    audio : tuple
        Value delivered by ``gr.Audio(source="microphone")``. Unpacked here as
        ``(wav_tensor, _)`` — NOTE(review): Gradio's microphone component
        normally yields ``(sample_rate, data)``; confirm the unpacking order.

    Returns
    -------
    str
        ``"<intent_name> : <description>"`` for the predicted class index.

    Raises
    ------
    KeyError
        If ``model.predict`` returns an index outside 0-13.
    """
    wav_tensor, _ = audio
    # NOTE(review): `resample` is presumably a project-level helper defined
    # elsewhere (not visible in this chunk); the original called the
    # misspelled name `resmaple`.
    wav_tensor = resample(wav_tensor)
    # model = model.to('cuda')
    y_hat = model.predict(wav_tensor)
    # Class index -> "intent_name : description" for the 14 S2I intents.
    labels = {
        0: "branch_address : enquiry about bank branch location",
        1: "activate_card : enquiry about activating card products",
        2: "past_transactions : enquiry about past transactions in a specific time period",
        3: "dispatch_status : enquiry about the dispatch status of card products",
        4: "outstanding_balance : enquiry about outstanding balance on card products",
        5: "card_issue : report about an issue with using card products",
        6: "ifsc_code : enquiry about IFSC code of bank branch",
        7: "generate_pin : enquiry about changing or generating a new pin for their card product",
        8: "unauthorised_transaction : report about an unauthorised or fraudulent transaction",
        9: "loan_query : enquiry about different kinds of loans",
        10: "balance_enquiry : enquiry about bank account balance",
        11: "change_limit : enquiry about changing the limit for card products",
        12: "block : enquiry about blocking card or banking product",
        # Original line was missing the closing quote (SyntaxError) — fixed.
        13: "lost : report about losing a card product",
    }
    return labels[y_hat]
150
+
151
# NOTE(review): removed the original `print(y_hat)` that stood here — y_hat is
# local to transcribe() and does not exist at module scope, so it raised
# NameError before the app could launch.
# Launch a Gradio app: record speech from the microphone, show the predicted
# intent as text. The original passed gr.Audio(...) positionally AFTER the
# keyword `fn=` (SyntaxError); it must be the `inputs=` keyword argument.
get_intent = gr.Interface(
    fn=transcribe,
    inputs=gr.Audio(source="microphone"),
    outputs="text",
).launch()
154
 
155