# app.py — sentiment-classification demo by ayse
# (last change: commit 6d51a31, "Update app.py")
import functools

import gradio as gr
import torch
import transformers
# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
class Model(torch.nn.Module):
    """DistilBERT sequence classifier bundled with its tokenizer.

    The classifier weights are loaded from the author's fine-tuned checkpoint.
    The tokenizer is loaded from the public
    ``distilbert-base-uncased-finetuned-sst-2-english`` checkpoint.
    """

    def __init__(self):
        super().__init__()
        # Default DistilBERT architecture. PretrainedConfig defaults to two labels.
        # NOTE(review): assumes the checkpoint was trained with this default
        # config — confirm.
        self.config = transformers.DistilBertConfig()
        # NOTE(review): passing a raw URL to a .pth file relies on legacy
        # from_pretrained behaviour. Newer transformers releases expect a Hub
        # repo id or a local directory — confirm against the pinned version.
        self.bert = transformers.AutoModelForSequenceClassification.from_pretrained(
            "https://huggingface.co/ayse/distilbert-english-finetuned/resolve/main/model.pth",
            config=self.config,
        )
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            "distilbert-base-uncased-finetuned-sst-2-english"
        )
        # forward() moves its input tensors to `device`, so the weights must
        # live there too. Otherwise CUDA hosts fail with a device mismatch.
        self.bert.to(device)

    def forward(self, input_text):
        """Tokenize *input_text* and run it through the classifier.

        Args:
            input_text: A single raw string to classify.

        Returns:
            The ``transformers`` sequence-classification output. Its ``logits``
            attribute holds the class scores for the single input.
        """
        encoding = self.tokenizer(
            input_text,
            add_special_tokens=True,
            # Exact replacement for the deprecated `pad_to_max_length=True`.
            padding="max_length",
            # Without truncation, inputs longer than the tokenizer's
            # model_max_length overflow DistilBERT's position embeddings and
            # raise at inference time.
            truncation=True,
            return_token_type_ids=False,
            return_attention_mask=True,
            return_tensors="pt",
        )
        input_ids = encoding["input_ids"].to(device)
        attention_mask = encoding["attention_mask"].to(device)
        return self.bert(input_ids=input_ids, attention_mask=attention_mask)
@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the classifier once, move it to ``device`` and switch it to eval mode."""
    model = Model()
    model.to(device)
    model.eval()
    return model


def predict(input_text):
    """Classify the sentiment of *input_text*.

    The model is constructed on the first call and reused afterwards. The
    original code rebuilt it, including reloading weights, on every request.

    Args:
        input_text: The text entered in the Gradio textbox.

    Returns:
        ``"NEGATIVE"`` for class 0, ``"POSITIVE"`` for class 1, or
        ``"LABEL_<n>"`` for any other class index the head might produce.
    """
    model = _load_model()
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        logits = model(input_text).logits
    index = torch.argmax(logits, dim=-1).item()
    labels = {0: "NEGATIVE", 1: "POSITIVE"}
    return labels.get(index, f"LABEL_{index}")
# Wire predict() into a minimal text-in / text-out Gradio UI and serve it.
iface = gr.Interface(
    fn=predict,
    inputs="text",
    outputs="text",
    title="Sentiment Classification from Text",
    description=(
        "This sentiment classifier is a final project of a data science bootcamp. "
        "I trained DistilBERT with Tinder Application Reviews on Google Play Store (EN)."
    ),
    allow_flagging="never",
)
# inbrowser opens a tab in the local default browser; share requests a public link.
iface.launch(share=True, inbrowser=True)