S2I / app.py
Pavankalyan's picture
Update app.py
fce051b
raw
history blame
4.72 kB
from model import Wav2VecModel
from dataset import S2IDataset, collate_fn
import requests
requests.packages.urllib3.disable_warnings()
import gradio as gr
import torch
import torch.nn as nn
import torchaudio
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
import os

# Process-wide configuration. CUDA reads CUDA_VISIBLE_DEVICES only when it is
# first initialised, so set these before any seeding / CUDA-touching call.
os.environ['WANDB_MODE'] = 'online'
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# NOTE(review): exposes only physical GPU 1 (PCI order). On a single-GPU host
# this leaves no GPU visible, so downstream code must tolerate CPU-only runs.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# SEED
SEED = 100
# pl.seed_everything is the public Lightning API; it seeds Python's `random`,
# NumPy and torch (CPU and CUDA), so a separate torch.manual_seed is redundant.
pl.seed_everything(SEED)
class LightningModel(pl.LightningModule):
    """PyTorch-Lightning wrapper around ``Wav2VecModel`` used as a classifier.

    Training, validation and test steps share one cross-entropy / accuracy
    computation (``_shared_step``); they differ only in the metric prefix
    they log under and in the keys of the dict they return.
    """

    def __init__(self):
        super().__init__()
        self.model = Wav2VecModel()

    def forward(self, x):
        """Return the wrapped model's class logits for input batch ``x``.

        NOTE(review): the expected layout of ``x`` is defined by
        ``Wav2VecModel`` (not visible here) — presumably (batch, samples).
        """
        return self.model(x)

    def configure_optimizers(self):
        """Optimise all parameters with Adam at a fixed learning rate of 1e-5."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-5)
        return [optimizer]

    def loss_fn(self, prediction, targets):
        """Cross-entropy between logits ``prediction`` and class indices ``targets``."""
        return F.cross_entropy(prediction, targets)

    def _shared_step(self, batch, stage):
        """Compute and log loss/accuracy for one batch under ``<stage>/...``.

        Returns:
            ``(loss, acc)`` as scalar tensors.
        """
        x, y = batch
        y = y.view(-1)  # flatten targets to (batch,) as cross-entropy expects
        logits = self(x)
        loss = self.loss_fn(logits, y)
        # Fraction of rows whose arg-max logit equals the target index.
        acc = (logits.argmax(dim=1) == y).float().mean()
        self.log(f'{stage}/loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log(f'{stage}/acc', acc, on_step=False, on_epoch=True, prog_bar=True)
        return loss, acc

    def training_step(self, batch, batch_idx):
        """Training step; Lightning back-propagates the returned ``'loss'``."""
        loss, acc = self._shared_step(batch, 'train')
        # Kept from the original: releases cached CUDA blocks after every step
        # (a no-op without CUDA) at the cost of some allocator overhead.
        torch.cuda.empty_cache()
        return {'loss': loss, 'acc': acc}

    def validation_step(self, batch, batch_idx):
        """Validation step; metrics are logged under ``val/``."""
        loss, acc = self._shared_step(batch, 'val')
        return {'val_loss': loss, 'val_acc': acc}

    def test_step(self, batch, batch_idx):
        """Test step; metrics are logged under ``test/`` (previously mislogged as ``val/``)."""
        loss, acc = self._shared_step(batch, 'test')
        # Return keys kept as in the original for backward compatibility.
        return {'val_loss': loss, 'val_acc': acc}

    def predict(self, wav):
        """Return the arg-max class index per row of ``wav``.

        Switches the module to eval mode and runs without gradient tracking.
        """
        self.eval()
        with torch.no_grad():
            output = self.forward(wav)
        return torch.argmax(output, dim=1)
# Report (free, total) GPU memory only when a GPU is visible:
# torch.cuda.mem_get_info() raises on CPU-only hosts.
if torch.cuda.is_available():
    print(torch.cuda.mem_get_info())

model = LightningModel()
run_name = "wav2vec"
checkpoint_path = "./wav2vec-epoch=epoch=4.ckpt.ckpt"
# The model stays on CPU for inference, so remap any CUDA tensors the
# checkpoint was saved with; without this, loading fails on CPU-only hosts.
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint['state_dict'])
model.eval()

# Only used by the commented-out fit/test calls below. Request a GPU only
# when one is visible; gpus=1 raises when none is available.
trainer = Trainer(
    gpus=1 if torch.cuda.is_available() else 0
)
#trainer.fit(model, train_dataloader=trainloader, val_dataloaders=valloader)
#trainer.test(model,dataloaders=testloader,verbose=True)
#with torch.no_grad():
# y_hat = model(wav_tensor)
# Class index produced by the model -> "intent_name : description".
INTENT_LABELS = {
    0: "branch_address : enquiry about bank branch location",
    1: "activate_card : enquiry about activating card products",
    2: "past_transactions : enquiry about past transactions in a specific time period",
    3: "dispatch_status : enquiry about the dispatch status of card products",
    4: "outstanding_balance : enquiry about outstanding balance on card products",
    5: "card_issue : report about an issue with using card products",
    6: "ifsc_code : enquiry about IFSC code of bank branch",
    7: "generate_pin : enquiry about changing or generating a new pin for their card product",
    8: "unauthorised_transaction : report about an unauthorised or fraudulent transaction",
    9: "loan_query : enquiry about different kinds of loans",
    10: "balance_enquiry : enquiry about bank account balance",
    11: "change_limit : enquiry about changing the limit for card products",
    12: "block : enquiry about blocking card or banking product",
    13: "lost : report about losing a card product",
}

# NOTE(review): assumes the model was trained on 16 kHz audio (the usual rate
# for wav2vec checkpoints) — confirm against the training dataset.
TARGET_SAMPLE_RATE = 16000


def _to_model_input(sample_rate, samples):
    """Turn a Gradio numpy recording into a (1, num_samples) float32 waveform.

    Integer PCM is scaled to [-1, 1], multi-channel audio laid out as
    (samples, channels) is averaged to mono, and the signal is resampled to
    TARGET_SAMPLE_RATE when needed.
    NOTE(review): the (batch, samples) layout fed to the model is an
    assumption — confirm against Wav2VecModel / the training collate_fn.
    """
    wav = torch.as_tensor(samples)
    if not torch.is_floating_point(wav):
        # Read the integer range before converting the dtype.
        full_scale = torch.iinfo(wav.dtype).max
        wav = wav.to(torch.float32) / full_scale
    else:
        wav = wav.to(torch.float32)
    if wav.ndim == 2:
        wav = wav.mean(dim=1)
    if sample_rate != TARGET_SAMPLE_RATE:
        wav = torchaudio.functional.resample(
            wav, orig_freq=sample_rate, new_freq=TARGET_SAMPLE_RATE
        )
    return wav.unsqueeze(0)


def transcribe(audio):
    """Predict the banking intent of a recorded utterance.

    Args:
        audio: value from a Gradio ``Audio`` component in numpy mode — a
            ``(sample_rate, samples)`` tuple — or ``None`` if nothing was
            recorded.

    Returns:
        The ``INTENT_LABELS`` description for the predicted class, or a
        prompt to record something when ``audio`` is ``None``.
    """
    if audio is None:
        return "No audio received - please record an utterance first."
    sample_rate, samples = audio
    wav_tensor = _to_model_input(sample_rate, samples)
    y_hat = model.predict(wav_tensor)
    # predict returns a 1-element tensor of class indices; the dict needs an int.
    return INTENT_LABELS[int(y_hat[0].item())]


# Backward-compatible alias for the original (misspelled) function name.
trabscribe = transcribe
# Microphone -> predicted-intent text demo. The input component must be passed
# by keyword: a positional argument after `fn=` is a SyntaxError.
get_intent = gr.Interface(
    fn=transcribe,
    inputs=gr.Audio(source="microphone"),
    outputs="text",
)
get_intent.launch()