import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from peft import PeftModel

base_model = "cardiffnlp/twitter-roberta-base-sentiment-latest"
adapter_model = "saideep-arikontham/twitter-roberta-base-sentiment-latest-trump-stance-1"

# define label maps
id2label = {0: "Anti-Trump", 1: "Pro-Trump"}
label2id = {"Anti-Trump": 0, "Pro-Trump": 1}

# load the base classification model with a fresh 2-label head, then attach the fine-tuned PEFT adapter
model = AutoModelForSequenceClassification.from_pretrained(
    base_model,
    num_labels=2,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

model = PeftModel.from_pretrained(model, adapter_model)
tokenizer = AutoTokenizer.from_pretrained(adapter_model)

# keep the model on CPU and in eval mode; inference is done once per request
model.to("cpu")
model.eval()
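# Optional (a sketch, not part of the original Space): if the adapter is a LoRA-style
# adapter, it can be merged into the base weights for slightly faster CPU inference:
# model = model.merge_and_unload()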


def greet(text):
    # tokenize the input text and run it through the fine-tuned classifier
    inputs = tokenizer(text, return_tensors="pt").to("cpu")
    with torch.no_grad():
        logits = model(**inputs).logits
    # map the highest-scoring logit to its stance label
    prediction = torch.argmax(logits, dim=-1).item()

    return "This text is " + id2label[prediction] + "!!"



demo = gr.Interface(fn=greet, inputs="text", outputs="text")
demo.launch()
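
# Optional launch settings (a sketch; not part of the original Space config):
# demo.launch(share=True)                                 # temporary public link when running locally
# demo.launch(server_name="0.0.0.0", server_port=7860)    # explicit bind, e.g. inside a container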