# Source: visual-arena/fastchat/serve/monitor/tag_openai_moderation.py
# Uploaded by tianleliphoebe via huggingface_hub (commit ec0c335, verified).
"""
Add OpenAI moderation API results to all conversations.
"""
import argparse
from concurrent.futures import ThreadPoolExecutor
import json
import os
import time
import openai
import requests
from tqdm import tqdm
# Maximum number of moderation API attempts made for a single text.
API_MAX_RETRY = 16
# Seconds to wait after an attempt that raised openai.error.OpenAIError.
API_RETRY_SLEEP = 10
# Sentinel returned (and stored in the output) when every attempt failed.
API_ERROR_OUTPUT = "$ERROR$"
def tag_moderation(text, max_retry=None, retry_sleep=None):
    """Classify *text* with the OpenAI moderation endpoint, retrying on API errors.

    Uses the legacy (pre-1.0) ``openai.Moderation`` / ``openai.error`` interface.

    Args:
        text: The string to send as the moderation ``input``.
        max_retry: Maximum number of API attempts. Defaults to ``API_MAX_RETRY``.
        retry_sleep: Seconds to wait between a failed attempt and the next one.
            Defaults to ``API_RETRY_SLEEP``.

    Returns:
        The first entry of the response's ``"results"`` list, or
        ``API_ERROR_OUTPUT`` if every attempt raised ``openai.error.OpenAIError``.
        Any other exception propagates to the caller.
    """
    # Resolved at call time (not as default values) so the module constants
    # are only looked up when the function actually runs.
    if max_retry is None:
        max_retry = API_MAX_RETRY
    if retry_sleep is None:
        retry_sleep = API_RETRY_SLEEP

    for attempt in range(max_retry):
        try:
            return openai.Moderation.create(input=text)["results"][0]
        except openai.error.OpenAIError as e:
            print(type(e), e)
            # Back off only if another attempt will follow; sleeping after the
            # final failure would just delay the caller.
            if attempt + 1 < max_retry:
                time.sleep(retry_sleep)
    return API_ERROR_OUTPUT
def tag_openai_moderation(x):
    """Attach OpenAI moderation results to one battle record, in place.

    All turns with ``role == "user"`` in ``x["conversation_a"]`` are joined
    with newlines and classified with a single ``tag_moderation`` call. The
    result is stored under ``x["openai_moderation"]``. Returns ``None``.

    Args:
        x: A battle dict whose ``"conversation_a"`` is a list of message
            dicts with ``"role"`` and ``"content"`` keys (schema inferred
            from the keys read here).
    """
    # NOTE(review): only conversation_a is inspected -- presumably both sides
    # of a battle share the same user turns; confirm against the data format.
    # The comprehension variable is named `turn` so it no longer shadows the
    # battle record `x` that it is reading from.
    user_prompts = "\n".join(
        turn["content"] for turn in x["conversation_a"] if turn["role"] == "user"
    )
    x["openai_moderation"] = tag_moderation(user_prompts)
def tagged_output_path(input_path):
    """Return the path where the tagged copy of *input_path* is written.

    ``_tagged`` is inserted before the final extension, so ``battles.json``
    becomes ``battles_tagged.json``. Only the last path component's extension
    is considered. Directory names containing ``.json`` are left alone, and an
    input without an extension still gets a distinct name, so the input file
    is never overwritten.
    """
    root, ext = os.path.splitext(input_path)
    return f"{root}_tagged{ext}"


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", type=str, required=True)
    parser.add_argument(
        "--parallel", type=int, default=1, help="The number of concurrent API calls."
    )
    parser.add_argument("--first-n", type=int)
    args = parser.parse_args()

    # JSON is UTF-8 by spec; don't depend on the platform default encoding.
    # The context manager also closes the handle instead of leaking it.
    with open(args.input, encoding="utf-8") as fin:
        battles = json.load(fin)
    # `is not None` so an explicit `--first-n 0` means "no rows", not "all rows".
    if args.first_n is not None:
        battles = battles[: args.first_n]

    # tag_openai_moderation mutates each battle in place. Draining the iterator
    # waits for every call and re-raises any worker exception here.
    with ThreadPoolExecutor(args.parallel) as executor:
        for _ in tqdm(
            executor.map(tag_openai_moderation, battles), total=len(battles)
        ):
            pass

    output = tagged_output_path(args.input)
    # ensure_ascii=False writes raw non-ASCII text, so the file must be UTF-8.
    with open(output, "w", encoding="utf-8") as fout:
        json.dump(battles, fout, indent=2, ensure_ascii=False)
    print(f"Write cleaned data to {output}")