Pavankalyan committed
Commit d6dd1d2
1 Parent(s): 592e443

Update app.py

Files changed (1)
  1. app.py +27 -133
app.py CHANGED
@@ -1,146 +1,40 @@
-
- from model import Wav2VecModel
- from dataset import S2IDataset, collate_fn
- import requests
- requests.packages.urllib3.disable_warnings()
  import gradio as gr
- import torch
- import torch.nn as nn
- import torchaudio
- import torch.nn.functional as F
- import pytorch_lightning as pl
-
- from pytorch_lightning import Trainer
- from pytorch_lightning.callbacks import ModelCheckpoint
- from pytorch_lightning.loggers import WandbLogger
-
- # SEED
- SEED = 100
- pl.seed_everything(SEED)
- torch.manual_seed(SEED)
-
+ from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
+ import json
+ import torch  # used by predict() below
  import os
- os.environ['WANDB_MODE'] = 'online'
- os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
- os.environ["CUDA_VISIBLE_DEVICES"] = "1"
-
- class LightningModel(pl.LightningModule):
-     def __init__(self,):
-         super().__init__()
-         self.model = Wav2VecModel()
-
-     def forward(self, x):
-         return self.model(x)
-
-     def configure_optimizers(self):
-         optimizer = torch.optim.Adam(self.parameters(), lr=1e-5)
-         return [optimizer]
-
-     def loss_fn(self, prediction, targets):
-         return nn.CrossEntropyLoss()(prediction, targets)
-
-     def training_step(self, batch, batch_idx):
-         x, y = batch
-         y = y.view(-1)
-
-         logits = self(x)
-         probs = F.softmax(logits, dim=1)
-         loss = self.loss_fn(logits, y)
-
-         winners = logits.argmax(dim=1)
-         corrects = (winners == y)
-         acc = corrects.sum().float() / float(logits.size(0))
-
-         self.log('train/loss', loss, on_step=False, on_epoch=True, prog_bar=True)
-         self.log('train/acc', acc, on_step=False, on_epoch=True, prog_bar=True)
-         torch.cuda.empty_cache()
-         return {
-             'loss': loss,
-             'acc': acc
-         }
-
-     def validation_step(self, batch, batch_idx):
-         x, y = batch
-         y = y.view(-1)
-
-         logits = self(x)
-         loss = self.loss_fn(logits, y)
-
-         winners = logits.argmax(dim=1)
-         corrects = (winners == y)
-         acc = corrects.sum().float() / float(logits.size(0))
-
-         self.log('val/loss', loss, on_step=False, on_epoch=True, prog_bar=True)
-         self.log('val/acc', acc, on_step=False, on_epoch=True, prog_bar=True)
-
-         return {'val_loss': loss,
-                 'val_acc': acc,
-                 }
-
-     def test_step(self, batch, batch_idx):
-         x, y = batch
-         y = y.view(-1)
-
-         logits = self(x)
-         loss = self.loss_fn(logits, y)
-
-         winners = logits.argmax(dim=1)
-         corrects = (winners == y)
-         acc = corrects.sum().float() / float(logits.size(0))
-
-         self.log('val/loss', loss, on_step=False, on_epoch=True, prog_bar=True)
-         self.log('val/acc', acc, on_step=False, on_epoch=True, prog_bar=True)
-
-         return {'val_loss': loss,
-                 'val_acc': acc,
-                 }
-
-     def predict(self, wav):
-         self.eval()
-         with torch.no_grad():
-             output = self.forward(wav)
-             predicted_class = torch.argmax(output, dim=1)
-         return predicted_class
-
-
- model = LightningModel()
-
- run_name = "wav2vec"
-
- checkpoint_path = "./wav2vec-epoch=epoch=4.ckpt.ckpt"
- checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
- model.load_state_dict(checkpoint['state_dict'])
- trainer = Trainer()
-
- #trainer.fit(model, train_dataloader=trainloader, val_dataloaders=valloader)
- #trainer.test(model, dataloaders=testloader, verbose=True)
-
- #with torch.no_grad():
- #    y_hat = model(wav_tensor)
-
+
+ data_dict = {}
+ with open('./results_classification/file.json', 'r') as file:
+     data = json.load(file)
+ intents_dict = data
+ tokenizer = AutoTokenizer.from_pretrained("roberta-base")
+ model = AutoModelForSequenceClassification.from_pretrained("./results_classification/checkpoint-1890/")
+
+ def preprocess(text):
+     inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
+     return inputs
+
+ def postprocess(outputs):
+     logits = outputs.logits
+     predicted_labels = logits.argmax(dim=1).tolist()
+     return predicted_labels
+
+ def predict(text):
+     inputs = preprocess(text)
+     with torch.no_grad():
+         outputs = model(**inputs)
+     predicted_labels = postprocess(outputs)
+     ans = intents_dict[predicted_labels[0]]  # assumes file.json decodes to a list indexed by label id
+     return ans
+
+ p = pipeline(model="openai/whisper-medium")
+
  def transcribe(audio):
-     resmaple = torchaudio.transforms.Resample(8000, 16000)
-     wav_tensor, _ = torchaudio.load(audio)
-     #sr, wav_tensor = audio
-     wav_tensor = resmaple(wav_tensor)
-     #model = model.to('cuda')
-     y_hat = model.predict(wav_tensor)
-     labels = {0: "branch_address : enquiry about bank branch location",
-               1: "activate_card : enquiry about activating card products",
-               2: "past_transactions : enquiry about past transactions in a specific time period",
-               3: "dispatch_status : enquiry about the dispatch status of card products",
-               4: "outstanding_balance : enquiry about outstanding balance on card products",
-               5: "card_issue : report about an issue with using card products",
-               6: "ifsc_code : enquiry about IFSC code of bank branch",
-               7: "generate_pin : enquiry about changing or generating a new pin for their card product",
-               8: "unauthorised_transaction : report about an unauthorised or fraudulent transaction",
-               9: "loan_query : enquiry about different kinds of loans",
-               10: "balance_enquiry : enquiry about bank account balance",
-               11: "change_limit : enquiry about changing the limit for card products",
-               12: "block : enquiry about blocking card or banking product",
-               13: "lost : report about losing a card product"}
-     return labels[y_hat[0].item()]
+     t = p(audio)['text']
+     ans = predict(t)
+     return ans
 
 
  get_intent = gr.Interface(fn = transcribe,
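
Note: the gr.Interface(...) call above is truncated in this view. For reference, here is a minimal end-to-end sketch of how the updated app plausibly fits together. The inputs/outputs arguments and the launch() call are assumptions for illustration, not the commit's actual code; the checkpoint path and intent file are the commit's local artifacts.

# Hypothetical sketch; gr.Audio(type="filepath") and outputs="text" are assumed,
# since the real gr.Interface arguments are cut off in the diff above.
import json
import torch
import gradio as gr
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification

asr = pipeline(model="openai/whisper-medium")  # speech -> text, as in the commit
tokenizer = AutoTokenizer.from_pretrained("roberta-base")
model = AutoModelForSequenceClassification.from_pretrained(
    "./results_classification/checkpoint-1890/")  # commit's fine-tuned classifier

with open('./results_classification/file.json', 'r') as f:
    intents = json.load(f)  # assumed: list of intent strings indexed by label id

def transcribe(audio_path):
    # Whisper transcribes the uploaded audio, then the RoBERTa
    # classifier maps the transcript to a banking intent.
    text = asr(audio_path)["text"]
    enc = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        label = model(**enc).logits.argmax(dim=1).item()
    return intents[label]

get_intent = gr.Interface(fn=transcribe,
                          inputs=gr.Audio(type="filepath"),
                          outputs="text")

if __name__ == "__main__":
    get_intent.launch()

Loading Whisper and the classifier once at import time keeps each request to one transcription plus a single classifier forward pass.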