|
|
import torch |
|
|
from torch.utils.data import Dataset |
|
|
import torchvision.transforms as T |
|
|
import os |
|
|
import random |
|
|
import numpy as np |
|
|
|
|
|
from PIL import Image, ImageDraw |
|
|
|
|
|
from datasets import load_dataset |
|
|
|
|
|
from .trainer import OminiModel, get_config, train |
|
|
from ..pipeline.flux_omini import Condition, convert_to_condition, generate |
|
|
|
|
|
|
|
|
class ImageConditionDataset(Dataset):
    """Pair every image/prompt sample with a derived condition image.

    Each item of ``base_dataset`` is read as ``item["jpg"]`` (an image object
    exposing ``resize``/``convert``) and ``item["json"]["prompt"]`` (caption).

    Args:
        base_dataset: Indexable, sized collection of samples as described above.
        condition_size: ``(width, height)`` the condition image is resized to.
        target_size: ``(width, height)`` the target image is resized to.
        condition_type: One of ``"canny"``, ``"coloring"``, ``"deblurring"``,
            ``"depth"``, ``"depth_pred"``, ``"fill"``, ``"sr"`` (alias
            ``"super_resolution"``).
        drop_text_prob: Probability of replacing the caption with ``""``.
        drop_image_prob: Probability of replacing the condition image with an
            all-black image of ``condition_size``.
        return_pil_image: When true, the sample also carries the two images
            under ``"pil_image"``.
        position_scale: Emitted as ``"position_scale_0"`` only when it differs
            from ``1.0``.
    """

    def __init__(
        self,
        base_dataset,
        condition_size=(512, 512),
        target_size=(512, 512),
        condition_type: str = "canny",
        drop_text_prob: float = 0.1,
        drop_image_prob: float = 0.1,
        return_pil_image: bool = False,
        position_scale=1.0,
    ):
        self.base_dataset = base_dataset
        self.condition_size = condition_size
        self.target_size = target_size
        self.condition_type = condition_type
        self.drop_text_prob = drop_text_prob
        self.drop_image_prob = drop_image_prob
        self.return_pil_image = return_pil_image
        self.position_scale = position_scale

        self.to_tensor = T.ToTensor()

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.base_dataset)

    def __get_condition__(self, image, condition_type):
        """Build the condition image for ``image``.

        Args:
            image: Source image (already resized to the target size by
                ``__getitem__``).
            condition_type: Name of the condition to derive.

        Returns:
            ``(condition_img, position_delta)`` where ``position_delta`` is a
            2-element numpy array (``[0, 0]`` except for super-resolution).

        Raises:
            ValueError: If ``condition_type`` is not supported.
        """
        condition_size = self.condition_size
        position_delta = np.array([0, 0])

        if condition_type in ("canny", "coloring", "deblurring", "depth"):
            image, kwargs = image.resize(condition_size), {}
            if condition_type == "deblurring":
                # Randomise the blur strength per sample.
                kwargs["blur_radius"] = random.randint(1, 10)
            condition_img = convert_to_condition(condition_type, image, **kwargs)
        elif condition_type == "depth_pred":
            # The RGB image is the condition; the depth map that serves as the
            # training target is produced in __getitem__, because this method
            # only returns the condition.
            condition_img = image.resize(condition_size)
        elif condition_type == "fill":
            # Build the mask on the *resized* image so the condition really has
            # `condition_size` (previously the resized copy was discarded and
            # the composite was made at the target size).
            base = image.resize(condition_size).convert("RGB")
            w, h = base.size
            x1, x2 = sorted([random.randint(0, w), random.randint(0, w)])
            y1, y2 = sorted([random.randint(0, h), random.randint(0, h)])
            mask = Image.new("L", base.size, 0)
            ImageDraw.Draw(mask).rectangle([x1, y1, x2, y2], fill=255)
            if random.random() > 0.5:
                # Half of the time keep the outside of the box instead.
                mask = Image.eval(mask, lambda a: 255 - a)
            # Pixels where the mask is 0 are blacked out.
            condition_img = Image.composite(
                base, Image.new("RGB", base.size, (0, 0, 0)), mask
            )
        elif condition_type in ("sr", "super_resolution"):
            condition_img = image.resize(condition_size)
            position_delta = np.array([0, -condition_size[0] // 16])
        else:
            raise ValueError(f"Condition type {condition_type} is not implemented.")

        return condition_img, position_delta

    def __getitem__(self, idx):
        """Return the training sample at ``idx`` as a dict.

        Keys: ``"image"``, ``"condition_0"`` (both passed through
        ``self.to_tensor``), ``"condition_type_0"``, ``"position_delta_0"``,
        ``"description"``, plus the optional ``"pil_image"`` and
        ``"position_scale_0"``.
        """
        # Fetch the sample once; indexing the base dataset twice would fetch
        # (and possibly decode) the same record twice.
        sample = self.base_dataset[idx]
        image = sample["jpg"].resize(self.target_size).convert("RGB")
        description = sample["json"]["prompt"]

        condition_size = self.condition_size
        position_scale = self.position_scale

        condition_img, position_delta = self.__get_condition__(
            image, self.condition_type
        )

        if self.condition_type == "depth_pred":
            # For depth prediction the model must learn image -> depth, so the
            # target is the depth map rather than the photo itself.
            # NOTE(review): assumes convert_to_condition("depth", ...) returns
            # an image compatible with self.to_tensor — confirm.
            image = convert_to_condition("depth", image).resize(self.target_size)

        # Randomly drop the caption and/or the condition image.
        drop_text = random.random() < self.drop_text_prob
        drop_image = random.random() < self.drop_image_prob

        if drop_text:
            description = ""
        if drop_image:
            condition_img = Image.new("RGB", condition_size, (0, 0, 0))

        return {
            "image": self.to_tensor(image),
            "condition_0": self.to_tensor(condition_img),
            "condition_type_0": self.condition_type,
            "position_delta_0": position_delta,
            "description": description,
            **({"pil_image": [image, condition_img]} if self.return_pil_image else {}),
            **({"position_scale_0": position_scale} if position_scale != 1.0 else {}),
        }
|
|
|
|
|
|
|
|
@torch.no_grad()
def test_function(model, save_path, file_name):
    """Generate sample images for the model's condition type and save them.

    Reads sizes, ``position_delta`` and ``position_scale`` from
    ``model.training_config["dataset"]``, builds a condition from
    ``assets/vase_hq.jpg`` and writes each result to
    ``{save_path}/{file_name}_{condition_type}_{i}.jpg``.

    Args:
        model: Object exposing ``training_config``, ``adapter_names``,
            ``device``, ``flux_pipe`` and ``model_config``.
        save_path: Output directory (created if missing).
        file_name: Prefix of the saved file names.

    Raises:
        NotImplementedError: If the configured condition type is unsupported.
    """
    dataset_config = model.training_config["dataset"]
    condition_size = dataset_config["condition_size"]
    target_size = dataset_config["target_size"]

    position_delta = dataset_config.get("position_delta", [0, 0])
    position_scale = dataset_config.get("position_scale", 1.0)

    adapter = model.adapter_names[2]
    condition_type = model.training_config["condition_type"]
    test_list = []

    if condition_type in ("canny", "coloring", "deblurring", "depth"):
        image = Image.open("assets/vase_hq.jpg")
        image = image.resize(condition_size)
        # Third positional argument is forwarded as-is; presumably the blur
        # radius used for "deblurring" — confirm against convert_to_condition.
        condition_img = convert_to_condition(condition_type, image, 5)
        condition = Condition(condition_img, adapter, position_delta, position_scale)
        test_list.append((condition, "A beautiful vase on a table."))
    elif condition_type == "depth_pred":
        image = Image.open("assets/vase_hq.jpg")
        image = image.resize(condition_size)
        condition = Condition(image, adapter, position_delta, position_scale)
        test_list.append((condition, "A beautiful vase on a table."))
    elif condition_type == "fill":
        condition_img = (
            Image.open("./assets/vase_hq.jpg").resize(condition_size).convert("RGB")
        )
        mask = Image.new("L", condition_img.size, 0)
        draw = ImageDraw.Draw(mask)
        # Square box from 1/4 to 3/4 of the image width.
        a = condition_img.size[0] // 4
        b = a * 3
        draw.rectangle([a, a, b, b], fill=255)
        # Pixels where the mask is 0 are blacked out.
        condition_img = Image.composite(
            condition_img, Image.new("RGB", condition_img.size, (0, 0, 0)), mask
        )
        # Fixed: previously passed the not-yet-defined name `condition`.
        condition = Condition(condition_img, adapter, position_delta, position_scale)
        test_list.append((condition, "A beautiful vase on a table."))
    elif condition_type in ("sr", "super_resolution"):
        # "sr" is the name the training dataset uses; accept both spellings.
        image = Image.open("assets/vase_hq.jpg")
        image = image.resize(condition_size)
        condition = Condition(image, adapter, position_delta, position_scale)
        test_list.append((condition, "A beautiful vase on a table."))
    else:
        raise NotImplementedError

    os.makedirs(save_path, exist_ok=True)
    for i, (condition, prompt) in enumerate(test_list):
        # Fixed seed so successive checkpoints are comparable.
        generator = torch.Generator(device=model.device)
        generator.manual_seed(42)

        res = generate(
            model.flux_pipe,
            prompt=prompt,
            conditions=[condition],
            height=target_size[1],
            width=target_size[0],
            generator=generator,
            model_config=model.model_config,
            kv_cache=model.model_config.get("independent_condition", False),
        )
        file_path = os.path.join(save_path, f"{file_name}_{condition_type}_{i}.jpg")
        res.images[0].save(file_path)
|
|
|
|
|
|
|
|
def main():
    """Load the configuration, build dataset and model, then start training."""
    config = get_config()
    training_config = config["train"]

    # Bind this process to the GPU selected by LOCAL_RANK (default 0).
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)

    dataset_config = training_config["dataset"]

    # Raw image/prompt samples, read as a webdataset from the configured URLs.
    raw_dataset = load_dataset(
        "webdataset",
        data_files={"train": dataset_config["urls"]},
        split="train",
        cache_dir="cache/t2i2m",
        num_proc=32,
    )

    # Wrap the raw samples so each one carries its condition image.
    dataset = ImageConditionDataset(
        raw_dataset,
        condition_size=dataset_config["condition_size"],
        target_size=dataset_config["target_size"],
        condition_type=training_config["condition_type"],
        drop_text_prob=dataset_config["drop_text_prob"],
        drop_image_prob=dataset_config["drop_image_prob"],
        position_scale=dataset_config.get("position_scale", 1.0),
    )

    model_dtype = getattr(torch, config["dtype"])
    use_checkpointing = training_config.get("gradient_checkpointing", False)
    trainable_model = OminiModel(
        flux_pipe_id=config["flux_path"],
        lora_config=training_config["lora_config"],
        device="cuda",
        dtype=model_dtype,
        optimizer_config=training_config["optimizer"],
        model_config=config.get("model", {}),
        gradient_checkpointing=use_checkpointing,
    )

    train(dataset, trainable_model, config, test_function)
|
|
|
|
|
|
|
|
# Script entry point: start training only when this module is run directly.
if __name__ == "__main__":
    main()
|
|
|