class HebEMO: def __init__(self, device=-1, emotions = ['anticipation', 'joy', 'trust', 'fear', 'surprise', 'anger', 'sadness', 'disgust']): from transformers import pipeline from tqdm import tqdm self.device = device self.emotions = emotions self.hebemo_models = {} for emo in tqdm(emotions): self.hebemo_models[emo] = pipeline( "sentiment-analysis", model="avichr/hebEMO_"+emo, tokenizer="avichr/heBERT", device = self.device #-1 run on CPU, else - device ID ) def hebemo(self, text = None, input_path=False, save_results=False, read_lines=False, plot=False): ''' text (str): a text or list of text to analyze input_path(str): the path to the text file (txt file, each row for different instance) returns pandas DataFrame of the analyzed texts and save it to the same dir of the input file ''' from pyplutchik import plutchik from spider_plot import spider_plot import matplotlib.pyplot as plt import pandas as pd import time import torch from tqdm import tqdm if text is None and type(input_path) is str: # read the file with open(input_path, encoding='utf8') as p: txt = p.readlines() elif text is not None and (input_path is None or input_path is False): if type(text) is str: if read_lines: txt = text.split('\n') else: txt = [text] elif type(text) is list: txt = text else: raise ValueError('text should be text or list of text.') else: raise ValueError('you should provide a text string, list of strings or text path.') # run hebEMO hebEMO_df = pd.DataFrame(txt) for emo in tqdm(self.emotions): x = self.hebemo_models[emo](txt) hebEMO_df = hebEMO_df.join(pd.DataFrame(x).rename(columns = {'label': emo, 'score':'confidence_'+emo})) del x torch.cuda.empty_cache() hebEMO_df = hebEMO_df.applymap(lambda x: 0 if x=='LABEL_0' else 1 if x=='LABEL_1' else x) if save_results is not False: gen_name = str(int(time.time()*1e7)) if type(save_results) is str: hebEMO_df.to_csv(save_results+'/'+gen_name+'_heEMOed.csv', encoding='utf8') else: hebEMO_df.to_csv(gen_name+'_heEMOed.csv', encoding='utf8') if plot: hebEMO = pd.DataFrame() for emo in hebEMO_df.columns[1::2]: hebEMO[emo] = abs(hebEMO_df[emo]-(1-hebEMO_df['confidence_'+emo])) for i in range(0,1): try: ax = plutchik(hebEMO.to_dict(orient='records')[i]) except: ax = spider_plot(hebEMO) print(hebEMO_df[0][i]) plt.show() return (hebEMO_df[0][i], ax) else: return (hebEMO_df) # HebEMO_model = HebEMO()