File size: 1,121 Bytes
3b1b881 4277173 3b1b881 60b09de 3b1b881 2f7bc0e 5c27e6d 2f7bc0e 4277173 2f7bc0e 3b1b881 8b9a20c 3b1b881 4277173 3b1b881 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
from peft import PeftModel, PeftConfig
# Hub identifiers: the pretrained sentiment backbone and the PEFT adapter
# that is layered on top of it.
base_model = "cardiffnlp/twitter-roberta-base-sentiment-latest"
adapter_model = 'saideep-arikontham/twitter-roberta-base-sentiment-latest-trump-stance-3'

# Class index <-> human-readable label. The reverse map is derived from the
# forward map so the two cannot drift apart.
id2label = {0: "Anti-Trump", 1: "Pro-Trump"}
label2id = {name: index for index, name in id2label.items()}

# Load the backbone with a 2-class head. ignore_mismatched_sizes=True lets the
# checkpoint load even where its stored weight shapes differ from the
# requested 2-label configuration.
model = AutoModelForSequenceClassification.from_pretrained(
    base_model,
    num_labels=2,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)
# Wrap the backbone with the adapter weights.
model = PeftModel.from_pretrained(model, adapter_model)

# The tokenizer is read from the adapter repository, not from the base model.
tokenizer = AutoTokenizer.from_pretrained(adapter_model)
def greet(text):
    """Classify *text* as Anti-Trump or Pro-Trump and return a sentence.

    Uses the module-level ``model``, ``tokenizer`` and ``id2label``.

    Args:
        text: The input string to classify.

    Returns:
        A string of the form ``"This text is <label>!!"`` where ``<label>``
        is the ``id2label`` entry for the highest-scoring class.
    """
    model.to("cpu")

    # Truncate so over-long inputs cannot exceed the encoder's position limit
    # (512 tokens including special tokens for RoBERTa-base checkpoints);
    # without this, long text raises an indexing error inside the model.
    input_ids = tokenizer.encode(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    ).to("cpu")

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(input_ids).logits

    # Reduce over the class axis explicitly rather than over the flattened
    # tensor; the batch holds the single encoded text, so .item() is valid.
    predicted_id = torch.argmax(logits, dim=-1).item()
    return "This text is " + id2label[predicted_id] + "!!"
# Build a text-in / text-out Gradio interface around greet() and start
# serving it.
demo = gr.Interface(fn=greet, inputs="text", outputs="text")
demo.launch()