from __future__ import annotations

import json
import time
from argparse import ArgumentParser
from pathlib import Path
from typing import Optional

import datasets
import numpy as np
import openai
from tqdm.auto import tqdm

# Delimiters used by the completion model: the prompt ends with DELIMITER_0, the completion
# is expected to contain DELIMITER_1 between the edit instruction and the edited caption,
# and STOP marks the end of the completion.
DELIMITER_0 = "\n##\n"
DELIMITER_1 = "\n%%\n"
STOP = "\nEND"


def generate(
    openai_model: str,
    caption: str,
    num_retries: int = 3,
    max_tokens: int = 256,
    temperature: float = 0.7,
    top_p: float = 1.0,
    frequency_penalty: float = 0.1,
    presence_penalty: float = 0.0,
    sleep_on_error: float = 1.0,
) -> Optional[tuple[str, str]]:
    """Generate an (instruction, edited caption) pair for a caption, or None on failure."""
    for _ in range(1 + num_retries):
        try:
            response = openai.Completion.create(
                model=openai_model,
                prompt=caption + DELIMITER_0,
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p,
                frequency_penalty=frequency_penalty,
                presence_penalty=presence_penalty,
                stop=[STOP],
            )
        except Exception as e:
            # Transient API errors: wait briefly, then retry.
            print(e)
            time.sleep(sleep_on_error)
            continue
        output = response["choices"][0]["text"].split(DELIMITER_1)
        if len(output) == 2:
            instruction, edited_caption = output
            # Discard generations flagged by the moderation endpoint.
            results = openai.Moderation.create([instruction, edited_caption])["results"]
            if results[0]["flagged"] or results[1]["flagged"]:
                continue
            # Only keep edits that actually change the caption (ignoring case and trailing punctuation).
            if caption.strip().strip(".!?").lower() != edited_caption.strip().strip(".!?").lower():
                return instruction, edited_caption


def main(openai_model: str, num_samples: int, num_partitions: int, partition: int, seed: int):
    dataset = datasets.load_dataset("ChristophSchuhmann/improved_aesthetics_6.5plus", split="train")
    # Other datasets we considered that may be worth trying:
    # dataset = datasets.load_dataset("ChristophSchuhmann/MS_COCO_2017_URL_TEXT", split="train")
    # dataset = datasets.load_dataset("laion/laion-coco", split="train")

    # Shuffle deterministically and take this worker's partition of the dataset.
    np.random.seed(seed)
    permutation = np.array_split(np.random.permutation(len(dataset)), num_partitions)[partition]
    dataset = dataset[permutation]
    captions = dataset["TEXT"]
    urls = dataset["URL"]
    output_path = f"data/dataset=laion-aesthetics-6.5_model={openai_model}_samples={num_samples}_partition={partition}.jsonl"  # fmt: skip
    print(f"Prompt file path: {output_path}")

    count = 0
    caption_set = set()
    url_set = set()

    # Resume from an existing output file, deduplicating on caption and URL.
    if Path(output_path).exists():
        with open(output_path, "r") as f:
            for line in tqdm(f, desc="Resuming from existing prompts"):
                prompt = json.loads(line)
                if prompt["caption"] not in caption_set and prompt["url"] not in url_set:
                    caption_set.add(prompt["caption"])
                    url_set.add(prompt["url"])
                    count += 1

    with open(output_path, "a") as fp:
        with tqdm(total=num_samples - count, desc="Generating instructions and edited captions") as progress_bar:
            for caption, url in zip(captions, urls):
                if caption in caption_set or url in url_set:
                    continue
                # Skip input captions flagged by the moderation endpoint.
                if openai.Moderation.create(caption)["results"][0]["flagged"]:
                    continue
                edit_output = generate(openai_model, caption)
                if edit_output is not None:
                    edit, output = edit_output
                    fp.write(f"{json.dumps(dict(caption=caption, edit=edit, output=output, url=url))}\n")
                    count += 1
                    progress_bar.update()
                caption_set.add(caption)
                url_set.add(url)
                if count == num_samples:
                    break


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--openai-api-key", required=True, type=str)
    parser.add_argument("--openai-model", required=True, type=str)
    parser.add_argument("--num-samples", default=10000, type=int)
    parser.add_argument("--num-partitions", default=1, type=int)
    parser.add_argument("--partition", default=0, type=int)
    parser.add_argument("--seed", default=0, type=int)
    args = parser.parse_args()
    openai.api_key = args.openai_api_key
    main(args.openai_model, args.num_samples, args.num_partitions, args.partition, args.seed)
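
# Example invocation (a sketch only: the script filename, API key, and fine-tuned completion
# model name below are placeholders, not values defined by this script). Partitions let several
# workers generate disjoint shards of the same num_samples in parallel:
#
#   python <this_script>.py \
#       --openai-api-key sk-... \
#       --openai-model <your-finetuned-completion-model> \
#       --num-samples 10000 \
#       --num-partitions 4 \
#       --partition 0 \
#       --seed 0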