"""Utilities for generating and comparing open_clip regression-test data.

When imported (e.g. by the test suite) ``open_clip`` is imported eagerly.
When run as a script, ``main`` imports ``open_clip`` itself, optionally after
checking out another git revision, so that data is produced with that
revision's code.
"""
import os
import random

import numpy as np
from PIL import Image
import torch

if __name__ != '__main__':
    import open_clip

# Hide all GPUs so that test data is always produced/compared on CPU.
os.environ['CUDA_VISIBLE_DEVICES'] = ''


def seed_all(seed=0):
    """Make torch/numpy/python RNGs and cuDNN behave deterministically.

    Args:
        seed: seed applied to ``random``, ``numpy.random`` and ``torch``.
    """
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # warn_only=False: non-deterministic ops raise instead of merely warning.
    torch.use_deterministic_algorithms(True, warn_only=False)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def inference_text(model, model_name, batches):
    """Encode each batch of strings with ``model.encode_text``.

    Args:
        model: an open_clip model.
        model_name: name used to look up the matching tokenizer.
        batches: iterable of batches, each a list of strings.

    Returns:
        The per-batch encodings stacked along a new leading dimension
        (``torch.stack`` requires all batches to yield equal shapes).
    """
    tokenizer = open_clip.get_tokenizer(model_name)
    with torch.no_grad():
        y = [model.encode_text(tokenizer(x)) for x in batches]
    return torch.stack(y)


def inference_image(model, preprocess_val, batches):
    """Encode each batch of images with ``model.encode_image``.

    Args:
        model: an open_clip model.
        preprocess_val: transform applied to every image before stacking.
        batches: iterable of batches, each a list of images.

    Returns:
        The per-batch encodings stacked along a new leading dimension.
    """
    with torch.no_grad():
        y = [
            model.encode_image(torch.stack([preprocess_val(img) for img in x]))
            for x in batches
        ]
    return torch.stack(y)


def forward_model(model, model_name, preprocess_val, image_batch, text_batch):
    """Run the full ``model(image, text)`` forward pass batch by batch.

    Image and text batches are consumed pairwise (``zip``), so the shorter
    input determines how many batches are processed.

    Returns:
        If the model returns a dict: a dict mapping each key to the stacked
        per-batch values. Otherwise: a list whose i-th entry stacks the i-th
        element of every batch's output.
    """
    y = []
    tokenizer = open_clip.get_tokenizer(model_name)
    with torch.no_grad():
        for x_im, x_txt in zip(image_batch, text_batch):
            x_im = torch.stack([preprocess_val(im) for im in x_im])
            x_txt = tokenizer(x_txt)
            y.append(model(x_im, x_txt))
    if isinstance(y[0], dict):
        out = {}
        for key in y[0].keys():
            out[key] = torch.stack([batch_out[key] for batch_out in y])
    else:
        out = []
        for i in range(len(y[0])):
            out.append(torch.stack([batch_out[i] for batch_out in y]))
    return out


def random_image_batch(batch_size, size):
    """Create ``batch_size`` random RGB PIL images of ``size == (h, w)``."""
    h, w = size
    # randint's upper bound is exclusive, so pixel values lie in [0, 254].
    data = np.random.randint(255, size=(batch_size, h, w, 3), dtype=np.uint8)
    return [Image.fromarray(d) for d in data]


def random_text_batch(batch_size, min_length=75, max_length=75):
    """Create ``batch_size`` strings built from randomly chosen tokens.

    Each string concatenates between ``min_length`` and ``max_length``
    (inclusive) decoded tokens of open_clip's ``SimpleTokenizer``.
    """
    t = open_clip.tokenizer.SimpleTokenizer()
    # every token decoded as string, exclude SOT and EOT,
    # replace the end-of-word marker '</w>' with a space
    token_words = [
        x[1].replace('</w>', ' ')
        for x in t.decoder.items()
        if x[0] not in t.all_special_ids
    ]
    # strings of randomly chosen tokens
    return [
        ''.join(random.choices(
            token_words,
            k=random.randint(min_length, max_length)
        ))
        for _ in range(batch_size)
    ]


def create_random_text_data(
        path,
        min_length=75,
        max_length=75,
        batches=1,
        batch_size=1
):
    """Save ``batches`` random text batches to ``path`` with ``torch.save``."""
    text_batches = [
        random_text_batch(batch_size, min_length, max_length)
        for _ in range(batches)
    ]
    print(f"{path}")
    torch.save(text_batches, path)


def create_random_image_data(path, size, batches=1, batch_size=1):
    """Save ``batches`` random image batches to ``path`` with ``torch.save``."""
    image_batches = [
        random_image_batch(batch_size, size)
        for _ in range(batches)
    ]
    print(f"{path}")
    torch.save(image_batches, path)


