Pavankalyan committed
Commit 6afc25f
1 parent: a93df0f

Upload app.py with huggingface_hub

Files changed (1)
  1. app.py +192 -0
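The commit message indicates the file was pushed with huggingface_hub. A minimal sketch of such an upload (the Space's repo id is not shown on this page, so it is a placeholder here):

    from huggingface_hub import HfApi

    api = HfApi()
    api.upload_file(
        path_or_fileobj="app.py",
        path_in_repo="app.py",
        repo_id="Pavankalyan/<space-id>",  # placeholder; actual repo id not shown
        repo_type="space",
    )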
app.py ADDED
@@ -0,0 +1,192 @@
+ from model import Wav2VecModel
+ from dataset import S2IDataset, collate_fn
+
+ import os
+ import requests
+ import torch
+ import torch.nn as nn
+ import torchaudio
+ import pytorch_lightning as pl
+
+ from pytorch_lightning import Trainer
+ from pytorch_lightning.callbacks import ModelCheckpoint
+ from pytorch_lightning.loggers import WandbLogger
+
+ # Silence urllib3's insecure-request warnings raised through requests.
+ requests.packages.urllib3.disable_warnings()
+
+ # Seed everything for reproducibility.
+ SEED = 100
+ pl.seed_everything(SEED)
+ torch.manual_seed(SEED)
+
+ # W&B logging mode and GPU selection.
+ os.environ['WANDB_MODE'] = 'online'
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+ os.environ["CUDA_VISIBLE_DEVICES"] = "1"
+
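+ # LightningModule wrapping the Wav2Vec classifier from model.py: cross-entropy
+ # training with accuracy logging, plus a convenience predict() for inference.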
+ class LightningModel(pl.LightningModule):
+     def __init__(self):
+         super().__init__()
+         self.model = Wav2VecModel()
+
+     def forward(self, x):
+         return self.model(x)
+
+     def configure_optimizers(self):
+         optimizer = torch.optim.Adam(self.parameters(), lr=1e-5)
+         return [optimizer]
+
+     def loss_fn(self, prediction, targets):
+         return nn.CrossEntropyLoss()(prediction, targets)
+
+     def training_step(self, batch, batch_idx):
+         x, y = batch
+         y = y.view(-1)
+
+         logits = self(x)
+         loss = self.loss_fn(logits, y)
+
+         # Batch accuracy from the argmax prediction.
+         winners = logits.argmax(dim=1)
+         corrects = (winners == y)
+         acc = corrects.sum().float() / float(logits.size(0))
+
+         self.log('train/loss', loss, on_step=False, on_epoch=True, prog_bar=True)
+         self.log('train/acc', acc, on_step=False, on_epoch=True, prog_bar=True)
+         torch.cuda.empty_cache()
+         return {'loss': loss, 'acc': acc}
+
+     def validation_step(self, batch, batch_idx):
+         x, y = batch
+         y = y.view(-1)
+
+         logits = self(x)
+         loss = self.loss_fn(logits, y)
+
+         winners = logits.argmax(dim=1)
+         corrects = (winners == y)
+         acc = corrects.sum().float() / float(logits.size(0))
+
+         self.log('val/loss', loss, on_step=False, on_epoch=True, prog_bar=True)
+         self.log('val/acc', acc, on_step=False, on_epoch=True, prog_bar=True)
+         return {'val_loss': loss, 'val_acc': acc}
+
+     def test_step(self, batch, batch_idx):
+         x, y = batch
+         y = y.view(-1)
+
+         logits = self(x)
+         loss = self.loss_fn(logits, y)
+
+         winners = logits.argmax(dim=1)
+         corrects = (winners == y)
+         acc = corrects.sum().float() / float(logits.size(0))
+
+         # Log under test/* (the original logged these as val/*).
+         self.log('test/loss', loss, on_step=False, on_epoch=True, prog_bar=True)
+         self.log('test/acc', acc, on_step=False, on_epoch=True, prog_bar=True)
+         return {'test_loss': loss, 'test_acc': acc}
+
+     def predict(self, wav):
+         self.eval()
+         # Move the input to the model's device so a CPU tensor works with a CUDA model.
+         wav = wav.to(next(self.parameters()).device)
+         with torch.no_grad():
+             output = self.forward(wav)
+             predicted_class = torch.argmax(output, dim=1)
+         return predicted_class
+
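+ # Entry point: build the train/val/test datasets and loaders, then restore a
+ # trained checkpoint and run single-utterance inference (the fit/test calls
+ # below are left commented out).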
+ if __name__ == "__main__":
+
+     dataset = S2IDataset(
+         csv_path="./speech-to-intent/train.csv",
+         wav_dir_path="/home/development/pavan/Telesoft/speech-to-intent-dataset/baselines/speech-to-intent"
+     )
+
+     test_dataset = S2IDataset(
+         csv_path="./speech-to-intent/test.csv",
+         wav_dir_path="/home/development/pavan/Telesoft/speech-to-intent-dataset/baselines/speech-to-intent"
+     )
+
+     # Deterministic 90/10 train/validation split.
+     train_len = int(len(dataset) * 0.90)
+     val_len = len(dataset) - train_len
+     print(f"train={train_len} val={val_len} test={len(test_dataset)}")
+     train_dataset, val_dataset = torch.utils.data.random_split(
+         dataset, [train_len, val_len], generator=torch.Generator().manual_seed(SEED)
+     )
+
+     trainloader = torch.utils.data.DataLoader(
+         train_dataset,
+         batch_size=4,
+         shuffle=True,
+         num_workers=4,
+         collate_fn=collate_fn,
+     )
+
+     valloader = torch.utils.data.DataLoader(
+         val_dataset,
+         batch_size=4,
+         num_workers=4,
+         collate_fn=collate_fn,
+     )
+
+     testloader = torch.utils.data.DataLoader(
+         test_dataset,
+         # batch_size=4,
+         num_workers=4,
+         collate_fn=collate_fn,
+     )
+
+     # Report free/total GPU memory before building the model.
+     print(torch.cuda.mem_get_info())
+
+     model = LightningModel()
+
+     run_name = "wav2vec"
+     logger = WandbLogger(
+         name=run_name,
+         project='S2I-baseline'
+     )
+
+     # Keep the checkpoint with the best validation accuracy.
+     model_checkpoint_callback = ModelCheckpoint(
+         dirpath='checkpoints',
+         monitor='val/acc',
+         mode='max',
+         verbose=1,
+         filename=run_name + "-epoch={epoch}.ckpt")
+
+     trainer = Trainer(
+         fast_dev_run=False,
+         gpus=1,
+         max_epochs=5,
+         checkpoint_callback=True,
+         callbacks=[
+             model_checkpoint_callback,
+         ],
+         logger=logger,
+     )
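+     # Note: Lightning expands "{epoch}" in the checkpoint filename template to
+     # "epoch=N" and appends its own ".ckpt" extension, so the file saved by the
+     # callback above is "wav2vec-epoch=epoch=4.ckpt.ckpt" -- the path used below.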
+     # Restore the trained weights from the saved checkpoint.
+     checkpoint_path = "./checkpoints/wav2vec-epoch=epoch=4.ckpt.ckpt"
+     checkpoint = torch.load(checkpoint_path)
+     model.load_state_dict(checkpoint['state_dict'])
+     trainer = Trainer(
+         gpus=1
+     )
+
+     # trainer.fit(model, train_dataloader=trainloader, val_dataloaders=valloader)
+     # trainer.test(model, dataloaders=testloader, verbose=True)
+
+     # Single-utterance inference: load the 8 kHz wav, resample to the 16 kHz
+     # rate Wav2Vec expects, and predict the intent class.
+     wav_path = "./speech-to-intent/wav_audios/92145547-3ab6-44e0-9245-085642fc4318.wav"
+     resample = torchaudio.transforms.Resample(8000, 16000)
+     wav_tensor, _ = torchaudio.load(wav_path)
+     wav_tensor = resample(wav_tensor)
+     model = model.to('cuda')
+     y_hat = model.predict(wav_tensor)
+
+     print(y_hat)
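+     # Hypothetical follow-up: map the predicted class index back to an intent
+     # label, assuming the dataset exposed a label list (no such attribute is
+     # shown in this commit, so this is left commented out):
+     # intent = dataset.intents[y_hat.item()]
+     # print(intent)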