# Hugging Face Spaces status captured with this source: "Runtime error".
import gradio as gr | |
import torch | |
import transformers | |
# Target the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
class Model(torch.nn.Module):
    """DistilBERT sentiment classifier that accepts raw text.

    The fine-tuned checkpoint at ``WEIGHTS_URL`` is downloaded (and cached) with
    ``torch.hub`` and loaded into a ``DistilBertForSequenceClassification``
    built from a default ``DistilBertConfig``. Input text is tokenized with the
    ``distilbert-base-uncased-finetuned-sst-2-english`` tokenizer.
    """

    # Fine-tuned weights for this app.
    # NOTE(review): assumed to be a state_dict for a
    # DistilBertForSequenceClassification (with or without the "distilbert."
    # prefix, which from_pretrained reconciles) -- confirm against the training code.
    WEIGHTS_URL = "https://huggingface.co/ayse/distilbert-english-finetuned/resolve/main/model.pth"
    # torch.hub caches by file name in one shared directory; a bare "model.pth"
    # could silently reuse some other project's checkpoint.
    WEIGHTS_FILE_NAME = "ayse-distilbert-english-finetuned-model.pth"

    def __init__(self):
        super().__init__()
        self.config = transformers.DistilBertConfig()
        # Passing a URL to a single weights file as from_pretrained's first
        # argument is deprecated in Transformers v4 (slated for removal in v5).
        # Download the checkpoint explicitly, then hand it over via `state_dict`
        # so from_pretrained's lenient key matching still applies.
        state_dict = torch.hub.load_state_dict_from_url(
            self.WEIGHTS_URL,
            map_location="cpu",
            file_name=self.WEIGHTS_FILE_NAME,
            # Only tensors/primitives: never unpickle arbitrary objects
            # fetched over the network.
            weights_only=True,
        )
        self.bert = transformers.DistilBertForSequenceClassification.from_pretrained(
            None, config=self.config, state_dict=state_dict
        )
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            "distilbert-base-uncased-finetuned-sst-2-english"
        )
        # Put the weights on the same device inference inputs are sent to;
        # otherwise CUDA hosts fail with a device-mismatch error.
        self.to(device)

    def forward(self, input_text):
        """Tokenize ``input_text`` and run it through the classifier.

        Args:
            input_text: The text to classify.

        Returns:
            The sequence-classification output of ``self.bert``; callers read
            its ``logits`` attribute.
        """
        encoding = self.tokenizer(
            input_text,
            add_special_tokens=True,
            # Replacement for the deprecated `pad_to_max_length=True`.
            padding="max_length",
            # Inputs longer than the model's maximum length would otherwise
            # index past DistilBERT's position embeddings and crash.
            truncation=True,
            return_token_type_ids=False,
            return_attention_mask=True,
            return_tensors="pt",
        )
        # Follow the weights wherever they live rather than assuming a device.
        model_device = self.bert.device
        return self.bert(
            input_ids=encoding["input_ids"].to(model_device),
            attention_mask=encoding["attention_mask"].to(model_device),
        )
# Shared classifier instance, built on first use by `_get_model`.
_MODEL = None


def _get_model():
    """Return the shared ``Model`` in eval mode, constructing it on first call.

    Constructing ``Model`` loads weights and a tokenizer, so it is done once
    rather than on every request.
    NOTE(review): two concurrent first requests could each build a model; the
    last one wins, which is wasteful but harmless.
    """
    global _MODEL
    if _MODEL is None:
        model = Model()
        model.eval()  # inference mode for layers such as dropout
        _MODEL = model
    return _MODEL


def predict(input_text):
    """Classify ``input_text`` as ``"NEGATIVE"`` or ``"POSITIVE"``.

    Args:
        input_text: The text entered in the UI.

    Returns:
        ``"NEGATIVE"`` for predicted class 0, ``"POSITIVE"`` for class 1, and
        ``None`` for any other class index.
    """
    model = _get_model()
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        logits = model(input_text).logits
    label_index = torch.argmax(logits, dim=-1).item()
    return {0: "NEGATIVE", 1: "POSITIVE"}.get(label_index)
# Web UI: a single text box goes into `predict`, and its label is shown as text.
iface = gr.Interface(
    fn=predict,
    inputs="text",
    outputs="text",
    title="Sentiment Classification from Text",
    description="This sentiment classifier is a final project of a data science bootcamp. I trained DistilBERT with Tinder Application Reviews on Google Play Store (EN).",
    allow_flagging="never",
)
# Start the server, asking Gradio for a public share link and a browser tab.
iface.launch(share=True, inbrowser=True)