File size: 1,773 Bytes
29f3613
 
 
 
 
 
0ee6e26
29f3613
8931265
bf84ccc
01ca571
3c6272b
29f3613
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8a02177
 
29f3613
6d51a31
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import functools

import gradio as gr
import torch
import transformers

# Target the CUDA device whenever PyTorch can see a GPU; otherwise use the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class Model(torch.nn.Module):
    """DistilBERT sequence classifier bundled with its tokenizer.

    Calling an instance with a raw string tokenizes it, runs the classifier
    on the module-level ``device``, and returns the output of ``self.bert``
    (callers read its ``.logits``).
    """

    def __init__(self):
        super().__init__()
        # Default DistilBERT hyper-parameters. ``num_labels`` is left at the
        # config default — presumably 2 (NEGATIVE/POSITIVE); TODO confirm it
        # matches the fine-tuned checkpoint.
        self.config = transformers.DistilBertConfig()
        # NOTE(review): this passes a direct ``.pth`` file URL instead of a Hub
        # repo id or local directory; whether ``from_pretrained`` accepts that
        # depends on the installed transformers version — confirm.
        self.bert = transformers.AutoModelForSequenceClassification.from_pretrained("https://huggingface.co/ayse/distilbert-english-finetuned/resolve/main/model.pth", config=self.config)
        self.tokenizer = transformers.AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
        # forward() moves its input tensors to ``device``; the weights must live
        # there too, otherwise inference fails with a device mismatch on GPU hosts.
        self.to(device)

    def forward(self, input_text):
        """Tokenize ``input_text`` and classify it.

        Args:
            input_text: A single raw string to classify.

        Returns:
            The output of ``self.bert`` for the encoded text.
        """
        encoding = self.tokenizer.encode_plus(
            input_text,
            add_special_tokens=True,
            # Equivalent of the deprecated ``pad_to_max_length=True``: pad up to
            # the tokenizer's maximum length.
            padding="max_length",
            # Cut over-long inputs down to the model's maximum length instead of
            # producing sequences the position embeddings cannot index.
            truncation=True,
            return_token_type_ids=False,
            return_attention_mask=True,
            return_tensors="pt",
        )

        input_ids = encoding["input_ids"].to(device)
        attention_mask = encoding["attention_mask"].to(device)

        return self.bert(input_ids=input_ids, attention_mask=attention_mask)


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the classifier once (eval mode, on ``device``) and reuse it across calls."""
    model = Model()
    model.eval()
    # Idempotent; guarantees weights and inputs share a device.
    model.to(device)
    return model


def predict(input_text):
    """Classify the sentiment of ``input_text``.

    Args:
        input_text: The raw text to classify.

    Returns:
        "NEGATIVE" for class index 0, "POSITIVE" for class index 1, and None
        for any other index.
    """
    # Loading the model downloads weights, so do it once rather than per request.
    model = _load_model()
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(input_text)
    prediction = torch.argmax(outputs.logits, dim=-1).item()
    return {0: "NEGATIVE", 1: "POSITIVE"}.get(prediction)


# Web UI: one free-text input box, the predicted sentiment label as output.
iface = gr.Interface(
    predict,
    inputs="text",
    outputs="text",
    title="Sentiment Classification from Text",
    description="This sentiment classifier is a final project of a data science bootcamp. I trained DistilBERT with Tinder Application Reviews on Google Play Store (EN).",
    allow_flagging="never",
)

# Start the server only when executed as a script, so importing this module
# (e.g. from tests or another app) does not block on launch().
if __name__ == "__main__":
    # inbrowser opens a local browser tab; share=True asks Gradio for a public link.
    iface.launch(inbrowser=True, share=True)