import os
import os.path as osp
import random
from functools import partial as PA
from pathlib import Path

import numpy as np
import pandas as pd
import PIL.Image as Image
import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset

from imagedream.ldm.util import add_random_background
from imagedream.camera_utils import get_camera_for_index
from libs.base_utils import do_resize_content, add_stroke


def to_rgb_image(maybe_rgba: Image.Image):
    """Composite an RGBA image onto a neutral gray background; pass RGB through."""
    if maybe_rgba.mode == "RGB":
        return maybe_rgba
    elif maybe_rgba.mode == "RGBA":
        rgba = maybe_rgba
        # neutral gray (127) background; note the (height, width) ordering for numpy
        img = np.full((rgba.size[1], rgba.size[0], 3), 127, dtype=np.uint8)
        img = Image.fromarray(img, "RGB")
        img.paste(rgba, mask=rgba.getchannel("A"))
        return img
    else:
        raise ValueError("Unsupported image type.", maybe_rgba.mode)
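
# Hedged usage sketch for to_rgb_image; the file name below is illustrative,
# not part of this repo.
#
#   rgba = Image.open("object_view.png").convert("RGBA")
#   rgb = to_rgb_image(rgba)  # transparent pixels become neutral gray (127)
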
def axis_rotate_xyz(img: Image.Image, rotate_axis="z", angle=np.pi / 2):
    """Rotate the per-pixel XYZ coordinates encoded in `img` about one axis.

    Pixel values are treated as coordinates centered at 127; `angle` is in
    radians. This rotates the encoded coordinate frame, not the image layout.
    """
    img = img.convert("RGB")
    # center values around zero and compute in float to avoid uint8 wraparound
    img = np.asarray(img, dtype=np.float32) - 127.0
    # element-wise sin-cos rotation of the (x, y, z) channels
    if rotate_axis == "z":
        img = np.stack(
            [
                img[..., 0] * np.cos(angle) - img[..., 1] * np.sin(angle),
                img[..., 0] * np.sin(angle) + img[..., 1] * np.cos(angle),
                img[..., 2],
            ],
            -1,
        )
    elif rotate_axis == "y":
        img = np.stack(
            [
                img[..., 0] * np.cos(angle) + img[..., 2] * np.sin(angle),
                img[..., 1],
                -img[..., 0] * np.sin(angle) + img[..., 2] * np.cos(angle),
            ],
            -1,
        )
    elif rotate_axis == "x":
        img = np.stack(
            [
                img[..., 0],
                img[..., 1] * np.cos(angle) - img[..., 2] * np.sin(angle),
                img[..., 1] * np.sin(angle) + img[..., 2] * np.cos(angle),
            ],
            -1,
        )
    # shift back to [0, 255]; clip to avoid wraparound when casting to uint8
    return Image.fromarray(np.clip(img + 127.0, 0, 255).astype(np.uint8))
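
# Hedged sanity check (illustrative only): rotating a coordinate map by +90
# degrees about z and then by -90 degrees should recover the original up to
# uint8 rounding. The file name is hypothetical.
#
#   m = Image.open("xyz_new_000.png").convert("RGB")
#   back = axis_rotate_xyz(axis_rotate_xyz(m, "z", np.pi / 2), "z", -np.pi / 2)
#   # np.abs(np.asarray(m, np.int16) - np.asarray(back, np.int16)).max() <= 2
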
class DataHQCRelative(Dataset):
    """
    Expected directory layout:
    - base_dir
        - uid1
            - 000.png
            - 001.png
            - ...
        - uid2
    - xyz_base
        - uid1
            - xyz_new_000.png
            - xyz_new_001.png
            - ...
    Accepts caption data in CSV format.
    """

    def __init__(
        self,
        base_dir,
        caption_csv,
        ref_indexs=[0],
        ref_position=-1,
        xyz_base=None,
        camera_views=[3, 6, 9, 12, 15],  # camera views are relative views, not absolute
        split="train",
        image_size=256,
        random_background=False,
        resize_rate=1,
        num_frames=5,
        repeat=100,
        outer_file=None,
        debug=False,
        eval_size=100,
    ):
        print(__class__)
        OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
        OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
        df = pd.read_csv(caption_csv, sep=",", names=["id", "caption"])
        id_to_caption = dict(zip(df["id"], df["caption"]))
        # outer_file is a txt file with one ident per line; those idents are
        # excluded from the training process
        outer_set = (
            set(Path(outer_file).read_text().strip().split("\n"))
            if outer_file is not None
            else set()
        )
        xyz_set = set(os.listdir(xyz_base)) if xyz_base is not None else set()
        common_keys = set(id_to_caption.keys()) & set(os.listdir(base_dir))
        common_keys = common_keys & xyz_set if xyz_base is not None else common_keys
        common_keys = common_keys - outer_set
        self.common_keys = common_keys
        self.id_to_caption = id_to_caption
        final_dict = {key: id_to_caption[key] for key in common_keys}
        self.image_size = image_size
        self.base_dir = Path(base_dir)
        self.xyz_base = xyz_base
        self.repeat = repeat
        self.num_frames = num_frames
        self.camera_views = camera_views[:num_frames]
        self.split = split
        self.ref_indexs = ref_indexs
        self.ref_position = ref_position
        self.resize_rate = resize_rate
        self.random_background = random_background
        self.debug = debug
        assert split in ["train", "eval"]
        clip_size = 224
        self.transfrom_clip = transforms.Compose(
            [
                transforms.Resize(
                    (clip_size, clip_size),
                    interpolation=transforms.InterpolationMode.BICUBIC,
                    antialias=True,
                ),
                transforms.ToTensor(),
                transforms.Normalize(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD),
            ]
        )
        self.transfrom_vae = transforms.Compose(
            [
                transforms.Resize((image_size, image_size)),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )
        # When view i is the reference, index_mapping[i] gives the image-file
        # indices for the left, down, back, right, and top views relative to it.
        self.index_mapping = [
            # front left down back right top
            [0, 1, 2, 3, 4, 5],  # 0
            [1, 3, 2, 4, 0, 5],  # 1
            [2, 1, 3, 5, 4, 0],  # 2
            [3, 4, 2, 0, 1, 5],  # 3
            [4, 0, 2, 1, 3, 5],  # 4
            [5, 1, 0, 2, 4, 3],  # 5
        ]
        TT = {
            "r90": PA(TF.rotate, angle=-90.0),  # 90 deg clockwise
            "r180": PA(TF.rotate, angle=-180.0),  # 180 deg clockwise
            "r270": PA(TF.rotate, angle=-270.0),  # 270 deg clockwise
            "s90": PA(TF.rotate, angle=90.0),  # 90 deg counter-clockwise
            "s180": PA(TF.rotate, angle=180.0),  # 180 deg counter-clockwise
            "s270": PA(TF.rotate, angle=270.0),  # 270 deg counter-clockwise
        }
        self.transfroms_mapping = [
            # front left down back right top
            [None, None, None, None, None, None],  # 0
            [None, None, TT["r90"], None, None, TT["s90"]],  # 1
            [None, TT["s90"], TT["s180"], TT["r180"], TT["r90"], None],  # 2
            [None, None, TT["r180"], None, None, TT["s180"]],  # 3
            [None, None, TT["s90"], None, None, TT["r90"]],  # 4
            [None, TT["r90"], None, TT["r180"], TT["s90"], TT["s180"]],  # 5
        ]
        XT = {  # xyz transforms
            "zRrota90": PA(axis_rotate_xyz, rotate_axis="z", angle=np.pi / 2),
            "zSrota90": PA(axis_rotate_xyz, rotate_axis="z", angle=-np.pi / 2),
            "zrota180": PA(axis_rotate_xyz, rotate_axis="z", angle=-np.pi),
            "xRrota90": PA(axis_rotate_xyz, rotate_axis="x", angle=np.pi / 2),
            "xSrota90": PA(axis_rotate_xyz, rotate_axis="x", angle=-np.pi / 2),
        }
        self.xyz_transforms_mapping = [
            # front left down back right top
            [None] * 6,  # 0
            [XT["zRrota90"]] * 6,  # 1
            [XT["xRrota90"]] * 6,  # 2
            [XT["zrota180"]] * 6,  # 3
            [XT["zSrota90"]] * 6,  # 4
            [XT["xSrota90"]] * 6,  # 5
        ]
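        # How the three tables combine in __getitem__: for a chosen ref_index r
        # and relative view v, the dataset loads image index_mapping[r][v],
        # applies transfroms_mapping[r][v] (an in-plane rotation aligning the
        # rendering with the reference frame), and, for coordinate maps,
        # additionally applies xyz_transforms_mapping[r][v] so the encoded XYZ
        # values are expressed in the reference view's frame.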
        total_items = [
            {
                "path": os.path.join(base_dir, k),
                "xyz_path": os.path.join(xyz_base, k) if xyz_base is not None else None,
                "caption": v,
            }
            for k, v in final_dict.items()
        ]
        total_items.sort(key=lambda x: x["path"])
        if len(total_items) > eval_size:
            if split == "train":
                self.items = total_items[eval_size:]
            else:
                self.items = total_items[:eval_size]
        else:
            self.items = total_items
        print("============= length of dataset %d =============" % len(self.items))

    def __len__(self):
        return len(self.items) * self.repeat

    def __getitem__(self, index):
        """
        Choose an object, load its six view images, and select one as the input.

        target_images_vae: batch of `num_frames` images of one object from
            different views, processed by the VAE transform
        ref_ip: reference image in pixel space
        ref_ip_img:

        Camera views decide the logical camera pose of the images:
            000 is front, elevation 0,   azimuth 0
            001 is left,  elevation 0,   azimuth -90
            002 is down,  elevation -90, azimuth 0
            003 is back,  elevation 0,   azimuth 180
            004 is right, elevation 0,   azimuth 90
            005 is top,   elevation 90,  azimuth 0

        ref_index decides which image is chosen as the input image.
        For example, when camera_views = [1, 2, 3, 4, 5, 0] and ref_position = 5,
        the dataset returns the instance images ordered as
        [left, down, back, right, top, front], in which
        camera_views[ref_position] = camera_views[5] = 0, so the reference image
        is the front image.
        Since every face can be rotated to the front, any image can be placed at
        ref_position as the reference image (with some transforms).
        To control which images may be placed at ref_position, set ref_indexs.
        ref_indexs defaults to [0], meaning only images named 000 are placed at
        ref_position; with ref_indexs=[0, 1, 3, 4], only images named 000, 001,
        003, and 004 are placed at ref_position.
        """
        index_mapping = self.index_mapping
        transfroms_mapping = self.transfroms_mapping
        index = index % len(self.items)
        target_dir = self.items[index]["path"]
        target_xyz_dir = self.items[index]["xyz_path"]
        caption = self.items[index]["caption"]
        bg_color = np.random.rand() * 255
        target_images = []
        target_xyz_images = []
        raw_images = []
        raw_xyz_images = []
        alpha_masks = []
        ref_index = random.choice(self.ref_indexs)
        cur_index_mapping = index_mapping[ref_index]
        cur_transfroms_mapping = transfroms_mapping[ref_index]
        cur_xyz_transfroms_mapping = self.xyz_transforms_mapping[ref_index]
        for relative_view in self.camera_views:
            image_index = cur_index_mapping[relative_view]
            trans = cur_transfroms_mapping[relative_view]
            trans_xyz = cur_xyz_transfroms_mapping[relative_view]
            # load the target view and align it with the reference frame
            img = Image.open(
                os.path.join(target_dir, f"{image_index:03d}.png")
            ).convert("RGBA")
            if trans is not None:
                img = trans(img)
            img = do_resize_content(img, self.resize_rate)
            alpha_mask = img.getchannel("A")
            alpha_masks.append(alpha_mask)
            if self.random_background:
                img = add_random_background(img, bg_color)
            img = img.convert("RGB")
            target_images.append(self.transfrom_vae(img))
            raw_images.append(img)
            if self.xyz_base is not None:
                img_xyz = Image.open(
                    os.path.join(target_xyz_dir, f"xyz_new_{image_index:03d}.png")
                ).convert("RGBA")
                img_xyz = trans_xyz(img_xyz) if trans_xyz is not None else img_xyz
                img_xyz = trans(img_xyz) if trans is not None else img_xyz
                img_xyz = do_resize_content(img_xyz, self.resize_rate)
                img_xyz.putalpha(alpha_mask)
                if self.random_background:
                    img_xyz = add_random_background(img_xyz, bg_color)
                img_xyz = img_xyz.convert("RGB")
                target_xyz_images.append(self.transfrom_vae(img_xyz))
                if self.debug:
                    raw_xyz_images.append(img_xyz)
        cameras = [get_camera_for_index(i).squeeze() for i in self.camera_views]
        if self.ref_position is not None:
            cameras[self.ref_position] = torch.zeros_like(
                cameras[self.ref_position]
            )  # set the reference camera to zero
        cameras = torch.stack(cameras)
        input_img = Image.open(
            os.path.join(target_dir, f"{ref_index:03d}.png")
        ).convert("RGBA")
        input_img = do_resize_content(input_img, self.resize_rate)
        if self.random_background:
            input_img = add_random_background(input_img, bg_color)
        input_img = input_img.convert("RGB")
        clip_cond = self.transfrom_clip(input_img)
        vae_cond = self.transfrom_vae(input_img)
        vae_target = torch.stack(target_images, dim=0)
        if self.xyz_base is not None:
            xyz_vae_target = torch.stack(target_xyz_images, dim=0)
        else:
            xyz_vae_target = []
        if self.debug:
            print(f"debug!!,{bg_color}")
            return {
                "target_images": raw_images,
                "target_images_xyz": raw_xyz_images,
                "input_img": input_img,
                "cameras": cameras,
                "caption": caption,
                "item": self.items[index],
                "alpha_masks": alpha_masks,
            }
        if self.split == "train":
            return {
                "target_images_vae": vae_target,
                "target_images_xyz_vae": xyz_vae_target,
                "clip_cond": clip_cond,
                "vae_cond": vae_cond,
                "cameras": cameras,
                "caption": caption,
            }
        else:  # eval
            path = os.path.join(target_dir, f"{ref_index:03d}.png")
            return dict(
                path=path,
                target_dir=target_dir,
                cond_raw_images=raw_images,
                cond=input_img,
                ref_index=ref_index,
                ident=f"{index}-{Path(target_dir).stem}",
            )
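

# Hedged usage sketch for DataHQCRelative. The paths and CSV below are
# placeholders, not files shipped with this repo; the CSV is expected to hold
# "id,caption" rows as parsed in __init__.
#
#   dataset = DataHQCRelative(
#       base_dir="data/renders",          # hypothetical render folder
#       caption_csv="data/captions.csv",  # hypothetical caption file
#       ref_indexs=[0, 1, 3, 4],
#       camera_views=[1, 2, 3, 4, 5, 0],
#       ref_position=5,
#       num_frames=6,
#       random_background=True,
#   )
#   batch = dataset[0]
#   # batch["target_images_vae"].shape -> (num_frames, 3, image_size, image_size)
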
class DataRelativeStroke(DataHQCRelative):
    """A temporary dataset variant that can draw a random stroke around the
    reference image and jitter the content resize rate, using FOV data as the
    reference image."""

    def __init__(
        self,
        base_dir,
        caption_csv,
        ref_indexs=[0],
        ref_position=-1,
        xyz_base=None,
        camera_views=[3, 6, 9, 12, 15],  # camera views are relative views, not absolute
        split="train",
        image_size=256,
        random_background=False,
        resize_rate=1,
        num_frames=5,
        repeat=100,
        outer_file=None,
        debug=False,
        eval_size=100,
        stroke_p=0.3,
        resize_range=None,
    ):
        print(__class__)
        super().__init__(
            base_dir,
            caption_csv,
            ref_indexs=ref_indexs,
            ref_position=ref_position,
            xyz_base=xyz_base,
            camera_views=camera_views,
            split=split,
            image_size=image_size,
            random_background=random_background,
            resize_rate=resize_rate,
            num_frames=num_frames,
            repeat=repeat,
            outer_file=outer_file,
            debug=debug,
            eval_size=eval_size,
        )
        self.stroke_p = stroke_p
        assert (
            resize_range is None or len(resize_range) == 2
        ), "resize_range should be a tuple of 2 elements"
        self.resize_range = resize_range

    def __len__(self):
        return len(self.items) * self.repeat

    def __getitem__(self, index):
        index_mapping = self.index_mapping
        transfroms_mapping = self.transfroms_mapping
        index = index % len(self.items)
        target_dir = self.items[index]["path"]
        target_xyz_dir = self.items[index]["xyz_path"]
        caption = self.items[index]["caption"]
        bg_color = np.random.rand() * 255
        target_images = []
        target_xyz_images = []
        raw_images = []
        raw_xyz_images = []
        alpha_masks = []
        ref_index = random.choice(self.ref_indexs)
        cur_index_mapping = index_mapping[ref_index]
        cur_transfroms_mapping = transfroms_mapping[ref_index]
        cur_xyz_transfroms_mapping = self.xyz_transforms_mapping[ref_index]
        # jitter the resize rate within resize_range when provided
        cur_resize_rate = (
            random.uniform(*self.resize_range) * self.resize_rate
            if self.resize_range is not None
            else self.resize_rate
        )
        for relative_view in self.camera_views:
            image_index = cur_index_mapping[relative_view]
            trans = cur_transfroms_mapping[relative_view]
            trans_xyz = cur_xyz_transfroms_mapping[relative_view]
            # load the target view and align it with the reference frame
            img = Image.open(
                os.path.join(target_dir, f"{image_index:03d}.png")
            ).convert("RGBA")
            if trans is not None:
                img = trans(img)
            img = do_resize_content(img, cur_resize_rate)
            alpha_mask = img.getchannel("A")
            alpha_masks.append(alpha_mask)
            if self.random_background:
                img = add_random_background(img, bg_color)
            img = img.convert("RGB")
            target_images.append(self.transfrom_vae(img))
            raw_images.append(img)
            if self.xyz_base is not None:
                img_xyz = Image.open(
                    os.path.join(target_xyz_dir, f"xyz_new_{image_index:03d}.png")
                ).convert("RGBA")
                img_xyz = trans_xyz(img_xyz) if trans_xyz is not None else img_xyz
                img_xyz = trans(img_xyz) if trans is not None else img_xyz
                img_xyz = do_resize_content(img_xyz, cur_resize_rate)
                img_xyz.putalpha(alpha_mask)
                if self.random_background:
                    img_xyz = add_random_background(img_xyz, bg_color)
                img_xyz = img_xyz.convert("RGB")
                target_xyz_images.append(self.transfrom_vae(img_xyz))
                if self.debug:
                    raw_xyz_images.append(img_xyz)
        cameras = [get_camera_for_index(i).squeeze() for i in self.camera_views]
        if self.ref_position is not None:
            cameras[self.ref_position] = torch.zeros_like(
                cameras[self.ref_position]
            )  # set the reference camera to zero
        cameras = torch.stack(cameras)
        input_img = Image.open(
            os.path.join(target_dir, f"{ref_index:03d}.png")
        ).convert("RGBA")
        input_img = do_resize_content(input_img, cur_resize_rate)
        if random.random() < self.stroke_p:
            # draw an outline around the content in a random RGB color
            color = (
                random.randint(0, 255),
                random.randint(0, 255),
                random.randint(0, 255),
            )
            radius = random.randint(1, 3)
            input_img = add_stroke(input_img, color=color, stroke_radius=radius)
        if self.random_background:
            input_img = add_random_background(input_img, bg_color)
        input_img = input_img.convert("RGB")
        clip_cond = self.transfrom_clip(input_img)
        vae_cond = self.transfrom_vae(input_img)
        vae_target = torch.stack(target_images, dim=0)
        if self.xyz_base is not None:
            xyz_vae_target = torch.stack(target_xyz_images, dim=0)
        else:
            xyz_vae_target = []
        if self.debug:
            print(f"debug!!,{bg_color}")
            return {
                "target_images": raw_images,
                "target_images_xyz": raw_xyz_images,
                "input_img": input_img,
                "cameras": cameras,
                "caption": caption,
                "item": self.items[index],
                "alpha_masks": alpha_masks,
                "cur_resize_rate": cur_resize_rate,
            }
        if self.split == "train":
            return {
                "target_images_vae": vae_target,
                "target_images_xyz_vae": xyz_vae_target,
                "clip_cond": clip_cond,
                "vae_cond": vae_cond,
                "cameras": cameras,
                "caption": caption,
            }
        else:  # eval
            path = os.path.join(target_dir, f"{ref_index:03d}.png")
            return dict(
                path=path,
                target_dir=target_dir,
                cond_raw_images=raw_images,
                cond=input_img,
                ref_index=ref_index,
                ident=f"{index}-{Path(target_dir).stem}",
            )
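

# DataRelativeStroke adds two augmentations on top of DataHQCRelative: with
# probability stroke_p a random-colored outline is drawn around the reference
# image, and when resize_range=(lo, hi) the content resize rate is jittered
# per sample. A hedged example (the paths are placeholders, as above):
#
#   dataset = DataRelativeStroke(
#       base_dir="data/renders",
#       caption_csv="data/captions.csv",
#       stroke_p=0.3,
#       resize_range=(0.8, 1.0),
#   )
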
class InTheWildImages(Dataset):
    """
    A dataset for in-the-wild images; accepts base folders, a list of image
    paths, and path files as input.
    """

    def __init__(self, base_dirs=[], image_paths=[], path_files=[]):
        print(__class__)
        self.base_dirs = base_dirs
        self.image_paths = image_paths
        self.path_files = path_files
        self.init_item()

    def init_item(self):
        items = []
        for d in self.base_dirs:
            items += [osp.join(d, f) for f in os.listdir(d)]
        items = items + self.image_paths
        for file in self.path_files:
            with open(file, "r") as f:
                # one image path per line; skip blank lines
                items += [line.strip() for line in f if line.strip()]
        items.sort()
        self.items = items

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        item = self.items[index]
        # convert to RGBA so alpha_composite also works for images without alpha
        img = Image.open(item).convert("RGBA")
        background = Image.new("RGBA", img.size, (0, 0, 0, 0))
        cond = Image.alpha_composite(background, img)
        return dict(
            path=item, ident=f"{index}-{Path(item).stem}", cond=cond.convert("RGB")
        )
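

# A minimal, hedged smoke test for InTheWildImages; the directory below is a
# placeholder, not part of this repo, so the block is a no-op unless it exists.
if __name__ == "__main__":
    demo_dir = "data/wild_images"  # hypothetical folder of RGB/RGBA images
    if os.path.isdir(demo_dir):
        ds = InTheWildImages(base_dirs=[demo_dir])
        print(len(ds))
        sample = ds[0]
        print(sample["path"], sample["ident"], sample["cond"].size)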