Spaces:
Runtime error
Runtime error
import gradio as gr | |
import os | |
os.system('python -m spacy download en_core_web_sm') | |
import spacy | |
from spacy import displacy | |
import torch | |
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification | |
# Token budget for model inputs; sentiment_analysis truncates its input to
# this same length (presumably the value used during fine-tuning -- confirm).
TOKEN_SIZE = 128

# Shared spaCy English pipeline, loaded once at import time and reused by
# text_analysis for part-of-speech tagging.
nlp = spacy.load("en_core_web_sm")
def multi_analysis(text):
    """Run both analyses on ``text`` for the Gradio interface.

    Returns:
        A ``(scores, pos_tokens)`` tuple: the result of
        ``sentiment_analysis(text)`` followed by the result of
        ``text_analysis(text)``, evaluated in that order.
    """
    return sentiment_analysis(text), text_analysis(text)
import functools


@functools.lru_cache(maxsize=1)
def _load_sentiment_model():
    """Load the fine-tuned DistilBERT model and tokenizer once and cache them.

    Reading the weights from ``saved_model/`` is expensive, so the pair is
    loaded on first use and reused for every later request.
    """
    model = DistilBertForSequenceClassification.from_pretrained('saved_model/')
    tokenizer = DistilBertTokenizer.from_pretrained('saved_model/')
    return model, tokenizer


def sentiment_analysis(text):
    """Classify the sentiment of ``text`` with the local DistilBERT model.

    Args:
        text: The sentence or comment to score.

    Returns:
        A dict mapping the labels ``"NEG"``, ``"NEU"`` and ``"POS"`` to the
        softmax probability the model assigns to each.
    """
    model, tokenizer = _load_sentiment_model()

    # Truncate to the shared token budget instead of a duplicated literal.
    inputs = tokenizer(
        text,
        return_tensors='pt',
        truncation=True,
        padding=True,
        max_length=TOKEN_SIZE,
    )

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Softmax over the class dimension, then take the single batch row.
    # Indexing (rather than a bare squeeze) always yields a list of per-class
    # probabilities, whatever the number of labels.
    predicted_probabilities = torch.softmax(logits, dim=1)[0].tolist()

    # Label order is negative, neutral, positive.
    # NOTE(review): assumes this matches the class index order of the
    # fine-tuned model -- confirm against the saved model's config.
    labels = ["NEG", "NEU", "POS"]
    return dict(zip(labels, predicted_probabilities))
def text_analysis(text):
    """Tag every token of ``text`` with its part of speech.

    Returns:
        A flat list in which each ``(token_text, pos_tag)`` tuple produced
        by the spaCy pipeline is followed by a ``(" ", None)`` separator
        entry.
    """
    tagged = []
    for tok in nlp(text):
        tagged.append((tok.text, tok.pos_))
        # Unlabelled single-space entry placed after every token.
        tagged.append((" ", None))
    return tagged
# Title and description passed to the Gradio interface.
title = "Reddit Sentiment Analysis"
# NOTE(review): the HuggingFace URL appears to be missing after
# "can be found at:" -- fill it in once the model page address is confirmed.
description = """In July 2023, Reddit changed its API pricing from free to $0.24 per 1000 API calls,
which was met with major backlash from various communities. This sentiment analysis
model is based on DistilBERT and has been fine-tuned to better analyze Reddit
comments, with its F1 score at ~94%. The HuggingFace page can be found at:
and the Github repository can be found at:
https://github.com/lukelike1001/PlaceAnalysis"""
# Single free-text input box for the sentence to analyse.
_text_input = gr.Textbox(placeholder="Enter sentence here...")

# Sample sentences offered as examples in the UI (one input value per row).
_example_inputs = [
    ["What are the coords for that?"],
    ["the CEO of Reddit killed 3rd party apps."],
]

# Wire multi_analysis into a Gradio interface. The two outputs receive the
# sentiment scores and the POS-tagged tokens respectively.
# NOTE(review): "label" / "highlight" are component string shortcuts --
# verify they resolve in the installed Gradio version.
app = gr.Interface(
    fn=multi_analysis,
    inputs=_text_input,
    outputs=["label", "highlight"],
    examples=_example_inputs,
    title=title,
    description=description,
)

# Start the web server, requesting a public share link.
app.launch(share=True)