def get_data_dirs(make_dir=True):
    """Return ``(input_dir, output_dir)`` under ``data/`` next to this file.

    Args:
        make_dir: create both directories if they do not exist.

    Raises:
        AssertionError: if either directory is missing afterwards.
    """
    data_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data')
    input_dir = os.path.join(data_dir, 'input')
    output_dir = os.path.join(data_dir, 'output')
    if make_dir:
        os.makedirs(input_dir, exist_ok=True)
        os.makedirs(output_dir, exist_ok=True)
    # Check each directory the message refers to (previously both asserts
    # only checked the parent data_dir).
    assert os.path.isdir(input_dir), f"data directory missing, expected at {input_dir}"
    assert os.path.isdir(output_dir), f"data directory missing, expected at {output_dir}"
    return input_dir, output_dir


def create_test_data_for_model(
        model_name,
        pretrained=None,
        precision='fp32',
        jit=False,
        pretrained_hf=False,
        force_quick_gelu=False,
        create_missing_input_data=True,
        batches=1,
        batch_size=1,
        overwrite=False
):
    """Generate reference text/image encodings for one model configuration.

    Inputs are read from (and, if ``create_missing_input_data``, created in)
    the shared input directory; outputs are written to
    ``<model_id>_random_text.pt`` / ``<model_id>_random_image.pt`` in the
    output directory. Does nothing if both outputs already exist and
    ``overwrite`` is False.
    """
    model_id = f'{model_name}_{pretrained or pretrained_hf}_{precision}'
    input_dir, output_dir = get_data_dirs()
    output_file_text = os.path.join(output_dir, f'{model_id}_random_text.pt')
    output_file_image = os.path.join(output_dir, f'{model_id}_random_image.pt')
    text_exists = os.path.exists(output_file_text)
    image_exists = os.path.exists(output_file_image)
    if not overwrite and text_exists and image_exists:
        return
    seed_all()
    model, _, preprocess_val = open_clip.create_model_and_transforms(
        model_name,
        pretrained=pretrained,
        precision=precision,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        pretrained_hf=pretrained_hf
    )
    # text
    if overwrite or not text_exists:
        input_file_text = os.path.join(input_dir, 'random_text.pt')
        if create_missing_input_data and not os.path.exists(input_file_text):
            create_random_text_data(
                input_file_text,
                batches=batches,
                batch_size=batch_size
            )
        assert os.path.isfile(input_file_text), f"missing input data, expected at {input_file_text}"
        input_data_text = torch.load(input_file_text)
        output_data_text = inference_text(model, model_name, input_data_text)
        print(f"{output_file_text}")
        torch.save(output_data_text, output_file_text)
    # image
    if overwrite or not image_exists:
        size = model.visual.image_size
        if not isinstance(size, tuple):
            size = (size, size)
        # Image inputs are keyed by resolution so models sharing an input
        # size share the same input file.
        input_file_image = os.path.join(input_dir, f'random_image_{size[0]}_{size[1]}.pt')
        if create_missing_input_data and not os.path.exists(input_file_image):
            create_random_image_data(
                input_file_image,
                size,
                batches=batches,
                batch_size=batch_size
            )
        assert os.path.isfile(input_file_image), f"missing input data, expected at {input_file_image}"
        input_data_image = torch.load(input_file_image)
        output_data_image = inference_image(model, preprocess_val, input_data_image)
        print(f"{output_file_image}")
        torch.save(output_data_image, output_file_image)


def create_test_data(
        models,
        batches=1,
        batch_size=1,
        overwrite=False
):
    """Generate test data for every known model in ``models``.

    Duplicates, models unknown to ``open_clip.list_models()`` and a fixed
    set of timm-backed models are dropped.

    Returns:
        The sorted list of model names data was generated for.
    """
    models = list(set(models).difference({
        # not available with timm
        # see https://github.com/mlfoundations/open_clip/issues/219
        'timm-convnext_xlarge',
        'timm-vit_medium_patch16_gap_256'
    }).intersection(open_clip.list_models()))
    models.sort()
    print(f"generating test data for:\n{models}")
    for model_name in models:
        print(model_name)
        create_test_data_for_model(
            model_name,
            batches=batches,
            batch_size=batch_size,
            overwrite=overwrite
        )
    return models


def _sytem_assert(string):
    """Run ``string`` through the shell; raise AssertionError on non-zero exit."""
    assert os.system(string) == 0


