import os
import json
import glob
import random

import numpy as np
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from torch.utils.data import Dataset, DataLoader
def image_resize(img, max_size=512):
    """Resize a PIL image so its longer side equals `max_size`, preserving the aspect ratio."""
    w, h = img.size
    if w >= h:
        new_w = max_size
        new_h = int((max_size / w) * h)
    else:
        new_h = max_size
        new_w = int((max_size / h) * w)
    return img.resize((new_w, new_h))
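
# For example (illustrative values only): image_resize preserves the aspect ratio,
# so a 1920x1080 input becomes 512x288 and an 800x1200 input becomes 341x512.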
def c_crop(image):
    """Center-crop a PIL image to a square with side min(width, height).
    Kept as a general helper; not called elsewhere in this file."""
    width, height = image.size
    new_size = min(width, height)
    left = (width - new_size) // 2
    top = (height - new_size) // 2
    right = left + new_size
    bottom = top + new_size
    return image.crop((left, top, right, bottom))
def crop_to_aspect_ratio(image, ratio="16:9"):
    """Center-crop a PIL image to the given aspect ratio ("16:9", "4:3", or "1:1")."""
    width, height = image.size
    ratio_map = {
        "16:9": (16, 9),
        "4:3": (4, 3),
        "1:1": (1, 1),
    }
    target_w, target_h = ratio_map[ratio]
    target_ratio_value = target_w / target_h
    current_ratio = width / height
    if current_ratio > target_ratio_value:
        # Image is too wide: trim the sides.
        new_width = int(height * target_ratio_value)
        offset = (width - new_width) // 2
        crop_box = (offset, 0, offset + new_width, height)
    else:
        # Image is too tall: trim the top and bottom.
        new_height = int(width / target_ratio_value)
        offset = (height - new_height) // 2
        crop_box = (0, offset, width, offset + new_height)
    return image.crop(crop_box)
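
# A quick sanity check of the cropping helper (illustrative values only):
#
#   img = Image.new("RGB", (1920, 1080))
#   crop_to_aspect_ratio(img, "1:1").size   # -> (1080, 1080): central square kept
#
# A 1920x1080 input is already 16:9, so that ratio leaves it essentially untouched.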
class CustomImageDataset(Dataset):
    """Single-image dataset: each .jpg/.jpeg/.png file is paired with a caption file
    of the same stem (e.g. 0001.jpg + 0001.json)."""

    def __init__(self, img_dir, img_size=512, caption_type='json', random_ratio=False):
        self.images = [
            os.path.join(img_dir, i)
            for i in os.listdir(img_dir)
            if i.lower().endswith(('.jpg', '.jpeg', '.png'))
        ]
        # Alternatively, collect images recursively:
        # self.images = glob.glob(img_dir + '**/*.jpg', recursive=True) + glob.glob(img_dir + '**/*.png', recursive=True) + glob.glob(img_dir + '**/*.jpeg', recursive=True)
        self.images.sort()
        self.img_size = img_size
        self.caption_type = caption_type
        self.random_ratio = random_ratio

    def __len__(self):
        return len(self.images)
    def __getitem__(self, idx):
        try:
            img = Image.open(self.images[idx]).convert('RGB')
            if self.random_ratio:
                ratio = random.choice(["16:9", "default", "1:1", "4:3"])
                if ratio != "default":
                    img = crop_to_aspect_ratio(img, ratio)
            img = image_resize(img, self.img_size)
            # Snap both sides down to multiples of 32 (a common latent-stride
            # requirement), then scale pixels from [0, 255] to [-1, 1].
            w, h = img.size
            new_w = (w // 32) * 32
            new_h = (h // 32) * 32
            img = img.resize((new_w, new_h))
            img = torch.from_numpy((np.array(img, dtype=np.float32) / 127.5) - 1)
            img = img.permute(2, 0, 1)
            # os.path.splitext handles paths with extra dots (e.g. "./data/x.jpg"),
            # unlike split('.')[0].
            caption_path = os.path.splitext(self.images[idx])[0] + '.' + self.caption_type
            if self.caption_type == "json":
                with open(caption_path) as f:
                    prompt = json.load(f)['caption']
            else:
                with open(caption_path) as f:
                    prompt = f.read()
            return img, prompt
        except Exception as e:
            # On a bad sample (missing caption, corrupt image, ...), log and
            # retry with a random index instead of crashing the worker.
            print(e)
            return self.__getitem__(random.randint(0, len(self.images) - 1))
def loader(train_batch_size, num_workers, **args):
    dataset = CustomImageDataset(**args)
    return DataLoader(dataset, batch_size=train_batch_size, num_workers=num_workers, shuffle=True)
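
# A minimal usage sketch for the single-image training loader. The directory layout
# is an assumption, not a documented spec: each image (e.g. ./data/0001.jpg) sits
# next to a caption file with the same stem (./data/0001.json holding {"caption": ...}).
# The path is hypothetical; batch size 1 sidesteps collating differently-sized images.
#
#   train_dataloader = loader(
#       train_batch_size=1,
#       num_workers=2,
#       img_dir="./data",
#       img_size=512,
#       caption_type="json",
#       random_ratio=True,
#   )
#   img, prompt = next(iter(train_dataloader))  # img: (1, 3, H, W) floats in [-1, 1]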
class ImageEditPairDataset(Dataset):
    """Paired image-editing dataset: each file stores the source (left half) and the
    edited target (right half) side by side, with a same-stem JSON file providing
    the 'caption' and 'edit' texts."""

    def __init__(self, img_dir, img_size=512, caption_type='json', random_ratio=False, grayscale_editing=False, zoom_camera=False):
        # Alternatively, collect images non-recursively:
        # self.images = [os.path.join(img_dir, i) for i in os.listdir(img_dir) if i.lower().endswith(('.jpg', '.png'))]
        # Note: img_dir should end with a path separator, since it is concatenated
        # directly into the glob patterns below.
        self.images = (
            glob.glob(img_dir + '**/*.jpg', recursive=True)
            + glob.glob(img_dir + '**/*.png', recursive=True)
            + glob.glob(img_dir + '**/*.jpeg', recursive=True)
        )
        self.images.sort()
        self.img_size = img_size
        self.caption_type = caption_type
        self.random_ratio = random_ratio
        self.grayscale_editing = grayscale_editing
        self.zoom_camera = zoom_camera
        # Treat benchmark directories as evaluation sets. The membership tests must
        # both be explicit: a bare string literal in an `or` is always truthy.
        self.eval = ("ByteMorph-Bench" in img_dir) or ("InstructMove" in img_dir)

    def __len__(self):
        return len(self.images)
    def __getitem__(self, idx):
        try:
            img = Image.open(self.images[idx]).convert('RGB')
            ori_width, ori_height = img.size
            left_half = (0, 0, ori_width // 2, ori_height)
            right_half = (ori_width // 2, 0, ori_width, ori_height)
            src_image = img.crop(left_half)   # Left half: source image
            tgt_image = img.crop(right_half)  # Right half: edited target image
            if self.random_ratio:
                ratio = random.choice(["16:9", "default", "1:1", "4:3"])
                if ratio != "default":
                    src_image = crop_to_aspect_ratio(src_image, ratio)
                    tgt_image = crop_to_aspect_ratio(tgt_image, ratio)
            src_image = image_resize(src_image, self.img_size)
            tgt_image = image_resize(tgt_image, self.img_size)
            # Snap both sides down to multiples of 32, then normalize to [-1, 1].
            w, h = src_image.size
            new_w = (w // 32) * 32
            new_h = (h // 32) * 32
            src_image = src_image.resize((new_w, new_h))
            src_image = torch.from_numpy((np.array(src_image, dtype=np.float32) / 127.5) - 1)
            src_image = src_image.permute(2, 0, 1)
            tgt_image = tgt_image.resize((new_w, new_h))
            tgt_image = torch.from_numpy((np.array(tgt_image, dtype=np.float32) / 127.5) - 1)
            tgt_image = tgt_image.permute(2, 0, 1)
            stem = os.path.splitext(self.images[idx])[0]
            json_path = stem + '.' + self.caption_type
            if self.eval:
                image_name = os.path.basename(stem)
                edit_type = os.path.basename(os.path.dirname(stem))
            if self.caption_type == "json":
                # Load the metadata once; evaluation sets carry no source caption.
                with open(json_path) as f:
                    meta = json.load(f)
                edit_prompt = meta['edit']
                prompt = meta['caption'] if not self.eval else []
            else:
                raise NotImplementedError
            # No synthetic edit requested: return the stored pair as-is.
            if (not self.grayscale_editing) and (not self.zoom_camera):
                if not self.eval:
                    return src_image, tgt_image, prompt, edit_prompt
                else:
                    return src_image, tgt_image, prompt, edit_prompt, image_name, edit_type
            if self.grayscale_editing and (not self.zoom_camera):
                # Synthetic grayscale edit: luma = 0.2989 R + 0.5870 G + 0.1140 B.
                # The weights sum to (almost exactly) 1, so values remain in [-1, 1].
                grayscale_image = 0.2989 * src_image[0, :, :] + 0.5870 * src_image[1, :, :] + 0.1140 * src_image[2, :, :]
                tgt_image = grayscale_image.unsqueeze(0).repeat(3, 1, 1)
                edit_prompt = "Convert the input image to a black and white grayscale image while maintaining the original composition and details."
                if not self.eval:
                    return src_image, tgt_image, prompt, edit_prompt
                else:
                    return src_image, tgt_image, prompt, edit_prompt, image_name, edit_type
            if (not self.grayscale_editing) and self.zoom_camera:
                # Synthetic 2x zoom edit: crop the central half of the source and
                # resize it back to the source resolution, so src and tgt shapes
                # always match (a fixed 256 -> 512 crop/resize would only be
                # correct for square 512x512 inputs).
                _, img_h, img_w = src_image.shape
                cropped = TF.center_crop(src_image, [img_h // 2, img_w // 2])
                tgt_image = TF.resize(cropped, [img_h, img_w])
                edit_prompt = "The central area of the input image is zoomed. The camera transitions from a wide shot to a closer position, narrowing its view."
                if not self.eval:
                    return src_image, tgt_image, prompt, edit_prompt
                else:
                    return src_image, tgt_image, prompt, edit_prompt, image_name, edit_type
            if self.grayscale_editing and self.zoom_camera:
                # Combined edit: grayscale conversion followed by the 2x central zoom.
                grayscale_image = 0.2989 * src_image[0, :, :] + 0.5870 * src_image[1, :, :] + 0.1140 * src_image[2, :, :]
                tgt_image = grayscale_image.unsqueeze(0).repeat(3, 1, 1)
                _, img_h, img_w = tgt_image.shape
                tgt_image = TF.center_crop(tgt_image, [img_h // 2, img_w // 2])
                tgt_image = TF.resize(tgt_image, [img_h, img_w])
                edit_prompt = "Convert the input image to a black and white grayscale image while maintaining the original composition and details. In addition, the central area of the input image is zoomed: the camera transitions from a wide shot to a closer position, narrowing its view."
                if not self.eval:
                    return src_image, tgt_image, prompt, edit_prompt
                else:
                    return src_image, tgt_image, prompt, edit_prompt, image_name, edit_type
        except Exception as e:
            # On a bad sample (missing caption, corrupt image, ...), log and
            # retry with a random index instead of crashing the worker.
            print(e)
            return self.__getitem__(random.randint(0, len(self.images) - 1))
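
# Expected on-disk layout for ImageEditPairDataset (inferred from the half-and-half
# cropping above, not a documented spec): each file stores the pair side by side,
# source on the left and edited target on the right, with a same-stem JSON, e.g.
#
#   data/zoom_in/0001.jpg    # 1024x512: left 512px = source, right 512px = target
#   data/zoom_in/0001.json   # {"caption": "...", "edit": "..."}
#
# For eval sets ("ByteMorph-Bench" or "InstructMove" in the path), the parent folder
# name ("zoom_in" here) is returned as edit_type and the file stem as image_name.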
def image_pair_loader(train_batch_size, num_workers, **args):
    dataset = ImageEditPairDataset(**args)
    return DataLoader(dataset, batch_size=train_batch_size, num_workers=num_workers, shuffle=True)
def eval_image_pair_loader(eval_batch_size, num_workers, **args):
    dataset = ImageEditPairDataset(**args)
    return DataLoader(dataset, batch_size=eval_batch_size, num_workers=num_workers, shuffle=False)
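
# A minimal evaluation sketch; the benchmark path is hypothetical (note the trailing
# slash, which the glob patterns in __init__ rely on), and batch size 1 again avoids
# collating differently-shaped pairs:
#
#   eval_dl = eval_image_pair_loader(
#       eval_batch_size=1,
#       num_workers=2,
#       img_dir="./ByteMorph-Bench/",
#       img_size=512,
#   )
#   src, tgt, prompt, edit_prompt, image_name, edit_type = next(iter(eval_dl))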
if __name__ == "__main__":
    from src.flux.util import save_image

    # Debug pass: point img_dir at a pair-image directory (left empty here) and
    # dump a few decoded source/target pairs to ./debug for visual inspection.
    example_dataset = ImageEditPairDataset(
        img_dir="",
        img_size=512,
        caption_type='json',
        random_ratio=False,
        grayscale_editing=False,
        zoom_camera=False,
    )
    train_dataloader = DataLoader(
        example_dataset,
        batch_size=1,
        num_workers=4,
        shuffle=False,
    )
    os.makedirs("./debug", exist_ok=True)
    for step, batch in enumerate(train_dataloader):
        src_image, tgt_image, prompt, edit_prompt = batch
        save_image(src_image, f"./debug/{step}-src_img.jpg")
        save_image(tgt_image, f"./debug/{step}-tgt_img.jpg")
        if step == 3:
            breakpoint()