# VirBert2 / app.py
# rajaatif786's picture
# Update app.py
# 832a362
import gradio as gr
# from transformers import pipeline
# pipe = pipeline("translation", model="Helsinki-NLP/opus-mt-en-es")
# def predict(text):
# return pipe(text)[0]["translation_text"]
# iface = gr.Interface(
# fn=predict,
# inputs='text',
# outputs='text',
# examples=[["Hello! My name is Omar"]]
# )
# iface.launch()
from Toxonomy.modules.confusionmatrix import plot_confusion_matrix
import glob
import pandas as pd
from Toxonomy.modules.preprocessor import preprocessing_for_bert
print("hello")
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
import torch
import numpy as np
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
from transformers import BertModel
from Toxonomy.modules.classifier import PretrainedBert,FinetunningBert,initialize_finetunningBert,finetunningBert_training,bertPredictions
from transformers import AdamW, get_linear_schedule_with_warmup
device = 0
import random
import time
import torch.nn as nn
print("completed")
def Kmers_funct(seq: str, size: int) -> list:
    """Split a sequence into its overlapping, lower-cased k-mers.

    A window of ``size`` characters slides over ``seq`` one position at a
    time, so a sequence of length ``n`` yields ``n - size + 1`` k-mers.

    Args:
        seq: Sequence to split.
        size: Length of every k-mer; must be a positive integer.

    Returns:
        The k-mers in order of appearance, or an empty list when ``seq`` is
        shorter than ``size``.

    Raises:
        ValueError: If ``size`` is smaller than 1.
    """
    if size < 1:
        # A zero or negative window would silently produce empty or
        # nonsensical slices instead of real k-mers.
        raise ValueError(f"k-mer size must be a positive integer, got {size!r}")
    return [
        seq[start:start + size].lower()
        for start in range(len(seq) - size + 1)
    ]
def kmers_sentences(mySeq, size=3):
    """Turn a sequence into a space-separated "sentence" of k-mers.

    The overlapping k-mers returned by ``Kmers_funct`` are used as words and
    joined with single spaces.

    Args:
        mySeq: Sequence to convert.
        size: k-mer length. Defaults to 3, the value this function used to
            hard-code, so existing single-argument callers are unaffected.

    Returns:
        The k-mers joined by single spaces; an empty string when the
        sequence yields no k-mers.
    """
    return ' '.join(Kmers_funct(mySeq, size=size))
import re
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
import torch
import numpy as np
# Class index -> virus label, in the order used by the original if-chain.
VIRUS_LABELS = (
    'IAV', 'IBV', 'SFTS', 'Dengue', 'EnteroA', 'EnteroB', 'ICV',
    'HBV', 'HCV', 'HSV-1', 'HPV', 'MPV', 'WNV', 'ZIKA',
)
# Longest sequence (in characters) that predict() accepts.
MAX_SEQUENCE_LENGTH = 7000
# kmers_sentences() builds 3-mers by default, so shorter input has no k-mers.
MIN_SEQUENCE_LENGTH = 3
# Lazily created classifier shared by every predict() call.
_classifier = None


def _get_classifier():
    """Return the classifier, loading the weights from disk on first use."""
    global _classifier
    if _classifier is None:
        model = PretrainedBert(freeze_bert=False)
        model.load_state_dict(
            torch.load("./virBERT.pt", map_location=torch.device('cpu'))
        )
        # Inference only: switch off dropout and similar training behaviour.
        # NOTE(review): assumes PretrainedBert is a torch.nn.Module (it
        # already exposes load_state_dict/parameters) - confirm.
        model.eval()
        _classifier = model
    return _classifier


def predict(text):
    """Classify a sequence and return the predicted virus label.

    Args:
        text: Raw sequence entered by the user; converted with ``str()``.

    Returns:
        One of ``VIRUS_LABELS``, or a human-readable message when the
        sequence is longer than ``MAX_SEQUENCE_LENGTH``, shorter than
        ``MIN_SEQUENCE_LENGTH``, or the classifier returns an index that has
        no label.
    """
    seq = str(text)
    # Validate first: previously an over-long sequence was filtered out and
    # the code then failed on the empty result instead of telling the user.
    if len(seq) > MAX_SEQUENCE_LENGTH:
        return (
            f"Sequence is too long ({len(seq)} characters); "
            f"the maximum supported length is {MAX_SEQUENCE_LENGTH}."
        )
    if len(seq) < MIN_SEQUENCE_LENGTH:
        return (
            f"Sequence is too short; at least {MIN_SEQUENCE_LENGTH} "
            "characters are required."
        )

    # Keep handing a pandas Series to the preprocessor, as before.
    sentences = pd.Series([kmers_sentences(seq)], name='Processed')
    test_inputs, test_masks = preprocessing_for_bert(sentences)
    # NOTE(review): the mask tensor is reused as a third element, presumably
    # as a placeholder where bertPredictions expects labels - confirm against
    # Toxonomy.modules.classifier.
    test_data = TensorDataset(test_inputs, test_masks, test_masks)
    # Inference must not shuffle, so iterate the data in order.
    test_dataloader = DataLoader(
        test_data, sampler=SequentialSampler(test_data), batch_size=8
    )

    pred = bertPredictions(torch, _get_classifier(), test_dataloader)
    class_index = int(pred[0][0])
    if 0 <= class_index < len(VIRUS_LABELS):
        return VIRUS_LABELS[class_index]
    # Previously an unexpected index left the result variable unassigned.
    return f"Unrecognised class index {class_index} returned by the classifier."
# Minimal Gradio front end: a single text input passed to predict(), with the
# returned label shown as text.
iface = gr.Interface(fn=predict, inputs='text', outputs='text')

# Start the web UI; this is a top-level statement, so it runs as soon as the
# module is executed.
iface.launch()