Spaces:
Sleeping
Sleeping
import functools

import gradio as gr
import torch
from openprompt import PromptDataLoader
from openprompt import PromptForClassification
from openprompt.data_utils import InputExample
from openprompt.plms import load_plm
from openprompt.prompts import ManualTemplate
from openprompt.prompts import ManualVerbalizer
# Maps each selectable checkpoint name to the model-family string passed as
# the first argument of OpenPrompt's `load_plm`.
_MODEL_TYPES = {
    "bert-base-uncased": "bert",
    "roberta-base": "roberta",
    "yiyanghkust/finbert-pretrain": "bert",
}

# Fixed class order; the argmax index over the model's logits indexes this list.
_CLASSES = ["positive", "neutral", "negative"]


@functools.lru_cache(maxsize=len(_MODEL_TYPES))
def _load_model(model_name):
    """Load the pretrained LM for *model_name* once and reuse it afterwards.

    Returns the ``(plm, tokenizer, model_config, WrapperClass)`` tuple produced
    by ``load_plm``. Loading weights is the expensive step of a request, so it
    is cached per model name instead of being repeated on every call.
    """
    return load_plm(_MODEL_TYPES[model_name], model_name)


def _build_template_text(template):
    """Translate the UI's ``[SENTENCE]``/``[MASK]`` markers into OpenPrompt syntax.

    Raises:
        ValueError: if either marker is absent from *template*.
    """
    missing = [marker for marker in ("[SENTENCE]", "[MASK]") if marker not in template]
    if missing:
        raise ValueError(f"Template must contain {' and '.join(missing)}: {template!r}")
    return (
        template.replace("[SENTENCE]", '{"placeholder":"text_a"}')
        .replace("[MASK]", '{"mask"}')
    )


def _parse_label_words(positive, neutral, negative):
    """Split each class's whitespace-separated label words into a list.

    ``str.split()`` (no argument) ignores repeated/leading/trailing whitespace,
    so no empty-string label words are produced.

    Raises:
        ValueError: if any class ends up with no label words.
    """
    label_words = {
        "positive": positive.split(),
        "neutral": neutral.split(),
        "negative": negative.split(),
    }
    empty = [cls for cls, words in label_words.items() if not words]
    if empty:
        raise ValueError(f"No label words given for: {', '.join(empty)}")
    return label_words


def sentiment_analysis(sentence, template, model_name, positive, neutral, negative):
    """Classify *sentence* as positive, neutral or negative with a prompt template.

    ``[SENTENCE]`` in *template* is replaced by *sentence*. The pretrained model
    scores each class's label words at the ``[MASK]`` position (via OpenPrompt's
    ``ManualVerbalizer``), and the class with the highest logit is returned.

    Args:
        sentence: Text to classify.
        template: Prompt containing the literal markers ``[SENTENCE]`` and ``[MASK]``.
        model_name: One of the keys of ``_MODEL_TYPES``.
        positive, neutral, negative: Whitespace-separated label words per class.

    Returns:
        The predicted class name: ``"positive"``, ``"neutral"`` or ``"negative"``.

    Raises:
        ValueError: for an unsupported/unselected model, a template missing a
            marker, or a class with no label words.
    """
    # Validate every input before touching the (slow) model loader.
    if model_name not in _MODEL_TYPES:
        raise ValueError(
            f"Unsupported model {model_name!r}; choose one of {sorted(_MODEL_TYPES)}"
        )
    template_text = _build_template_text(template)
    label_words = _parse_label_words(positive, neutral, negative)

    plm, tokenizer, _, wrapper_class = _load_model(model_name)
    prompt_template = ManualTemplate(text=template_text, tokenizer=tokenizer)
    prompt_verbalizer = ManualVerbalizer(
        classes=_CLASSES,
        label_words=label_words,
        tokenizer=tokenizer,
    )
    dataloader = PromptDataLoader(
        dataset=[InputExample(guid=0, text_a=sentence, label=0)],
        tokenizer=tokenizer,
        template=prompt_template,
        tokenizer_wrapper_class=wrapper_class,
        batch_size=1,
        max_seq_length=512,
    )
    prompt_model = PromptForClassification(
        plm=plm,
        template=prompt_template,
        verbalizer=prompt_verbalizer,
        freeze_plm=False,  # whether or not to freeze the pretrained language model
    )
    prompt_model.eval()  # inference only: disable dropout-style training behavior

    # One example with batch_size=1 -> the loader yields exactly one batch.
    inputs = next(iter(dataloader))
    with torch.no_grad():  # no backward pass, so skip building the autograd graph
        logits = prompt_model(inputs)
    return _CLASSES[torch.argmax(logits, dim=-1)[0].item()]
# Web UI: one input component per `sentiment_analysis` parameter, in the same
# order as the function signature; the predicted class is shown as text.
demo = gr.Interface(
    fn=sentiment_analysis,
    inputs=[
        gr.Textbox(placeholder="Enter sentence here.", label="sentence"),
        gr.Textbox(
            placeholder="Your template must have a [SENTENCE] token and a [MASK] token.",
            label="template",
        ),
        # Pre-select a model so the function never receives model_name=None.
        gr.Radio(
            choices=["roberta-base", "bert-base-uncased", "yiyanghkust/finbert-pretrain"],
            value="roberta-base",
            label="model choice",
        ),
        gr.Textbox(placeholder="Separate words with Spaces.", label="positive label words"),
        gr.Textbox(placeholder="Separate words with Spaces.", label="neutral label words"),
        gr.Textbox(placeholder="Separate words with Spaces.", label="negative label words"),
    ],
    outputs="text",
)

# `demo` stays defined at module level (for `gradio app.py` reload mode);
# the server only starts when the file is run as a script.
if __name__ == "__main__":
    demo.launch()