# Hugging Face Space status: Runtime error
# Last commit: 725b13f (unverified) by Oresti Theodoridis
#   "Merge branch 'develop' into 35-create-new-text-classifier-sentiment"
import os
from datetime import date

import openai
import regex as re

from twitterscraper import TwitterScraper
class TextClassifier:
    """Classify the sentiment, sentiment target and topic of a Twitter user's tweets via OpenAI.

    On construction the user's tweets are scraped into ``self.df`` (presumably a
    pandas DataFrame with a ``'tweet'`` column, as used by the ``*_of_tweets``
    methods). Those methods add ``'sentiment'``, ``'target'`` and ``'topic'``
    columns to a copy of it and store the copy back on ``self.df``.
    """

    # Abbreviations / English names the model may return for Swedish parties and
    # the government, mapped to one canonical Swedish name so equivalent targets
    # are merged. Keys are lower-case; lookups are case-insensitive.
    _TARGET_ALIASES = {
        "v": "Vänsterpartiet",
        "mp": "Miljöpartiet",
        "the swedish green party": "Miljöpartiet",
        "s": "Socialdemokraterna",
        "the swedish social democratic party": "Socialdemokraterna",
        "c": "Centerpartiet",
        "l": "Liberalerna",
        "kd": "Kristdemokraterna",
        "m": "Moderaterna",
        "the swedish moderate party": "Moderaterna",
        "the moderate party": "Moderaterna",
        "sd": "Sverigedemokraterna",
        "the swedish government": "Regeringen",
    }

    # Target predictions longer than this are treated as model failures.
    _MAX_TARGET_LENGTH = 50

    def __init__(self, model_name="text-davinci-002", from_date='2022-01-01', to_date=None,
                 user_name='jimmieakesson',
                 num_tweets=20, api_key=None):
        """
        Initializes the TextClassifier and scrapes the user's tweets.

        :param model_name: name of the model from openai.
        :param from_date: string of the format 'YYYY-MM-DD'.
        :param to_date: string of the format 'YYYY-MM-DD'; defaults to today's
            date, resolved when the instance is created.
        :param user_name: Twitter user name whose tweets are scraped.
        :param num_tweets: integer value of the maximum number of tweets to be scraped.
        :param api_key: OpenAI API key. Falls back to the OPENAI_API_KEY
            environment variable; if neither is set, ``openai.api_key`` is left untouched.
        """
        if to_date is None:
            # Resolved per call: a signature default of str(date.today()) would be
            # frozen at import time in a long-running process.
            to_date = str(date.today())
        self.model_name = model_name
        self.from_date = from_date
        self.to_date = to_date
        self.num_tweets = num_tweets
        self.user_name = user_name
        self.ts = TwitterScraper.TwitterScraper(from_date, to_date, num_tweets)
        self.df = self.ts.scrape_by_user(user_name)
        # Secrets must never be hard-coded in source; take them from the caller or the environment.
        api_key = api_key or os.environ.get("OPENAI_API_KEY")
        if api_key:
            openai.api_key = api_key

    def scrape_tweets(self):
        """
        Scrapes tweets from the given date range.

        Delegates to the underlying scraper. Its return value is discarded and
        ``self.df`` is not updated here.
        """
        self.ts.scrape_tweets()

    @staticmethod
    def cleanup_sentiment_results(classification_unclean):
        """
        Cleans up the results of the sentiment classification.

        Static so that both ``self.cleanup_sentiment_results(x)`` and
        ``TextClassifier.cleanup_sentiment_results(x)`` work.

        :param classification_unclean: raw completion text of the classification.
        :return: the text with newlines removed and surrounding whitespace stripped.
        """
        classification_clean = classification_unclean.replace("\n", "")
        # Strip only the edges so multi-word labels keep their inner spaces.
        return classification_clean.strip()

    def classify_sentiment(self, text: str):
        """
        Classifies the sentiment of a text.

        :param text: string of the tweet text.
        :return: lower-cased sentiment label returned by the model.
        :raises TypeError: if ``text`` is not a string.
        """
        if not isinstance(text, str):
            raise TypeError(f"text must be a str, got {type(text).__name__}")
        prompt_string = "Classify one sentiment for this tweet:\n \""
        prompt_string += text
        prompt_string += ("\" \nFor example:\nSupport,\nOpposition,\nCriticism,\nPraise,\nDisagreement,"
                          "\nAgreement,\nSkepticism,\nAdmiration,\nAnecdotes,\nJokes,\nMemes,\nSarcasm,\nSatire,"
                          "\nQuestions,\nStatements,\nOpinions,\nPredictions.\nSENTIMENT=")
        response = openai.Completion.create(
            model=self.model_name,
            prompt=prompt_string,
            temperature=0.0,
            max_tokens=256,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0,
            logprobs=5
        )
        classification_unclean = response.choices[0]['text']
        classification_clean = self.cleanup_sentiment_results(classification_unclean)
        return classification_clean.lower()

    def classify_sentiment_of_tweets(self):
        """
        Classifies the sentiment of a user's tweets.

        :return: a copy of ``self.df`` with a ``'sentiment'`` column, also stored on ``self.df``.
        """
        df_sentiment = self.df.copy()
        df_sentiment['sentiment'] = df_sentiment['tweet'].apply(self.classify_sentiment)
        self.df = df_sentiment
        return self.df

    def analyze_sentiment(self, text: str, sentiment: str):
        """
        Asks the model who the target of a tweet's sentiment is.

        :param text: string of the tweet text.
        :param sentiment: sentiment label of the tweet, interpolated into the prompt.
        :return: the normalized target (see ``_normalize_target``), or "N/A"
            when the prediction is implausibly long.
        """
        # TODO: fix the prompt before relying on this method.
        # NOTE(review): "\\n" sends a literal backslash-n to the model, unlike the
        # other prompts in this class, which use real newlines - confirm intent.
        prompt_string = "Who is the TARGET of this "
        prompt_string += sentiment
        prompt_string += " TWEET?\\nTWEET=\""
        prompt_string += text
        prompt_string += "\"\\n.TARGET should consist of less than 5 words.\\nTARGET="
        response = openai.Completion.create(
            model=self.model_name,
            prompt=prompt_string,
            temperature=0,
            max_tokens=256,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0
        )
        return self._normalize_target(response.choices[0]['text'])

    @classmethod
    def _normalize_target(cls, target):
        """Clean a raw target prediction and merge known party/government aliases."""
        target = target.strip()
        # GPT-3 sometimes gives faulty results; overly long answers become N/A (not applicable).
        if len(target) > cls._MAX_TARGET_LENGTH:
            return "N/A"
        # Drop parentheses so e.g. "(SD)" and "SD" merge to the same target.
        target = target.replace("(", "").replace(")", "").strip()
        return cls._TARGET_ALIASES.get(target.lower(), target)

    def analyze_sentiment_of_tweets(self):
        """
        Analyzes the target of the sentiment of a user's tweets.

        :return: a copy of ``self.df`` with a ``'target'`` column, also stored on ``self.df``.
        :raises ValueError: if ``self.df`` has no ``'sentiment'`` column.
        """
        if 'sentiment' not in self.df.columns:
            raise ValueError(
                "'sentiment' column does not exist. Please run classify_sentiment_of_tweets first.")
        df_sentiment = self.df.copy()
        df_sentiment['target'] = df_sentiment.apply(
            lambda row: self.analyze_sentiment(row['tweet'], row['sentiment']), axis=1)
        self.df = df_sentiment
        return self.df

    def classify_topic(self, text: str):
        """
        Classifies the topic of a text.

        :param text: string of the tweet text.
        :return: lower-cased topic label returned by the model.
        :raises TypeError: if ``text`` is not a string.
        """
        if not isinstance(text, str):
            raise TypeError(f"text must be a str, got {type(text).__name__}")
        prompt_string = "Classify one topic for this tweet:\n \""
        prompt_string += text
        prompt_string += ("\" \nFor example:\nEconomy,\nEnvironment,\nHealth,\nPolitics,\nScience,\nSports,\nTechnology,"
                          "\nTransportation,\nWorld.\nTOPIC=")
        response = openai.Completion.create(
            model=self.model_name,
            prompt=prompt_string,
            temperature=0,
            max_tokens=892,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0,
        )
        classification_unclean = response.choices[0]['text']
        classification_clean = self.cleanup_topic_results(classification_unclean)
        return classification_clean.lower()

    def classify_topics_of_tweets(self):
        """
        Classifies the topics of a user's tweets.

        :return: a copy of ``self.df`` with a ``'topic'`` column, also stored on
            ``self.df`` (same pattern as the other ``*_of_tweets`` methods).
        """
        df_topic = self.df.copy()
        df_topic['topic'] = df_topic['tweet'].apply(self.classify_topic)
        self.df = df_topic
        return self.df

    def __repr__(self):
        return (f"{type(self).__name__}(user_name={self.user_name!r}, from_date={self.from_date!r}, "
                f"to_date={self.to_date!r}, num_tweets={self.num_tweets!r})")

    def cleanup_topic_results(self, text):
        """
        Cleans up the raw text of a topic classification.

        :param text: raw completion text from the model.
        :return: the text with every whitespace run (including newlines)
            collapsed to a single space and the ends stripped.
        """
        return " ".join(text.split())
if __name__ == "__main__": | |