PlaceAnalysis / app.py
lukelike1001's picture
added data, pre-trained model, and app
de99c92
raw
history blame
2.36 kB
import importlib
import importlib.util
import os
import subprocess
import sys

import gradio as gr
import spacy
from spacy import displacy
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

# Name of the spaCy pipeline used for part-of-speech tagging.
SPACY_MODEL = "en_core_web_sm"

# Download the spaCy pipeline only when it is not installed yet. Run the
# download with the interpreter executing this app (a bare "python" on PATH may
# belong to a different environment), pass an argv list instead of a shell
# string, and fail immediately if the download fails rather than later inside
# spacy.load().
if importlib.util.find_spec(SPACY_MODEL) is None:
    subprocess.run(
        [sys.executable, "-m", "spacy", "download", SPACY_MODEL],
        check=True,
    )
    # The package appeared after start-up; make the import system rescan
    # sys.path so spacy.load() can find it in this same process.
    importlib.invalidate_caches()

# Maximum sequence length (in tokens) intended for DistilBERT inputs.
TOKEN_SIZE = 128

# spaCy pipeline loaded once at start-up and shared by every request.
nlp = spacy.load(SPACY_MODEL)
def multi_analysis(text):
    """Run both analyses on *text* for the Gradio interface.

    Returns a 2-tuple: first the label-to-probability dict produced by
    ``sentiment_analysis``, then the highlight segments produced by
    ``text_analysis`` — one value per configured output component.
    """
    # Tuple elements are evaluated left to right, so sentiment runs first.
    return sentiment_analysis(text), text_analysis(text)
# Local directory holding the fine-tuned DistilBERT weights and tokenizer files.
MODEL_DIR = "saved_model/"

# Output labels in the order of the classifier's logits.
# NOTE(review): assumes the fine-tuned head maps index 0/1/2 to
# negative/neutral/positive — confirm against the training label ids.
SENTIMENT_LABELS = ("NEG", "NEU", "POS")

# model_dir -> (model, tokenizer); filled lazily by _load_sentiment_model().
_SENTIMENT_MODEL_CACHE = {}


def _load_sentiment_model(model_dir):
    """Return the ``(model, tokenizer)`` pair stored in *model_dir*.

    Reading the weights from disk is expensive, so the pair is loaded on first
    use and reused by every later call with the same directory.
    """
    cached = _SENTIMENT_MODEL_CACHE.get(model_dir)
    if cached is None:
        model = DistilBertForSequenceClassification.from_pretrained(model_dir)
        tokenizer = DistilBertTokenizer.from_pretrained(model_dir)
        cached = (model, tokenizer)
        _SENTIMENT_MODEL_CACHE[model_dir] = cached
    return cached


def sentiment_analysis(text):
    """Classify the sentiment of a single piece of text.

    Args:
        text: The input string (e.g. a Reddit comment). Inputs longer than
            ``TOKEN_SIZE`` tokens are truncated.

    Returns:
        A dict mapping each label in ``SENTIMENT_LABELS`` ("NEG", "NEU",
        "POS") to its softmax probability; the values sum to 1.
    """
    model, tokenizer = _load_sentiment_model(MODEL_DIR)
    # Use the module-wide token limit rather than a duplicated literal.
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=TOKEN_SIZE,
    )
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits
    # logits has one row per input; take the single row for *text* and turn
    # it into probabilities that sum to 1.
    probabilities = torch.softmax(logits, dim=-1)[0].tolist()
    return dict(zip(SENTIMENT_LABELS, probabilities))
def text_analysis(text):
    """Tag each token of *text* with its spaCy part-of-speech label.

    Returns a list of ``(segment, label)`` pairs for the "highlight" output:
    every token is paired with its coarse POS tag (``token.pos_``), and the
    whitespace that actually followed the token in *text* is emitted as an
    unlabelled ``(whitespace, None)`` segment. Joining the segments therefore
    reproduces the input text exactly.
    """
    segments = []
    for token in nlp(text):
        segments.append((token.text, token.pos_))
        # Keep the original spacing (nothing before punctuation, no trailing
        # blank) instead of inserting a space after every token.
        if token.whitespace_:
            segments.append((token.whitespace_, None))
    return segments
# Title and description shown at the top of the Gradio page.
title = "Reddit Sentiment Analysis"
# TODO(review): an earlier version of this text announced a HuggingFace page
# link but never supplied the URL; re-add that sentence once the URL is known.
description = """In July 2023, Reddit changed its API pricing from free to $0.24 per 1000 API calls,
which was met with major backlash from various communities. This sentiment analysis
model is based on DistilBERT and has been fine-tuned to better analyze Reddit
comments, with its F1 score at ~94%. The Github repository can be found at:
https://github.com/lukelike1001/PlaceAnalysis"""

# One text input; two outputs matching multi_analysis's return tuple:
# a label/confidence display and a highlighted-text view of the tagged tokens.
app = gr.Interface(
    fn=multi_analysis,
    inputs=gr.Textbox(placeholder="Enter sentence here..."),
    outputs=["label", "highlight"],
    title=title,
    description=description,
    examples=[
        ["What are the coords for that?"],
        ["the CEO of Reddit killed 3rd party apps."],
    ],
)

# Start the server only when run as a script, so importing this module
# (e.g. from tests) does not launch it.
if __name__ == "__main__":
    # share=True additionally requests a public share link.
    app.launch(share=True)