|
import gradio as gr |
|
import torch |
|
import os |
|
from huggingface_hub import login |
|
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer |
|
from transformers import AutoTokenizer |
|
|
|
# Markdown overview of the app. Not referenced in the visible part of this file.
README = """
# Movie Review Score Discriminator
It is a program that classifies whether it is positive or negative by entering movie reviews.
You can choose between the Korean version and the English version.
## Usage

"""

# Hub identifier used for both the tokenizer and the classifier architecture.
model_name = "roberta-base"
# Not read anywhere in the visible part of this file; presumably the settings
# the checkpoint was fine-tuned with -- TODO confirm before removing.
learning_rate = 5e-5
batch_size_train = 64
# Selects which local checkpoint file ("model-<step>.pt") is loaded below.
step = 1900

# SECURITY: an access token used to be hard-coded on this line. A token that
# has been committed to source control must be treated as leaked and revoked.
# The token is now taken from the environment, and login is skipped entirely
# when it is absent so that start-up never blocks on an interactive prompt.
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)

file_name = f"model-{step}.pt"
# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only
# host; inference in this file never moves tensors to another device.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
state_dict = torch.load(file_name, map_location="cpu")

# Mapping between class indices and the human-readable labels.
id2label = {0: "NEGATIVE", 1: "POSITIVE"}
label2id = {"NEGATIVE": 0, "POSITIVE": 1}

# Texts shown in the Gradio UI.
title = "Movie Review Score Discriminator"
description = "It is a program that classifies whether it is positive or negative by entering movie reviews. You can choose between the Korean version and the English version."
examples = ["the greatest musicians ", "cold movie "]
|
|
|
|
|
|
|
def tokenized_data(tokenizer, inputs):
    """Tokenize a batch of texts into fixed-length PyTorch tensors.

    Args:
        tokenizer: Object exposing ``batch_encode_plus`` (e.g. a Hugging Face
            tokenizer).
        inputs: A sequence of texts, or a single string. A single string is
            treated as a batch of one, because ``batch_encode_plus`` expects a
            list/tuple rather than a bare ``str``.

    Returns:
        Whatever ``tokenizer.batch_encode_plus`` returns for the batch, with
        every sample padded/truncated to 64 tokens and tensors in "pt" format.
    """
    if isinstance(inputs, str):
        inputs = [inputs]
    return tokenizer.batch_encode_plus(
        inputs,
        return_tensors="pt",
        padding="max_length",
        max_length=64,
        truncation=True,
    )
|
|
|
|
|
# Holds the tokenizer and model once they have been built, so that the
# expensive from_pretrained() calls run on the first request only.
_CLASSIFIER = {}


def _load_classifier():
    """Return the ``(tokenizer, model)`` pair, building it on first use."""
    if not _CLASSIFIER:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            num_labels=2,
            id2label=id2label,
            label2id=label2id,
            state_dict=state_dict,
        )
        # Inference only: switch off dropout and other train-time behaviour.
        model.eval()
        _CLASSIFIER["tokenizer"] = tokenizer
        _CLASSIFIER["model"] = model
    return _CLASSIFIER["tokenizer"], _CLASSIFIER["model"]


def greet(text):
    """Run the sentiment classifier on one review and return its logits.

    Args:
        text: The review text entered in the UI.

    Returns:
        The ``logits`` attribute of the model output for a batch containing
        only ``text``.
    """
    tokenizer, model = _load_classifier()

    # The tokenizer helper works on batches, so wrap the single review.
    inputs = tokenized_data(tokenizer, [text])

    with torch.no_grad():
        # Index the encoding by key: integer indexing does not yield the
        # input_ids / attention_mask tensors, and index 1 does not exist for
        # a batch of one.
        logits = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        ).logits

    return logits
|
|
|
# Keyword arguments common to both interfaces below.
_interface_options = dict(
    inputs="text",
    outputs="text",
    title=title,
    theme="peach",
    allow_flagging="auto",
    description=description,
    examples=examples,
)

# Interface loaded from the "models/cardiffnlp/twitter-roberta-base-sentiment"
# identifier.
demo1 = gr.Interface.load(
    "models/cardiffnlp/twitter-roberta-base-sentiment",
    **_interface_options,
)

# Interface backed by the local greet() function.
demo2 = gr.Interface(fn=greet, **_interface_options)

if __name__ == "__main__":
    # Only demo2 is launched; demo1 is constructed above but not launched here.
    demo2.launch()