HebEMO_demo / HebEMO.py
avichr's picture
Update HebEMO.py
e021892
raw
history blame
No virus
3.18 kB
class HebEMO:
def __init__(self, device=-1, emotions = ['anticipation', 'joy', 'trust', 'fear', 'surprise', 'anger',
'sadness', 'disgust']):
from transformers import pipeline
from tqdm import tqdm
self.device = device
self.emotions = emotions
self.hebemo_models = {}
for emo in tqdm(emotions):
self.hebemo_models[emo] = pipeline(
"sentiment-analysis",
model="avichr/hebEMO_"+emo,
tokenizer="avichr/heBERT",
device = self.device #-1 run on CPU, else - device ID
)
def hebemo(self, text = None, input_path=False, save_results=False, read_lines=False, plot=False):
'''
text (str): a text or list of text to analyze
input_path(str): the path to the text file (txt file, each row for different instance)
returns pandas DataFrame of the analyzed texts and save it to the same dir of the input file
'''
from pyplutchik import plutchik
from spider_plot import spider_plot
import matplotlib.pyplot as plt
import pandas as pd
import time
import torch
from tqdm import tqdm
if text is None and type(input_path) is str:
# read the file
with open(input_path, encoding='utf8') as p:
txt = p.readlines()
elif text is not None and (input_path is None or input_path is False):
if type(text) is str:
if read_lines:
txt = text.split('\n')
else:
txt = [text]
elif type(text) is list:
txt = text
else:
raise ValueError('text should be text or list of text.')
else:
raise ValueError('you should provide a text string, list of strings or text path.')
# run hebEMO
hebEMO_df = pd.DataFrame(txt)
for emo in tqdm(self.emotions):
x = self.hebemo_models[emo](txt)
hebEMO_df = hebEMO_df.join(pd.DataFrame(x).rename(columns = {'label': emo, 'score':'confidence_'+emo}))
del x
torch.cuda.empty_cache()
hebEMO_df = hebEMO_df.applymap(lambda x: 0 if x=='LABEL_0' else 1 if x=='LABEL_1' else x)
if save_results is not False:
gen_name = str(int(time.time()*1e7))
if type(save_results) is str:
hebEMO_df.to_csv(save_results+'/'+gen_name+'_heEMOed.csv', encoding='utf8')
else:
hebEMO_df.to_csv(gen_name+'_heEMOed.csv', encoding='utf8')
if plot:
hebEMO = pd.DataFrame()
for emo in hebEMO_df.columns[1::2]:
hebEMO[emo] = abs(hebEMO_df[emo]-(1-hebEMO_df['confidence_'+emo]))
for i in range(0,1):
try: ax = plutchik(hebEMO.to_dict(orient='records')[i])
except: ax = spider_plot(hebEMO)
print(hebEMO_df[0][i])
plt.show()
return (hebEMO_df[0][i], ax)
else:
return (hebEMO_df)
# HebEMO_model = HebEMO()