import functools
import os

import gradio as gr
import numpy as np
import pandas as pd
import torch
from minicons import cwe

from model import FeatureNormPredictor

# Hugging Face checkpoint loaded for each selectable LM family.
# NOTE(review): the original code always loaded 'bert-base-uncased' no matter
# which LM was selected; the roberta/electra ids below are assumptions — confirm
# they match the models the feature-norm predictors were trained on.
LM_CHECKPOINTS = {
    "bert": "bert-base-uncased",
    "roberta": "roberta-base",
    "electra": "google/electra-base-discriminator",
}

# Feature-norm datasets offered in the UI; also used to validate `norm`.
NORMS = ["Binder", "McRae", "Buchanan"]


@functools.lru_cache(maxsize=len(LM_CHECKPOINTS))
def _load_lm(lm_name):
    """Load the minicons CWE wrapper for *lm_name* once and reuse it across requests."""
    return cwe.CWE(LM_CHECKPOINTS[lm_name])


def predict(word, sentence, lm_name, layer, norm):
    """Validate the request and return per-input lines for the Gradio text output.

    Args:
        word: Target word; must be non-empty and occur in ``sentence``
            (plain substring test).
        sentence: Context sentence containing ``word``.
        lm_name: Key of ``LM_CHECKPOINTS`` ("bert", "roberta" or "electra").
        layer: Layer index from the Gradio number box (may arrive as a float);
            must be a whole number in ``range(lm.layers)``.
        norm: One of ``NORMS``.

    Returns:
        An "invalid input: ..." message when validation fails. Otherwise, one
        "<value>\\t<score>" line per input, joined with newlines.

    Note:
        The scores are currently random placeholders in [0, 100); the loaded
        ``FeatureNormPredictor`` is not yet used to compute them.
    """
    if not word or word not in sentence:
        return "invalid input: word not in sentence"
    # An unselected gr.Radio yields None, which previously crashed on string concatenation.
    if lm_name not in LM_CHECKPOINTS:
        return "invalid input: no language model selected"
    if norm not in NORMS:
        return "invalid input: no norm selected"

    # gr.Number hands over floats (e.g. 8.0); str(8.0) would build the
    # checkpoint name "bert8.0_to_..." instead of "bert8_to_...".
    try:
        layer_index = int(layer)
    except (TypeError, ValueError, OverflowError):
        return "invalid input: layer not in lm"
    if layer_index != layer:  # reject fractional layers such as 3.5
        return "invalid input: layer not in lm"

    lm = _load_lm(lm_name)
    if layer_index not in range(lm.layers):
        return "invalid input: layer not in lm"

    model_name = f"{lm_name}{layer_index}_to_{norm}"
    checkpoint_path = model_name + ".ckpt"
    if not os.path.isfile(checkpoint_path):
        return f"invalid input: no trained model found at {checkpoint_path}"
    model = FeatureNormPredictor.load_from_checkpoint(
        checkpoint_path=checkpoint_path,
        map_location=None,
    )
    model.eval()

    # TODO: replace the random placeholder scores with predictions from `model`.
    inputs = [word, sentence, lm_name, str(layer_index), norm]
    return "\n".join(f"{value}\t{np.random.randint(0, 100)}" for value in inputs)


demo = gr.Interface(
    fn=predict,
    inputs=[
        "text",
        "text",
        gr.Radio(list(LM_CHECKPOINTS)),
        gr.Number(precision=0),  # precision=0 makes Gradio deliver an int
        gr.Radio(NORMS),
    ],
    outputs=["text"],
)

demo.launch()