|
from ast import Str |
|
import gradio as gr |
|
from tweetnlp import Sentiment, NER |
|
from typing import Tuple, Dict |
|
from statistics import mean |
|
|
|
def clean_tweet(tweet: str, remove_chars: str = "@#") -> str:
    """Remove any unwanted characters.

    Args:
        tweet (str): The raw tweet
        remove_chars (str, optional): The characters to remove. Every
            character of this string is removed individually, wherever it
            occurs in the tweet. Defaults to "@#".

    Returns:
        str: The tweet with these characters removed
    """
    # A deletion-only translation table strips every listed character in one
    # pass over the tweet instead of one full `replace` pass per character.
    deletion_table = str.maketrans("", "", remove_chars)
    return tweet.translate(deletion_table)
|
|
|
|
|
def format_sentiment(model_output: Dict) -> Dict:
    """Format the output of the sentiment model.

    Args:
        model_output (Dict): The model output. It must contain a "label"
            entry (the predicted class) and a "probability" entry. The
            probability is either a single number (the confidence of the
            predicted label) or a dict mapping each label to its probability.

    Returns:
        Dict: The format for gradio, a mapping of label -> confidence.
            * "positive" / "negative" predictions also get the opposite label
              with the complementary confidence (1 - probability).
            * Any other label (e.g. "neutral") is reported under its own name
              only, instead of being mislabelled as "negative".
            * A per-label probability dict is returned as a copy, unchanged.

    Raises:
        KeyError: If "label" or "probability" is missing from model_output.
    """
    label = model_output["label"]
    probability = model_output["probability"]

    # A full per-label distribution already has the shape gradio needs.
    if isinstance(probability, dict):
        return dict(probability)

    opposite = {"positive": "negative", "negative": "positive"}
    if label not in opposite:
        # Only the predicted label's confidence is known, so the remainder
        # cannot be attributed to any other specific class.
        return {label: probability}

    return {label: probability, opposite[label]: 1 - probability}
|
|
|
|
|
def format_entities(model_output: Dict) -> Dict:
    """Format the output of the NER model.

    Args:
        model_output (Dict): The model output. Its "entity_prediction" entry
            is iterated; each item supplies "entity" (strings joined with
            single spaces to form the name), "type" and "probability"
            (numbers that are averaged).

    Returns:
        Dict: The format for gradio. Keys are "<name>:<type>" and values are
            the mean of the entity's probabilities (0.0 when the entity has
            no probabilities). When two entities produce the same key, the
            later one wins.

    Raises:
        KeyError: If "entity_prediction" or one of the per-entity entries is
            missing.
    """
    formatted_output = {}
    for entity in model_output["entity_prediction"]:
        name = " ".join(entity["entity"])
        entity_type = entity["type"]
        # Materialise first so the emptiness check is safe for any iterable;
        # `mean` raises StatisticsError on empty data.
        scores = list(entity["probability"])
        formatted_output[f"{name}:{entity_type}"] = mean(scores) if scores else 0.0
    return formatted_output
|
|
|
|
|
def classify(tweet: str) -> Tuple[Dict, Dict]:
    """Clean a tweet and run the sentiment and NER models on it.

    Relies on the module-level `se_model` and `ner_model` objects, which the
    script's `__main__` section creates before the interface is launched.

    Args:
        tweet (str): The raw tweet

    Returns:
        Tuple[Dict, Dict]: The formatted sentiment followed by the formatted
            entities of the tweet
    """
    cleaned = clean_tweet(tweet)

    # Sentiment runs first, then entity recognition, both on the cleaned text.
    sentiment_scores = format_sentiment(se_model.sentiment(cleaned))
    entity_scores = format_entities(ner_model.ner(cleaned))

    return sentiment_scores, entity_scores
|
|
|
|
|
if __name__ == "__main__":
    # Build both models once up front; classify() reads them as module globals.
    se_model = Sentiment()
    ner_model = NER()

    # Sample tweets passed to the interface as its examples.
    examples = [
        "Dameon Pierce is clearly the #Texans starter and he once again looks good",
        "Deebo Samuel had 150+ receiving yards in 4 games last year - the most by any receiver in the league.",
    ]

    # One text input, two label outputs (sentiment, entities).
    iface = gr.Interface(
        fn=classify,
        inputs="text",
        outputs=["label", "label"],
        examples=examples,
    )
    iface.launch()