"""Gradio demo that scores a sentence's sentiment and tags its parts of speech.

The sentiment scores come from a DistilBERT sequence-classification model
stored in ``saved_model/``; the part-of-speech tags come from spaCy's
``en_core_web_sm`` pipeline.
"""

import functools

import gradio as gr
import spacy
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

# Maximum number of tokens fed to the classifier; longer inputs are truncated.
TOKEN_SIZE = 128

# Directory holding both the fine-tuned model weights and the tokenizer files.
MODEL_DIR = 'saved_model/'

# Name of the spaCy pipeline used for part-of-speech tagging.
SPACY_MODEL = "en_core_web_sm"

# Label names paired positionally with the classifier's output probabilities.
# NOTE(review): assumes the model's logit order is NEG, NEU, POS - confirm
# against the training configuration.
LABELS = ("NEG", "NEU", "POS")

# Only fetch the spaCy pipeline when it is not already installed, instead of
# invoking the downloader on every start-up.
try:
    nlp = spacy.load(SPACY_MODEL)
except OSError:
    spacy.cli.download(SPACY_MODEL)
    nlp = spacy.load(SPACY_MODEL)


@functools.lru_cache(maxsize=1)
def _load_sentiment_model():
    """Load the classifier and tokenizer from ``MODEL_DIR`` once and cache them.

    Loading is deferred to the first call (rather than done at import time) so
    the app still starts exactly as before; subsequent calls reuse the cached
    objects instead of re-reading them from disk for every request.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = DistilBertForSequenceClassification.from_pretrained(MODEL_DIR)
    tokenizer = DistilBertTokenizer.from_pretrained(MODEL_DIR)
    return model, tokenizer


def multi_analysis(text: str) -> tuple:
    """Run both analyses on *text* for the Gradio interface.

    Args:
        text: The sentence entered by the user.

    Returns:
        A ``(scores, pos_tokens)`` tuple: the label-to-probability mapping
        from ``sentiment_analysis`` and the token/tag pairs from
        ``text_analysis``.
    """
    scores = sentiment_analysis(text)
    pos_tokens = text_analysis(text)
    return scores, pos_tokens


def sentiment_analysis(text: str) -> dict:
    """Return the classifier's probability for each sentiment label.

    Args:
        text: The sentence to classify. It is truncated to ``TOKEN_SIZE``
            tokens.

    Returns:
        A dict mapping each name in ``LABELS`` to its softmax probability.
    """
    model, tokenizer = _load_sentiment_model()

    # tokenize the inputs
    inputs = tokenizer(
        text,
        return_tensors='pt',
        truncation=True,
        padding=True,
        max_length=TOKEN_SIZE,
    )

    # ignore gradients as we only need inference (aka logits)
    with torch.no_grad():
        logits = model(**inputs).logits

    # apply softmax so the probabilities sum to 1; a single text is passed in,
    # so the batch holds exactly one row and we take it explicitly
    predicted_probabilities = torch.softmax(logits, dim=1)[0].tolist()

    # pair each label (negative, neutral, positive) with its probability
    return dict(zip(LABELS, predicted_probabilities))


def text_analysis(text: str) -> list:
    """Tag every token of *text* with its coarse part of speech.

    Args:
        text: The sentence to analyse.

    Returns:
        A flat list of ``(text, label)`` tuples: each spaCy token paired with
        its ``pos_`` tag, followed by a ``(" ", None)`` separator entry so the
        highlighted output keeps tokens visually apart.
    """
    doc = nlp(text)
    pos_tokens = []
    for token in doc:
        pos_tokens.extend([(token.text, token.pos_), (" ", None)])
    return pos_tokens


# add a title and description to the model
title = "Reddit Sentiment Analysis"

# The explicit "\n" keeps the original line break (and the space before it)
# without depending on trailing whitespace inside a triple-quoted string.
description = (
    "In July 2023, Reddit changed its API pricing from free to $0.24 per 1000 "
    "API calls, which was met with major backlash various communities. This "
    "sentiment analysis model is based on DistilBERT and has been fine-tuned to "
    "better analyze Reddit comments, with its F1 score at ~94%. \n"
    "For further documentation, check out the Github repository at "
    "https://github.com/lukelike1001/PlaceAnalysis, and the project’s info page "
    "at https://lukelike1001.github.io/place.html."
)

app = gr.Interface(
    fn=multi_analysis,
    inputs=gr.Textbox(placeholder="Enter sentence here..."),
    outputs=["label", "highlight"],
    title=title,
    description=description,
    examples=[
        ["What are the coords for that?"],
        ["the CEO of Reddit killed 3rd party apps."],
    ],
)

# Launched at import time, as in the original script.
app.launch()