class TestWrapper(torch.nn.Module):
    """Wrap an open_clip model with a small linear head on image features.

    The head maps ``embed_dim`` (from the model config) to 2 outputs and the
    forward pass returns ``{"test_output": ...}``.
    """
    # Final so TorchScript can treat the flag as a compile-time constant.
    output_dict: torch.jit.Final[bool]

    def __init__(self, model, model_name, output_dict=True) -> None:
        super().__init__()
        self.model = model
        self.output_dict = output_dict
        # Exact-type check: subclasses of these model classes are not matched.
        if type(model) in [open_clip.CLIP, open_clip.CustomTextCLIP]:
            self.model.output_dict = self.output_dict
        config = open_clip.get_model_config(model_name)
        self.head = torch.nn.Linear(config["embed_dim"], 2)

    def forward(self, image, text):
        x = self.model(image, text)
        # Image features are read by key in dict mode, else from position 0.
        if self.output_dict:
            out = self.head(x["image_features"])
        else:
            out = self.head(x[0])
        return {"test_output": out}


def main(args):
    """Command-line entry point: populate the test data directory.

    With ``--git_revision`` the working tree is stashed, the revision is
    checked out, data is generated with that revision's ``open_clip``, then
    the original branch/commit (and stash) is restored and the generated
    ``data`` directory is carried back over.
    """
    global open_clip
    import importlib
    import shutil
    import subprocess
    import argparse
    parser = argparse.ArgumentParser(description="Populate test data directory")
    parser.add_argument(
        '-a', '--all',
        action='store_true',
        help="create test data for all models"
    )
    parser.add_argument(
        '-m', '--model',
        type=str,
        default=[],
        nargs='+',
        help="model(s) to create test data for"
    )
    parser.add_argument(
        '-f', '--model_list',
        type=str,
        help="path to a text file containing a list of model names, one model per line"
    )
    parser.add_argument(
        '-s', '--save_model_list',
        type=str,
        help="path to save the list of models that data was generated for"
    )
    parser.add_argument(
        '-g', '--git_revision',
        type=str,
        help="git revision to generate test data for"
    )
    parser.add_argument(
        '--overwrite',
        action='store_true',
        help="overwrite existing output data"
    )
    parser.add_argument(
        '-n', '--num_batches',
        default=1,
        type=int,
        help="amount of data batches to create (default: 1)"
    )
    parser.add_argument(
        '-b', '--batch_size',
        default=1,
        type=int,
        help="test data batch size (default: 1)"
    )
    args = parser.parse_args(args)
    model_list = []
    if args.model_list is not None:
        with open(args.model_list, 'r') as f:
            model_list = f.read().splitlines()
    if not args.all and len(args.model) < 1 and len(model_list) < 1:
        print("error: at least one model name is required")
        parser.print_help()
        parser.exit(1)
    if args.git_revision is not None:
        stash_output = subprocess.check_output(['git', 'stash']).decode().splitlines()
        has_stash = len(stash_output) > 0 and stash_output[0] != 'No local changes to save'
        current_branch = subprocess.check_output(['git', 'branch', '--show-current'])
        if len(current_branch) < 1:
            # not on a branch -> detached head
            current_branch = subprocess.check_output(['git', 'rev-parse', 'HEAD'])
        current_branch = current_branch.splitlines()[0].decode()
        try:
            _sytem_assert(f'git checkout {args.git_revision}')
        except AssertionError as e:
            # Restore the original state before propagating the failure.
            _sytem_assert(f'git checkout -f {current_branch}')
            if has_stash:
                os.system(f'git stash pop')
            raise e
    # Imported here so that, with --git_revision, the checked-out code is used.
    open_clip = importlib.import_module('open_clip')
    models = open_clip.list_models() if args.all else args.model + model_list
    try:
        models = create_test_data(
            models,
            batches=args.num_batches,
            batch_size=args.batch_size,
            overwrite=args.overwrite
        )
    finally:
        if args.git_revision is not None:
            # Same location get_data_dirs() uses (realpath, so symlinked
            # invocations agree on where the data lives).
            base_dir = os.path.dirname(os.path.realpath(__file__))
            test_dir = os.path.join(base_dir, 'data')
            test_dir_ref = os.path.join(base_dir, 'data_ref')
            if os.path.exists(test_dir_ref):
                shutil.rmtree(test_dir_ref, ignore_errors=True)
            # Move the generated data aside so the checkout cannot touch it.
            if os.path.exists(test_dir):
                os.rename(test_dir, test_dir_ref)
            _sytem_assert(f'git checkout {current_branch}')
            if has_stash:
                os.system(f'git stash pop')
            # Only move back what was actually moved aside; otherwise a
            # FileNotFoundError here would mask the original exception.
            if os.path.exists(test_dir_ref):
                os.rename(test_dir_ref, test_dir)
    if args.save_model_list is not None:
        print(f"Saving model list as {args.save_model_list}")
        with open(args.save_model_list, 'w') as f:
            for m in models:
                print(m, file=f)


if __name__ == '__main__':
    import sys
    main(sys.argv[1:])