# small-shlepa/src/gen/utils.py
import json
import os
import random
import time
from glob import glob
import yaml
# API setting constants
API_MAX_RETRY = 16
API_RETRY_SLEEP = 10
API_ERROR_OUTPUT = "$ERROR$"
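
# Each chat_completion_* helper below retries up to API_MAX_RETRY times,
# sleeping API_RETRY_SLEEP seconds between attempts, and returns
# API_ERROR_OUTPUT if every attempt fails.
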
OPENAI_MODEL_LIST = (
"gpt-3.5-turbo",
"gpt-3.5-turbo-0301",
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-0613-verbose",
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo-0125",
"gpt-4",
"gpt-4-0314",
"gpt-4-0613",
"gpt-4-turbo",
"gpt-4-1106-preview",
"gpt-4-0125-preview",
)
temperature_config = {
"writing": 0.7,
"roleplay": 0.7,
"extraction": 0.0,
"math": 0.0,
"coding": 0.0,
"reasoning": 0.0,
"stem": 0.1,
"humanities": 0.1,
}
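

# A minimal sketch (not part of the original module) of how the
# category -> temperature map is typically consumed; the "category" field
# name and the 0.7 fallback are assumptions.
def get_temperature(question: dict, default: float = 0.7) -> float:
    """Return the sampling temperature configured for a question's category."""
    return temperature_config.get(question.get("category", ""), default)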


def load_questions(question_file: str):
    """Load questions from a JSONL file (one JSON object per line)."""
    questions = []
    with open(question_file, "r") as ques_file:
        for line in ques_file:
            line = line.strip()
            if line:  # skip blank lines, which json.loads would reject
                questions.append(json.loads(line))
    return questions
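
# Usage sketch (hypothetical path and fields): each line of the file is one
# standalone JSON object, e.g.
#   {"question_id": 1, "category": "math", "turns": ["What is 2 + 2?"]}
# questions = load_questions("data/questions.jsonl")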


def load_model_answers(answer_dir: str):
    """Load model answers.

    The return value is a python dict of type:
    Dict[model_name: str -> Dict[question_id: int -> answer: dict]]
    """
    filenames = glob(os.path.join(answer_dir, "*.jsonl"))
    filenames.sort()
    model_answers = {}
    for filename in filenames:
        model_name = os.path.basename(filename)[: -len(".jsonl")]  # strip the extension
answer = {}
with open(filename) as fin:
for line in fin:
line = json.loads(line)
answer[line["question_id"]] = line
model_answers[model_name] = answer
return model_answers
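
# Usage sketch (hypothetical directory): answers are keyed first by model
# name (the .jsonl file stem), then by question id:
#   model_answers = load_model_answers("data/model_answer")
#   answer = model_answers["gpt-4-0613"][42]  # answer dict for question_id 42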


def get_endpoint(endpoint_list):
    if endpoint_list is None:
        return None
    # randomly pick one endpoint
    return random.choice(endpoint_list)


# load config args from a config yaml file
def make_config(config_file: str) -> dict:
    with open(config_file, "r") as f:
        return yaml.safe_load(f)
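
# Usage sketch (hypothetical file name and keys, chosen to mirror what
# get_endpoint() and the chat_completion_* helpers read):
#   config/api_config.yaml:
#     endpoints:
#       - api_base: https://api.openai.com/v1
#         api_key_ENV: OPENAI_API_KEY
#   config = make_config("config/api_config.yaml")
#   api_dict = get_endpoint(config["endpoints"])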


def chat_completion_gigachat(model, messages, temperature, max_tokens, api_dict=None):
    from gigachat import GigaChat
    from gigachat.models import Chat, Messages

    assert api_dict is not None, "no api settings provided!"
    # Read the token from the config, else from the env var named by
    # "auth_token_ENV" (following the *_ENV convention of the other helpers
    # in this module). `or` keeps the fallback from being evaluated eagerly,
    # which would raise KeyError whenever the env-var key is absent.
    auth_token = api_dict.get("auth_token") or os.environ.get(api_dict.get("auth_token_ENV", ""), "")
    client = GigaChat(credentials=auth_token, model=model, verify_ssl_certs=False)
    temperature = max(temperature, 0.001)  # clamp to a positive value; zero is not accepted
    messages = [Messages.parse_obj(m) for m in messages]
    chat = Chat(messages=messages, max_tokens=max_tokens, temperature=temperature)
output = API_ERROR_OUTPUT
for _ in range(API_MAX_RETRY):
try:
output = client.chat(chat)
output = output.choices[0].message.content
break
        except Exception as e:  # the SDK's specific error types are not enumerated here
print(type(e), e)
time.sleep(API_RETRY_SLEEP)
return output


def chat_completion_yandex(model, messages, temperature, max_tokens, api_dict=None):
    from yandex_gpt import YandexGPT, YandexGPTConfigManagerForIAMToken

    assert api_dict is not None, "no api settings provided!"
    # `or` keeps the env-var fallback from being evaluated (and raising
    # KeyError) when the token is already present in the config
    iam_token = api_dict.get("iam_token") or os.environ.get(api_dict.get("iam_token_ENV", ""), "")
    config = YandexGPTConfigManagerForIAMToken(model_type=model, catalog_id=api_dict["catalog_id"], iam_token=iam_token)
    client = YandexGPT(config_manager=config)
    # YandexGPT messages use "text" rather than "content"
    messages = [{"role": m["role"], "text": m["content"]} for m in messages]
output = API_ERROR_OUTPUT
for _ in range(API_MAX_RETRY):
try:
output = client.get_sync_completion(
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
)
break
        except Exception as e:  # the SDK's specific error types are not enumerated here
print(type(e), e)
time.sleep(API_RETRY_SLEEP)
return output


def chat_completion_openai(model, messages, temperature, max_tokens, api_dict=None):
    import openai

    if api_dict:
        # read the key from the config, else from the env var named by
        # "api_key_ENV"; `or` avoids a KeyError when that key is absent
        api_key = api_dict.get("api_key") or os.environ.get(api_dict.get("api_key_ENV", ""), "")
        client = openai.OpenAI(
            base_url=api_dict["api_base"],
            api_key=api_key,
        )
    else:
        # fall back to the OPENAI_API_KEY environment variable
        client = openai.OpenAI()
output = API_ERROR_OUTPUT
for _ in range(API_MAX_RETRY):
try:
completion = client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
stop=["</s>", "<eos>", "<|eot_id|>"],
)
output = completion.choices[0].message.content
break
except openai.RateLimitError as e:
print(type(e), e)
time.sleep(API_RETRY_SLEEP)
        except openai.BadRequestError as e:
            print(messages)
            print(type(e), e)
            break  # a malformed request will not succeed on retry
except KeyError as e:
print(type(e), e)
break
return output
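
# Usage sketch (hypothetical values): the chat_completion_* helpers share this
# call shape and return the reply text, or API_ERROR_OUTPUT if every retry
# failed:
#   reply = chat_completion_openai(
#       model="gpt-4-0613",
#       messages=[{"role": "user", "content": "Hello"}],
#       temperature=temperature_config["writing"],
#       max_tokens=512,
#       api_dict={"api_base": "https://api.openai.com/v1", "api_key_ENV": "OPENAI_API_KEY"},
#   )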


def chat_completion_openai_azure(model, messages, temperature, max_tokens, api_dict=None):
    import openai
    from openai import AzureOpenAI

    api_base = api_dict["api_base"]
    # `or` avoids a KeyError when "api_key_ENV" is absent from the config
    api_key = api_dict.get("api_key") or os.environ.get(api_dict.get("api_key_ENV", ""), "")
client = AzureOpenAI(
azure_endpoint=api_base, api_key=api_key, api_version=api_dict["api_version"], timeout=240, max_retries=2
)
output = API_ERROR_OUTPUT
for _ in range(API_MAX_RETRY):
try:
response = client.chat.completions.create(
model=model,
messages=messages,
n=1,
temperature=temperature,
max_tokens=max_tokens,
seed=42,
)
output = response.choices[0].message.content
break
except openai.RateLimitError as e:
print(type(e), e)
time.sleep(API_RETRY_SLEEP)
except openai.BadRequestError as e:
print(type(e), e)
break
except KeyError as e:
print(type(e), e)
break
return output


def chat_completion_anthropic(model, messages, temperature, max_tokens, api_dict=None):
    import anthropic

    if api_dict:
        # `or` avoids a KeyError when "api_key_ENV" is absent from the config
        api_key = api_dict.get("api_key") or os.environ.get(api_dict.get("api_key_ENV", ""), "")
    else:
        api_key = os.environ["ANTHROPIC_API_KEY"]
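    # Anthropic takes the system prompt as a separate argument, so it is
    # peeled off the message list here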
sys_msg = ""
if messages[0]["role"] == "system":
sys_msg = messages[0]["content"]
messages = messages[1:]
output = API_ERROR_OUTPUT
for _ in range(API_MAX_RETRY):
try:
c = anthropic.Anthropic(api_key=api_key)
response = c.messages.create(
model=model,
messages=messages,
stop_sequences=[anthropic.HUMAN_PROMPT],
max_tokens=max_tokens,
temperature=temperature,
system=sys_msg,
)
output = response.content[0].text
break
except anthropic.APIError as e:
print(type(e), e)
time.sleep(API_RETRY_SLEEP)
return output


def chat_completion_mistral(model, messages, temperature, max_tokens):
from mistralai.client import MistralClient
from mistralai.exceptions import MistralException
from mistralai.models.chat_completion import ChatMessage
api_key = os.environ["MISTRAL_API_KEY"]
client = MistralClient(api_key=api_key)
prompts = [ChatMessage(role=message["role"], content=message["content"]) for message in messages]
output = API_ERROR_OUTPUT
for _ in range(API_MAX_RETRY):
try:
chat_response = client.chat(
model=model,
messages=prompts,
temperature=temperature,
max_tokens=max_tokens,
)
output = chat_response.choices[0].message.content
break
except MistralException as e:
print(type(e), e)
break
return output


def chat_completion_gemini(model, messages, temperature, max_tokens):
import google.generativeai as genai
genai.configure(api_key=os.environ["GEMINI_API_KEY"])
safety_settings = [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
]
# Set up the model
generation_config = {
"temperature": temperature,
"top_p": 1,
"top_k": 1,
"max_output_tokens": max_tokens,
}
output = API_ERROR_OUTPUT
for _ in range(API_MAX_RETRY):
try:
gemini = genai.GenerativeModel(
model_name=model, generation_config=generation_config, safety_settings=safety_settings
)
convo = gemini.start_chat(history=[])
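            # NOTE: `messages` is passed to send_message() as-is, so the
            # caller is expected to provide genai-compatible content (e.g. a
            # prompt string) rather than OpenAI-style role dicts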
convo.send_message(messages)
output = convo.last.text
break
except genai.types.generation_types.StopCandidateException as e:
print(type(e), e)
break
except Exception as e:
print(type(e), e)
time.sleep(API_RETRY_SLEEP)
return output


def chat_completion_cohere(model, messages, temperature, max_tokens):
import cohere
co = cohere.Client(os.environ["COHERE_API_KEY"])
assert len(messages) > 0
template_map = {"system": "SYSTEM", "assistant": "CHATBOT", "user": "USER"}
assert messages[-1]["role"] == "user"
prompt = messages[-1]["content"]
if len(messages) > 1:
history = []
for message in messages[:-1]:
history.append({"role": template_map[message["role"]], "message": message["content"]})
else:
history = None
output = API_ERROR_OUTPUT
for _ in range(API_MAX_RETRY):
try:
response = co.chat(
message=prompt,
model=model,
temperature=temperature,
max_tokens=max_tokens,
chat_history=history,
)
output = response.text
break
except cohere.core.api_error.ApiError as e:
print(type(e), e)
raise
except Exception as e:
print(type(e), e)
break
return output


def reorg_answer_file(answer_file):
    """Sort answers by question id and de-duplicate, keeping the last answer per id."""
answers = {}
with open(answer_file, "r") as fin:
for line in fin:
qid = json.loads(line)["question_id"]
answers[qid] = line
    qids = sorted(answers.keys())
with open(answer_file, "w") as fout:
for qid in qids:
fout.write(answers[qid])
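
# Usage sketch (hypothetical path): rewrite an answer file in place, sorted by
# question_id with only the last answer kept per id:
#   reorg_answer_file("data/model_answer/gpt-4-0613.jsonl")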