# src/kg/save_triples.py
from pathlib import Path
import json
import argparse
import os

from pysbd import Segmenter
from tiktoken import Encoding

from .knowledge_graph import PROMPT_FILE_PATH
from .openai_api import (RESPONSES_DIRECTORY_PATH,
                         get_max_chapter_segment_token_count,
                         get_openai_model_encoding, save_openai_api_response)
from .utils import (execute_function_in_parallel, set_up_logging,
                    strip_and_remove_empty_strings)

logger = set_up_logging('openai-api-scripts.log')


def get_paragraphs(text):
    """Split a text into paragraphs."""
    paragraphs = strip_and_remove_empty_strings(text.split('\n\n'))
    # Convert all whitespace into single spaces.
    paragraphs = [' '.join(paragraph.split()) for paragraph in paragraphs]
    return paragraphs
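
# For example (illustrative): get_paragraphs('One  two\nthree.\n\nFour.')
# returns ['One two three.', 'Four.'].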


def combine_text_subunits_into_segments(subunits, join_string,
                                        encoding: Encoding,
                                        max_token_count):
    """
    Combine subunits of text into segments that do not exceed a maximum number
    of tokens.
    """
    # `encode_ordinary_batch()` ignores special tokens and is slightly faster
    # than `encode_batch()`.
    subunit_token_counts = [len(tokens) for tokens
                            in encoding.encode_ordinary_batch(subunits)]
    join_string_token_count = len(encoding.encode_ordinary(join_string))
    total_token_count = (sum(subunit_token_counts) + join_string_token_count
                         * (len(subunits) - 1))
    if total_token_count <= max_token_count:
        return [join_string.join(subunits)]
    # Calculate the approximate number of segments and the approximate number
    # of tokens per segment, in order to keep the segment lengths roughly
    # equal.
    approximate_segment_count = total_token_count // max_token_count + 1
    approximate_segment_token_count = round(total_token_count
                                            / approximate_segment_count)
    segments = []
    current_segment_subunits = []
    current_segment_token_count = 0
    for i, (subunit, subunit_token_count) in enumerate(
            zip(subunits, subunit_token_counts)):
        # The token count if the current subunit is added to the current
        # segment.
        extended_segment_token_count = (current_segment_token_count
                                        + join_string_token_count
                                        + subunit_token_count)
        # Add the current subunit to the current segment if it results in a
        # token count that is closer to the approximate segment token count
        # than the current segment token count.
        if (extended_segment_token_count <= max_token_count
                and abs(extended_segment_token_count
                        - approximate_segment_token_count)
                <= abs(current_segment_token_count
                       - approximate_segment_token_count)):
            current_segment_subunits.append(subunit)
            current_segment_token_count = extended_segment_token_count
        else:
            segment = join_string.join(current_segment_subunits)
            segments.append(segment)
            # If it is possible to join the remaining subunits into a single
            # segment, do so. Additionally, add the current subunit as a
            # segment if it is the last subunit.
            if (sum(subunit_token_counts[i:]) + join_string_token_count
                    * (len(subunits) - i - 1) <= max_token_count
                    or i == len(subunits) - 1):
                segment = join_string.join(subunits[i:])
                segments.append(segment)
                break
            current_segment_subunits = [subunit]
            current_segment_token_count = subunit_token_count
    return segments
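
# Illustrative trace of the balancing above, assuming a one-token join
# string: for subunit token counts [4, 4, 4] and max_token_count=10, the
# total is 14 tokens, so the function targets 2 segments of about 7 tokens
# and returns two segments, the first covering the first subunit and the
# second covering the remaining two.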


def split_long_sentences(sentences, encoding: Encoding,
                         max_token_count):
    """
    Given a list of sentences, split sentences that exceed a maximum number of
    tokens into multiple segments.
    """
    token_counts = [len(tokens) for tokens
                    in encoding.encode_ordinary_batch(sentences)]
    split_sentences = []
    for sentence, token_count in zip(sentences, token_counts):
        if token_count > max_token_count:
            words = sentence.split()
            segments = combine_text_subunits_into_segments(
                words, ' ', encoding, max_token_count)
            split_sentences.extend(segments)
        else:
            split_sentences.append(sentence)
    return split_sentences


def split_long_paragraphs(paragraphs, encoding: Encoding,
                          max_token_count):
    """
    Given a list of paragraphs, split paragraphs that exceed a maximum number
    of tokens into multiple segments.
    """
    token_counts = [len(tokens) for tokens
                    in encoding.encode_ordinary_batch(paragraphs)]
    split_paragraphs = []
    for paragraph, token_count in zip(paragraphs, token_counts):
        if token_count > max_token_count:
            sentences = Segmenter().segment(paragraph)
            sentences = split_long_sentences(sentences, encoding,
                                             max_token_count)
            segments = combine_text_subunits_into_segments(
                sentences, ' ', encoding, max_token_count)
            split_paragraphs.extend(segments)
        else:
            split_paragraphs.append(paragraph)
    return split_paragraphs


def get_chapter_segments(chapter_text, encoding: Encoding,
                         max_token_count):
    """
    Split a chapter text into segments that do not exceed a maximum number of
    tokens.
    """
    paragraphs = get_paragraphs(chapter_text)
    paragraphs = split_long_paragraphs(paragraphs, encoding, max_token_count)
    chapter_segments = combine_text_subunits_into_segments(
        paragraphs, '\n', encoding, max_token_count)
    return chapter_segments


def get_response_save_path(idx, save_path, project_gutenberg_id,
                           chapter_index=None,
                           chapter_segment_index=None,
                           chapter_segment_count=None):
    """
    Get the path to the JSON file(s) containing response data from the OpenAI
    API.
    """
    save_path = Path(save_path)
    os.makedirs(save_path, exist_ok=True)

    if chapter_index is not None:
        save_path /= str(chapter_index)
        if chapter_segment_index is not None:
            save_path /= (f'{chapter_segment_index + 1}-of-'
                          f'{chapter_segment_count}.json')
    return save_path
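
# For example (illustrative): with save_path='responses/11', chapter_index=3,
# chapter_segment_index=1 and chapter_segment_count=5, the returned path is
# 'responses/11/3/2-of-5.json' (only the 'responses/11' directory itself is
# created); the idx and project_gutenberg_id arguments do not affect the
# returned path.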




def save_openai_api_responses_for_script(script, prompt, encoding: Encoding,
                                         max_chapter_segment_token_count,
                                         idx, api_key, model_id):
    """
    Call the OpenAI API for each chapter segment in a script and collect the
    responses in a list.
    """
    project_gutenberg_id = script['id']
    chapter_count = len(script['chapters'])
    logger.info(f'Starting to call the OpenAI API and process responses for '
                f'script {project_gutenberg_id} ({chapter_count} chapters).')

    prompt_message_lists = []
    for chapter in script['chapters']:
        # Split the chapter text into segments that fit within the model's
        # maximum context length alongside the prompt.
        chapter_segments = get_chapter_segments(
            chapter['text'], encoding, max_chapter_segment_token_count)
        for chapter_segment in chapter_segments:
            prompt_with_story = prompt.replace('{STORY}', chapter_segment)
            # Bundle the API key and model ID with each prompt message.
            prompt_message_lists.append([{
                'role': 'user',
                'content': prompt_with_story,
                'api_key': api_key,
                'model_id': model_id,
            }])

    # Send the prompts to the API concurrently and collect the responses.
    responses = execute_function_in_parallel(save_openai_api_response,
                                             prompt_message_lists)
    response_list = list(responses)

    logger.info(f'Finished processing responses for script '
                f'{project_gutenberg_id}.')
    return response_list


def save_triples_for_scripts(input_data, idx, api_key, model_id):
    """
    Call the OpenAI API to generate knowledge graph nodes and edges for a
    script, and return the responses as a list.
    """
    script = input_data

    # Load the prompt and determine how many tokens of each chapter segment
    # can accompany it in a single request.
    prompt = PROMPT_FILE_PATH.read_text()
    max_chapter_segment_token_count = get_max_chapter_segment_token_count(
        prompt, model_id)
    encoding = get_openai_model_encoding(model_id)
    responses = save_openai_api_responses_for_script(
        script, prompt, encoding, max_chapter_segment_token_count, idx,
        api_key, model_id)
    return responses
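

# Illustrative command-line entry point (a minimal sketch, not part of the
# original module): it assumes each input JSON file contains one script with
# the 'id' and 'chapters' fields used above, that the API key and model ID
# are supplied as arguments, and that the returned responses are
# JSON-serializable. The argument names and output file naming here are
# assumptions made for demonstration; responses could instead be written
# under RESPONSES_DIRECTORY_PATH using get_response_save_path().
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Generate knowledge graph triples for scripts with the '
                    'OpenAI API.')
    parser.add_argument('script_paths', nargs='+', type=Path,
                        help='Paths to JSON files, one script per file.')
    parser.add_argument('--api-key', required=True, help='OpenAI API key.')
    parser.add_argument('--model-id', required=True, help='OpenAI model ID.')
    args = parser.parse_args()

    for index, script_path in enumerate(args.script_paths):
        script_data = json.loads(script_path.read_text())
        responses = save_triples_for_scripts(
            script_data, index, args.api_key, args.model_id)
        output_name = f'{script_path.stem}_responses.json'
        output_path = script_path.with_name(output_name)
        output_path.write_text(json.dumps(responses, indent=2))
        logger.info(f'Wrote {len(responses)} responses to {output_path}.')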