Spaces:
Runtime error
Runtime error
import gradio as gr | |
# from transformers import pipeline | |
# pipe = pipeline("translation", model="Helsinki-NLP/opus-mt-en-es") | |
# def predict(text): | |
# return pipe(text)[0]["translation_text"] | |
# iface = gr.Interface( | |
# fn=predict, | |
# inputs='text', | |
# outputs='text', | |
# examples=[["Hello! My name is Omar"]] | |
# ) | |
# iface.launch() | |
from Toxonomy.modules.confusionmatrix import plot_confusion_matrix | |
import glob | |
import pandas as pd | |
from Toxonomy.modules.preprocessor import preprocessing_for_bert | |
print("hello") | |
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler | |
import torch | |
import numpy as np | |
from sklearn.model_selection import train_test_split | |
import torch | |
import torch.nn as nn | |
from transformers import BertModel | |
from Toxonomy.modules.classifier import PretrainedBert,FinetunningBert,initialize_finetunningBert,finetunningBert_training,bertPredictions | |
from transformers import AdamW, get_linear_schedule_with_warmup | |
device = 0 | |
import random | |
import time | |
import torch.nn as nn | |
print("completed") | |
def Kmers_funct(seq, size):
    """Split *seq* into overlapping, lower-cased substrings of length *size*.

    The window advances one character at a time, so a sequence of length L
    yields L - size + 1 pieces; a sequence shorter than *size* yields [].
    Each slice is lower-cased after it is taken.
    """
    last_start = len(seq) - size
    kmers = []
    for start in range(last_start + 1):
        kmers.append(seq[start:start + size].lower())
    return kmers
def kmers_sentences(mySeq):
    """Render *mySeq* as a single space-separated string of lower-cased 3-mers.

    predict() passes this string to preprocessing_for_bert.
    """
    return " ".join(Kmers_funct(mySeq, size=3))
import re | |
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler | |
import torch | |
import numpy as np | |
# Class index -> virus label. Position i is the label returned for index i.
_VIRUS_LABELS = (
    'IAV',      # 0
    'IBV',      # 1
    'SFTS',     # 2
    'Dengue',   # 3
    'EnteroA',  # 4
    'EnteroB',  # 5
    'ICV',      # 6
    'HBV',      # 7
    'HCV',      # 8
    'HSV-1',    # 9
    'HPV',      # 10
    'MPV',      # 11
    'WNV',      # 12
    'ZIKA',     # 13
)

# Longest input sequence (in characters) that predict() accepts.
_MAX_SEQ_LEN = 7000

# Fine-tuned weights loaded into PretrainedBert.
_MODEL_PATH = "./virBERT.pt"

# Classifier shared across predict() calls; populated on first use.
_bert_classifier = None


def _get_classifier():
    """Return the PretrainedBert classifier, loading its weights on first call only."""
    global _bert_classifier
    if _bert_classifier is None:
        model = PretrainedBert(freeze_bert=False)
        model.load_state_dict(
            torch.load(_MODEL_PATH, map_location=torch.device('cpu'))
        )
        _bert_classifier = model
    return _bert_classifier


def predict(text):
    """Classify one raw sequence and return the predicted virus label.

    Args:
        text: the sequence entered in the Gradio text box; coerced to str.

    Returns:
        One of the strings in ``_VIRUS_LABELS``.

    Raises:
        ValueError: if the sequence is longer than ``_MAX_SEQ_LEN`` characters,
            or if the classifier returns an index that has no label.
    """
    seq = str(text)
    # Reject over-long input up front instead of filtering it away and
    # failing later on an empty batch.
    if len(seq) > _MAX_SEQ_LEN:
        raise ValueError(
            f"Sequence too long: {len(seq)} characters (max {_MAX_SEQ_LEN})."
        )

    sentences = pd.Series([kmers_sentences(seq)], name='Processed')
    test_inputs, test_masks = preprocessing_for_bert(sentences)
    # The masks also fill the third tensor slot, as before; presumably that
    # slot (labels?) is ignored by bertPredictions -- TODO confirm.
    test_data = TensorDataset(test_inputs, test_masks, test_masks)
    # Sequential order so pred[0] is guaranteed to belong to the first input.
    test_dataloader = DataLoader(
        test_data, sampler=SequentialSampler(test_data), batch_size=8
    )

    pred = bertPredictions(torch, _get_classifier(), test_dataloader)
    class_index = int(pred[0][0])
    if not 0 <= class_index < len(_VIRUS_LABELS):
        raise ValueError(f"Classifier returned unknown class index {class_index}.")
    return _VIRUS_LABELS[class_index]
# Expose predict() as a single text-in / text-out Gradio app and start serving it.
_iface_options = {
    "fn": predict,
    "inputs": "text",
    "outputs": "text",
}
iface = gr.Interface(**_iface_options)
iface.launch()