Pavankalyan commited on
Commit
fce051b
1 Parent(s): c979e2a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -46
app.py CHANGED
@@ -102,49 +102,48 @@ class LightningModel(pl.LightningModule):
102
  predicted_class = torch.argmax(output, dim=1)
103
  return predicted_class
104
 
105
- if __name__ == "__main__":
106
- print(torch.cuda.mem_get_info())
107
-
108
- model = LightningModel()
109
-
110
- run_name = "wav2vec"
111
-
112
- checkpoint_path = "./wav2vec-epoch=epoch=4.ckpt.ckpt"
113
- checkpoint = torch.load(checkpoint_path)
114
- model.load_state_dict(checkpoint['state_dict'])
115
- trainer = Trainer(
116
- gpus=1
117
- )
118
-
119
- #trainer.fit(model, train_dataloader=trainloader, val_dataloaders=valloader)
120
- #trainer.test(model,dataloaders=testloader,verbose=True)
121
-
122
- #with torch.no_grad():
123
- # y_hat = model(wav_tensor)
124
-
125
- def trabscribe(audio):
126
- wav_tensor,_ = audio
127
- wav_tensor = resmaple(wav_tensor)
128
- #model = model.to('cuda')
129
- y_hat = model.predict(wav_tensor)
130
- labels = {0:"branch_address : enquiry about bank branch location",
131
- 1:"activate_card : enquiry about activating card products",
132
- 2:"past_transactions : enquiry about past transactions in a specific time period",
133
- 3:"dispatch_status : enquiry about the dispatch status of card products",
134
- 4:"outstanding_balance : enquiry about outstanding balance on card products",
135
- 5:"card_issue : report about an issue with using card products",
136
- 6:"ifsc_code : enquiry about IFSC code of bank branch",
137
- 7:"generate_pin : enquiry about changing or generating a new pin for their card product",
138
- 8:"unauthorised_transaction : report about an unauthorised or fraudulent transaction",
139
- 9:"loan_query : enquiry about different kinds of loans",
140
- 10:"balance_enquiry : enquiry about bank account balance",
141
- 11:"change_limit : enquiry about changing the limit for card products",
142
- 12:"block : enquiry about blocking card or banking product",
143
- 13:"lost : report about losing a card product}
144
- return labels[y_hat]
145
-
146
- print(y_hat)
147
- get_intent = gr.Interface(fn = transcribe,
148
- gr.Audio(source="microphone"), outputs="text").launch()
149
-
150
-
 
102
  predicted_class = torch.argmax(output, dim=1)
103
  return predicted_class
104
 
105
+
106
+ print(torch.cuda.mem_get_info())
107
+
108
+ model = LightningModel()
109
+
110
+ run_name = "wav2vec"
111
+
112
+ checkpoint_path = "./wav2vec-epoch=epoch=4.ckpt.ckpt"
113
+ checkpoint = torch.load(checkpoint_path)
114
+ model.load_state_dict(checkpoint['state_dict'])
115
+ trainer = Trainer(
116
+ gpus=1
117
+ )
118
+
119
+ #trainer.fit(model, train_dataloader=trainloader, val_dataloaders=valloader)
120
+ #trainer.test(model,dataloaders=testloader,verbose=True)
121
+
122
+ #with torch.no_grad():
123
+ # y_hat = model(wav_tensor)
124
+
125
+ def trabscribe(audio):
126
+ wav_tensor,_ = audio
127
+ wav_tensor = resmaple(wav_tensor)
128
+ #model = model.to('cuda')
129
+ y_hat = model.predict(wav_tensor)
130
+ labels = {0:"branch_address : enquiry about bank branch location",
131
+ 1:"activate_card : enquiry about activating card products",
132
+ 2:"past_transactions : enquiry about past transactions in a specific time period",
133
+ 3:"dispatch_status : enquiry about the dispatch status of card products",
134
+ 4:"outstanding_balance : enquiry about outstanding balance on card products",
135
+ 5:"card_issue : report about an issue with using card products",
136
+ 6:"ifsc_code : enquiry about IFSC code of bank branch",
137
+ 7:"generate_pin : enquiry about changing or generating a new pin for their card product",
138
+ 8:"unauthorised_transaction : report about an unauthorised or fraudulent transaction",
139
+ 9:"loan_query : enquiry about different kinds of loans",
140
+ 10:"balance_enquiry : enquiry about bank account balance",
141
+ 11:"change_limit : enquiry about changing the limit for card products",
142
+ 12:"block : enquiry about blocking card or banking product",
143
+ 13:"lost : report about losing a card product}
144
+ return labels[y_hat]
145
+
146
+ print(y_hat)
147
+ get_intent = gr.Interface(fn = transcribe,
148
+ gr.Audio(source="microphone"), outputs="text").launch()
149
+