storytelling / story_gen.py
jitesh's picture
cache importing the models
1061dba
raw
history blame
8.21 kB
import sys
import time
import printj
from transformers import pipeline # , set_seed
import numpy as np
import pandas as pd
# import nltk
import re
import streamlit as st
class StoryGenerator:
    """Generate GPT-2 stories and tag generated pieces with an emotion label.

    Stories produced by ``get_stats`` accumulate in ``self.stories``, their
    per-turn records in ``self.data``, and per-turn statistics in
    ``self.stats_df``.
    """

    def __init__(self):
        self.initialise_models()
        self.clear_stats()
        self.clear_stories()

    @staticmethod
    @st.cache(allow_output_mutation=True)
    def _load_models():
        """Load and return ``(generator, classifier)`` pipelines.

        Cached with ``st.cache`` so Streamlit reruns reuse the loaded models.
        ``allow_output_mutation=True`` stops Streamlit from hashing the large
        returned pipeline objects on every call.
        """
        start = time.time()
        generator = pipeline('text-generation', model='gpt2')
        classifier = pipeline("text-classification",
                              model="j-hartmann/emotion-english-distilroberta-base",
                              return_all_scores=True)
        initialising_time = time.time() - start
        print(f'Initialising Time: {initialising_time}')
        return generator, classifier

    def initialise_models(self):
        """Attach the text-generation and emotion-classification pipelines.

        The cache lives on ``_load_models`` (which *returns* the pipelines)
        rather than on this method: a cached method whose only effect is
        assigning attributes would skip the assignment on a cache hit and
        leave a new instance without ``generator``/``classifier``.
        """
        self.generator, self.classifier = self._load_models()

    def reset(self):
        """Clear the accumulated stories, story data and statistics."""
        self.clear_stories()
        self.clear_stats()

    def clear_stories(self):
        """Forget all stories and per-turn story data collected so far."""
        self.data = []
        self.stories = []

    def clear_stats(self):
        """Replace the statistics table with an empty DataFrame."""
        self.stats_df = pd.DataFrame(data=[], columns=[])

    def get_emotion(self, text):
        """Return the highest-scoring emotion dict for *text*.

        The classifier output is indexed with ``[0]`` to get the list of
        per-label dicts. The dict with the largest ``'score'`` is returned; callers
        read its ``'label'`` and ``'score'`` keys.
        """
        emotions = self.classifier(text)
        return max(emotions[0], key=lambda x: x['score'])

    @staticmethod
    def get_num_token(text):
        """Return the number of ``\\w+`` word matches in *text*."""
        return len(re.findall(r'\w+', text))

    @staticmethod
    def check_show_emotion(confidence_score, frequency, w):
        """Randomly decide whether to show an emotion.

        Returns True with probability
        ``w * confidence_score + (1 - w) * (1 - frequency)``, compared against
        a uniform sample from ``np.random.random_sample()``.
        """
        frequency_penalty = 1 - frequency
        probability_emote = w * confidence_score + (1 - w) * frequency_penalty
        return probability_emote > np.random.random_sample()

    def _continue_story(self, story_till_now, length):
        """Generate one continuation; return ``(full_text, new_part)``.

        ``new_part`` is the generated text after the original prompt.
        """
        last_length = len(story_till_now)
        # NOTE(review): get_num_token counts regex words, while the pipeline's
        # max_length is measured in model tokens (GPT-2 BPE pieces often
        # outnumber words), so fewer than `length` new tokens may be produced.
        generated = self.generator(
            story_till_now,
            max_length=self.get_num_token(story_till_now) + length,
            num_return_sequences=1)
        full_text = generated[0]['generated_text']
        return full_text, full_text[last_length:]

    @staticmethod
    def _turn_record(sentence, turn, emotion):
        """Build one ``story_data`` entry for a single turn."""
        return {
            'sentence': sentence,
            'turn': turn,
            'emotion': emotion['label'],
            'confidence_score': emotion['score'],
        }

    @staticmethod
    def _stats_row(sentence_no, turn, sentence, emotion, show_emotion=None,
                   num_reactions=None, reaction_frequency=None,
                   reaction_weight=None):
        """Return a one-row stats DataFrame indexed ``idx_<sentence_no>``."""
        return pd.DataFrame({
            'sentence_no': sentence_no,
            'turn': turn,
            'sentence': sentence,
            'show_emotion': show_emotion,
            'emotion_label': emotion['label'],
            'emotion_score': emotion['score'],
            'num_reactions': num_reactions,
            'reaction_frequency': reaction_frequency,
            'reaction_weight': reaction_weight,
        }, index=[f'idx_{sentence_no}'])

    def story(self,
              story_till_now="Hello, I'm a language model,",
              num_generation=4,
              length=10):
        """Extend the story ``num_generation`` times.

        Returns:
            ``(story, emotion)``. ``emotion`` is the top emotion of the last
            generated piece, or None when ``num_generation <= 0`` (nothing
            was generated).
        """
        new_sentence = None
        for _ in range(num_generation):
            story_till_now, new_sentence = self._continue_story(
                story_till_now, length)
        # Only the final piece's emotion is returned, so classify it once.
        emotion = None if new_sentence is None else self.get_emotion(new_sentence)
        return story_till_now, emotion

    def auto_ist(self,
                 story_till_now="Hello, I'm a language model,",
                 num_generation=4,
                 length=20, reaction_weight=0.5):
        """Simulate alternating user/robot turns and record per-turn stats.

        Each of the ``num_generation`` rounds appends a "user" continuation,
        then a "robot" continuation. A neutral user emotion is never shown.
        A non-neutral one is shown according to ``check_show_emotion``, with
        ``frequency = num_reactions / (round + 1)``. For a neutral user turn,
        the recorded ``reaction_frequency`` is the value from the most recent
        non-neutral turn (0 if none yet).

        Returns:
            ``(stats_df, story, story_data)``:
            - ``stats_df`` has one row per turn. The user and robot rows of
              round ``i`` share the index label ``idx_<i>``.
            - ``story`` is the final story text.
            - ``story_data`` is the list of per-turn dicts, starting with the
              ``'first'`` (prompt) entry.
        """
        frames = []
        num_reactions = 0
        reaction_frequency = 0
        emotion = self.get_emotion(story_till_now)  # first line emotion
        story_data = [self._turn_record(story_till_now, 'first', emotion)]
        for i in range(num_generation):
            # Text generation for User
            printj.cyan(story_till_now)
            printj.red.bold_on_white(
                f'loop: {i}; generate user text; length: {len(story_till_now)}')
            story_till_now, new_sentence_user = self._continue_story(
                story_till_now, length)
            printj.red.bold_on_white(f'loop: {i}; check emotion')
            emotion_user = self.get_emotion(new_sentence_user)
            if emotion_user['label'] == 'neutral':
                show_emotion_user = False
            else:
                reaction_frequency = num_reactions / (i + 1)
                show_emotion_user = self.check_show_emotion(
                    confidence_score=emotion_user['score'],
                    frequency=reaction_frequency, w=reaction_weight)
            if show_emotion_user:
                num_reactions += 1
            story_data.append(
                self._turn_record(new_sentence_user, 'user', emotion_user))
            frames.append(self._stats_row(
                i, 'user', new_sentence_user, emotion_user,
                show_emotion=show_emotion_user,
                num_reactions=num_reactions,
                reaction_frequency=reaction_frequency,
                reaction_weight=reaction_weight))

            # Text generation for Robot
            printj.cyan(story_till_now)
            printj.red.bold_on_white(
                f'loop: {i}; generate robot text; length: {len(story_till_now)}')
            story_till_now, new_sentence_robot = self._continue_story(
                story_till_now, length)
            emotion_robot = self.get_emotion(new_sentence_robot)
            story_data.append(
                self._turn_record(new_sentence_robot, 'robot', emotion_robot))
            frames.append(self._stats_row(
                i, 'robot', new_sentence_robot, emotion_robot))

        # Concatenate once (linear) instead of growing the frame per turn.
        # pd.concat([]) raises, so return an empty frame when no rounds ran.
        stats_df = pd.concat(frames) if frames else pd.DataFrame(data=[], columns=[])
        return stats_df, story_till_now, story_data

    def get_stats(self,
                  story_till_now="Hello, I'm a language model,",
                  num_generation=4,
                  length=20, reaction_weight=-1, num_tests=2):
        """Run ``auto_ist`` ``num_tests`` times and accumulate the results.

        Args:
            story_till_now: prompt each simulated story starts from.
            num_generation: user/robot rounds per story (passed to auto_ist).
            length: generation length per turn (passed to auto_ist).
            reaction_weight: weight ``w`` for check_show_emotion. ``-1`` means
                a fresh random weight in [0.0, 1.0] is drawn for each story,
                rounded to one decimal place.
            num_tests: number of stories to simulate.

        Rows are appended to ``self.stats_df`` with a ``story_id`` column that
        continues after any ids already present. Stories and story data are
        appended to ``self.stories`` and ``self.data``.
        """
        use_random_w = reaction_weight == -1
        # Continue numbering after existing stories; start at 0 if the table
        # is empty or has no story_id column yet (e.g. after clear_stats()).
        story_ids = self.stats_df.get('story_id')
        if story_ids is not None and story_ids.notna().any():
            num_rows = int(story_ids.max()) + 1
        else:
            num_rows = 0
        for story_id in range(num_tests):
            if use_random_w:
                reaction_weight = np.round(np.random.random_sample(), 1)
            stats_df0, _story_till_now, story_data = self.auto_ist(
                story_till_now=story_till_now,
                num_generation=num_generation,
                length=length, reaction_weight=reaction_weight)
            stats_df0.insert(loc=0, column='story_id', value=story_id + num_rows)
            self.stats_df = pd.concat([self.stats_df, stats_df0])
            printj.yellow(f'story_id: {story_id}')
            printj.green(stats_df0)
            self.stories.append(_story_till_now)
            self.data.append(story_data)
        self.stats_df = self.stats_df.reset_index(drop=True)
        print(self.stats_df)

    def save_stats(self, path='pandas_simple.xlsx'):
        """Write ``self.stats_df`` to the Excel file *path* (sheet ``'IST'``).

        Uses the ``xlsxwriter`` engine. The writer is used as a context
        manager, so the file is written and closed even if ``to_excel``
        raises. ``ExcelWriter.save()`` no longer exists in pandas 2.x.
        """
        with pd.ExcelWriter(path, engine='xlsxwriter') as writer:
            self.stats_df.to_excel(writer, sheet_name='IST')