# src.kg.openai_api.py
import json
import logging
import os
from pathlib import Path
import openai
from dotenv import load_dotenv
from openai.error import (APIError, RateLimitError, ServiceUnavailableError,
Timeout, APIConnectionError, InvalidRequestError)
from tenacity import (before_sleep_log, retry, retry_if_exception_type,
stop_after_delay, wait_random_exponential, stop_after_attempt)
from tiktoken import Encoding, encoding_for_model
logger = logging.getLogger(__name__)
load_dotenv()
# This value is set by OpenAI for the selected model and cannot be changed.
MAX_MODEL_TOKEN_COUNT = 4096
# This value can be changed.
MAX_RESPONSE_TOKEN_COUNT = 512
RESPONSES_DIRECTORY_PATH = Path('../openai-api-responses-new')
def get_openai_model_encoding(model_id):
    """
    Return the tiktoken ``Encoding`` (tokenizer) associated with an OpenAI
    model.

    :param model_id: OpenAI model identifier, e.g. ``'gpt-3.5-turbo'``.
    :return: the ``Encoding`` reported by ``tiktoken.encoding_for_model``.
    """
    encoding = encoding_for_model(model_id)
    return encoding
def get_max_chapter_segment_token_count(prompt: str, model_id: str) -> int:
    """
    Compute how many tokens a chapter segment may occupy alongside the given
    prompt without exceeding the model's context window.

    :param prompt: the prompt text that will precede the chapter segment.
    :param model_id: OpenAI model identifier used to pick the tokenizer.
    :return: the remaining token budget for the chapter segment.
    """
    encoding = get_openai_model_encoding(model_id)
    # `encode_ordinary()` skips special-token handling and is slightly faster
    # than `encode()`.
    prompt_token_count = len(encoding.encode_ordinary(prompt))
    # Reserve 8 tokens for the chat-format overhead OpenAI adds to the prompt
    # and response (see
    # https://platform.openai.com/docs/guides/chat/managing-tokens) plus 1 for
    # the newline appended to the end of the prompt. The budget does not have
    # to be exact.
    overhead_token_count = 8 + 1
    return (MAX_MODEL_TOKEN_COUNT - MAX_RESPONSE_TOKEN_COUNT
            - prompt_token_count - overhead_token_count)
# NOTE: `InvalidRequestError` is deliberately NOT in the retry tuple. It is a
# deterministic client error (e.g. prompt too long), and the `except` clause
# below catches it before tenacity could ever see it anyway — listing it was
# dead code.
@retry(retry=retry_if_exception_type((APIError, Timeout, RateLimitError,
                                      ServiceUnavailableError,
                                      APIConnectionError)),
       wait=wait_random_exponential(max=60), stop=stop_after_attempt(10),
       before_sleep=before_sleep_log(logger, logging.WARNING))
def save_openai_api_response(prompt_messages):
    """
    Make a ChatCompletion request to the OpenAI API and return the response
    data.

    The first message must carry two extra keys, ``'api_key'`` and
    ``'model_id'``, which are popped (the input list is mutated in place)
    before the request so they are neither sent to the API nor included in the
    returned ``prompt_messages``.

    Transient failures (API, connection, rate-limit, timeout and service
    errors) are retried with random exponential backoff, up to 10 attempts.
    An ``InvalidRequestError`` is not retried; a placeholder response is
    returned instead so callers can continue processing.

    :param prompt_messages: list of chat message dicts; the first one also
        contains ``'api_key'`` and ``'model_id'`` (both are removed).
    :return: dict with ``'model'``, ``'usage'``, ``'finish_reason'``,
        ``'prompt_messages'`` and ``'response'`` keys.
    """
    first_message = prompt_messages[0]
    openai.api_key = first_message.pop('api_key')
    model_id = first_message.pop('model_id')
    try:
        logger.info('Calling OpenAI API...')
        response = openai.ChatCompletion.create(
            model=model_id, messages=prompt_messages, temperature=0
        )
        finish_reason = response.choices[0].finish_reason
        if finish_reason != 'stop':
            # Typically `length`: the response was cut off by the token limit.
            logger.error(f'`finish_reason` is `{finish_reason}`.')
        save_data = {
            'model': response.model,
            'usage': response.usage,
            'finish_reason': finish_reason,
            'prompt_messages': prompt_messages,
            'response': response.choices[0].message.content
        }
    except InvalidRequestError:
        # Deterministic request error — retrying cannot help. Return a
        # single-space placeholder response so downstream code gets a string.
        logger.error('InvalidRequestError encountered. '
                     'Returning placeholder response.')
        save_data = {
            'model': None,
            'usage': None,
            'finish_reason': 'invalid_request',
            'prompt_messages': prompt_messages,
            'response': ' '
        }
    return save_data
def load_response_text(save_path):
    """
    Load the response text from a JSON file containing response data from the
    OpenAI API.

    :param save_path: path to a JSON file (as written by
        ``save_openai_api_response``) with a top-level ``'response'`` key.
    :return: the value of the file's ``'response'`` key.
    """
    # JSON is UTF-8 by specification; don't rely on the platform's default
    # encoding (which e.g. on Windows may be cp1252 and corrupt the read).
    with open(save_path, 'r', encoding='utf-8') as save_file:
        save_data = json.load(save_file)
    return save_data['response']