"""Gradio demo: classify the banking intent of a spoken query.

Restores a ``Wav2VecModel`` classifier from a Lightning checkpoint and serves
it through a microphone-driven Gradio interface.
"""
import os

import gradio as gr
import pytorch_lightning as pl
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

from dataset import S2IDataset, collate_fn
from model import Wav2VecModel

requests.packages.urllib3.disable_warnings()

# Device selection must be in the environment before CUDA is initialised,
# so it is done ahead of any seeding / model construction.
os.environ['WANDB_MODE'] = 'online'
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# SEED
SEED = 100
pl.seed_everything(SEED)
torch.manual_seed(SEED)

# Sample rate fed to the classifier.
# NOTE(review): 16 kHz is the rate wav2vec 2.0 checkpoints are normally trained
# on; confirm it matches what Wav2VecModel / S2IDataset used during training.
TARGET_SAMPLE_RATE = 16_000

# Class index -> human readable intent description shown in the UI.
INTENT_LABELS = {
    0: "branch_address : enquiry about bank branch location",
    1: "activate_card : enquiry about activating card products",
    2: "past_transactions : enquiry about past transactions in a specific time period",
    3: "dispatch_status : enquiry about the dispatch status of card products",
    4: "outstanding_balance : enquiry about outstanding balance on card products",
    5: "card_issue : report about an issue with using card products",
    6: "ifsc_code : enquiry about IFSC code of bank branch",
    7: "generate_pin : enquiry about changing or generating a new pin for their card product",
    8: "unauthorised_transaction : report about an unauthorised or fraudulent transaction",
    9: "loan_query : enquiry about different kinds of loans",
    10: "balance_enquiry : enquiry about bank account balance",
    11: "change_limit : enquiry about changing the limit for card products",
    12: "block : enquiry about blocking card or banking product",
    13: "lost : report about losing a card product",
}


class LightningModel(pl.LightningModule):
    """Lightning wrapper around ``Wav2VecModel`` for intent classification."""

    def __init__(self):
        super().__init__()
        self.model = Wav2VecModel()

    def forward(self, x):
        """Return the wrapped model's output (class logits) for ``x``."""
        return self.model(x)

    def configure_optimizers(self):
        """Optimise all parameters with Adam at a learning rate of 1e-5."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-5)
        return [optimizer]

    def loss_fn(self, prediction, targets):
        """Cross-entropy between logits ``prediction`` and class ids ``targets``."""
        return F.cross_entropy(prediction, targets)

    def _shared_step(self, batch):
        """Compute ``(loss, accuracy)`` for one ``(inputs, labels)`` batch."""
        x, y = batch
        y = y.view(-1)  # flatten labels to a 1-D vector of class ids
        logits = self(x)
        loss = self.loss_fn(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        return loss, acc

    def _log_metrics(self, stage, loss, acc):
        """Log epoch-aggregated loss/accuracy under the ``stage/`` prefix."""
        self.log(f'{stage}/loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log(f'{stage}/acc', acc, on_step=False, on_epoch=True, prog_bar=True)

    def training_step(self, batch, batch_idx):
        """Run one training batch; returns the loss Lightning backpropagates."""
        loss, acc = self._shared_step(batch)
        self._log_metrics('train', loss, acc)
        # Kept from the original script. NOTE(review): emptying the CUDA cache
        # every step slows training; drop it if memory pressure allows.
        torch.cuda.empty_cache()
        return {'loss': loss, 'acc': acc}

    def validation_step(self, batch, batch_idx):
        """Run one validation batch and log ``val/loss`` and ``val/acc``."""
        loss, acc = self._shared_step(batch)
        self._log_metrics('val', loss, acc)
        return {'val_loss': loss, 'val_acc': acc}

    def test_step(self, batch, batch_idx):
        """Run one test batch and log ``test/loss`` and ``test/acc``.

        Metrics use their own ``test`` prefix so that a test run no longer
        overwrites the validation metrics.
        """
        loss, acc = self._shared_step(batch)
        self._log_metrics('test', loss, acc)
        return {'test_loss': loss, 'test_acc': acc}

    def predict(self, wav):
        """Return the arg-max class index (per row of the output) for ``wav``.

        Switches the module to eval mode and disables gradient tracking.
        """
        self.eval()
        with torch.no_grad():
            output = self.forward(wav)
            predicted_class = torch.argmax(output, dim=1)
        return predicted_class


model = LightningModel()

run_name = "wav2vec"
checkpoint_path = "./wav2vec-epoch=epoch=4.ckpt.ckpt"
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['state_dict'])

trainer = Trainer()
# Training / evaluation entry points, disabled for the demo:
#trainer.fit(model, train_dataloader=trainloader, val_dataloaders=valloader)
#trainer.test(model,dataloaders=testloader,verbose=True)


def resample(waveform, orig_freq, new_freq=TARGET_SAMPLE_RATE):
    """Return ``waveform`` converted from ``orig_freq`` to ``new_freq`` Hz."""
    if orig_freq == new_freq:
        return waveform
    return torchaudio.functional.resample(waveform, orig_freq, new_freq)


def transcribe(audio):
    """Classify a recorded utterance and return its intent description.

    Args:
        audio: Path of the recorded audio file, as supplied by
            ``gr.Audio(type="filepath")``; ``None`` when nothing was recorded.

    Returns:
        The ``INTENT_LABELS`` entry for the predicted class, or a short
        message when no audio was provided.
    """
    if audio is None:
        return "No audio received. Please record a query and try again."

    waveform, sample_rate = torchaudio.load(audio)  # (channels, samples)
    # Down-mix to mono so the tensor is a batch of exactly one utterance.
    # NOTE(review): assumes Wav2VecModel takes (batch, samples) — confirm.
    waveform = waveform.mean(dim=0, keepdim=True)
    waveform = resample(waveform, sample_rate)

    y_hat = model.predict(waveform)
    # Convert the single-element tensor to a plain int: tensors hash by
    # identity and would never match the integer keys of INTENT_LABELS.
    return INTENT_LABELS[int(y_hat.item())]


# Backward-compatible alias for the original (misspelled) function name.
trabscribe = transcribe

get_intent = gr.Interface(
    fn=transcribe,
    inputs=gr.Audio(source="microphone", type="filepath"),
    outputs="text",
).launch()