"""Gradio demo for a sentence-pair model with early exit.

Loads a fine-tuned ALBERT checkpoint for the GLUE ``STS-B`` task (model
classes come from the local ``pabee`` package). It wraps the model in
``MyClassifier`` with a ``'patience'`` exit setting. The result is served
through a two-textbox Gradio interface.
"""
import argparse
import logging
import os
import random

import datasets
import gradio as gr
import numpy as np
import requests
import torch
import torch.nn as nn
import torchvision.transforms as transform
from transformers import (
    WEIGHTS_NAME,
    AdamW,
    AlbertConfig,
    AlbertTokenizer,
    BertConfig,
    BertTokenizer,
)
from transformers import glue_output_modes as output_modes
from transformers import glue_processors as processors

from pabee.modeling_albert import AlbertForSequenceClassification
from pabee.modeling_bert import BertForSequenceClassification
from whitebox_utils.classifier import MyClassifier

# NOTE(review): argparse, os, datasets, requests, nn, transform, WEIGHTS_NAME
# and AdamW are not used in this script; kept in case other code relies on
# them being importable from here.


def random_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU + all CUDA devices) RNGs.

    Args:
        seed: integer seed passed to every generator.

    NOTE(review): defined but never called in this script.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


logger = logging.getLogger(__name__)

# TODO: dataset model tokenizer etc.
# Checkpoint to serve, keyed by "<model_name>_<dataset>".
best_model_path = {
    'albert_STS-B': './outputs/train/albert/STS-B/checkpoint-600',
}

# model_name -> (config class, sequence-classification model class, tokenizer class)
MODEL_CLASSES = {
    "bert": (BertConfig, BertForSequenceClassification, BertTokenizer),
    "albert": (AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer),
}

model_name = 'albert'
dataset = 'STS-B'
# The GLUE processor / output-mode lookups below use the lower-cased task name.
task_name = dataset.lower()

if task_name not in processors:
    raise ValueError(f"Task not found: {task_name}")
processor = processors[task_name]()  # transformers GLUE data processor for the task
output_mode = output_modes[task_name]  # output type for the task
label_list = processor.get_labels()
num_labels = len(label_list)

output_dir = f'./outputs/train/{model_name}/{dataset}'
data_dir = f'./glue_data/{dataset}'  # NOTE(review): unused in this script

config_class, model_class, tokenizer_class = MODEL_CLASSES[model_name]
# Tokenizer is read from the training output dir; weights from the chosen checkpoint.
tokenizer = tokenizer_class.from_pretrained(output_dir, do_lower_case=True)
model = model_class.from_pretrained(best_model_path[f'{model_name}_{dataset}'])

# Early-exit settings forwarded to MyClassifier.
exit_type = 'patience'
exit_value = 3
classifier = MyClassifier(
    model, tokenizer, label_list, output_mode, exit_type, exit_value, model_name
)


def greet(text, text2):
    """Run the classifier on one sentence pair.

    Args:
        text: first sentence of the pair.
        text2: second sentence of the pair.

    Returns:
        A 2-tuple ``(classifier.get_prob([(text, text2)]),
        classifier.get_current_exit())``. It maps to the "text" and "number"
        outputs of the Gradio interface. The second value is presumably the
        layer at which inference exited; confirm in ``MyClassifier``.
    """
    text_input = [(text, text2)]
    return classifier.get_prob(text_input), classifier.get_current_exit()


iface = gr.Interface(fn=greet, inputs=["text", "text"], outputs=["text", "number"])

if __name__ == "__main__":
    # Start the (blocking) web server only when run as a script, so importing
    # this module to reuse `greet`/`iface` does not launch it as a side effect.
    iface.launch()