# saideep-arikontham's picture
# Update app.py
# 60b09de verified
# raw
# history blame
# 804 Bytes
import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoModel, AutoTokenizer
import torch
from peft import PeftModel, PeftConfig
# Hub id of the base sequence-classification checkpoint.
base_model = "cardiffnlp/twitter-roberta-base-sentiment-latest"
# Hub id of the PEFT adapter that is applied on top of the base checkpoint.
adapter_model = 'saideep-arikontham/twitter-roberta-base-sentiment-latest-trump-stance'

# PeftModel.from_pretrained wraps an already-instantiated model, so the base
# classifier has to be loaded first (the original passed an undefined name).
# NOTE(review): assumes the adapter was trained against this base checkpoint
# with its default classification head size — confirm.
model = AutoModelForSequenceClassification.from_pretrained(base_model)
model = PeftModel.from_pretrained(model, adapter_model)

# The tokenizer is loaded from the adapter repository, as in the original.
tokenizer = AutoTokenizer.from_pretrained(adapter_model)
def _select_device():
    """Return the torch device name to run on: "mps", then "cuda", else "cpu"."""
    # getattr guard: older torch builds have no torch.backends.mps module.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


def greet(text):
    """Classify *text* with the module-level model and return a sentence.

    Args:
        text: The input string to classify.

    Returns:
        The string ``"This text is <label>!!"`` where ``<label>`` is the
        name mapped to the highest-scoring class index.
    """
    # Pick the device at call time instead of hard-coding "mps", which fails
    # on hosts without Apple-Silicon GPU support.
    device = _select_device()
    model.to(device)

    inputs = tokenizer.encode(text, return_tensors="pt").to(device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(inputs).logits

    # Take the arg-max over the class axis; a single input text yields a
    # one-element result, so .item() gives a plain int.
    predicted_id = torch.argmax(logits, dim=-1).item()

    # NOTE(review): label names come from the loaded model's config; confirm
    # they match the labels the adapter was trained with.
    id2label = model.config.id2label
    return "This text is " + id2label[predicted_id] + "!!"
# Wrap greet in a single text-in / text-out Gradio interface and start it.
demo = gr.Interface(
    fn=greet,
    inputs="text",
    outputs="text",
)
demo.launch()