saideep-arikontham's picture
Update app.py
14a0b2f verified
raw
history blame
No virus
1.13 kB
import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
from peft import PeftModel, PeftConfig
import re
# Hub ids: the pretrained backbone, and the PEFT adapter repo that holds the
# stance-classification weights (and the tokenizer) trained on top of it.
base_model = "cardiffnlp/twitter-roberta-base-sentiment-latest"
adapter_model = 'saideep-arikontham/twitter-roberta-base-sentiment-latest-trump-stance-1'

# Class index -> stance name; the reverse map is derived so the two can
# never drift apart.
id2label = {0: "Anti-Trump", 1: "Pro-Trump"}
label2id = {name: idx for idx, name in id2label.items()}

# Build a 2-way classifier on the base checkpoint. ignore_mismatched_sizes
# allows the classification head to be re-created with len(id2label) outputs
# (presumably the base checkpoint's head has a different class count -- verify).
_backbone = AutoModelForSequenceClassification.from_pretrained(
    base_model,
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)
# Attach the trained adapter weights to the backbone.
model = PeftModel.from_pretrained(_backbone, adapter_model)
# Note: the tokenizer comes from the adapter repo, not from base_model.
tokenizer = AutoTokenizer.from_pretrained(adapter_model)
def greet(text):
    """Classify the stance of ``text`` and return it as a sentence.

    Runs the module-level ``model``/``tokenizer`` on a single input string.

    Args:
        text: Input string from the Gradio text box.

    Returns:
        ``"This text is <label>!!"`` where ``<label>`` is the ``id2label``
        entry for the highest-scoring class.
    """
    # Keep the model on CPU, matching the device the inputs are placed on.
    model.to('cpu')
    # truncation=True clips the sequence to tokenizer.model_max_length so long
    # inputs do not overflow the model's position embeddings at forward time.
    inputs = tokenizer.encode(text, truncation=True, return_tensors="pt").to("cpu")
    # Pure inference: disable autograd to avoid building a gradient graph.
    with torch.no_grad():
        logits = model(inputs).logits
    # logits is (batch, num_labels); take the best class for our single row
    # instead of an argmax over the flattened tensor.
    prediction = int(logits.argmax(dim=-1)[0])
    return "This text is " + id2label[prediction] + "!!"
# Minimal UI: one text box in, one text box out, with `greet` doing the
# classification; launching happens as soon as this module runs.
demo = gr.Interface(
    inputs="text",
    outputs="text",
    fn=greet,
)
demo.launch()