import sys
import time

import printj
from transformers import pipeline  # , set_seed
import numpy as np
import pandas as pd
# import nltk
import re
import streamlit as st

class StoryGenerator:
    def __init__(self):
        self.initialise_models()
        self.stats_df = pd.DataFrame(data=[], columns=[])
        self.stories = []
        self.data = []

    def initialise_models(self):
        start = time.time()
        self.generator = pipeline('text-generation', model='gpt2')
        self.classifier = pipeline("text-classification",
                                   model="j-hartmann/emotion-english-distilroberta-base",
                                   return_all_scores=True)
        initialising_time = time.time() - start
        print(f'Initialising Time: {initialising_time}')
        # set_seed(42)
        # sys.exit()

    def reset(self):
        self.clear_stories()
        self.clear_stats()

    def clear_stories(self):
        self.data = []
        self.stories = []

    def clear_stats(self):
        self.stats_df = pd.DataFrame(data=[], columns=[])

    def get_emotion(self, text):
        # Return the highest-scoring emotion predicted for the given text.
        emotions = self.classifier(text)
        emotion = max(emotions[0], key=lambda x: x['score'])
        return emotion

    def get_num_token(self, text):
        # return len(nltk.word_tokenize(text))
        return len(re.findall(r'\w+', text))

    def check_show_emotion(self, confidence_score, frequency, w):
        # Decide stochastically whether to show an emotion, weighing classifier
        # confidence against how often reactions have already been shown.
        frequency_penalty = 1 - frequency
        probability_emote = w * confidence_score + (1 - w) * frequency_penalty
        return probability_emote > np.random.random_sample()
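
    # Worked example for check_show_emotion (assumed values, not taken from the
    # code above): with confidence_score=0.9, frequency=0.25 and w=0.5,
    # frequency_penalty = 0.75 and probability_emote = 0.5*0.9 + 0.5*0.75 = 0.825,
    # so the emotion would be shown in roughly 82.5% of draws.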

    def story(self,
              story_till_now="Hello, I'm a language model,",
              num_generation=4,
              length=10):
        # last_length = 0
        for i in range(num_generation):
            last_length = len(story_till_now)
            generate_robot_sentence = self.generator(
                story_till_now,
                max_length=self.get_num_token(story_till_now) + length,
                num_return_sequences=1)
            story_till_now = generate_robot_sentence[0]['generated_text']
            new_sentence = story_till_now[last_length:]
            emotion = self.get_emotion(new_sentence)
            # printj.yellow(f'Sentence {i}:')
            # story_to_print = f'{printj.ColorText.cyan(story_till_now[:last_length])}{printj.ColorText.green(story_till_now[last_length:])}\n'
            # print(story_to_print)
            # printj.purple(f'Emotion: {emotion}')
        return story_till_now, emotion

    def auto_ist(self,
                 story_till_now="Hello, I'm a language model,",
                 num_generation=4,
                 length=20, reaction_weight=0.5):
        stats_df = pd.DataFrame(data=[], columns=[])
        stats_dict = dict()
        num_reactions = 0
        reaction_frequency = 0
        emotion = self.get_emotion(story_till_now)  # first line emotion
        story_data = [{
            'sentence': story_till_now,
            'turn': 'first',
            'emotion': emotion['label'],
            'confidence_score': emotion['score'],
        }]
        for i in range(num_generation):
            # Text generation for User
            last_length = len(story_till_now)
            printj.cyan(story_till_now)
            printj.red.bold_on_white(
                f'loop: {i}; generate user text; length: {last_length}')
            generate_user_sentence = self.generator(
                story_till_now,
                max_length=self.get_num_token(story_till_now) + length,
                num_return_sequences=1)
            story_till_now = generate_user_sentence[0]['generated_text']
            new_sentence_user = story_till_now[last_length:]
            printj.red.bold_on_white(f'loop: {i}; check emotion')
            # Emotion self.classifier for User
            emotion_user = self.get_emotion(new_sentence_user)
            if emotion_user['label'] == 'neutral':
                show_emotion_user = False
            else:
                reaction_frequency = num_reactions / (i + 1)
                show_emotion_user = self.check_show_emotion(
                    confidence_score=emotion_user['score'],
                    frequency=reaction_frequency, w=reaction_weight)
            if show_emotion_user:
                num_reactions += 1
            story_data.append({
                'sentence': new_sentence_user,
                'turn': 'user',
                'emotion': emotion_user['label'],
                'confidence_score': emotion_user['score'],
            })
            stats_dict['sentence_no'] = i
            stats_dict['turn'] = 'user'
            stats_dict['sentence'] = new_sentence_user
            stats_dict['show_emotion'] = show_emotion_user
            stats_dict['emotion_label'] = emotion_user['label']
            stats_dict['emotion_score'] = emotion_user['score']
            stats_dict['num_reactions'] = num_reactions
            stats_dict['reaction_frequency'] = reaction_frequency
            stats_dict['reaction_weight'] = reaction_weight
            stats_df = pd.concat(
                [stats_df, pd.DataFrame(stats_dict, index=[f'idx_{i}'])])

            # Text generation for Robot
            last_length = len(story_till_now)
            printj.cyan(story_till_now)
            printj.red.bold_on_white(
                f'loop: {i}; generate robot text; length: {last_length}')
            generate_robot_sentence = self.generator(
                story_till_now,
                max_length=self.get_num_token(story_till_now) + length,
                num_return_sequences=1)
            story_till_now = generate_robot_sentence[0]['generated_text']
            new_sentence_robot = story_till_now[last_length:]
            emotion_robot = self.get_emotion(new_sentence_robot)
            story_data.append({
                'sentence': new_sentence_robot,
                'turn': 'robot',
                'emotion': emotion_robot['label'],
                'confidence_score': emotion_robot['score'],
            })
            stats_dict['sentence_no'] = i
            stats_dict['turn'] = 'robot'
            stats_dict['sentence'] = new_sentence_robot
            stats_dict['show_emotion'] = None
            stats_dict['emotion_label'] = emotion_robot['label']
            stats_dict['emotion_score'] = emotion_robot['score']
            stats_dict['num_reactions'] = None
            stats_dict['reaction_frequency'] = None
            stats_dict['reaction_weight'] = None
            stats_df = pd.concat(
                [stats_df, pd.DataFrame(stats_dict, index=[f'idx_{i}'])])
        return stats_df, story_till_now, story_data

    def get_stats(self,
                  story_till_now="Hello, I'm a language model,",
                  num_generation=4,
                  length=20, reaction_weight=-1, num_tests=2):
        use_random_w = reaction_weight == -1
        # self.stories = []
        try:
            num_rows = max(self.stats_df.story_id) + 1
        except Exception:
            num_rows = 0
        for story_id in range(num_tests):
            if use_random_w:
                # reaction_weight = np.random.random_sample()
                reaction_weight = np.round(np.random.random_sample(), 1)
            stats_df0, _story_till_now, story_data = self.auto_ist(
                story_till_now=story_till_now,
                num_generation=num_generation,
                length=length, reaction_weight=reaction_weight)
            stats_df0.insert(loc=0, column='story_id', value=story_id + num_rows)
            # stats_df0['story_id'] = story_id
            self.stats_df = pd.concat([self.stats_df, stats_df0])
            printj.yellow(f'story_id: {story_id}')
            printj.green(stats_df0)
            self.stories.append(_story_till_now)
            self.data.append(story_data)
        self.stats_df = self.stats_df.reset_index(drop=True)
        print(self.stats_df)

    def save_stats(self, path='pandas_simple.xlsx'):
        writer = pd.ExcelWriter(path, engine='xlsxwriter')
        # Convert the dataframe to an XlsxWriter Excel object.
        self.stats_df.to_excel(writer, sheet_name='IST')
        # Close the writer and write out the Excel file
        # (ExcelWriter.save() was removed in pandas 2.0; close() works everywhere).
        writer.close()
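

# Minimal usage sketch (assumptions: run directly as a script rather than through
# the Streamlit UI, and the GPT-2 / emotion models download successfully). It
# generates one story with a random reaction weight and writes the stats to the
# default Excel path used by save_stats.
if __name__ == '__main__':
    story_gen = StoryGenerator()
    story_gen.get_stats(story_till_now="Hello, I'm a language model,",
                        num_generation=2, length=20,
                        reaction_weight=-1, num_tests=1)
    story_gen.save_stats(path='pandas_simple.xlsx')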