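"""Data loading utilities for paired image-to-image training.

`make_train_dataset` builds a `datasets` train split from a JSON manifest whose rows
provide a caption, a source (conditioning) image path and a target image path, applies
resize/normalize transforms, and tokenizes each caption with a CLIP tokenizer (77 tokens)
and a T5 tokenizer (512 tokens). `collate_fn` stacks the per-example tensors into batches.
"""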

import os
import random

import numpy as np
import torch
from PIL import Image
from datasets import load_dataset
from torchvision import transforms

# Disable Pillow's decompression-bomb check so very large training images can be opened.
Image.MAX_IMAGE_PIXELS = None


def make_train_dataset(args, tokenizer, accelerator=None):
    if args.train_data_dir is not None:
        print(f"loading training data from {args.train_data_dir}")
        dataset = load_dataset("json", data_files=args.train_data_dir)
    else:
        raise ValueError("`--train_data_dir` must be provided to build the training dataset.")
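
    # Each manifest row is assumed to look roughly like (field names are illustrative):
    #   {"caption": "a photo of a cat", "source": "conds/0001.png", "target": "images/0001.png"}
    # where the source/target columns hold paths to the conditioning and ground-truth images.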

    column_names = dataset["train"].column_names

    # Get the column names for input/target.
    if args.caption_column is None:
        caption_column = column_names[0]
        print(f"caption column defaulting to {caption_column}")
    else:
        caption_column = args.caption_column
        if caption_column not in column_names:
            raise ValueError(
                f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
            )

    if args.source_column is None:
        source_column = column_names[1]
        print(f"source column defaulting to {source_column}")
    else:
        source_column = args.source_column
        if source_column not in column_names:
            raise ValueError(
                f"`--source_column` value '{args.source_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
            )

    if args.target_column is None:
        target_column = column_names[1]
        print(f"target column defaulting to {target_column}")
    else:
        target_column = args.target_column
        if target_column not in column_names:
            raise ValueError(
                f"`--target_column` value '{args.target_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
            )

    h = args.height
    w = args.width

    # Resize to the training resolution and scale pixel values from [0, 1] to [-1, 1].
    train_transforms = transforms.Compose(
        [
            transforms.Resize((h, w), interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    # `tokenizer` is expected to be a (CLIP tokenizer, T5 tokenizer) pair.
    tokenizer_clip = tokenizer[0]
    tokenizer_t5 = tokenizer[1]

    def tokenize_prompt_clip_t5(examples):
        captions = []
        for caption in examples[caption_column]:
            if isinstance(caption, str):
                captions.append(caption)
            elif isinstance(caption, list):
                # If a row provides several captions, pick one at random.
                captions.append(random.choice(caption))
            else:
                raise ValueError(
                    f"Caption column `{caption_column}` should contain either strings or lists of strings."
                )

        text_inputs = tokenizer_clip(
            captions,
            padding="max_length",
            max_length=77,  # CLIP's fixed context length
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        text_input_ids_1 = text_inputs.input_ids

        text_inputs = tokenizer_t5(
            captions,
            padding="max_length",
            max_length=512,  # longer context for the T5 text encoder
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        text_input_ids_2 = text_inputs.input_ids
        return text_input_ids_1, text_input_ids_2

    def preprocess_train(examples):
        _examples = {}
        # The source/target columns hold image file paths; load and transform them on the fly.
        source_images = [Image.open(image).convert("RGB") for image in examples[source_column]]
        target_images = [Image.open(image).convert("RGB") for image in examples[target_column]]
        _examples["cond_pixel_values"] = [train_transforms(source) for source in source_images]
        _examples["pixel_values"] = [train_transforms(image) for image in target_images]
        _examples["token_ids_clip"], _examples["token_ids_t5"] = tokenize_prompt_clip_t5(examples)
        return _examples

    if accelerator is not None:
        # In distributed runs, let the main process set up the transformed dataset first.
        with accelerator.main_process_first():
            train_dataset = dataset["train"].with_transform(preprocess_train)
    else:
        train_dataset = dataset["train"].with_transform(preprocess_train)

    return train_dataset


def collate_fn(examples):
    # Stack per-example tensors into contiguous float batches for the DataLoader.
    cond_pixel_values = torch.stack([example["cond_pixel_values"] for example in examples])
    cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float()
    target_pixel_values = torch.stack([example["pixel_values"] for example in examples])
    target_pixel_values = target_pixel_values.to(memory_format=torch.contiguous_format).float()
    token_ids_clip = torch.stack([torch.tensor(example["token_ids_clip"]) for example in examples])
    token_ids_t5 = torch.stack([torch.tensor(example["token_ids_t5"]) for example in examples])
    return {
        "cond_pixel_values": cond_pixel_values,
        "pixel_values": target_pixel_values,
        "text_ids_1": token_ids_clip,
        "text_ids_2": token_ids_t5,
    }
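

if __name__ == "__main__":
    # Minimal usage sketch, not part of the training pipeline itself. The manifest path,
    # column names, resolution and base-model id below are illustrative assumptions; the
    # only requirement from this file is a (CLIP tokenizer, T5 tokenizer) pair.
    from argparse import Namespace

    from torch.utils.data import DataLoader
    from transformers import CLIPTokenizer, T5TokenizerFast

    args = Namespace(
        train_data_dir="data/train.json",  # placeholder manifest path
        caption_column="caption",
        source_column="source",
        target_column="target",
        height=512,
        width=512,
    )
    tokenizer = (
        CLIPTokenizer.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="tokenizer"),
        T5TokenizerFast.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="tokenizer_2"),
    )

    train_dataset = make_train_dataset(args, tokenizer)
    train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)

    batch = next(iter(train_dataloader))
    print(batch["pixel_values"].shape, batch["cond_pixel_values"].shape)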