from PIL import Image
from datasets import load_dataset
from torchvision import transforms
import random
import torch

# Disable PIL's decompression-bomb check so very large training images can be opened.
Image.MAX_IMAGE_PIXELS = None


def multiple_16(num: float):
    """Round a value to the nearest multiple of 16."""
    return int(round(num / 16) * 16)


def get_random_resolution(min_size=512, max_size=1280, multiple=16):
    """Sample a random resolution in [min_size, max_size] that is a multiple of `multiple`."""
    resolution = random.randint(min_size // multiple, max_size // multiple) * multiple
    return resolution


def load_image_safely(image_path, size):
    """Open an image as RGB; on failure, log the path and return a white placeholder."""
    try:
        image = Image.open(image_path).convert("RGB")
        return image
    except Exception:
        print(f"file error: {image_path}")
        with open("failed_images.txt", "a") as f:
            f.write(f"{image_path}\n")
        return Image.new("RGB", (size, size), (255, 255, 255))


def make_train_dataset(args, tokenizer, accelerator=None):
    """Build the training dataset with on-the-fly image transforms and CLIP/T5 tokenization.

    `tokenizer` is a pair: (CLIP tokenizer, T5 tokenizer).
    """
    if args.train_data_dir is not None:
        print("load_data")
        dataset = load_dataset("json", data_files=args.train_data_dir)

    column_names = dataset["train"].column_names

    # Get the column names for input/target.
    caption_column = args.caption_column
    target_column = args.target_column
    if args.subject_column is not None:
        subject_columns = args.subject_column.split(",")
    if args.spatial_column is not None:
        spatial_columns = args.spatial_column.split(",")

    size = args.cond_size
    # The target resolution is sampled once per dataset build; args.noise_size is the upper bound (e.g. 768 or higher).
    noise_size = get_random_resolution(max_size=args.noise_size)

    # Subject conditions: resize so the longer edge is `size` (both edges rounded to multiples of 16),
    # apply light augmentation, then center-pad to a square `size` x `size` canvas. The padding splits
    # evenly because `size` and the resized edges are multiples of 16.
    subject_cond_train_transforms = transforms.Compose(
        [
            transforms.Lambda(lambda img: img.resize((
                multiple_16(size * img.size[0] / max(img.size)),
                multiple_16(size * img.size[1] / max(img.size))
            ), resample=Image.BILINEAR)),
            transforms.RandomHorizontalFlip(p=0.7),
            transforms.RandomRotation(degrees=20),
            transforms.Lambda(lambda img: transforms.Pad(
                padding=(
                    int((size - img.size[0]) / 2),
                    int((size - img.size[1]) / 2),
                    int((size - img.size[0]) / 2),
                    int((size - img.size[1]) / 2)
                ),
                fill=0
            )(img)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    # Spatial conditions: plain resize + center crop to `size` x `size`, no augmentation.
    cond_train_transforms = transforms.Compose(
        [
            transforms.Resize((size, size), interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop((size, size)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    def train_transforms(image, noise_size):
        # Target images keep their aspect ratio: the longer edge becomes `noise_size`,
        # and both edges are rounded to multiples of 16.
        train_transforms_ = transforms.Compose(
            [
                transforms.Lambda(lambda img: img.resize((
                    multiple_16(noise_size * img.size[0] / max(img.size)),
                    multiple_16(noise_size * img.size[1] / max(img.size))
                ), resample=Image.BILINEAR)),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )
        transformed_image = train_transforms_(image)
        return transformed_image

    def load_and_transform_cond_images(images):
        # Concatenate all spatial condition images along the height dimension (C, H, W).
        transformed_images = [cond_train_transforms(image) for image in images]
        concatenated_image = torch.cat(transformed_images, dim=1)
        return concatenated_image

    def load_and_transform_subject_images(images):
        # Concatenate all subject condition images along the height dimension (C, H, W).
        transformed_images = [subject_cond_train_transforms(image) for image in images]
        concatenated_image = torch.cat(transformed_images, dim=1)
        return concatenated_image

    tokenizer_clip = tokenizer[0]
    tokenizer_t5 = tokenizer[1]

    def tokenize_prompt_clip_t5(examples):
        captions = []
        for caption in examples[caption_column]:
            if isinstance(caption, str):
                if random.random() < 0.1:
                    captions.append(" ")  # caption dropout: replace the text with an empty prompt
                else:
                    captions.append(caption)
            elif isinstance(caption, list):
                # take a random caption if there are multiple
                if random.random() < 0.1:
                    captions.append(" ")
                else:
                    captions.append(random.choice(caption))
            else:
                raise ValueError(
                    f"Caption column `{caption_column}` should contain either strings or lists of strings."
                )

        text_inputs = tokenizer_clip(
            captions,
            padding="max_length",
            max_length=77,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        text_input_ids_1 = text_inputs.input_ids

        text_inputs = tokenizer_t5(
            captions,
            padding="max_length",
            max_length=512,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        text_input_ids_2 = text_inputs.input_ids
        return text_input_ids_1, text_input_ids_2

    def preprocess_train(examples):
        _examples = {}

        if args.subject_column is not None:
            subject_images = [
                [load_image_safely(examples[column][i], args.cond_size) for column in subject_columns]
                for i in range(len(examples[target_column]))
            ]
            _examples["subject_pixel_values"] = [
                load_and_transform_subject_images(subject) for subject in subject_images
            ]

        if args.spatial_column is not None:
            spatial_images = [
                [load_image_safely(examples[column][i], args.cond_size) for column in spatial_columns]
                for i in range(len(examples[target_column]))
            ]
            _examples["cond_pixel_values"] = [
                load_and_transform_cond_images(spatial) for spatial in spatial_images
            ]

        target_images = [load_image_safely(image_path, args.cond_size) for image_path in examples[target_column]]
        _examples["pixel_values"] = [train_transforms(image, noise_size) for image in target_images]
        _examples["token_ids_clip"], _examples["token_ids_t5"] = tokenize_prompt_clip_t5(examples)

        return _examples

    if accelerator is not None:
        with accelerator.main_process_first():
            train_dataset = dataset["train"].with_transform(preprocess_train)
    else:
        train_dataset = dataset["train"].with_transform(preprocess_train)

    return train_dataset


def collate_fn(examples):
    if examples[0].get("cond_pixel_values") is not None:
        cond_pixel_values = torch.stack([example["cond_pixel_values"] for example in examples])
        cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float()
    else:
        cond_pixel_values = None

    if examples[0].get("subject_pixel_values") is not None:
        subject_pixel_values = torch.stack([example["subject_pixel_values"] for example in examples])
        subject_pixel_values = subject_pixel_values.to(memory_format=torch.contiguous_format).float()
    else:
        subject_pixel_values = None

    # Note: torch.stack requires every target image in the batch to have the same shape,
    # so batches must share a resolution (e.g. batch_size=1 or aspect-ratio bucketing).
    target_pixel_values = torch.stack([example["pixel_values"] for example in examples])
    target_pixel_values = target_pixel_values.to(memory_format=torch.contiguous_format).float()
    token_ids_clip = torch.stack([torch.tensor(example["token_ids_clip"]) for example in examples])
    token_ids_t5 = torch.stack([torch.tensor(example["token_ids_t5"]) for example in examples])

    return {
        "cond_pixel_values": cond_pixel_values,
        "subject_pixel_values": subject_pixel_values,
        "pixel_values": target_pixel_values,
        "text_ids_1": token_ids_clip,
        "text_ids_2": token_ids_t5,
    }
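

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the training pipeline): how the
# dataset and collate_fn above could be wired into a torch DataLoader. The
# `args` namespace, the data path, and the record schema below are hypothetical
# stand-ins for whatever the real training script builds via argparse; the
# tokenizer checkpoints are assumptions consistent with the CLIP (77-token) and
# T5 (512-token) max lengths used above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from argparse import Namespace

    from torch.utils.data import DataLoader
    from transformers import CLIPTokenizer, T5TokenizerFast

    # Hypothetical config; data/train.json would hold records such as
    # {"caption": "...", "target": "/path/to/target.png", "subject": "/path/to/subject.png"}.
    args = Namespace(
        train_data_dir="data/train.json",
        caption_column="caption",
        target_column="target",
        subject_column="subject",   # comma-separated subject image columns, or None
        spatial_column=None,        # comma-separated spatial condition columns, or None
        cond_size=512,
        noise_size=768,
    )

    tokenizer_clip = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    tokenizer_t5 = T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl")

    train_dataset = make_train_dataset(args, (tokenizer_clip, tokenizer_t5))
    # batch_size=1 sidesteps the shape mismatch that torch.stack would hit when
    # target images in a batch have different aspect ratios.
    loader = DataLoader(train_dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)

    batch = next(iter(loader))
    print(batch["pixel_values"].shape)   # (1, 3, H, W), H and W multiples of 16
    print(batch["text_ids_1"].shape)     # (1, 77)  CLIP token ids
    print(batch["text_ids_2"].shape)     # (1, 512) T5 token ids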