# Hugging Face Space app script (Space status when captured: Running)
import gradio as gr | |
import requests | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torchvision.transforms as transform | |
import os | |
import logging | |
from transformers import WEIGHTS_NAME,AdamW,AlbertConfig,AlbertTokenizer,BertConfig,BertTokenizer | |
from pabee.modeling_albert import AlbertForSequenceClassification | |
from pabee.modeling_bert import BertForSequenceClassification | |
from transformers import glue_output_modes as output_modes | |
from transformers import glue_processors as processors | |
import datasets | |
from whitebox_utils.classifier import MyClassifier | |
import random | |
import numpy as np | |
import torch | |
import argparse | |
def random_seed(seed):
    """Seed the random number generators used by this app.

    The same ``seed`` is applied, in order, to Python's ``random`` module,
    NumPy's global RNG, PyTorch's default generator, and every CUDA device
    generator.

    Args:
        seed: Integer seed shared by all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
logger = logging.getLogger(__name__)

# TODO: make dataset / model / tokenizer selectable instead of hard-coded.
# "<model>_<dataset>" -> checkpoint directory whose weights are served.
best_model_path = {
    'albert_STS-B': './outputs/train/albert/STS-B/checkpoint-600',
}

# Model type -> (config class, model class, tokenizer class).
MODEL_CLASSES = {
    "bert": (BertConfig, BertForSequenceClassification, BertTokenizer),
    "albert": (AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer),
}

model_name = 'albert'
dataset = 'STS-B'

# Tasks are looked up in transformers' GLUE registries by lower-cased name.
task_name = dataset.lower()
if task_name not in processors:
    raise ValueError(f"Task not found: {task_name}")
processor = processors[task_name]()  # GLUE data processor from transformers
output_mode = output_modes[task_name]  # output type for this task
label_list = processor.get_labels()
num_labels = len(label_list)

# Training output directory for this model/dataset; the tokenizer is loaded
# from here. Must use `model_name` (the only model-name variable defined).
output_dir = f'./outputs/train/{model_name}/{dataset}'
data_dir = f'./glue_data/{dataset}'  # GLUE data directory for the task

config_class, model_class, tokenizer_class = MODEL_CLASSES[model_name]
tokenizer = tokenizer_class.from_pretrained(output_dir, do_lower_case=True)
# Weights come from the specific checkpoint listed in best_model_path.
model = model_class.from_pretrained(best_model_path[f'{model_name}_{dataset}'])

# Early-exit settings handed to MyClassifier. The names (and the `pabee`
# package) suggest patience-based early exit with patience 3 — the exact
# semantics live in MyClassifier; confirm there.
exit_type = 'patience'
exit_value = 3
classifier = MyClassifier(
    model, tokenizer, label_list, output_mode, exit_type, exit_value, model_name
)
def greet(text, text2, exit_pos):
    """Classify one sentence pair for the Gradio UI.

    Args:
        text: First sentence of the pair.
        text2: Second sentence of the pair.
        exit_pos: Forwarded to ``classifier.get_prob_time`` as
            ``exit_position``. Gradio's ``"number"`` input yields a float by
            default, so whole-number floats are converted to ``int``
            (assumed to be a layer/exit index — confirm against MyClassifier).
            Any other value is passed through unchanged.

    Returns:
        The value returned by ``classifier.get_prob_time``. Returning it is
        what makes the result appear in the interface's ``"text"`` output;
        a function that returns nothing leaves that box empty.
    """
    text_input = [(text, text2)]
    if isinstance(exit_pos, float) and exit_pos.is_integer():
        exit_pos = int(exit_pos)
    return classifier.get_prob_time(text_input, exit_position=exit_pos)
# Web UI: two free-text inputs (the sentence pair) and a numeric exit
# position, all passed positionally to `greet`; its result is shown as text.
iface = gr.Interface(
    fn=greet,
    inputs=["text", "text", "number"],
    outputs="text",
)
iface.launch()