""" Helper scripts for generating synthetic images using diffusion model. Functions: - get_top_misclassified - get_class_list - generateClassPairs - outputDirectory - pipe_img - createPrompts - interpolatePrompts - slerp - get_middle_elements - remove_middle - genClassImg - getMetadata - groupbyInterpolation - ungroupInterpolation - groupAllbyInterpolation - getPairIndices - generateImagesFromDataset - generateTrace """ import json import os import numpy as np import pandas as pd import torch from DeepCache import DeepCacheSDHelper from diffusers import ( LMSDiscreteScheduler, StableDiffusionImg2ImgPipeline, ) from torch import nn from torchmetrics.functional.image import structural_similarity_index_measure as ssim from torchvision import transforms def get_top_misclassified(val_classifier_json): """ Retrieves the top misclassified classes from a validation classifier JSON file. Args: val_classifier_json (str): The path to the validation classifier JSON file. Returns: dict: A dictionary containing the top misclassified classes, where the keys are the class names and the values are the number of misclassifications. """ with open(val_classifier_json) as f: val_output = json.load(f) val_metrics_df = pd.DataFrame.from_dict( val_output["val_metrics_details"], orient="index" ) class_dict = dict() for k, v in val_metrics_df["top_n_classes"].items(): class_dict[k] = v return class_dict def get_class_list(val_classifier_json): """ Retrieves the list of classes from the given validation classifier JSON file. Args: val_classifier_json (str): The path to the validation classifier JSON file. Returns: list: A sorted list of class names extracted from the JSON file. """ with open(val_classifier_json, "r") as f: data = json.load(f) return sorted(list(data["val_metrics_details"].keys())) def generateClassPairs(val_classifier_json): """ Generate pairs of misclassified classes from the given validation classifier JSON. Args: val_classifier_json (str): The path to the validation classifier JSON file. Returns: list: A sorted list of pairs of misclassified classes. """ pairs = set() misclassified_classes = get_top_misclassified(val_classifier_json) for key, value in misclassified_classes.items(): for v in value: pairs.add(tuple(sorted([key, v]))) return sorted(list(pairs)) def outputDirectory(class_pairs, synth_path, metadata_path): """ Creates the output directory structure for the synthesized data. Args: class_pairs (list): A list of class pairs. synth_path (str): The path to the directory where the synthesized data will be stored. metadata_path (str): The path to the directory where the metadata will be stored. Returns: None """ for id in class_pairs: class_folder = f"{synth_path}/{id}" if not (os.path.exists(class_folder)): os.makedirs(class_folder) if not (os.path.exists(metadata_path)): os.makedirs(metadata_path) print("Info: Output directory ready.") def pipe_img( model_path, device="cuda", apply_optimization=True, use_torchcompile=False, ci_cb=(5, 1), use_safetensors=None, cpu_offload=False, scheduler=None, ): """ Creates and returns an image-to-image pipeline for stable diffusion. Args: model_path (str): The path to the pretrained model. device (str, optional): The device to use for computation. Defaults to "cuda". apply_optimization (bool, optional): Whether to apply optimization techniques. Defaults to True. use_torchcompile (bool, optional): Whether to use torchcompile for model compilation. Defaults to False. ci_cb (tuple, optional): A tuple containing the cache interval and cache branch ID. 


def pipe_img(
    model_path,
    device="cuda",
    apply_optimization=True,
    use_torchcompile=False,
    ci_cb=(5, 1),
    use_safetensors=None,
    cpu_offload=False,
    scheduler=None,
):
    """
    Creates and returns an image-to-image pipeline for stable diffusion.

    Args:
        model_path (str): The path to the pretrained model.
        device (str, optional): The device to use for computation. Defaults to "cuda".
        apply_optimization (bool, optional): Whether to apply optimization
            techniques. Defaults to True.
        use_torchcompile (bool, optional): Whether to use torch.compile for model
            compilation. Defaults to False.
        ci_cb (tuple, optional): A tuple containing the cache interval and cache
            branch ID. Defaults to (5, 1).
        use_safetensors (bool, optional): Whether to use safetensors. Defaults to None.
        cpu_offload (bool, optional): Whether to enable CPU offloading. Defaults
            to False.
        scheduler (LMSDiscreteScheduler, optional): The scheduler for the pipeline.
            Defaults to None.

    Returns:
        StableDiffusionImg2ImgPipeline: The image-to-image pipeline for stable diffusion.
    """
    ###############################
    # Reference:
    # Akimov, R. (2024) Images Interpolation with Stable Diffusion - Hugging Face
    # Open-Source AI Cookbook. Available at:
    # https://huggingface.co/learn/cookbook/en/stable_diffusion_interpolation
    # (Accessed: 4 June 2024).
    ###############################
    if scheduler is None:
        scheduler = LMSDiscreteScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            num_train_timesteps=1000,
            steps_offset=1,
        )
    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
        model_path,
        scheduler=scheduler,
        torch_dtype=torch.float32,
        use_safetensors=use_safetensors,
    ).to(device)
    if cpu_offload:
        pipe.enable_model_cpu_offload()
    if apply_optimization:
        # tomesd.apply_patch(pipe, ratio=0.5)
        helper = DeepCacheSDHelper(pipe=pipe)
        cache_interval, cache_branch_id = ci_cb
        helper.set_params(
            cache_interval=cache_interval, cache_branch_id=cache_branch_id
        )  # lower is faster but lower quality
        helper.enable()
        # if torch.cuda.is_available():
        #     pipe.to("cuda")
        #     pipe.enable_xformers_memory_efficient_attention()
    if use_torchcompile:
        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
    return pipe


def createPrompts(
    class_name_pairs,
    prompt_structure=None,
    use_default_negative_prompt=False,
    negative_prompt=None,
):
    """
    Create prompts for image generation.

    Args:
        class_name_pairs (list): A list of two class names.
        prompt_structure (str, optional): The structure of the prompt; must contain
            the "<class>" placeholder. Defaults to "a photo of a <class>".
        use_default_negative_prompt (bool, optional): Whether to use the default
            negative prompt. Defaults to False.
        negative_prompt (str, optional): The negative prompt to steer the
            generation away from certain features.

    Returns:
        tuple: A tuple containing two lists - prompts and negative_prompts.
            prompts (list): Text prompts that describe the desired output image.
            negative_prompts (list): Negative prompts that can be used to steer
                the generation away from certain features.
    """
    if prompt_structure is None:
        prompt_structure = "a photo of a <class>"
    elif "<class>" not in prompt_structure:
        raise ValueError('The prompt structure must contain the "<class>" placeholder.')
    if use_default_negative_prompt:
        default_negative_prompt = (
            "blurry image, disfigured, deformed, distorted, cartoon, drawings"
        )
        negative_prompt = default_negative_prompt
    class1 = class_name_pairs[0]
    class2 = class_name_pairs[1]
    prompt1 = prompt_structure.replace("<class>", class1)
    prompt2 = prompt_structure.replace("<class>", class2)
    prompts = [prompt1, prompt2]
    if negative_prompt is None:
        print("Info: Negative prompt not provided, returning as None.")
        return prompts, None
    # Negative prompts that can be used to steer the generation away from
    # certain features.
    negative_prompts = [negative_prompt] * len(prompts)
    return prompts, negative_prompts
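

# Hedged usage sketch combining pipe_img() and createPrompts(). The checkpoint
# name is an assumption (any Stable Diffusion v1.x img2img-compatible model
# should fit), and the class names are placeholders.
def _example_pipeline_and_prompts():
    pipe = pipe_img(
        "runwayml/stable-diffusion-v1-5",
        device="cuda",
        apply_optimization=True,
        ci_cb=(5, 1),
    )
    prompts, negative_prompts = createPrompts(
        ["tabby cat", "tiger cat"],
        prompt_structure="a photo of a <class>",
        use_default_negative_prompt=True,
    )
    return pipe, prompts, negative_prompts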


def interpolatePrompts(
    prompts,
    pipeline,
    num_interpolation_steps,
    sample_mid_interpolation,
    remove_n_middle=0,
    device="cuda",
):
    """
    Interpolates prompts by generating intermediate embeddings between pairs of prompts.

    Args:
        prompts (List[str]): A list of prompts to be interpolated.
        pipeline: The pipeline object containing the tokenizer and text encoder.
        num_interpolation_steps (int): The number of interpolation steps between
            each pair of prompts.
        sample_mid_interpolation (int): The number of intermediate embeddings to
            sample from the middle of the interpolated prompts.
        remove_n_middle (int, optional): The number of middle embeddings to remove
            from the interpolated prompts. Defaults to 0.
        device (str, optional): The device to run the interpolation on.
            Defaults to "cuda".

    Returns:
        interpolated_prompt_embeds (torch.Tensor): The interpolated prompt embeddings.
        prompt_metadata (dict): Metadata about the interpolation process, including
            similarity scores and nearest class information.

    e.g. if num_interpolation_steps = 10, sample_mid_interpolation = 6,
    remove_n_middle = 2:

        Interpolated: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
        Sampled:            [2, 3, 4, 5, 6, 7]
        Removed:                  [4, 5]
        Returns:            [2, 3,       6, 7]
    """
    ###############################
    # Reference:
    # Akimov, R. (2024) Images Interpolation with Stable Diffusion - Hugging Face
    # Open-Source AI Cookbook. Available at:
    # https://huggingface.co/learn/cookbook/en/stable_diffusion_interpolation
    # (Accessed: 4 June 2024).
    ###############################

    def slerp(v0, v1, num, t0=0, t1=1):
        """
        Performs spherical linear interpolation between two vectors.

        Args:
            v0 (torch.Tensor): The starting vector.
            v1 (torch.Tensor): The ending vector.
            num (int): The number of interpolation points.
            t0 (float, optional): The starting time. Defaults to 0.
            t1 (float, optional): The ending time. Defaults to 1.

        Returns:
            torch.Tensor: The interpolated vectors.
        """
        ###############################
        # Reference:
        # Karpathy, A. (2022) hacky stablediffusion code for generating videos,
        # Gist. Available at:
        # https://gist.github.com/karpathy/00103b0037c5aaea32fe1da1af553355
        # (Accessed: 4 June 2024).
        ###############################
        v0 = v0.detach().cpu().numpy()
        v1 = v1.detach().cpu().numpy()

        def interpolation(t, v0, v1, DOT_THRESHOLD=0.9995):
            """Helper function to spherically interpolate two arrays v0 and v1."""
            dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
            if np.abs(dot) > DOT_THRESHOLD:
                # Vectors are nearly parallel; fall back to linear interpolation.
                v2 = (1 - t) * v0 + t * v1
            else:
                theta_0 = np.arccos(dot)
                sin_theta_0 = np.sin(theta_0)
                theta_t = theta_0 * t
                sin_theta_t = np.sin(theta_t)
                s0 = np.sin(theta_0 - theta_t) / sin_theta_0
                s1 = sin_theta_t / sin_theta_0
                v2 = s0 * v0 + s1 * v1
            return v2

        t = np.linspace(t0, t1, num)
        v3 = torch.tensor(np.array([interpolation(t[i], v0, v1) for i in range(num)]))
        return v3

    def get_middle_elements(lst, n):
        """
        Returns a tuple containing a sublist of the middle elements of the given
        list `lst` and a range of indices of those elements.

        Args:
            lst (list): The list from which to extract the middle elements.
            n (int): The number of middle elements to extract.

        Returns:
            tuple: A tuple containing the sublist of middle elements and a range
                of indices.

        Examples:
            lst = [1, 2, 3, 4, 5]
            get_middle_elements(lst, 3) -> ([2, 3, 4], range(1, 4))
        """
        if n % 2 == 0:  # Even number of elements
            middle_index = len(lst) // 2 - 1
            start = middle_index - n // 2 + 1
            end = middle_index + n // 2 + 1
            return lst[start:end], range(start, end)
        else:  # Odd number of elements
            middle_index = len(lst) // 2
            start = middle_index - n // 2
            end = middle_index + n // 2 + 1
            return lst[start:end], range(start, end)

    def remove_middle(data, n):
        """
        Remove the middle n elements from a list.

        Args:
            data (list): The input list.
            n (int): The number of elements to remove from the middle of the list.

        Returns:
            list: The modified list with the middle n elements removed.

        Raises:
            ValueError: If n is negative or greater than the length of the list.
        """
        if n < 0 or n > len(data):
            raise ValueError(
                "Invalid value for n. It should be non-negative and not greater "
                "than the list length."
            )

        # Find the middle index
        middle = len(data) // 2

        # Create slices to exclude the middle n elements
        if n == 1:
            return data[:middle] + data[middle + 1 :]
        elif n % 2 == 0:
            return data[: middle - n // 2] + data[middle + n // 2 :]
        else:
            return data[: middle - n // 2] + data[middle + n // 2 + 1 :]

    batch_size = len(prompts)
    # Tokenizing and encoding prompts into embeddings.
    prompts_tokens = pipeline.tokenizer(
        prompts,
        padding="max_length",
        max_length=pipeline.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    prompts_embeds = pipeline.text_encoder(prompts_tokens.input_ids.to(device))[0]

    # Interpolating between embedding pairs for the given number of interpolation steps.
    interpolated_prompt_embeds = []
    for i in range(batch_size - 1):
        interpolated_prompt_embeds.append(
            slerp(prompts_embeds[i], prompts_embeds[i + 1], num_interpolation_steps)
        )
    full_interpolated_prompt_embeds = interpolated_prompt_embeds[:]
    interpolated_prompt_embeds[0], sample_range = get_middle_elements(
        interpolated_prompt_embeds[0], sample_mid_interpolation
    )
    if remove_n_middle > 0:
        interpolated_prompt_embeds[0] = remove_middle(
            interpolated_prompt_embeds[0], remove_n_middle
        )

    prompt_metadata = dict()
    similarity = nn.CosineSimilarity(dim=-1, eps=1e-6)
    for i in range(num_interpolation_steps):
        class1_sim = (
            similarity(
                full_interpolated_prompt_embeds[0][0],
                full_interpolated_prompt_embeds[0][i],
            )
            .mean()
            .item()
        )
        class2_sim = (
            similarity(
                full_interpolated_prompt_embeds[0][num_interpolation_steps - 1],
                full_interpolated_prompt_embeds[0][i],
            )
            .mean()
            .item()
        )
        relative_distance = class1_sim / (class1_sim + class2_sim)
        prompt_metadata[i] = {
            "selected": i in sample_range,
            "similarity": {
                "class1": class1_sim,
                "class2": class2_sim,
                "class1_relative_distance": relative_distance,
                "class2_relative_distance": 1 - relative_distance,
            },
            "nearest_class": int(relative_distance < 0.5),
        }
    interpolated_prompt_embeds = torch.cat(interpolated_prompt_embeds, dim=0).to(device)
    return interpolated_prompt_embeds, prompt_metadata
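

# Hedged usage sketch for interpolatePrompts(): 16 slerp steps between the two
# class prompts, keeping the middle 8 embeddings. `pipe` and `prompts` are
# assumed to come from the sketches above.
def _example_interpolate(pipe, prompts):
    embeds, meta = interpolatePrompts(
        prompts,
        pipe,
        num_interpolation_steps=16,
        sample_mid_interpolation=8,
    )
    # For SD 1.x text encoders, embeds has shape (8, 77, 768); meta records, per
    # step, the cosine similarity to each endpoint and the nearest class.
    return embeds, meta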


def genClassImg(
    pipeline,
    pos_embed,
    neg_embed,
    input_image,
    generator,
    latents,
    num_imgs=1,
    height=512,
    width=512,
    num_inference_steps=25,
    guidance_scale=7.5,
):
    """
    Generate a class image using the given inputs.

    Args:
        pipeline: The pipeline object used for image generation.
        pos_embed: The positive embedding for the class.
        neg_embed: The negative embedding for the class (optional).
        input_image: The input image for guidance (optional).
        generator: The torch random generator used for image generation.
        latents: The latent vectors used for image generation.
        num_imgs: The number of images to generate (default is 1).
        height: The height of the generated images (default is 512).
        width: The width of the generated images (default is 512).
        num_inference_steps: The number of inference steps for image generation
            (default is 25).
        guidance_scale: The scale factor for guidance (default is 7.5).

    Returns:
        The generated class image.
    """
    if neg_embed is not None:
        npe = neg_embed[None, ...]
    else:
        npe = None
    return pipeline(
        height=height,
        width=width,
        num_images_per_prompt=num_imgs,
        prompt_embeds=pos_embed[None, ...],
        negative_prompt_embeds=npe,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
        latents=latents,
        image=input_image,
    ).images[0]
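

# Hedged usage sketch for genClassImg(): generate one image guided by the
# step-th interpolated embedding. `pipe`, `embeds`, and `input_image` are
# assumed to come from the sketches above; the latents mirror the ones built
# in generateImagesFromDataset().
def _example_gen_class_img(pipe, embeds, input_image, step=0, seed=0):
    generator = torch.manual_seed(seed)
    latents = torch.randn(
        (1, pipe.unet.config.in_channels, 512 // 8, 512 // 8),
        generator=generator,
    ).to("cuda")
    return genClassImg(pipe, embeds[step], None, input_image, generator, latents)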


def getMetadata(
    class_pairs,
    path,
    seed,
    guidance_scale,
    num_inference_steps,
    num_interpolation_steps,
    sample_mid_interpolation,
    height,
    width,
    prompts,
    negative_prompts,
    pipeline,
    prompt_metadata,
    negative_prompt_metadata,
    ssim_metadata=None,
    save_json=True,
    save_path=".",
):
    """
    Generate metadata for the given parameters.

    Args:
        class_pairs (list): List of class pairs.
        path (str): Path to the data.
        seed (int): Seed value for randomization.
        guidance_scale (float): Scale factor for guidance.
        num_inference_steps (int): Number of inference steps.
        num_interpolation_steps (int): Number of interpolation steps.
        sample_mid_interpolation (int): Number of mid-interpolation steps sampled.
        height (int): Height of the image.
        width (int): Width of the image.
        prompts (list): List of prompts.
        negative_prompts (list): List of negative prompts.
        pipeline (object): Pipeline object.
        prompt_metadata (dict): Metadata for prompts.
        negative_prompt_metadata (dict): Metadata for negative prompts.
        ssim_metadata (dict, optional): SSIM scores metadata. Defaults to None.
        save_json (bool, optional): Flag to save metadata as JSON. Defaults to True.
        save_path (str, optional): Path to save the JSON file. Defaults to ".".

    Returns:
        dict: Generated metadata.
    """
    metadata = dict()
    metadata["class_pairs"] = class_pairs
    metadata["path"] = path
    metadata["seed"] = seed
    metadata["params"] = {
        "CFG": guidance_scale,
        "inferenceSteps": num_inference_steps,
        "interpolationSteps": num_interpolation_steps,
        "sampleMidInterpolation": sample_mid_interpolation,
        "height": height,
        "width": width,
    }
    for i in range(len(prompts)):
        metadata[f"prompt_text_{i}"] = prompts[i]
        if negative_prompts is not None:
            metadata[f"negative_prompt_text_{i}"] = negative_prompts[i]
    metadata["pipe_config"] = dict(pipeline.config)
    metadata["prompt_embed_similarity"] = prompt_metadata
    metadata["negative_prompt_embed_similarity"] = negative_prompt_metadata
    if ssim_metadata is not None:
        print("Info: SSIM scores are available.")
        metadata["ssim_scores"] = ssim_metadata
    if save_json:
        with open(
            os.path.join(save_path, f"{'_'.join(class_pairs)}_{seed}.json"),
            "w",
        ) as f:
            json.dump(metadata, f, indent=4)
    return metadata


def groupbyInterpolation(dir_to_classfolder):
    """
    Group files in a directory by interpolation step.

    Args:
        dir_to_classfolder (str): The path to the directory containing the files.

    Returns:
        None
    """
    # Filenames look like "{seed}-{image_id}_{interpolation_step}.{ext}", so the
    # interpolation step is the token between "_" and the extension.
    files = [
        (f.split(sep="_")[1].split(sep=".")[0], os.path.join(dir_to_classfolder, f))
        for f in os.listdir(dir_to_classfolder)
    ]
    # Create a subfolder for each step of the interpolation.
    for interpolation_step, file_path in files:
        new_dir = os.path.join(dir_to_classfolder, interpolation_step)
        if not os.path.exists(new_dir):
            os.makedirs(new_dir)
        os.rename(file_path, os.path.join(new_dir, os.path.basename(file_path)))


def ungroupInterpolation(dir_to_classfolder):
    """
    Moves all files from subdirectories within `dir_to_classfolder` to
    `dir_to_classfolder` itself, and then removes the subdirectories.

    Args:
        dir_to_classfolder (str): The path to the directory containing the
            subdirectories.

    Returns:
        None
    """
    for interpolation_step in os.listdir(dir_to_classfolder):
        if os.path.isdir(os.path.join(dir_to_classfolder, interpolation_step)):
            for f in os.listdir(os.path.join(dir_to_classfolder, interpolation_step)):
                os.rename(
                    os.path.join(dir_to_classfolder, interpolation_step, f),
                    os.path.join(dir_to_classfolder, f),
                )
            os.rmdir(os.path.join(dir_to_classfolder, interpolation_step))
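

# Hedged sketch of the folder transform performed by the two helpers above:
# files named "{seed}-{image_id}_{interpolation_step}.jpg" are moved into one
# subfolder per interpolation step, and ungroupInterpolation() reverses it.
# The class folder path is a hypothetical placeholder.
def _example_group_and_ungroup():
    groupbyInterpolation("synth/n01440764")
    ungroupInterpolation("synth/n01440764")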


def groupAllbyInterpolation(
    data_path,
    group=True,
    fn_group=groupbyInterpolation,
    fn_ungroup=ungroupInterpolation,
):
    """
    Group or ungroup all data classes by interpolation.

    Args:
        data_path (str): The path to the data.
        group (bool, optional): Whether to group the data. Defaults to True.
        fn_group (function, optional): The function to use for grouping.
            Defaults to groupbyInterpolation.
        fn_ungroup (function, optional): The function to use for ungrouping.
            Defaults to ungroupInterpolation.
    """
    data_classes = sorted(os.listdir(data_path))
    if group:
        fn = fn_group
    else:
        fn = fn_ungroup
    for c in data_classes:
        c_path = os.path.join(data_path, c)
        if os.path.isdir(c_path):
            fn(c_path)
            print(f"Processed {c}")


def getPairIndices(subset_len, total_pair_count=1, seed=None):
    """
    Generate groups of indices for a given subset length.

    Args:
        subset_len (int): The length of the subset.
        total_pair_count (int, optional): The total number of groups to generate.
            Defaults to 1.
        seed (int, optional): The seed value for the random number generator.
            Defaults to None.

    Returns:
        list: A list of index groups.
    """
    rng = np.random.default_rng(seed)
    group_size = (subset_len + total_pair_count - 1) // total_pair_count
    numbers = list(range(subset_len))
    numbers_selection = list(range(subset_len))
    rng.shuffle(numbers)
    # Pad with leading indices so the total splits evenly into groups.
    for i in range(group_size - subset_len % group_size):
        numbers.append(numbers_selection[i])
    numbers = np.array(numbers)
    groups = numbers[: group_size * total_pair_count].reshape(-1, group_size)
    return groups.tolist()
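

# getPairIndices() splits a subset's index range into `total_pair_count`
# shuffled groups, recycling leading indices as padding when the subset size is
# not divisible. A small self-contained check:
def _example_pair_indices():
    groups = getPairIndices(10, total_pair_count=3, seed=0)
    # Three groups of ceil(10 / 3) = 4 indices each; all of 0-9 appear, with
    # two indices repeated as padding.
    assert len(groups) == 3 and all(len(g) == 4 for g in groups)
    return groups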


def generateImagesFromDataset(
    img_subsets,
    class_iterables,
    pipeline,
    interpolated_prompt_embeds,
    interpolated_negative_prompts_embeds,
    num_inference_steps,
    guidance_scale,
    height=512,
    width=512,
    seed=None,
    save_path=".",
    class_pairs=("0", "1"),
    save_image=True,
    image_type="jpg",
    interpolate_range="full",
    device="cuda",
    return_images=False,
):
    """
    Generates images from a dataset using the given parameters.

    Args:
        img_subsets (dict): A dictionary containing image subsets for each class.
        class_iterables (dict): A dictionary containing iterable objects for each class.
        pipeline (object): The pipeline object used for image generation.
        interpolated_prompt_embeds (list): A list of interpolated prompt embeddings.
        interpolated_negative_prompts_embeds (list): A list of interpolated
            negative prompt embeddings.
        num_inference_steps (int): The number of inference steps for image generation.
        guidance_scale (float): The scale factor for guidance during image generation.
        height (int, optional): The height of the generated images. Defaults to 512.
        width (int, optional): The width of the generated images. Defaults to 512.
        seed (int, optional): The seed value for random number generation.
            Defaults to None.
        save_path (str, optional): The path to save the generated images.
            Defaults to ".".
        class_pairs (tuple, optional): A tuple containing pairs of class
            identifiers. Defaults to ("0", "1").
        save_image (bool, optional): Whether to save the generated images.
            Defaults to True.
        image_type (str, optional): The file format of the saved images.
            Defaults to "jpg".
        interpolate_range (str, optional): The range of interpolation for prompt
            embeddings. Possible values are "full", "nearest", or "furthest".
            Defaults to "full".
        device (str, optional): The device to use for image generation.
            Defaults to "cuda".
        return_images (bool, optional): Whether to return the generated images.
            Defaults to False.

    Returns:
        dict or tuple: If return_images is True, returns a tuple of a dictionary
            containing the generated images for each class and a dictionary
            containing the SSIM scores for each class and interpolation step.
            If return_images is False, returns only the SSIM score dictionary.
    """
    if interpolate_range == "nearest":
        nearest_half = True
        furthest_half = False
    elif interpolate_range == "furthest":
        nearest_half = False
        furthest_half = True
    else:
        nearest_half = False
        furthest_half = False

    if seed is None:
        seed = torch.Generator().seed()
    generator = torch.manual_seed(seed)
    rng = np.random.default_rng(seed)

    # Generating initial U-Net latent vectors from a random normal distribution.
    latents = torch.randn(
        (1, pipeline.unet.config.in_channels, height // 8, width // 8),
        generator=generator,
    ).to(device)

    embed_len = len(interpolated_prompt_embeds)
    embed_pairs = zip(interpolated_prompt_embeds, interpolated_negative_prompts_embeds)
    embed_pairs_list = list(embed_pairs)

    if return_images:
        class_images = dict()
    class_ssim = dict()

    if nearest_half:
        steps_range = (range(0, embed_len // 2), range(embed_len // 2, embed_len))
        multiplier = 2
    elif furthest_half:
        # Uses the opposite class of images for the text interpolation.
        steps_range = (range(embed_len // 2, embed_len), range(0, embed_len // 2))
        multiplier = 2
    else:
        steps_range = (range(embed_len), range(embed_len))
        multiplier = 1

    for class_iter, class_id in enumerate(class_pairs):
        if return_images:
            class_images[class_id] = list()
        class_ssim[class_id] = {
            i: {"ssim_sum": 0, "ssim_count": 0, "ssim_avg": 0}
            for i in range(embed_len)
        }
        subset_len = len(img_subsets[class_id])
        # To efficiently randomize the interpolation steps for each image in the
        # class, group_map is used.
        # group_map: index is the image id, element is the interpolation step.
        # steps_range[class_iter] determines the range of steps to interpolate
        # for the class, so the first half of the steps belongs to the first
        # class and so on, e.g. range(0, 8) and range(8, 16) for 16 steps.
        # The steps are then repeated to cover the whole subset plus remainder.
        group_map = (
            list(steps_range[class_iter]) * multiplier * (subset_len // embed_len + 1)
        )
        # Shuffle the steps to interpolate for each image; the position in
        # group_map is mapped to the image id.
        rng.shuffle(group_map)
        iter_indices = class_iterables[class_id].pop()
        # Generate an image for each source image in the class, using the
        # randomly assigned interpolation step.
        for image_id in iter_indices:
            img, _ = img_subsets[class_id][image_id]
            input_image = img.unsqueeze(0)
            interpolate_step = group_map[image_id]
            prompt_embeds, negative_prompt_embeds = embed_pairs_list[interpolate_step]
            generated_image = genClassImg(
                pipeline,
                prompt_embeds,
                negative_prompt_embeds,
                input_image,
                generator,
                latents,
                num_imgs=1,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
            )
            pred_image = transforms.ToTensor()(generated_image).unsqueeze(0)
            ssim_score = ssim(pred_image, input_image).item()
            class_ssim[class_id][interpolate_step]["ssim_sum"] += ssim_score
            class_ssim[class_id][interpolate_step]["ssim_count"] += 1
            if return_images:
                class_images[class_id].append(generated_image)
            if save_image:
                if image_type == "jpg":
                    generated_image.save(
                        f"{save_path}/{class_id}/{seed}-{image_id}_{interpolate_step}.{image_type}",
                        format="JPEG",
                        quality=95,
                    )
                elif image_type == "png":
                    generated_image.save(
                        f"{save_path}/{class_id}/{seed}-{image_id}_{interpolate_step}.{image_type}",
                        format="PNG",
                    )
                else:
                    generated_image.save(
                        f"{save_path}/{class_id}/{seed}-{image_id}_{interpolate_step}.{image_type}"
                    )
        # Calculate the average SSIM for the class.
        for i_step in range(embed_len):
            if class_ssim[class_id][i_step]["ssim_count"] > 0:
                class_ssim[class_id][i_step]["ssim_avg"] = (
                    class_ssim[class_id][i_step]["ssim_sum"]
                    / class_ssim[class_id][i_step]["ssim_count"]
                )
    if return_images:
        return class_images, class_ssim
    return class_ssim
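

# Hedged end-to-end sketch for generateImagesFromDataset() on one hypothetical
# class pair. `subsets` maps class id -> a dataset subset of that class; the
# iterables come from getPairIndices() so each call consumes one index group.
def _example_generate(pipe, embeds, neg_embeds, subsets):
    class_pairs = ("n01440764", "n01443537")
    iterables = {c: getPairIndices(len(subsets[c]), 1, seed=0) for c in class_pairs}
    return generateImagesFromDataset(
        subsets,
        iterables,
        pipe,
        embeds,
        neg_embeds,
        num_inference_steps=25,
        guidance_scale=7.5,
        seed=0,
        save_path="synth",
        class_pairs=class_pairs,
    )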


def generateTrace(
    prompts,
    img_subsets,
    class_iterables,
    interpolated_prompt_embeds,
    interpolated_negative_prompts_embeds,
    subset_indices,
    seed=None,
    save_path=".",
    class_pairs=("0", "1"),
    image_type="jpg",
    interpolate_range="full",
    save_prompt_embeds=False,
):
    """
    Generate a trace dictionary containing information about the generated
    images, mirroring the step assignment of generateImagesFromDataset without
    running the diffusion pipeline.

    Args:
        prompts (list): List of prompt texts, with the positive prompt at
            index 0 and the negative prompt at index 1.
        img_subsets (dict): Dictionary containing image subsets for each class.
        class_iterables (dict): Dictionary containing iterable objects for each class.
        interpolated_prompt_embeds (torch.Tensor): Tensor containing interpolated
            prompt embeddings.
        interpolated_negative_prompts_embeds (torch.Tensor): Tensor containing
            interpolated negative prompt embeddings.
        subset_indices (dict): Dictionary containing indices of subsets for each class.
        seed (int, optional): Seed value for random number generation.
            Defaults to None.
        save_path (str, optional): Path to save the generated images.
            Defaults to ".".
        class_pairs (tuple, optional): Tuple containing class pairs.
            Defaults to ("0", "1").
        image_type (str, optional): Type of the generated images. Defaults to "jpg".
        interpolate_range (str, optional): Range of interpolation. Defaults to "full".
        save_prompt_embeds (bool, optional): Flag to save prompt embeddings.
            Defaults to False.

    Returns:
        dict: Trace dictionary containing information about the generated images.
    """
""" trace_dict = { "class_pairs": list(), "class_id": list(), "image_id": list(), "interpolation_step": list(), "embed_len": list(), "pos_prompt_text": list(), "neg_prompt_text": list(), "input_file_path": list(), "output_file_path": list(), "input_prompts_embed": list(), } if interpolate_range == "nearest": nearest_half = True furthest_half = False elif interpolate_range == "furthest": nearest_half = False furthest_half = True else: nearest_half = False furthest_half = False if seed is None: seed = torch.Generator().seed() rng = np.random.default_rng(seed) embed_len = len(interpolated_prompt_embeds) embed_pairs = zip( interpolated_prompt_embeds.cpu().numpy(), interpolated_negative_prompts_embeds.cpu().numpy(), ) embed_pairs_list = list(embed_pairs) if nearest_half or furthest_half: if nearest_half: steps_range = (range(0, embed_len // 2), range(embed_len // 2, embed_len)) mutiplier = 2 elif furthest_half: # uses opposite class of images of the text interpolation steps_range = (range(embed_len // 2, embed_len), range(0, embed_len // 2)) mutiplier = 2 else: steps_range = (range(embed_len), range(embed_len)) mutiplier = 1 for class_iter, class_id in enumerate(class_pairs): subset_len = len(img_subsets[class_id]) # to efficiently randomize the steps to interpolate for each image in the class, group_map is used # group_map: index is the image id, element is the group id # steps_range[class_iter] determines the range of steps to interpolate for the class, # so the first half of the steps are for the first class and so on. range(0,7) and range(8,15) for 16 steps # then the rest is to multiply the steps to cover the whole subset + remainder group_map = ( list(steps_range[class_iter]) * mutiplier * (subset_len // embed_len + 1) ) rng.shuffle( group_map ) # shuffle the steps to interpolate for each image, position in the group_map is mapped to the image id iter_indices = class_iterables[class_id].pop() # generate images for each image in the class, randomly selecting an interpolated step for image_id in iter_indices: class_ds = img_subsets[class_id] interpolate_step = group_map[image_id] sample_count = subset_indices[class_id][0] + image_id input_file = os.path.normpath(class_ds.dataset.samples[sample_count][0]) pos_prompt = prompts[0] neg_prompt = prompts[1] output_file = f"{save_path}/{class_id}/{seed}-{image_id}_{interpolate_step}.{image_type}" if save_prompt_embeds: input_prompts_embed = embed_pairs_list[interpolate_step] else: input_prompts_embed = None trace_dict["class_pairs"].append(class_pairs) trace_dict["class_id"].append(class_id) trace_dict["image_id"].append(image_id) trace_dict["interpolation_step"].append(interpolate_step) trace_dict["embed_len"].append(embed_len) trace_dict["pos_prompt_text"].append(pos_prompt) trace_dict["neg_prompt_text"].append(neg_prompt) trace_dict["input_file_path"].append(input_file) trace_dict["output_file_path"].append(output_file) trace_dict["input_prompts_embed"].append(input_prompts_embed) return trace_dict