# src.kg.save_triples.py
from pathlib import Path
import json
import argparse
import os
from pysbd import Segmenter
from tiktoken import Encoding
from .knowledge_graph import PROMPT_FILE_PATH
from .openai_api import (RESPONSES_DIRECTORY_PATH,
get_max_chapter_segment_token_count,
get_openai_model_encoding, save_openai_api_response)
from .utils import (execute_function_in_parallel, set_up_logging,
strip_and_remove_empty_strings)
logger = set_up_logging('openai-api-scripts.log')
def get_paragraphs(text):
"""Split a text into paragraphs."""
paragraphs = strip_and_remove_empty_strings(text.split('\n\n'))
# Convert all whitespace into single spaces.
paragraphs = [' '.join(paragraph.split()) for paragraph in paragraphs]
return paragraphs
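
# Illustrative example (comment only, not part of the original pipeline):
#   get_paragraphs('First  para.\n\nSecond\npara.')
#   -> ['First para.', 'Second para.']
# Paragraphs are blank-line-separated blocks with internal whitespace
# collapsed to single spaces.
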
def combine_text_subunits_into_segments(subunits, join_string,
encoding: Encoding,
max_token_count):
"""
Combine subunits of text into segments that do not exceed a maximum number
of tokens.
"""
# `encode_ordinary_batch()` ignores special tokens and is slightly faster
# than `encode_batch()`.
subunit_token_counts = [len(tokens) for tokens
in encoding.encode_ordinary_batch(subunits)]
join_string_token_count = len(encoding.encode_ordinary(join_string))
total_token_count = (sum(subunit_token_counts) + join_string_token_count
* (len(subunits) - 1))
if total_token_count <= max_token_count:
return [join_string.join(subunits)]
# Calculate the approximate number of segments and the approximate number
# of tokens per segment, in order to keep the segment lengths roughly
# equal.
approximate_segment_count = total_token_count // max_token_count + 1
approximate_segment_token_count = round(total_token_count
/ approximate_segment_count)
segments = []
current_segment_subunits = []
current_segment_token_count = 0
for i, (subunit, subunit_token_count) in enumerate(
zip(subunits, subunit_token_counts)):
# The token count if the current subunit is added to the current
# segment.
extended_segment_token_count = (current_segment_token_count
+ join_string_token_count
+ subunit_token_count)
        # Add the current subunit to the current segment if doing so keeps
        # the token count within the maximum and at least as close to the
        # approximate segment token count as the current count is.
if (extended_segment_token_count <= max_token_count
and abs(extended_segment_token_count
- approximate_segment_token_count)
<= abs(current_segment_token_count
- approximate_segment_token_count)):
current_segment_subunits.append(subunit)
current_segment_token_count = extended_segment_token_count
else:
segment = join_string.join(current_segment_subunits)
segments.append(segment)
# If it is possible to join the remaining subunits into a single
# segment, do so. Additionally, add the current subunit as a
# segment if it is the last subunit.
if (sum(subunit_token_counts[i:]) + join_string_token_count
* (len(subunits) - i - 1) <= max_token_count
or i == len(subunits) - 1):
segment = join_string.join(subunits[i:])
segments.append(segment)
break
current_segment_subunits = [subunit]
current_segment_token_count = subunit_token_count
return segments
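
# Worked example of the balancing targets above (illustrative numbers only):
# for subunits totaling 3000 tokens with max_token_count=2000, the function
# aims for 3000 // 2000 + 1 = 2 segments of about round(3000 / 2) = 1500
# tokens each, closing a segment once adding another subunit would either
# exceed the maximum or move the running count further from that target.
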
def split_long_sentences(sentences, encoding: Encoding,
max_token_count):
"""
Given a list of sentences, split sentences that exceed a maximum number of
tokens into multiple segments.
"""
token_counts = [len(tokens) for tokens
in encoding.encode_ordinary_batch(sentences)]
split_sentences = []
for sentence, token_count in zip(sentences, token_counts):
if token_count > max_token_count:
words = sentence.split()
segments = combine_text_subunits_into_segments(
words, ' ', encoding, max_token_count)
split_sentences.extend(segments)
else:
split_sentences.append(sentence)
return split_sentences
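
# Note: a sentence longer than max_token_count is rebuilt from its
# whitespace-separated words, so it becomes two or more word-boundary
# segments; sentences under the limit pass through unchanged.
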
def split_long_paragraphs(paragraphs, encoding: Encoding,
max_token_count):
"""
Given a list of paragraphs, split paragraphs that exceed a maximum number
of tokens into multiple segments.
"""
token_counts = [len(tokens) for tokens
in encoding.encode_ordinary_batch(paragraphs)]
split_paragraphs = []
for paragraph, token_count in zip(paragraphs, token_counts):
if token_count > max_token_count:
sentences = Segmenter().segment(paragraph)
sentences = split_long_sentences(sentences, encoding,
max_token_count)
segments = combine_text_subunits_into_segments(
sentences, ' ', encoding, max_token_count)
split_paragraphs.extend(segments)
else:
split_paragraphs.append(paragraph)
return split_paragraphs
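
# An over-long paragraph is handled in two passes: pysbd splits it into
# sentences (with any over-long sentence further split by
# split_long_sentences), and the pieces are then regrouped into roughly
# equal segments that respect max_token_count.
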
def get_chapter_segments(chapter_text, encoding: Encoding,
max_token_count):
"""
Split a chapter text into segments that do not exceed a maximum number of
tokens.
"""
paragraphs = get_paragraphs(chapter_text)
paragraphs = split_long_paragraphs(paragraphs, encoding, max_token_count)
chapter_segments = combine_text_subunits_into_segments(
paragraphs, '\n', encoding, max_token_count)
return chapter_segments
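
# Hedged usage sketch (the chapter text and model id below are assumptions,
# not values taken from this repository):
#   prompt = PROMPT_FILE_PATH.read_text()
#   encoding = get_openai_model_encoding('gpt-3.5-turbo')
#   max_tokens = get_max_chapter_segment_token_count(prompt, 'gpt-3.5-turbo')
#   segments = get_chapter_segments(chapter_text, encoding, max_tokens)
# Each returned segment fits the token budget and keeps paragraph
# boundaries ('\n') where possible.
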
def get_response_save_path(idx, save_path, project_gutenberg_id,
                           chapter_index=None,
                           chapter_segment_index=None,
                           chapter_segment_count=None):
"""
Get the path to the JSON file(s) containing response data from the OpenAI
API.
"""
save_path = Path(save_path)
os.makedirs(save_path, exist_ok=True)
if chapter_index is not None:
save_path /= str(chapter_index)
if chapter_segment_index is not None:
save_path /= (f'{chapter_segment_index + 1}-of-'
f'{chapter_segment_count}.json')
return save_path
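
# Example layout (illustrative): with save_path='responses', chapter_index=3,
# chapter_segment_index=0 and chapter_segment_count=5, the returned path is
# responses/3/1-of-5.json; with only save_path given, the bare directory
# path is returned.
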
def save_openai_api_responses_for_script(script, prompt, encoding,
                                         max_chapter_segment_token_count,
                                         idx, api_key, model_id):
    """
    Call the OpenAI API for each chapter segment in a script and collect the
    responses into a list.
    """
project_gutenberg_id = script['id']
chapter_count = len(script['chapters'])
logger.info(f'Starting to call OpenAI API and process responses for script '
f'{project_gutenberg_id} ({chapter_count} chapters).')
prompt_message_lists = []
response_list = []
for chapter in script['chapters']:
chapter_index = chapter['index']
chapter_segments = chapter['text']
chapter_segment_count = len(chapter_segments)
for chapter_segment_index, chapter_segment in enumerate(chapter_segments):
prompt_with_story = prompt.replace('{STORY}', chapter_segment)
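            # The API key and model id are packed into the message dict,
            # presumably so save_openai_api_response can read them when run
            # in parallel below.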
prompt_message_lists.append([{
'role': 'user',
'content': prompt_with_story,
'api_key': api_key,
'model_id': model_id
}])
    responses = execute_function_in_parallel(save_openai_api_response,
                                             prompt_message_lists)
    response_list.extend(responses)
    logger.info(f'Finished processing responses for script '
                f'{project_gutenberg_id}.')
return response_list
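
# The function above assumes `script` has roughly this shape (inferred from
# the accesses it makes; the values here are only an illustration):
#   {
#       'id': 1342,
#       'chapters': [
#           {'index': 0, 'text': ['segment 1 ...', 'segment 2 ...']},
#           ...
#       ]
#   }
# Each segment replaces the '{STORY}' placeholder in the prompt, and all
# prompts are sent to the OpenAI API in parallel.
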
def save_triples_for_scripts(input_data, idx, api_key, model_id):
"""
Call the OpenAI API to generate knowledge graph nodes and edges, and store
the responses in a list.
"""
    # 1) the input script data to extract triples from
    script = input_data
    # 2) build the prompt and call the OpenAI API
    prompt = PROMPT_FILE_PATH.read_text()  # load the prompt template
    max_chapter_segment_token_count = get_max_chapter_segment_token_count(
        prompt, model_id)
encoding = get_openai_model_encoding(model_id)
responses = save_openai_api_responses_for_script(
script, prompt, encoding, max_chapter_segment_token_count, idx, api_key, model_id
)
return responses
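

# Hedged example driver (not part of the original file; the argument names
# and script JSON layout are assumptions based on the functions above):
#
#   if __name__ == '__main__':
#       parser = argparse.ArgumentParser()
#       parser.add_argument('--script_path', type=Path)
#       parser.add_argument('--api_key')
#       parser.add_argument('--model_id', default='gpt-3.5-turbo')
#       args = parser.parse_args()
#       script = json.loads(args.script_path.read_text())
#       responses = save_triples_for_scripts(script, 0, args.api_key,
#                                            args.model_id)
#       logger.info(f'Collected {len(responses)} responses.')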