diff --git a/__pycache__/app.cpython-38.pyc b/__pycache__/app.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e0eae3ea980e15d6009b1173dec5089a568e4531 Binary files /dev/null and b/__pycache__/app.cpython-38.pyc differ diff --git a/__pycache__/example_component.cpython-38.pyc b/__pycache__/example_component.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b30b851966058d9bdb533147b833163a4bcc3e6 Binary files /dev/null and b/__pycache__/example_component.cpython-38.pyc differ diff --git a/dataset/__init__.py b/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/dataset/__pycache__/__init__.cpython-38.pyc b/dataset/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05b0ef16d074d28a35ecbe025b6821120d8605da Binary files /dev/null and b/dataset/__pycache__/__init__.cpython-38.pyc differ diff --git a/dataset/__pycache__/catalog.cpython-38.pyc b/dataset/__pycache__/catalog.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc5b0f96eb8b80206ac3e28b4e5237e9d1cf29a3 Binary files /dev/null and b/dataset/__pycache__/catalog.cpython-38.pyc differ diff --git a/dataset/__pycache__/concat_dataset.cpython-38.pyc b/dataset/__pycache__/concat_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0d24742b9b477ac163ee9f1bd4a17b85f370051 Binary files /dev/null and b/dataset/__pycache__/concat_dataset.cpython-38.pyc differ diff --git a/dataset/base_dataset.py b/dataset/base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..3005bfc7cbef54b20006ca88ee01783cec9425c3 --- /dev/null +++ b/dataset/base_dataset.py @@ -0,0 +1,220 @@ +import torch +from PIL import Image, ImageDraw +import torchvision.transforms as transforms +import torchvision +from zipfile import ZipFile +import os +import multiprocessing +import math +import numpy as np +import random +from io import BytesIO + +VALID_IMAGE_TYPES = ['.jpg', '.jpeg', '.tiff', '.bmp', '.png'] + + +def check_filenames_in_zipdata(filenames, ziproot): + samples = [] + for fst in ZipFile(ziproot).infolist(): + fname = fst.filename + if fname.endswith('/') or fname.startswith('.') or fst.file_size == 0: + continue + if os.path.splitext(fname)[1].lower() in VALID_IMAGE_TYPES: + samples.append((fname)) + filenames = set(filenames) + samples = set(samples) + assert filenames.issubset(samples), 'Something wrong with your zip data' + + + +def draw_box(img, boxes): + colors = ["red", "olive", "blue", "green", "orange", "brown", "cyan", "purple"] + draw = ImageDraw.Draw(img) + for bid, box in enumerate(boxes): + draw.rectangle([box[0], box[1], box[2], box[3]], outline =colors[bid % len(colors)], width=4) + # draw.rectangle([box[0], box[1], box[2], box[3]], outline ="red", width=2) # x0 y0 x1 y1 + return img + + + +def to_valid(x0, y0, x1, y1, image_size, min_box_size): + valid = True + + if x0>image_size or y0>image_size or x1<0 or y1<0: + valid = False # no way to make this box vide, it is completely cropped out + return valid, (None, None, None, None) + + x0 = max(x0, 0) + y0 = max(y0, 0) + x1 = min(x1, image_size) + y1 = min(y1, image_size) + + if (x1-x0)*(y1-y0) / (image_size*image_size) < min_box_size: + valid = False + return valid, (None, None, None, None) + + return valid, (x0, y0, x1, y1) + + + + + +def recalculate_box_and_verify_if_valid(x, y, w, h, trans_info, 
image_size, min_box_size): + """ + x,y,w,h: the original annotation corresponding to the raw image size. + trans_info: what resizing and cropping have been applied to the raw image + image_size: what is the final image size + """ + + x0 = x * trans_info["performed_scale"] - trans_info['crop_x'] + y0 = y * trans_info["performed_scale"] - trans_info['crop_y'] + x1 = (x + w) * trans_info["performed_scale"] - trans_info['crop_x'] + y1 = (y + h) * trans_info["performed_scale"] - trans_info['crop_y'] + + + # at this point, box annotation has been recalculated based on scaling and cropping + # but some point may fall off the image_size region (e.g., negative value), thus we + # need to clamp them into 0-image_size. But if all points falling outsize of image + # region, then we will consider this is an invalid box. + valid, (x0, y0, x1, y1) = to_valid(x0, y0, x1, y1, image_size, min_box_size) + + if valid: + # we also perform random flip. + # Here boxes are valid, and are based on image_size + if trans_info["performed_flip"]: + x0, x1 = image_size-x1, image_size-x0 + + return valid, (x0, y0, x1, y1) + + + +class BaseDataset(torch.utils.data.Dataset): + def __init__(self, image_root, random_crop, random_flip, image_size): + super().__init__() + self.image_root = image_root + self.random_crop = random_crop + self.random_flip = random_flip + self.image_size = image_size + self.use_zip = False + + if image_root[-4::] == 'zip': + self.use_zip = True + self.zip_dict = {} + + if self.random_crop: + assert False, 'NOT IMPLEMENTED' + + + def fetch_zipfile(self, ziproot): + pid = multiprocessing.current_process().pid # get pid of this process. + if pid not in self.zip_dict: + self.zip_dict[pid] = ZipFile(ziproot) + zip_file = self.zip_dict[pid] + return zip_file + + def fetch_image(self, filename): + if self.use_zip: + zip_file = self.fetch_zipfile(self.image_root) + image = Image.open( BytesIO(zip_file.read(filename)) ).convert('RGB') + return image + else: + image = Image.open( os.path.join(self.image_root,filename) ).convert('RGB') + return image + + + def vis_getitem_data(self, index=None, out=None, return_tensor=False, name="res.jpg", print_caption=True): + + if out is None: + out = self[index] + + img = torchvision.transforms.functional.to_pil_image( out["image"]*0.5+0.5 ) + canvas = torchvision.transforms.functional.to_pil_image( torch.ones_like(out["image"]) ) + W, H = img.size + + if print_caption: + caption = out["caption"] + print(caption) + print(" ") + + boxes = [] + for box in out["boxes"]: + x0,y0,x1,y1 = box + boxes.append( [float(x0*W), float(y0*H), float(x1*W), float(y1*H)] ) + img = draw_box(img, boxes) + + if return_tensor: + return torchvision.transforms.functional.to_tensor(img) + else: + img.save(name) + + + def transform_image(self, pil_image): + if self.random_crop: + assert False + arr = random_crop_arr(pil_image, self.image_size) + else: + arr, info = center_crop_arr(pil_image, self.image_size) + + info["performed_flip"] = False + if self.random_flip and random.random()<0.5: + arr = arr[:, ::-1] + info["performed_flip"] = True + + arr = arr.astype(np.float32) / 127.5 - 1 + arr = np.transpose(arr, [2,0,1]) + + return torch.tensor(arr), info + + + +def center_crop_arr(pil_image, image_size): + # We are not on a new enough PIL to support the `reducing_gap` + # argument, which uses BOX downsampling at powers of two first. + # Thus, we do it by hand to improve downsample quality. 
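+    # Editor's note (hedged, not part of the original patch): Pillow >= 7.0 does expose the
+    # `reducing_gap` argument on Image.resize, e.g.
+    #     pil_image.resize(target_size, resample=Image.BICUBIC, reducing_gap=2.0)
+    # where `target_size` is the aspect-preserving size computed below. The manual halving
+    # loop is kept here because the applied scale and crop offsets must be recorded in
+    # `info` so that box annotations can be recalculated afterwards.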
+ WW, HH = pil_image.size + + while min(*pil_image.size) >= 2 * image_size: + pil_image = pil_image.resize( + tuple(x // 2 for x in pil_image.size), resample=Image.BOX + ) + + scale = image_size / min(*pil_image.size) + + pil_image = pil_image.resize( + tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC + ) + + # at this point, the min of pil_image side is desired image_size + performed_scale = image_size / min(WW, HH) + + arr = np.array(pil_image) + crop_y = (arr.shape[0] - image_size) // 2 + crop_x = (arr.shape[1] - image_size) // 2 + + info = {"performed_scale":performed_scale, 'crop_y':crop_y, 'crop_x':crop_x, "WW":WW, 'HH':HH} + + return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size], info + + +def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0): + min_smaller_dim_size = math.ceil(image_size / max_crop_frac) + max_smaller_dim_size = math.ceil(image_size / min_crop_frac) + smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1) + + # We are not on a new enough PIL to support the `reducing_gap` + # argument, which uses BOX downsampling at powers of two first. + # Thus, we do it by hand to improve downsample quality. + while min(*pil_image.size) >= 2 * smaller_dim_size: + pil_image = pil_image.resize( + tuple(x // 2 for x in pil_image.size), resample=Image.BOX + ) + + scale = smaller_dim_size / min(*pil_image.size) + pil_image = pil_image.resize( + tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC + ) + + arr = np.array(pil_image) + crop_y = random.randrange(arr.shape[0] - image_size + 1) + crop_x = random.randrange(arr.shape[1] - image_size + 1) + return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size] diff --git a/dataset/catalog.py b/dataset/catalog.py new file mode 100644 index 0000000000000000000000000000000000000000..b622e477dae7cb4ba5c599fa7d2f7220b4311885 --- /dev/null +++ b/dataset/catalog.py @@ -0,0 +1,72 @@ +import os + +class DatasetCatalog: + def __init__(self, ROOT, which_embedder): + assert which_embedder in ['clip', 'bert'] + + # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # + + + self.VGGrounding = { + "target": "dataset.tsv_dataset.TSVDataset", + "train_params": dict( + tsv_path=os.path.join(ROOT,'GROUNDING/gqa/tsv/train-00.tsv'), + ) + } + + + # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # + + + self.FlickrGrounding = { + "target": "dataset.tsv_dataset.TSVDataset", + "train_params":dict( + tsv_path=os.path.join(ROOT,'GROUNDING/flickr30k/tsv/train-00.tsv'), + ) + } + + # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # + + self.SBUGrounding = { + "target": "dataset.tsv_dataset.TSVDataset", + "train_params":dict( + tsv_path=os.path.join(ROOT,'GROUNDING/SBU/tsv/train-00.tsv'), + ) + } + + + # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # + + + self.CC3MGrounding = { + "target": "dataset.tsv_dataset.TSVDataset", + "train_params":dict( + tsv_path=os.path.join(ROOT,'GROUNDING/CC3M/tsv/train-00.tsv'), + ) + } + + + # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # + + + self.CC12MGrounding = { + "target": "dataset.tsv_dataset.TSVDataset", + "train_params":dict( + tsv_path=os.path.join(ROOT,'GROUNDING/CC12M/tsv/train-00.tsv'), + ) + } + + + # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
- - - # + + # temp = 'category_embedding_clip.pth' if which_embedder == 'clip' else 'category_embedding_bert.pth' + # obj365_category_embedding_path = os.path.join(ROOT, 'OBJECTS365', temp) + + self.Obj365Detection = { + "target": "dataset.tsv_dataset.TSVDataset", + "train_params":dict( + tsv_path=os.path.join(ROOT,'OBJECTS365/tsv/train-00.tsv'), + ), + } + + diff --git a/dataset/cd_dataset.py b/dataset/cd_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..0627329bda44a15c6821fc477bbde45acfe86a2f --- /dev/null +++ b/dataset/cd_dataset.py @@ -0,0 +1,250 @@ +import json, os, random, math +from collections import defaultdict +from copy import deepcopy + +import torch +from torch.utils.data import Dataset +import torchvision.transforms as transforms + +import numpy as np +from PIL import Image +from .base_dataset import BaseDataset, check_filenames_in_zipdata, recalculate_box_and_verify_if_valid +from io import BytesIO + + + +def not_in_at_all(list1, list2): + for a in list1: + if a in list2: + return False + return True + + +def clean_annotations(annotations): + for anno in annotations: + anno.pop("segmentation", None) + anno.pop("area", None) + anno.pop("iscrowd", None) + # anno.pop("id", None) + + +def make_a_sentence(obj_names, clean=False): + + if clean: + obj_names = [ name[:-6] if ("-other" in name) else name for name in obj_names] + + caption = "" + tokens_positive = [] + for obj_name in obj_names: + start_len = len(caption) + caption += obj_name + end_len = len(caption) + caption += ", " + tokens_positive.append( + [[start_len, end_len]] # in real caption, positive tokens can be disjoint, thus using list of list + ) + caption = caption[:-2] # remove last ", " + + return caption #, tokens_positive + + +def check_all_have_same_images(instances_data, stuff_data, caption_data): + if stuff_data is not None: + assert instances_data["images"] == stuff_data["images"] + if caption_data is not None: + assert instances_data["images"] == caption_data["images"] + + +class CDDataset(BaseDataset): + "CD: Caption Detection" + def __init__(self, + image_root, + category_embedding_path, + instances_json_path = None, + stuff_json_path = None, + caption_json_path = None, + prob_real_caption = 0, + fake_caption_type = 'empty', + image_size=256, + max_images=None, + min_box_size=0.01, + max_boxes_per_image=8, + include_other=False, + random_crop = False, + random_flip = True, + ): + super().__init__(random_crop, random_flip, image_size) + + self.image_root = image_root + self.category_embedding_path = category_embedding_path + self.instances_json_path = instances_json_path + self.stuff_json_path = stuff_json_path + self.caption_json_path = caption_json_path + self.prob_real_caption = prob_real_caption + self.fake_caption_type = fake_caption_type + self.max_images = max_images + self.min_box_size = min_box_size + self.max_boxes_per_image = max_boxes_per_image + self.include_other = include_other + + + assert fake_caption_type in ["empty", "made"] + if prob_real_caption > 0: + assert caption_json_path is not None, "caption json must be given" + + + # Load all jsons + with open(instances_json_path, 'r') as f: + instances_data = json.load(f) # keys: 'info', 'images', 'licenses', 'categories', 'annotations' + clean_annotations(instances_data["annotations"]) + self.instances_data = instances_data + + self.stuff_data = None + if stuff_json_path is not None: + with open(stuff_json_path, 'r') as f: + stuff_data = json.load(f) # keys: 'info', 'images', 'licenses', 'categories', 
'annotations' + clean_annotations(stuff_data["annotations"]) + self.stuff_data = stuff_data + + self.captions_data = None + if caption_json_path is not None: + with open(caption_json_path, 'r') as f: + captions_data = json.load(f) # keys: 'info', 'images', 'licenses', 'categories', 'annotations' + clean_annotations(captions_data["annotations"]) + self.captions_data = captions_data + + + # Load preprocessed name embedding + self.category_embeddings = torch.load(category_embedding_path) + self.embedding_len = list( self.category_embeddings.values() )[0].shape[0] + + + # Misc + self.image_ids = [] # main list for selecting images + self.image_id_to_filename = {} # file names used to read image + check_all_have_same_images(self.instances_data, self.stuff_data, self.captions_data) + for image_data in self.instances_data['images']: + image_id = image_data['id'] + filename = image_data['file_name'] + self.image_ids.append(image_id) + self.image_id_to_filename[image_id] = filename + + + # All category names (including things and stuff) + self.object_idx_to_name = {} + for category_data in self.instances_data['categories']: + self.object_idx_to_name[category_data['id']] = category_data['name'] + if self.stuff_data is not None: + for category_data in self.stuff_data['categories']: + self.object_idx_to_name[category_data['id']] = category_data['name'] + + + # Add object data from instances and stuff + self.image_id_to_objects = defaultdict(list) + self.select_objects( self.instances_data['annotations'] ) + if self.stuff_data is not None: + self.select_objects( self.stuff_data['annotations'] ) + + # Add caption data + if self.captions_data is not None: + self.image_id_to_captions = defaultdict(list) + self.select_captions( self.captions_data['annotations'] ) + + # Check if all filenames can be found in the zip file + # all_filenames = [self.image_id_to_filename[idx] for idx in self.image_ids] + # check_filenames_in_zipdata(all_filenames, image_root) + + + def select_objects(self, annotations): + for object_anno in annotations: + image_id = object_anno['image_id'] + object_name = self.object_idx_to_name[object_anno['category_id']] + other_ok = object_name != 'other' or self.include_other + if other_ok: + self.image_id_to_objects[image_id].append(object_anno) + + + def select_captions(self, annotations): + for caption_data in annotations: + image_id = caption_data['image_id'] + self.image_id_to_captions[image_id].append(caption_data) + + + def total_images(self): + return len(self) + + + def __getitem__(self, index): + if self.max_boxes_per_image > 99: + assert False, "Are you sure setting such large number of boxes?" 
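+        # Descriptive note (added by the editor): the sample assembled below contains
+        #   "image"               - image tensor scaled to [-1, 1]
+        #   "boxes"               - up to max_boxes_per_image xyxy boxes, normalized to [0, 1]
+        #   "masks"               - 1 for slots holding a surviving box, 0 for padding
+        #   "positive_embeddings" - pre-computed category embedding for each kept box
+        #   "caption"             - a real caption with probability prob_real_caption, otherwise
+        #                           an empty or synthesized ("made") caption from object names
+        # When more boxes survive cropping than max_boxes_per_image, the largest-area boxes are kept.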
+ + out = {} + + image_id = self.image_ids[index] + out['id'] = image_id + + # Image + filename = self.image_id_to_filename[image_id] + image = self.fetch_image(filename) + #WW, HH = image.size + image_tensor, trans_info = self.transform_image(image) + out["image"] = image_tensor + + + # Select valid boxes after cropping (center or random) + this_image_obj_annos = deepcopy(self.image_id_to_objects[image_id]) + areas = [] + all_obj_names = [] + all_boxes = [] + all_masks = [] + all_positive_embeddings = [] + for object_anno in this_image_obj_annos: + + x, y, w, h = object_anno['bbox'] + valid, (x0, y0, x1, y1) = recalculate_box_and_verify_if_valid(x, y, w, h, trans_info, self.image_size, self.min_box_size) + + if valid: + areas.append( (x1-x0)*(y1-y0) ) + obj_name = self.object_idx_to_name[ object_anno['category_id'] ] + all_obj_names.append(obj_name) + all_boxes.append( torch.tensor([x0,y0,x1,y1]) / self.image_size ) # scale to 0-1 + all_masks.append(1) + all_positive_embeddings.append( self.category_embeddings[obj_name] ) + + wanted_idxs = torch.tensor(areas).sort(descending=True)[1] + wanted_idxs = wanted_idxs[0:self.max_boxes_per_image] + obj_names = [] # used for making a sentence + boxes = torch.zeros(self.max_boxes_per_image, 4) + masks = torch.zeros(self.max_boxes_per_image) + positive_embeddings = torch.zeros(self.max_boxes_per_image, self.embedding_len) + for i, idx in enumerate(wanted_idxs): + obj_names.append( all_obj_names[idx] ) + boxes[i] = all_boxes[idx] + masks[i] = all_masks[idx] + positive_embeddings[i] = all_positive_embeddings[idx] + + # Caption + if random.uniform(0, 1) < self.prob_real_caption: + caption_data = self.image_id_to_captions[image_id] + idx = random.randint(0, len(caption_data)-1 ) + caption = caption_data[idx]["caption"] + else: + if self.fake_caption_type == "empty": + caption = "" + else: + caption = make_a_sentence(obj_names, clean=True) + + + out["caption"] = caption + out["boxes"] = boxes + out["masks"] = masks + out["positive_embeddings"] = positive_embeddings + + return out + + + def __len__(self): + if self.max_images is None: + return len(self.image_ids) + return min(len(self.image_ids), self.max_images) + diff --git a/dataset/concat_dataset.py b/dataset/concat_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..df637663567a8c74673de9361950a6d663357fa0 --- /dev/null +++ b/dataset/concat_dataset.py @@ -0,0 +1,65 @@ +from .catalog import DatasetCatalog +from ldm.util import instantiate_from_config +import torch + + + + +class ConCatDataset(): + def __init__(self, dataset_name_list, ROOT, which_embedder, train=True, repeats=None): + self.datasets = [] + cul_previous_dataset_length = 0 + offset_map = [] + which_dataset = [] + + if repeats is None: + repeats = [1] * len(dataset_name_list) + else: + assert len(repeats) == len(dataset_name_list) + + + Catalog = DatasetCatalog(ROOT, which_embedder) + for dataset_idx, (dataset_name, yaml_params) in enumerate(dataset_name_list.items()): + repeat = repeats[dataset_idx] + + dataset_dict = getattr(Catalog, dataset_name) + + target = dataset_dict['target'] + params = dataset_dict['train_params'] if train else dataset_dict['val_params'] + if yaml_params is not None: + params.update(yaml_params) + dataset = instantiate_from_config( dict(target=target, params=params) ) + + self.datasets.append(dataset) + for _ in range(repeat): + offset_map.append( torch.ones(len(dataset))*cul_previous_dataset_length ) + which_dataset.append( torch.ones(len(dataset))*dataset_idx ) + 
cul_previous_dataset_length += len(dataset) + offset_map = torch.cat(offset_map, dim=0).long() + self.total_length = cul_previous_dataset_length + + self.mapping = torch.arange(self.total_length) - offset_map + self.which_dataset = torch.cat(which_dataset, dim=0).long() + + + def total_images(self): + count = 0 + for dataset in self.datasets: + print(dataset.total_images()) + count += dataset.total_images() + return count + + + + def __getitem__(self, idx): + dataset = self.datasets[ self.which_dataset[idx] ] + return dataset[ self.mapping[idx] ] + + + def __len__(self): + return self.total_length + + + + + diff --git a/dataset/grounding_dataset.py b/dataset/grounding_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..1b1fa74fc948466bd3d1a522025413ee5224577a --- /dev/null +++ b/dataset/grounding_dataset.py @@ -0,0 +1,205 @@ +from tkinter.messagebox import NO +import torch +import json +from collections import defaultdict +from PIL import Image, ImageDraw +from copy import deepcopy +import os +import torchvision.transforms as transforms +import torchvision +from .base_dataset import BaseDataset, check_filenames_in_zipdata, recalculate_box_and_verify_if_valid +from io import BytesIO +import random + +def check_unique(images, fields): + for field in fields: + temp_list = [] + for img_info in images: + temp_list.append(img_info[field]) + assert len(set(temp_list)) == len(temp_list), field + +def clean_data(data): + for data_info in data: + data_info.pop("original_img_id", None) + data_info.pop("original_id", None) + data_info.pop("sentence_id", None) # sentence id for each image (multiple sentences for one image) + data_info.pop("dataset_name", None) + data_info.pop("data_source", None) + data_info["data_id"] = data_info.pop("id") + + +def clean_annotations(annotations): + for anno_info in annotations: + anno_info.pop("iscrowd", None) # I have checked that all 0 for flickr, vg, coco + anno_info.pop("category_id", None) # I have checked that all 1 for flickr vg. 
This is not always 1 for coco, but I do not think we need this annotation + anno_info.pop("area", None) + # anno_info.pop("id", None) + anno_info["data_id"] = anno_info.pop("image_id") + + +def draw_box(img, boxes): + draw = ImageDraw.Draw(img) + for box in boxes: + draw.rectangle([box[0], box[1], box[2], box[3]], outline ="red", width=2) # x0 y0 x1 y1 + return img + + +def xyhw2xyxy(box): + x0, y0, w, h = box + return [ x0, y0, x0+w, y0+h ] + + + +class GroundingDataset(BaseDataset): + def __init__(self, + image_root, + json_path, + annotation_embedding_path, + prob_real_caption=1, + image_size=256, + min_box_size=0.01, + max_boxes_per_data=8, + max_images=None, # set as 30K used to eval + random_crop = False, + random_flip = True, + ): + super().__init__(image_root, random_crop, random_flip, image_size) + self.image_root = image_root + self.json_path = json_path + self.annotation_embedding_path = annotation_embedding_path + self.prob_real_caption = prob_real_caption + self.min_box_size = min_box_size + self.max_boxes_per_data = max_boxes_per_data + self.max_images = max_images + + + # Load raw data + with open(json_path, 'r') as f: + json_raw = json.load(f) # keys: 'info', 'images', 'licenses', 'categories', 'annotations' + self.data = json_raw["images"] # donot name it images, which is misleading + self.annotations = json_raw["annotations"] + + + # Load preprocessed name embedding + if 'bert' in annotation_embedding_path: + self.embedding_len = 1280 + elif 'clip' in annotation_embedding_path: + self.embedding_len = 768 + else: + assert False + + + # clean data and annotation + check_unique( self.data, ['id'] ) + check_unique( self.annotations, ['id'] ) + clean_data(self.data) + clean_annotations(self.annotations) + self.data_id_list = [ datum['data_id'] for datum in self.data ] + self.data = { datum['data_id']:datum for datum in self.data } # map self.data from a list into a dict + + + # data point to its annotation mapping + self.data_id_to_annos = defaultdict(list) + for anno in self.annotations: + self.data_id_to_annos[ anno["data_id"] ].append(anno) + + + + # These are not used that offen, but are useful in some cases + self.file_names = [] # all training images + self.file_name_to_data_ids = defaultdict(list) # for each image, there are multiple data points (captions) + for data_id in self.data_id_list: + fine_name = self.data[data_id]["file_name"] + self.file_names.append(fine_name) + self.file_name_to_data_ids[fine_name].append(data_id) + self.file_names = list(set(self.file_names)) + + + if self.max_images is not None: + "This is only used as COCO2017P evulation, when we set max_images as 30k" + assert False, 'I have commented out the following code to save cpu memory' + # new_data_id_list = [] + # new_file_name_to_data_ids = defaultdict(list) + # self.file_names = self.file_names[0:self.max_images] + # for file_name in self.file_names: + # data_id = self.file_name_to_data_ids[file_name][0] + # new_data_id_list.append(data_id) + # new_file_name_to_data_ids[file_name].append(data_id) + # self.data_id_list = new_data_id_list + # self.file_name_to_data_ids = new_file_name_to_data_ids + + + # Check if all filenames can be found in the zip file + # all_filenames = [self.data[idx]['file_name'] for idx in self.data_id_list ] + # check_filenames_in_zipdata(all_filenames, image_root) + + + def total_images(self): + return len(self.file_names) + + + def __getitem__(self, index): + if self.max_boxes_per_data > 99: + assert False, "Are you sure setting such large number of boxes?" 
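+        # Descriptive note (added by the editor): this mirrors CDDataset.__getitem__, except
+        # that the per-box "positive_embeddings" are loaded from disk, one tensor per
+        # annotation id under annotation_embedding_path, and the caption is the grounded
+        # sentence stored with this data point (replaced by "" with probability 1 - prob_real_caption).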
+ + out = {} + + data_id = self.data_id_list[index] + out['id'] = data_id + + + # Image and caption + file_name = self.data[data_id]['file_name'] + image = self.fetch_image(file_name) + image_tensor, trans_info = self.transform_image(image) + out["image"] = image_tensor + + if random.uniform(0, 1) < self.prob_real_caption: + out["caption"] = self.data[data_id]["caption"] + else: + out["caption"] = "" + + + + annos = deepcopy(self.data_id_to_annos[data_id]) + areas = [] + all_boxes = [] + all_masks = [] + all_positive_embeddings = [] + + + for anno in annos: + + x, y, w, h = anno['bbox'] + valid, (x0, y0, x1, y1) = recalculate_box_and_verify_if_valid(x, y, w, h, trans_info, self.image_size, self.min_box_size) + + if valid: + areas.append( (x1-x0)*(y1-y0) ) + all_boxes.append( torch.tensor([x0,y0,x1,y1]) / self.image_size ) # scale to 0-1 + all_masks.append(1) + all_positive_embeddings.append( torch.load(os.path.join(self.annotation_embedding_path,str(anno["id"])), map_location='cpu' ) ) + + wanted_idxs = torch.tensor(areas).sort(descending=True)[1] + wanted_idxs = wanted_idxs[0:self.max_boxes_per_data] + + boxes = torch.zeros(self.max_boxes_per_data, 4) + masks = torch.zeros(self.max_boxes_per_data) + positive_embeddings = torch.zeros(self.max_boxes_per_data, self.embedding_len) + for i, idx in enumerate(wanted_idxs): + boxes[i] = all_boxes[idx] + masks[i] = all_masks[idx] + positive_embeddings[i] = all_positive_embeddings[idx] + + + out["boxes"] = boxes + out["masks"] = masks + out["positive_embeddings"] = positive_embeddings + + return out + + + + def __len__(self): + return len(self.data_id_list) + + diff --git a/dataset/layout_dataset.py b/dataset/layout_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..8d2b4dc73e8c194e92725faeab368f0951f6f7e8 --- /dev/null +++ b/dataset/layout_dataset.py @@ -0,0 +1,237 @@ +import json, os, random, math +from collections import defaultdict +from copy import deepcopy + +import torch +from torch.utils.data import Dataset +import torchvision.transforms as transforms + +import numpy as np +from PIL import Image, ImageOps +from .base_dataset import BaseDataset, check_filenames_in_zipdata +from io import BytesIO + + + + +def clean_annotations(annotations): + for anno in annotations: + anno.pop("segmentation", None) + anno.pop("area", None) + anno.pop("iscrowd", None) + anno.pop("id", None) + + +def make_a_sentence(obj_names, clean=False): + + if clean: + obj_names = [ name[:-6] if ("-other" in name) else name for name in obj_names] + + caption = "" + tokens_positive = [] + for obj_name in obj_names: + start_len = len(caption) + caption += obj_name + end_len = len(caption) + caption += ", " + tokens_positive.append( + [[start_len, end_len]] # in real caption, positive tokens can be disjoint, thus using list of list + ) + caption = caption[:-2] # remove last ", " + + return caption #, tokens_positive + + +class LayoutDataset(BaseDataset): + """ + Note: this dataset can somehow be achieved in cd_dataset.CDDataset + Since if you donot set prob_real_caption=0 in CDDataset, then that + dataset will only use detection annotations. However, in that dataset, + we do not remove images but remove boxes. + + However, in layout2img works, people will just resize raw image data into 256*256, + thus they pre-calculate box size and apply min_box_size before min/max_boxes_per_image. + And then they will remove images if does not follow the rule. + + These two different methods will lead to different number of training/val images. 
+ Thus this dataset here is only for layout2img. + + """ + def __init__(self, + image_root, + instances_json_path, + stuff_json_path, + category_embedding_path, + fake_caption_type = 'empty', + image_size=256, + max_samples=None, + min_box_size=0.02, + min_boxes_per_image=3, + max_boxes_per_image=8, + include_other=False, + random_flip=True + ): + super().__init__(random_crop=None, random_flip=None, image_size=None) # we only use vis_getitem func in BaseDataset, donot use the others. + + assert fake_caption_type in ['empty', 'made'] + self.image_root = image_root + self.instances_json_path = instances_json_path + self.stuff_json_path = stuff_json_path + self.category_embedding_path = category_embedding_path + self.fake_caption_type = fake_caption_type + self.image_size = image_size + self.max_samples = max_samples + self.min_box_size = min_box_size + self.min_boxes_per_image = min_boxes_per_image + self.max_boxes_per_image = max_boxes_per_image + self.include_other = include_other + self.random_flip = random_flip + + + self.transform = transforms.Compose([transforms.Resize( (image_size, image_size) ), + transforms.ToTensor(), + transforms.Lambda(lambda t: (t * 2) - 1) ]) + + # Load all jsons + with open(instances_json_path, 'r') as f: + instances_data = json.load(f) # keys: 'info', 'images', 'licenses', 'categories', 'annotations' + clean_annotations(instances_data["annotations"]) + self.instances_data = instances_data + + with open(stuff_json_path, 'r') as f: + stuff_data = json.load(f) # keys: 'info', 'images', 'licenses', 'categories', 'annotations' + clean_annotations(stuff_data["annotations"]) + self.stuff_data = stuff_data + + + # Load preprocessed name embedding + self.category_embeddings = torch.load(category_embedding_path) + self.embedding_len = list( self.category_embeddings.values() )[0].shape[0] + + + # Misc + self.image_ids = [] # main list for selecting images + self.image_id_to_filename = {} # file names used to read image + self.image_id_to_size = {} # original size of this image + assert instances_data['images'] == stuff_data["images"] + for image_data in instances_data['images']: + image_id = image_data['id'] + filename = image_data['file_name'] + width = image_data['width'] + height = image_data['height'] + self.image_ids.append(image_id) + self.image_id_to_filename[image_id] = filename + self.image_id_to_size[image_id] = (width, height) + + # All category names (including things and stuff) + self.things_id_list = [] + self.stuff_id_list = [] + self.object_idx_to_name = {} + for category_data in instances_data['categories']: + self.things_id_list.append( category_data['id'] ) + self.object_idx_to_name[category_data['id']] = category_data['name'] + for category_data in stuff_data['categories']: + self.stuff_id_list.append( category_data['id'] ) + self.object_idx_to_name[category_data['id']] = category_data['name'] + self.all_categories = [ self.object_idx_to_name.get(k, None) for k in range(183+1) ] + + + # Add object data from instances and stuff + self.image_id_to_objects = defaultdict(list) + self.select_objects( instances_data['annotations'] ) + self.select_objects( stuff_data['annotations'] ) + + + # Prune images that have too few or too many objects + new_image_ids = [] + for image_id in self.image_ids: + num_objs = len(self.image_id_to_objects[image_id]) + if self.min_boxes_per_image <= num_objs <= self.max_boxes_per_image: + new_image_ids.append(image_id) + self.image_ids = new_image_ids + + + # Check if all filenames can be found in the zip file + all_filenames 
= [self.image_id_to_filename[idx] for idx in self.image_ids] + check_filenames_in_zipdata(all_filenames, image_root) + + + + def select_objects(self, annotations): + for object_anno in annotations: + image_id = object_anno['image_id'] + _, _, w, h = object_anno['bbox'] + W, H = self.image_id_to_size[image_id] + box_area = (w * h) / (W * H) + box_ok = box_area > self.min_box_size + object_name = self.object_idx_to_name[object_anno['category_id']] + other_ok = object_name != 'other' or self.include_other + if box_ok and other_ok: + self.image_id_to_objects[image_id].append(object_anno) + + + def total_images(self): + return len(self) + + + def __getitem__(self, index): + if self.max_boxes_per_image > 99: + assert False, "Are you sure setting such large number of boxes?" + + out = {} + + image_id = self.image_ids[index] + out['id'] = image_id + + flip = self.random_flip and random.random()<0.5 + + # Image + filename = self.image_id_to_filename[image_id] + zip_file = self.fetch_zipfile(self.image_root) + image = Image.open(BytesIO(zip_file.read(filename))).convert('RGB') + WW, HH = image.size + if flip: + image = ImageOps.mirror(image) + out["image"] = self.transform(image) + + this_image_obj_annos = deepcopy(self.image_id_to_objects[image_id]) + + # Make a sentence + obj_names = [] # used for make a sentence + boxes = torch.zeros(self.max_boxes_per_image, 4) + masks = torch.zeros(self.max_boxes_per_image) + positive_embeddings = torch.zeros(self.max_boxes_per_image, self.embedding_len) + for idx, object_anno in enumerate(this_image_obj_annos): + obj_name = self.object_idx_to_name[ object_anno['category_id'] ] + obj_names.append(obj_name) + x, y, w, h = object_anno['bbox'] + x0 = x / WW + y0 = y / HH + x1 = (x + w) / WW + y1 = (y + h) / HH + if flip: + x0, x1 = 1-x1, 1-x0 + boxes[idx] = torch.tensor([x0,y0,x1,y1]) + masks[idx] = 1 + positive_embeddings[idx] = self.category_embeddings[obj_name] + + if self.fake_caption_type == 'empty': + caption = "" + else: + caption = make_a_sentence(obj_names, clean=True) + + out["caption"] = caption + out["boxes"] = boxes + out["masks"] = masks + out["positive_embeddings"] = positive_embeddings + + + return out + + + def __len__(self): + if self.max_samples is None: + return len(self.image_ids) + return min(len(self.image_ids), self.max_samples) + + diff --git a/dataset/tsv.py b/dataset/tsv.py new file mode 100644 index 0000000000000000000000000000000000000000..dfbc4c4d0c1ac4b833b8229a952c1e2fe03bd6f1 --- /dev/null +++ b/dataset/tsv.py @@ -0,0 +1,212 @@ +import os +import os.path as op +import gc +import json +from typing import List +import logging + +try: + from .blob_storage import BlobStorage, disk_usage +except: + class BlobStorage: + pass + + +def generate_lineidx(filein: str, idxout: str) -> None: + idxout_tmp = idxout + '.tmp' + with open(filein, 'r') as tsvin, open(idxout_tmp, 'w') as tsvout: + fsize = os.fstat(tsvin.fileno()).st_size + fpos = 0 + while fpos != fsize: + tsvout.write(str(fpos) + "\n") + tsvin.readline() + fpos = tsvin.tell() + os.rename(idxout_tmp, idxout) + + +def read_to_character(fp, c): + result = [] + while True: + s = fp.read(32) + assert s != '' + if c in s: + result.append(s[: s.index(c)]) + break + else: + result.append(s) + return ''.join(result) + + +class TSVFile(object): + def __init__(self, + tsv_file: str, + if_generate_lineidx: bool = False, + lineidx: str = None, + class_selector: List[str] = None, + blob_storage: BlobStorage = None): + self.tsv_file = tsv_file + self.lineidx = op.splitext(tsv_file)[0] + '.lineidx' 
\ + if not lineidx else lineidx + self.linelist = op.splitext(tsv_file)[0] + '.linelist' + self.chunks = op.splitext(tsv_file)[0] + '.chunks' + self._fp = None + self._lineidx = None + self._sample_indices = None + self._class_boundaries = None + self._class_selector = class_selector + self._blob_storage = blob_storage + self._len = None + # the process always keeps the process which opens the file. + # If the pid is not equal to the currrent pid, we will re-open the file. + self.pid = None + # generate lineidx if not exist + if not op.isfile(self.lineidx) and if_generate_lineidx: + generate_lineidx(self.tsv_file, self.lineidx) + + def __del__(self): + self.gcidx() + if self._fp: + self._fp.close() + # physically remove the tsv file if it is retrieved by BlobStorage + if self._blob_storage and 'azcopy' in self.tsv_file and os.path.exists(self.tsv_file): + try: + original_usage = disk_usage('/') + os.remove(self.tsv_file) + logging.info("Purged %s (disk usage: %.2f%% => %.2f%%)" % + (self.tsv_file, original_usage, disk_usage('/') * 100)) + except: + # Known issue: multiple threads attempting to delete the file will raise a FileNotFound error. + # TODO: try Threadling.Lock to better handle the race condition + pass + + def __str__(self): + return "TSVFile(tsv_file='{}')".format(self.tsv_file) + + def __repr__(self): + return str(self) + + def gcidx(self): + logging.debug('Run gc collect') + self._lineidx = None + self._sample_indices = None + #self._class_boundaries = None + return gc.collect() + + def get_class_boundaries(self): + return self._class_boundaries + + def num_rows(self, gcf=False): + if (self._len is None): + self._ensure_lineidx_loaded() + retval = len(self._sample_indices) + + if (gcf): + self.gcidx() + + self._len = retval + + return self._len + + def seek(self, idx: int): + self._ensure_tsv_opened() + self._ensure_lineidx_loaded() + try: + pos = self._lineidx[self._sample_indices[idx]] + except: + logging.info('=> {}-{}'.format(self.tsv_file, idx)) + raise + self._fp.seek(pos) + return [s.strip() for s in self._fp.readline().split('\t')] + + def seek_first_column(self, idx: int): + self._ensure_tsv_opened() + self._ensure_lineidx_loaded() + pos = self._lineidx[idx] + self._fp.seek(pos) + return read_to_character(self._fp, '\t') + + def get_key(self, idx: int): + return self.seek_first_column(idx) + + def __getitem__(self, index: int): + return self.seek(index) + + def __len__(self): + return self.num_rows() + + def _ensure_lineidx_loaded(self): + if self._lineidx is None: + logging.debug('=> loading lineidx: {}'.format(self.lineidx)) + with open(self.lineidx, 'r') as fp: + lines = fp.readlines() + lines = [line.strip() for line in lines] + self._lineidx = [int(line) for line in lines] + + # read the line list if exists + linelist = None + if op.isfile(self.linelist): + with open(self.linelist, 'r') as fp: + linelist = sorted( + [ + int(line.strip()) + for line in fp.readlines() + ] + ) + + if op.isfile(self.chunks): + self._sample_indices = [] + self._class_boundaries = [] + class_boundaries = json.load(open(self.chunks, 'r')) + for class_name, boundary in class_boundaries.items(): + start = len(self._sample_indices) + if class_name in self._class_selector: + for idx in range(boundary[0], boundary[1] + 1): + # NOTE: potentially slow when linelist is long, try to speed it up + if linelist and idx not in linelist: + continue + self._sample_indices.append(idx) + end = len(self._sample_indices) + self._class_boundaries.append((start, end)) + else: + if linelist: + 
self._sample_indices = linelist + else: + self._sample_indices = list(range(len(self._lineidx))) + + def _ensure_tsv_opened(self): + if self._fp is None: + if self._blob_storage: + self._fp = self._blob_storage.open(self.tsv_file) + else: + self._fp = open(self.tsv_file, 'r') + self.pid = os.getpid() + + if self.pid != os.getpid(): + logging.debug('=> re-open {} because the process id changed'.format(self.tsv_file)) + self._fp = open(self.tsv_file, 'r') + self.pid = os.getpid() + + +class TSVWriter(object): + def __init__(self, tsv_file): + self.tsv_file = tsv_file + self.lineidx_file = op.splitext(tsv_file)[0] + '.lineidx' + self.tsv_file_tmp = self.tsv_file + '.tmp' + self.lineidx_file_tmp = self.lineidx_file + '.tmp' + + self.tsv_fp = open(self.tsv_file_tmp, 'w') + self.lineidx_fp = open(self.lineidx_file_tmp, 'w') + + self.idx = 0 + + def write(self, values, sep='\t'): + v = '{0}\n'.format(sep.join(map(str, values))) + self.tsv_fp.write(v) + self.lineidx_fp.write(str(self.idx) + '\n') + self.idx = self.idx + len(v) + + def close(self): + self.tsv_fp.close() + self.lineidx_fp.close() + os.rename(self.tsv_file_tmp, self.tsv_file) + os.rename(self.lineidx_file_tmp, self.lineidx_file) diff --git a/dataset/tsv_dataset.py b/dataset/tsv_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..dc2db59faf1254970b35d2fc8dec78afde4f6918 --- /dev/null +++ b/dataset/tsv_dataset.py @@ -0,0 +1,326 @@ +from tkinter.messagebox import NO +import torch +import json +from collections import defaultdict +from PIL import Image, ImageDraw +from copy import deepcopy +import os +import torchvision.transforms as transforms +import torchvision +from .base_dataset import BaseDataset, check_filenames_in_zipdata, recalculate_box_and_verify_if_valid +from io import BytesIO +import random + +from .tsv import TSVFile + +from io import BytesIO +import base64 +from PIL import Image +import numpy as np + + +def decode_base64_to_pillow(image_b64): + return Image.open(BytesIO(base64.b64decode(image_b64))).convert('RGB') + +def decode_tensor_from_string(arr_str, use_tensor=True): + arr = np.frombuffer(base64.b64decode(arr_str), dtype='float32') + if use_tensor: + arr = torch.from_numpy(arr) + return arr + +def decode_item(item): + item = json.loads(item) + item['image'] = decode_base64_to_pillow(item['image']) + + for anno in item['annos']: + anno['image_embedding_before'] = decode_tensor_from_string(anno['image_embedding_before']) + anno['text_embedding_before'] = decode_tensor_from_string(anno['text_embedding_before']) + anno['image_embedding_after'] = decode_tensor_from_string(anno['image_embedding_after']) + anno['text_embedding_after'] = decode_tensor_from_string(anno['text_embedding_after']) + return item + +def check_unique(images, fields): + for field in fields: + temp_list = [] + for img_info in images: + temp_list.append(img_info[field]) + assert len(set(temp_list)) == len(temp_list), field + +def clean_data(data): + for data_info in data: + data_info.pop("original_img_id", None) + data_info.pop("original_id", None) + data_info.pop("sentence_id", None) # sentence id for each image (multiple sentences for one image) + data_info.pop("dataset_name", None) + data_info.pop("data_source", None) + data_info["data_id"] = data_info.pop("id") + + +def clean_annotations(annotations): + for anno_info in annotations: + anno_info.pop("iscrowd", None) # I have checked that all 0 for flickr, vg, coco + anno_info.pop("category_id", None) # I have checked that all 1 for flickr vg. 
This is not always 1 for coco, but I do not think we need this annotation + anno_info.pop("area", None) + # anno_info.pop("id", None) + anno_info["data_id"] = anno_info.pop("image_id") + + +def draw_box(img, boxes): + draw = ImageDraw.Draw(img) + for box in boxes: + draw.rectangle([box[0], box[1], box[2], box[3]], outline ="red", width=2) # x0 y0 x1 y1 + return img + + +def xyhw2xyxy(box): + x0, y0, w, h = box + return [ x0, y0, x0+w, y0+h ] + + +def make_a_sentence(obj_names, clean=False): + + if clean: + obj_names = [ name[:-6] if ("-other" in name) else name for name in obj_names] + + caption = "" + tokens_positive = [] + for obj_name in obj_names: + start_len = len(caption) + caption += obj_name + end_len = len(caption) + caption += ", " + tokens_positive.append( + [[start_len, end_len]] # in real caption, positive tokens can be disjoint, thus using list of list + ) + caption = caption[:-2] # remove last ", " + + return caption #, tokens_positive + + +def mask_for_random_drop_text_or_image_feature(masks, random_drop_embedding): + """ + input masks tell how many valid grounding tokens for this image + e.g., 1,1,1,1,0,0,0,0,0,0... + + If random_drop_embedding=both. we will random drop either image or + text feature for each token, + but we always make sure there is at least one feature used. + In other words, the following masks are not valid + (because for the second obj, no feature at all): + image: 1,0,1,1,0,0,0,0,0 + text: 1,0,0,0,0,0,0,0,0 + + if random_drop_embedding=image. we will random drop image feature + and always keep the text one. + + """ + N = masks.shape[0] + + if random_drop_embedding=='both': + temp_mask = torch.ones(2,N) + for i in range(N): + if random.uniform(0, 1) < 0.5: # else keep both features + idx = random.sample([0,1], 1)[0] # randomly choose to drop image or text feature + temp_mask[idx,i] = 0 + image_masks = temp_mask[0]*masks + text_masks = temp_mask[1]*masks + + if random_drop_embedding=='image': + image_masks = masks*(torch.rand(N)>0.5)*1 + text_masks = masks + + return image_masks, text_masks + + + + + +def project(x, projection_matrix): + """ + x (Batch*768) should be the penultimate feature of CLIP (before projection) + projection_matrix (768*768) is the CLIP projection matrix, which should be weight.data of Linear layer + defined in CLIP (out_dim, in_dim), thus we need to apply transpose below. + this function will return the CLIP feature (without normalziation) + """ + return x@torch.transpose(projection_matrix, 0, 1) + + +def inv_project(y, projection_matrix): + """ + y (Batch*768) should be the CLIP feature (after projection) + projection_matrix (768*768) is the CLIP projection matrix, which should be weight.data of Linear layer + defined in CLIP (out_dim, in_dim). + this function will return the CLIP penultimate feature. + + Note: to make sure getting the correct penultimate feature, the input y should not be normalized. + If it is normalized, then the result will be scaled by CLIP feature norm, which is unknown. 
+ """ + return y@torch.transpose(torch.linalg.inv(projection_matrix), 0, 1) + + + + +class TSVDataset(BaseDataset): + def __init__(self, + tsv_path, + which_embedder='clip', + which_layer=['after','after'], # text and image + prob_use_caption=1, + random_drop_embedding='none', + image_size=256, + min_box_size=0.01, + max_boxes_per_data=8, + max_images=None, # set as 30K used to eval + random_crop = False, + random_flip = True, + ): + image_root = "a placeholder path as we are using tsv here" + super().__init__(image_root, random_crop, random_flip, image_size) + self.tsv_path = tsv_path + self.which_embedder = which_embedder + self.prob_use_caption = prob_use_caption + self.random_drop_embedding = random_drop_embedding + self.min_box_size = min_box_size + self.max_boxes_per_data = max_boxes_per_data + self.max_images = max_images + + assert which_layer in [ ['after','after'], ['before','after_renorm'], ['before','after_reproject'] ] + assert random_drop_embedding in ['none', 'both', 'image'] + self.which_layer_text = which_layer[0] + self.which_layer_image = which_layer[1] + + #self.projection_matrix = torch.load(os.path.join(os.path.dirname(__file__), 'projection_matrix') ) + self.projection_matrix = torch.load('projection_matrix.pth') + + # Load tsv data + self.tsv_file = TSVFile(self.tsv_path) + + + # Load preprocessed name embedding + if which_embedder == 'bert': + self.embedding_len = 1280 + elif which_embedder == 'clip': + self.embedding_len = 768 + else: + assert False + + def total_images(self): + return len(self) + + def get_item_from_tsv(self, index): + _, item = self.tsv_file[index] + item = decode_item(item) + return item + + + def mapping(self, image_embedding): + if self.which_layer_image == 'after': + # both use CLIP aligned feature + return image_embedding + elif self.which_layer_image == 'after_renorm': + # text use before, but image use after projection but normalize to 28.7 + return image_embedding*28.7 + elif self.which_layer_image == 'after_reproject': + image_embedding = project( image_embedding.unsqueeze(0), self.projection_matrix.T ) + image_embedding = image_embedding.squeeze(0) + image_embedding = image_embedding / image_embedding.norm() + image_embedding = image_embedding * 28.7 + return image_embedding + + + + def __getitem__(self, index): + if self.max_boxes_per_data > 99: + assert False, "Are you sure setting such large number of boxes?" 
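+        # Descriptive note (added by the editor): each TSV row decodes (via decode_item) into a
+        # PIL image plus annotations that already carry pre-computed text/image embeddings,
+        # so no embedding lookup is needed here. Detection-style rows (is_det=True) get a
+        # caption synthesized from their category names, and "image_masks"/"text_masks"
+        # implement the modality dropping selected by random_drop_embedding.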
+ + raw_item = self.get_item_from_tsv(index) + is_det = raw_item.get('is_det', False) # if it is from detection (such as o365), then we will make a caption + + out = {} + + # -------------------- id and image ------------------- # + out['id'] = raw_item['data_id'] + image = raw_item['image'] + image_tensor, trans_info = self.transform_image(image) + out["image"] = image_tensor + + + + # -------------------- grounding token ------------------- # + annos = raw_item['annos'] + + areas = [] + all_boxes = [] + all_masks = [] + all_text_embeddings = [] + all_image_embeddings = [] + if is_det: + all_category_names = [] + + text_embedding_name = 'text_embedding_before' if self.which_layer_text == 'before' else 'text_embedding_after' + image_embedding_name = 'image_embedding_after' + + for anno in annos: + x, y, w, h = anno['bbox'] + valid, (x0, y0, x1, y1) = recalculate_box_and_verify_if_valid(x, y, w, h, trans_info, self.image_size, self.min_box_size) + + if valid: + areas.append( (x1-x0)*(y1-y0) ) + all_boxes.append( torch.tensor([x0,y0,x1,y1]) / self.image_size ) # scale to 0-1 + all_masks.append(1) + all_text_embeddings.append(anno[text_embedding_name]) + all_image_embeddings.append( self.mapping(anno[image_embedding_name]) ) + if is_det: + all_category_names.append(anno["category_name"]) + + + wanted_idxs = torch.tensor(areas).sort(descending=True)[1] + wanted_idxs = wanted_idxs[0:self.max_boxes_per_data] + + boxes = torch.zeros(self.max_boxes_per_data, 4) + masks = torch.zeros(self.max_boxes_per_data) + text_embeddings = torch.zeros(self.max_boxes_per_data, self.embedding_len) + image_embeddings = torch.zeros(self.max_boxes_per_data, self.embedding_len) + if is_det: + category_names = [] + for i, idx in enumerate(wanted_idxs): + boxes[i] = all_boxes[idx] + masks[i] = all_masks[idx] + text_embeddings[i] = all_text_embeddings[idx] + image_embeddings[i] = all_image_embeddings[idx] + if is_det: + category_names.append(all_category_names[idx]) + + if self.random_drop_embedding != 'none': + image_masks, text_masks = mask_for_random_drop_text_or_image_feature(masks, self.random_drop_embedding) + else: + image_masks = masks + text_masks = masks + + + out["boxes"] = boxes + out["masks"] = masks + out["image_masks"] = image_masks + out["text_masks"] = text_masks + out["text_embeddings"] = text_embeddings + out["image_embeddings"] = image_embeddings + + + + # -------------------- caption ------------------- # + if random.uniform(0, 1) < self.prob_use_caption: + if is_det: + out["caption"] = make_a_sentence(category_names) + else: + out["caption"] = raw_item["caption"] + else: + out["caption"] = "" + + return out + + + + def __len__(self): + return len(self.tsv_file) + + diff --git a/dataset/utils.py b/dataset/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0ceafd04bc6860eaccfe5a480fb452f00792dac4 --- /dev/null +++ b/dataset/utils.py @@ -0,0 +1,116 @@ +#!/usr/bin/python +# +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
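+# Editor's descriptive comment: the helpers below provide ImageNet mean/std normalization
+# and its inverse (imagenet_preprocess / imagenet_deprocess / imagenet_deprocess_batch),
+# a PIL-based Resize transform, and split_graph_batch for regrouping scene-graph triples
+# and per-object data back into per-image lists.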
+ +import PIL +import torch +import torchvision.transforms as T + + +IMAGENET_MEAN = [0.485, 0.456, 0.406] +IMAGENET_STD = [0.229, 0.224, 0.225] + +INV_IMAGENET_MEAN = [-m for m in IMAGENET_MEAN] +INV_IMAGENET_STD = [1.0 / s for s in IMAGENET_STD] + + +def imagenet_preprocess(): + return T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD) + + +def rescale(x): + lo, hi = x.min(), x.max() + return x.sub(lo).div(hi - lo) + + +def imagenet_deprocess(rescale_image=True): + transforms = [ + T.Normalize(mean=[0, 0, 0], std=INV_IMAGENET_STD), + T.Normalize(mean=INV_IMAGENET_MEAN, std=[1.0, 1.0, 1.0]), + ] + if rescale_image: + transforms.append(rescale) + return T.Compose(transforms) + + +def imagenet_deprocess_batch(imgs, rescale=True): + """ + Input: + - imgs: FloatTensor of shape (N, C, H, W) giving preprocessed images + + Output: + - imgs_de: ByteTensor of shape (N, C, H, W) giving deprocessed images + in the range [0, 255] + """ + if isinstance(imgs, torch.autograd.Variable): + imgs = imgs.data + imgs = imgs.cpu().clone() + deprocess_fn = imagenet_deprocess(rescale_image=rescale) + imgs_de = [] + for i in range(imgs.size(0)): + img_de = deprocess_fn(imgs[i])[None] + img_de = img_de.mul(255).clamp(0, 255).byte() + imgs_de.append(img_de) + imgs_de = torch.cat(imgs_de, dim=0) + return imgs_de + + +class Resize(object): + def __init__(self, size, interp=PIL.Image.BILINEAR): + if isinstance(size, tuple): + H, W = size + self.size = (W, H) + else: + self.size = (size, size) + self.interp = interp + + def __call__(self, img): + return img.resize(self.size, self.interp) + + +def unpack_var(v): + if isinstance(v, torch.autograd.Variable): + return v.data + return v + + +def split_graph_batch(triples, obj_data, obj_to_img, triple_to_img): + triples = unpack_var(triples) + obj_data = [unpack_var(o) for o in obj_data] + obj_to_img = unpack_var(obj_to_img) + triple_to_img = unpack_var(triple_to_img) + + triples_out = [] + obj_data_out = [[] for _ in obj_data] + obj_offset = 0 + N = obj_to_img.max() + 1 + for i in range(N): + o_idxs = (obj_to_img == i).nonzero().view(-1) + t_idxs = (triple_to_img == i).nonzero().view(-1) + + cur_triples = triples[t_idxs].clone() + cur_triples[:, 0] -= obj_offset + cur_triples[:, 2] -= obj_offset + triples_out.append(cur_triples) + + for j, o_data in enumerate(obj_data): + cur_o_data = None + if o_data is not None: + cur_o_data = o_data[o_idxs] + obj_data_out[j].append(cur_o_data) + + obj_offset += o_idxs.size(0) + + return triples_out, obj_data_out diff --git a/gligen/__pycache__/__init__.cpython-38.pyc b/gligen/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3dddf0891f9b86a5a4c19aad9273cdac089f3782 Binary files /dev/null and b/gligen/__pycache__/__init__.cpython-38.pyc differ diff --git a/gligen/__pycache__/distributed.cpython-38.pyc b/gligen/__pycache__/distributed.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7e01e94be05a8a6f8afe0291682765163cf72ac Binary files /dev/null and b/gligen/__pycache__/distributed.cpython-38.pyc differ diff --git a/gligen/__pycache__/evaluator.cpython-38.pyc b/gligen/__pycache__/evaluator.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b775a0d708154f3ead33f5fd96f7d51fcf266103 Binary files /dev/null and b/gligen/__pycache__/evaluator.cpython-38.pyc differ diff --git a/gligen/__pycache__/task_grounded_generation.cpython-38.pyc b/gligen/__pycache__/task_grounded_generation.cpython-38.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..2d5b36ff062326f901f0a31f3f547ce95993e4fd Binary files /dev/null and b/gligen/__pycache__/task_grounded_generation.cpython-38.pyc differ diff --git a/gligen/__pycache__/trainer.cpython-38.pyc b/gligen/__pycache__/trainer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16ac8ae5d0a4e8c3ade40746a342320c7aab222b Binary files /dev/null and b/gligen/__pycache__/trainer.cpython-38.pyc differ diff --git a/gligen/ldm/__pycache__/util.cpython-38.pyc b/gligen/ldm/__pycache__/util.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c49bacc41c35909fac61e2ecc1c916fc1ffb7605 Binary files /dev/null and b/gligen/ldm/__pycache__/util.cpython-38.pyc differ diff --git a/gligen/ldm/models/.DS_Store b/gligen/ldm/models/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..8e2da8e41b76fb8a3c71433582fcacba45e51b72 Binary files /dev/null and b/gligen/ldm/models/.DS_Store differ diff --git a/gligen/ldm/models/__pycache__/autoencoder.cpython-38.pyc b/gligen/ldm/models/__pycache__/autoencoder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31267258e90660e79d7d00084135454ec92e8285 Binary files /dev/null and b/gligen/ldm/models/__pycache__/autoencoder.cpython-38.pyc differ diff --git a/gligen/ldm/models/autoencoder.py b/gligen/ldm/models/autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..1163e72dd063ee6773fe3e3c586c43b0663da4c9 --- /dev/null +++ b/gligen/ldm/models/autoencoder.py @@ -0,0 +1,52 @@ +import torch +import torch.nn as nn +#import pytorch_lightning as pl +import torch.nn.functional as F +from contextlib import contextmanager + +# from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer + +from ldm.modules.diffusionmodules.model import Encoder, Decoder +from ldm.modules.distributions.distributions import DiagonalGaussianDistribution + +from ldm.util import instantiate_from_config + + + + +class AutoencoderKL(nn.Module): + def __init__(self, + ddconfig, + embed_dim, + scale_factor=1 + ): + super().__init__() + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + self.scale_factor = scale_factor + + + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior.sample() * self.scale_factor + + def decode(self, z): + z = 1. 
/ self.scale_factor * z + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + + + + + + + diff --git a/gligen/ldm/models/diffusion/__init__.py b/gligen/ldm/models/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/gligen/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc b/gligen/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aab67efed99c0c65c5da44eab775d687302dcbd4 Binary files /dev/null and b/gligen/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc differ diff --git a/gligen/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc b/gligen/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e496b8b850a5e43196b3aa6381f453e21f1d1766 Binary files /dev/null and b/gligen/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc differ diff --git a/gligen/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc b/gligen/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..967b19d2615c7aca3ade02c313bdf641867ed6a0 Binary files /dev/null and b/gligen/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc differ diff --git a/gligen/ldm/models/diffusion/__pycache__/gaussian_smoothing.cpython-38.pyc b/gligen/ldm/models/diffusion/__pycache__/gaussian_smoothing.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07c6ce260940772e875980d5df10dfba907d352e Binary files /dev/null and b/gligen/ldm/models/diffusion/__pycache__/gaussian_smoothing.cpython-38.pyc differ diff --git a/gligen/ldm/models/diffusion/__pycache__/ldm.cpython-38.pyc b/gligen/ldm/models/diffusion/__pycache__/ldm.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0130f2932816a221829355e8f8fbf412e035960 Binary files /dev/null and b/gligen/ldm/models/diffusion/__pycache__/ldm.cpython-38.pyc differ diff --git a/gligen/ldm/models/diffusion/__pycache__/loss.cpython-38.pyc b/gligen/ldm/models/diffusion/__pycache__/loss.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d450a9c0e02daf65da42ac35d0f564f98c894e8 Binary files /dev/null and b/gligen/ldm/models/diffusion/__pycache__/loss.cpython-38.pyc differ diff --git a/gligen/ldm/models/diffusion/__pycache__/plms.cpython-38.pyc b/gligen/ldm/models/diffusion/__pycache__/plms.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3067ac490bb3ebd63857301562bf694b559888fc Binary files /dev/null and b/gligen/ldm/models/diffusion/__pycache__/plms.cpython-38.pyc differ diff --git a/gligen/ldm/models/diffusion/classifier.py b/gligen/ldm/models/diffusion/classifier.py new file mode 100644 index 0000000000000000000000000000000000000000..67e98b9d8ffb96a150b517497ace0a242d7163ef --- /dev/null +++ b/gligen/ldm/models/diffusion/classifier.py @@ -0,0 +1,267 @@ +import os +import torch +import pytorch_lightning as pl +from omegaconf import OmegaConf +from torch.nn import functional as F +from torch.optim import AdamW +from torch.optim.lr_scheduler import LambdaLR +from copy import deepcopy +from einops import rearrange +from glob import glob +from natsort import natsorted + +from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel +from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config + +__models__ = { + 'class_label': 
EncoderUNetModel, + 'segmentation': UNetModel +} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class NoisyLatentImageClassifier(pl.LightningModule): + + def __init__(self, + diffusion_path, + num_classes, + ckpt_path=None, + pool='attention', + label_key=None, + diffusion_ckpt_path=None, + scheduler_config=None, + weight_decay=1.e-2, + log_steps=10, + monitor='val/loss', + *args, + **kwargs): + super().__init__(*args, **kwargs) + self.num_classes = num_classes + # get latest config of diffusion model + diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1] + self.diffusion_config = OmegaConf.load(diffusion_config).model + self.diffusion_config.params.ckpt_path = diffusion_ckpt_path + self.load_diffusion() + + self.monitor = monitor + self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 + self.log_time_interval = self.diffusion_model.num_timesteps // log_steps + self.log_steps = log_steps + + self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \ + else self.diffusion_model.cond_stage_key + + assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params' + + if self.label_key not in __models__: + raise NotImplementedError() + + self.load_classifier(ckpt_path, pool) + + self.scheduler_config = scheduler_config + self.use_scheduler = self.scheduler_config is not None + self.weight_decay = weight_decay + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def load_diffusion(self): + model = instantiate_from_config(self.diffusion_config) + self.diffusion_model = model.eval() + self.diffusion_model.train = disabled_train + for param in self.diffusion_model.parameters(): + param.requires_grad = False + + def load_classifier(self, ckpt_path, pool): + model_config = deepcopy(self.diffusion_config.params.unet_config.params) + model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels + model_config.out_channels = self.num_classes + if self.label_key == 'class_label': + model_config.pool = pool + + self.model = __models__[self.label_key](**model_config) + if ckpt_path is not None: + print('#####################################################################') + print(f'load from ckpt "{ckpt_path}"') + print('#####################################################################') + self.init_from_ckpt(ckpt_path) + + @torch.no_grad() + def get_x_noisy(self, x, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x)) + continuous_sqrt_alpha_cumprod = None + if self.diffusion_model.use_continuous_noise: + continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) + # todo: make sure t+1 is correct here + + return 
self.diffusion_model.q_sample(x_start=x, t=t, noise=noise, + continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) + + def forward(self, x_noisy, t, *args, **kwargs): + return self.model(x_noisy, t) + + @torch.no_grad() + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = rearrange(x, 'b h w c -> b c h w') + x = x.to(memory_format=torch.contiguous_format).float() + return x + + @torch.no_grad() + def get_conditioning(self, batch, k=None): + if k is None: + k = self.label_key + assert k is not None, 'Needs to provide label key' + + targets = batch[k].to(self.device) + + if self.label_key == 'segmentation': + targets = rearrange(targets, 'b h w c -> b c h w') + for down in range(self.numd): + h, w = targets.shape[-2:] + targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') + + # targets = rearrange(targets,'b c h w -> b h w c') + + return targets + + def compute_top_k(self, logits, labels, k, reduction="mean"): + _, top_ks = torch.topk(logits, k, dim=1) + if reduction == "mean": + return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() + elif reduction == "none": + return (top_ks == labels[:, None]).float().sum(dim=-1) + + def on_train_epoch_start(self): + # save some memory + self.diffusion_model.model.to('cpu') + + @torch.no_grad() + def write_logs(self, loss, logits, targets): + log_prefix = 'train' if self.training else 'val' + log = {} + log[f"{log_prefix}/loss"] = loss.mean() + log[f"{log_prefix}/acc@1"] = self.compute_top_k( + logits, targets, k=1, reduction="mean" + ) + log[f"{log_prefix}/acc@5"] = self.compute_top_k( + logits, targets, k=5, reduction="mean" + ) + + self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True) + self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False) + self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True) + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) + + def shared_step(self, batch, t=None): + x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key) + targets = self.get_conditioning(batch) + if targets.dim() == 4: + targets = targets.argmax(dim=1) + if t is None: + t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long() + else: + t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long() + x_noisy = self.get_x_noisy(x, t) + logits = self(x_noisy, t) + + loss = F.cross_entropy(logits, targets, reduction='none') + + self.write_logs(loss.detach(), logits.detach(), targets.detach()) + + loss = loss.mean() + return loss, logits, x_noisy, targets + + def training_step(self, batch, batch_idx): + loss, *_ = self.shared_step(batch) + return loss + + def reset_noise_accs(self): + self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in + range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)} + + def on_validation_start(self): + self.reset_noise_accs() + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + loss, *_ = self.shared_step(batch) + + for t in self.noisy_acc: + _, logits, _, targets = self.shared_step(batch, t) + self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean')) + self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean')) + + return loss + + def configure_optimizers(self): + optimizer = 
AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) + + if self.use_scheduler: + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }] + return [optimizer], scheduler + + return optimizer + + @torch.no_grad() + def log_images(self, batch, N=8, *args, **kwargs): + log = dict() + x = self.get_input(batch, self.diffusion_model.first_stage_key) + log['inputs'] = x + + y = self.get_conditioning(batch) + + if self.label_key == 'class_label': + y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) + log['labels'] = y + + if ismap(y): + log['labels'] = self.diffusion_model.to_rgb(y) + + for step in range(self.log_steps): + current_time = step * self.log_time_interval + + _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) + + log[f'inputs@t{current_time}'] = x_noisy + + pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) + pred = rearrange(pred, 'b h w c -> b c h w') + + log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred) + + for key in log: + log[key] = log[key][:N] + + return log diff --git a/gligen/ldm/models/diffusion/ddim.py b/gligen/ldm/models/diffusion/ddim.py new file mode 100644 index 0000000000000000000000000000000000000000..7db86661e94319b54bec15bf521097bb7b7faf87 --- /dev/null +++ b/gligen/ldm/models/diffusion/ddim.py @@ -0,0 +1,134 @@ +import torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like + + +class DDIMSampler(object): + def __init__(self, diffusion, model, schedule="linear", alpha_generator_func=None, set_alpha_scale=None): + super().__init__() + self.diffusion = diffusion + self.model = model + self.device = diffusion.betas.device + self.ddpm_num_timesteps = diffusion.num_timesteps + self.schedule = schedule + self.alpha_generator_func = alpha_generator_func + self.set_alpha_scale = set_alpha_scale + + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + attr = attr.to(self.device) + setattr(self, name, attr) + + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.): + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=False) + alphas_cumprod = self.diffusion.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device) + + self.register_buffer('betas', to_torch(self.diffusion.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.diffusion.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. 
/ alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=False) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + + @torch.no_grad() + def sample(self, S, shape, input, uc=None, guidance_scale=1, mask=None, x0=None): + self.make_schedule(ddim_num_steps=S) + return self.ddim_sampling(shape, input, uc, guidance_scale, mask=mask, x0=x0) + + + @torch.no_grad() + def ddim_sampling(self, shape, input, uc, guidance_scale=1, mask=None, x0=None): + b = shape[0] + + img = input["x"] + if img == None: + img = torch.randn(shape, device=self.device) + input["x"] = img + + + time_range = np.flip(self.ddim_timesteps) + total_steps = self.ddim_timesteps.shape[0] + + #iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + iterator = time_range + + if self.alpha_generator_func != None: + alphas = self.alpha_generator_func(len(iterator)) + + + for i, step in enumerate(iterator): + + # set alpha + if self.alpha_generator_func != None: + self.set_alpha_scale(self.model, alphas[i]) + if alphas[i] == 0: + self.model.restore_first_conv_from_SD() + + # run + index = total_steps - i - 1 + input["timesteps"] = torch.full((b,), step, device=self.device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.diffusion.q_sample( x0, input["timesteps"] ) + img = img_orig * mask + (1. - mask) * img + input["x"] = img + + img, pred_x0 = self.p_sample_ddim(input, index=index, uc=uc, guidance_scale=guidance_scale) + input["x"] = img + + return img + + + @torch.no_grad() + def p_sample_ddim(self, input, index, uc=None, guidance_scale=1): + + + e_t = self.model(input) + if uc is not None and guidance_scale != 1: + unconditional_input = dict(x=input["x"], timesteps=input["timesteps"], context=uc, inpainting_extra_input=input["inpainting_extra_input"], grounding_extra_input=input['grounding_extra_input']) + e_t_uncond = self.model( unconditional_input ) + e_t = e_t_uncond + guidance_scale * (e_t - e_t_uncond) + + # select parameters corresponding to the currently considered timestep + b = input["x"].shape[0] + a_t = torch.full((b, 1, 1, 1), self.ddim_alphas[index], device=self.device) + a_prev = torch.full((b, 1, 1, 1), self.ddim_alphas_prev[index], device=self.device) + sigma_t = torch.full((b, 1, 1, 1), self.ddim_sigmas[index], device=self.device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), self.ddim_sqrt_one_minus_alphas[index],device=self.device) + + # current prediction for x_0 + pred_x0 = (input["x"] - sqrt_one_minus_at * e_t) / a_t.sqrt() + + # direction pointing to x_t + dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * torch.randn_like( input["x"] ) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + + return x_prev, pred_x0 diff --git a/gligen/ldm/models/diffusion/ddpm.py b/gligen/ldm/models/diffusion/ddpm.py new file mode 100644 index 0000000000000000000000000000000000000000..8e3feeabf55dbc0cf6fd112195bcebd7fddbec41 --- /dev/null +++ b/gligen/ldm/models/diffusion/ddpm.py @@ -0,0 +1,72 @@ +import torch +import torch.nn as nn +import numpy as np +from functools import partial +from ldm.modules.diffusionmodules.util import make_beta_schedule + + + + + +class DDPM(nn.Module): + def __init__(self, beta_schedule="linear", timesteps=1000, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + super().__init__() + + self.v_posterior = 0 + self.register_schedule(beta_schedule, timesteps, linear_start, linear_end, cosine_s) + + + def register_schedule(self, beta_schedule="linear", timesteps=1000, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( 1. - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + + self.register_buffer('posterior_variance', to_torch(posterior_variance)) + + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer('posterior_mean_coef1', to_torch( betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) + self.register_buffer('posterior_mean_coef2', to_torch( (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. 
- alphas_cumprod))) + + + + + + + + + + + + + + + + + + diff --git a/gligen/ldm/models/diffusion/gaussian_smoothing.py b/gligen/ldm/models/diffusion/gaussian_smoothing.py new file mode 100644 index 0000000000000000000000000000000000000000..eec81e48b935ae1d3111f2c71d8d9c430bf8c19c --- /dev/null +++ b/gligen/ldm/models/diffusion/gaussian_smoothing.py @@ -0,0 +1,119 @@ +import math +import numbers +import torch +from torch import nn +from torch.nn import functional as F + + +class GaussianSmoothing(nn.Module): + """ + Apply gaussian smoothing on a + 1d, 2d or 3d tensor. Filtering is performed seperately for each channel + in the input using a depthwise convolution. + Arguments: + channels (int, sequence): Number of channels of the input tensors. Output will + have this number of channels as well. + kernel_size (int, sequence): Size of the gaussian kernel. + sigma (float, sequence): Standard deviation of the gaussian kernel. + dim (int, optional): The number of dimensions of the data. + Default value is 2 (spatial). + """ + def __init__(self, channels, kernel_size, sigma, dim=2): + super(GaussianSmoothing, self).__init__() + if isinstance(kernel_size, numbers.Number): + kernel_size = [kernel_size] * dim + if isinstance(sigma, numbers.Number): + sigma = [sigma] * dim + + # The gaussian kernel is the product of the + # gaussian function of each dimension. + kernel = 1 + meshgrids = torch.meshgrid( + [ + torch.arange(size, dtype=torch.float32) + for size in kernel_size + ] + ) + for size, std, mgrid in zip(kernel_size, sigma, meshgrids): + mean = (size - 1) / 2 + kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \ + torch.exp(-((mgrid - mean) / (2 * std)) ** 2) + + # Make sure sum of values in gaussian kernel equals 1. + kernel = kernel / torch.sum(kernel) + + # Reshape to depthwise convolutional weight + kernel = kernel.view(1, 1, *kernel.size()) + kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) + + self.register_buffer('weight', kernel) + self.groups = channels + + if dim == 1: + self.conv = F.conv1d + elif dim == 2: + self.conv = F.conv2d + elif dim == 3: + self.conv = F.conv3d + else: + raise RuntimeError( + 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim) + ) + + def forward(self, input): + """ + Apply gaussian filter to input. + Arguments: + input (torch.Tensor): Input to apply gaussian filter on. + Returns: + filtered (torch.Tensor): Filtered output. + """ + return self.conv(input, weight=self.weight.to(input.dtype), groups=self.groups) + + +class AverageSmoothing(nn.Module): + """ + Apply average smoothing on a + 1d, 2d or 3d tensor. Filtering is performed seperately for each channel + in the input using a depthwise convolution. + Arguments: + channels (int, sequence): Number of channels of the input tensors. Output will + have this number of channels as well. + kernel_size (int, sequence): Size of the average kernel. + sigma (float, sequence): Standard deviation of the rage kernel. + dim (int, optional): The number of dimensions of the data. + Default value is 2 (spatial). + """ + def __init__(self, channels, kernel_size, dim=2): + super(AverageSmoothing, self).__init__() + + # Make sure sum of values in gaussian kernel equals 1. 
+ kernel = torch.ones(size=(kernel_size, kernel_size)) / (kernel_size * kernel_size) + + # Reshape to depthwise convolutional weight + kernel = kernel.view(1, 1, *kernel.size()) + kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) + + self.register_buffer('weight', kernel) + self.groups = channels + + if dim == 1: + self.conv = F.conv1d + elif dim == 2: + self.conv = F.conv2d + elif dim == 3: + self.conv = F.conv3d + else: + raise RuntimeError( + 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim) + ) + + def forward(self, input): + """ + Apply average filter to input. + Arguments: + input (torch.Tensor): Input to apply average filter on. + Returns: + filtered (torch.Tensor): Filtered output. + """ + return self.conv(input, weight=self.weight, groups=self.groups) diff --git a/gligen/ldm/models/diffusion/ldm.py b/gligen/ldm/models/diffusion/ldm.py new file mode 100644 index 0000000000000000000000000000000000000000..78fa65862d848a3fa49ff8c2b7bc475067175891 --- /dev/null +++ b/gligen/ldm/models/diffusion/ldm.py @@ -0,0 +1,88 @@ +import torch +import torch.nn as nn +import numpy as np +from tqdm import tqdm +from ldm.util import default +from ldm.modules.diffusionmodules.util import extract_into_tensor +from .ddpm import DDPM + + + +class LatentDiffusion(DDPM): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # hardcoded + self.clip_denoised = False + + + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + + "Does not support DDPM sampling anymore. Only do DDIM or PLMS" + + # = = = = = = = = = = = = Below is for sampling = = = = = = = = = = = = # + + # def predict_start_from_noise(self, x_t, t, noise): + # return ( extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + # extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise ) + + # def q_posterior(self, x_start, x_t, t): + # posterior_mean = ( + # extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + # extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + # ) + # posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) + # posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) + # return posterior_mean, posterior_variance, posterior_log_variance_clipped + + + # def p_mean_variance(self, model, x, c, t): + + # model_out = model(x, t, c) + # x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + + # if self.clip_denoised: + # x_recon.clamp_(-1., 1.) 
+ + # model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + # return model_mean, posterior_variance, posterior_log_variance, x_recon + + + # @torch.no_grad() + # def p_sample(self, model, x, c, t): + # b, *_, device = *x.shape, x.device + # model_mean, _, model_log_variance, x0 = self.p_mean_variance(model, x=x, c=c, t=t, ) + # noise = torch.randn_like(x) + + # # no noise when t == 0 + # nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + + # return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0 + + + # @torch.no_grad() + # def p_sample_loop(self, model, shape, c): + # device = self.betas.device + # b = shape[0] + # img = torch.randn(shape, device=device) + + # iterator = tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps) + # for i in iterator: + # ts = torch.full((b,), i, device=device, dtype=torch.long) + # img, x0 = self.p_sample(model, img, c, ts) + + # return img + + + # @torch.no_grad() + # def sample(self, model, shape, c, uc=None, guidance_scale=None): + # return self.p_sample_loop(model, shape, c) + + + + + diff --git a/gligen/ldm/models/diffusion/loss.py b/gligen/ldm/models/diffusion/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..25e199219de4b6c3792d56a03e8d71450416151c --- /dev/null +++ b/gligen/ldm/models/diffusion/loss.py @@ -0,0 +1,170 @@ +import math +import torch +from ldm.models.diffusion.gaussian_smoothing import GaussianSmoothing +from torch.nn import functional as F +from torchvision.utils import save_image + + + + + + +def loss_one_att_outside(attn_map,bboxes, object_positions,t): + # loss = torch.tensor(0).to('cuda') + loss = 0 + object_number = len(bboxes) + b, i, j = attn_map.shape + H = W = int(math.sqrt(i)) + + + # if t== 20: import pdb; pdb.set_trace() + + for obj_idx in range(object_number): + + for obj_box in bboxes[obj_idx]: + mask = torch.zeros(size=(H, W)).cuda() if torch.cuda.is_available() else torch.zeros(size=(H, W)) + x_min, y_min, x_max, y_max = int(obj_box[0] * W), \ + int(obj_box[1] * H), int(obj_box[2] * W), int(obj_box[3] * H) + mask[y_min: y_max, x_min: x_max] = 1. + mask_out = 1. 
- mask + index = (mask == 1.).nonzero(as_tuple=False) + index_in_key = index[:,0]* H + index[:, 1] + att_box = torch.zeros_like(attn_map) + att_box[:,index_in_key,:] = attn_map[:,index_in_key,:] + + att_box = att_box.sum(axis=1) / index_in_key.shape[0] + att_box = att_box.reshape(-1, H, H) + activation_value = (att_box* mask_out).reshape(b, -1).sum(dim=-1) #/ att_box.reshape(b, -1).sum(dim=-1) + loss += torch.mean(activation_value) + + return loss / object_number + +def caculate_loss_self_att(self_first, self_second, self_third, bboxes, object_positions, t, list_res=[256], smooth_att = True,sigma=0.5,kernel_size=3 ): + all_attn = get_all_self_att(self_first, self_second, self_third) + cnt = 0 + total_loss = 0 + for res in list_res: + attn_maps = all_attn[res] + for attn in attn_maps: + total_loss += loss_one_att_outside(attn, bboxes, object_positions,t) + cnt += 1 + + return total_loss /cnt + + +def get_all_self_att(self_first, self_second, self_third): + result = {256:[], 1024:[], 4096:[], 64:[], 94:[],1054:[] ,286:[],4126:[] } + # import pdb; pdb.set_trace() + all_att = [self_first, self_second, self_third] + for self_att in all_att: + for att in self_att: + if att != []: + temp = att[0] + for attn_map in temp: + current_res = attn_map.shape[1] + # print(current_res) + result[current_res].append(attn_map) + return result + +def get_all_attention(attn_maps_mid, attn_maps_up , attn_maps_down, res): + result = [] + + for attn_map_integrated in attn_maps_up: + if attn_map_integrated == []: continue + attn_map = attn_map_integrated[0][0] + b, i, j = attn_map.shape + H = W = int(math.sqrt(i)) + # print(H) + if H == res: + result.append(attn_map.reshape(-1, res, res,attn_map.shape[-1] )) + for attn_map_integrated in attn_maps_mid: + + # for attn_map_integrated in attn_maps_mid: + attn_map = attn_map_integrated[0] + b, i, j = attn_map.shape + H = W = int(math.sqrt(i)) + # print(H) + if (H==res): + result.append(attn_map.reshape(-1, res, res,attn_map.shape[-1] )) + # import pdb; pdb.set_trace() + for attn_map_integrated in attn_maps_down: + if attn_map_integrated == []: continue + attn_map = attn_map_integrated[0][0] + if attn_map == []: continue + b, i, j = attn_map.shape + H = W = int(math.sqrt(i)) + # print(H) + if (H==res): + result.append(attn_map.reshape(-1, res, res,attn_map.shape[-1] )) + + result = torch.cat(result, dim=0) + result = result.sum(0) / result.shape[0] + return result + + +def caculate_loss_att_fixed_cnt(attn_maps_mid, attn_maps_up, attn_maps_down, bboxes, object_positions, t, res=16, smooth_att = True,sigma=0.5,kernel_size=3 ): + attn16 = get_all_attention(attn_maps_mid, attn_maps_up, attn_maps_down, res) + # attn32 = get_all_attention(attn_maps_mid, attn_maps_up, attn_maps_down, 32) + # attn64 = get_all_attention(attn_maps_mid, attn_maps_up, attn_maps_down, 64) + # attn8 = get_all_attention(attn_maps_mid, attn_maps_up, attn_maps_down, 8) + all_attn = [attn16] + obj_number = len(bboxes) + total_loss = 0 + # import pdb; pdb.set_trace() + for attn in all_attn[0:1]: + attn_text = attn[:, :, 1:-1] + attn_text *= 100 + attn_text = torch.nn.functional.softmax(attn_text, dim=-1) + current_res = attn.shape[0] + H = W = current_res + + # if t == 49: import pdb; pdb.set_trace() + for obj_idx in range(obj_number): + num_boxes= 0 + + for obj_position in object_positions[obj_idx]: + true_obj_position = obj_position - 1 + att_map_obj = attn_text[:,:, true_obj_position] + if smooth_att: + smoothing = GaussianSmoothing(channels=1, kernel_size=kernel_size, sigma=sigma, dim=2).cuda() + input 
= F.pad(att_map_obj.unsqueeze(0).unsqueeze(0), (1, 1, 1, 1), mode='reflect') + att_map_obj = smoothing(input).squeeze(0).squeeze(0) + other_att_map_obj = att_map_obj.clone() + att_copy = att_map_obj.clone() + + for obj_box in bboxes[obj_idx]: + x_min, y_min, x_max, y_max = int(obj_box[0] * W), \ + int(obj_box[1] * H), int(obj_box[2] * W), int(obj_box[3] * H) + + + if att_map_obj[y_min: y_max, x_min: x_max].numel() == 0: + max_inside=1. + + else: + max_inside = att_map_obj[y_min: y_max, x_min: x_max].max() + total_loss += 1. - max_inside + + # find max outside the box, find in the other boxes + + att_copy[y_min: y_max, x_min: x_max] = 0. + other_att_map_obj[y_min: y_max, x_min: x_max] = 0. + + for obj_outside in range(obj_number): + if obj_outside != obj_idx: + for obj_out_box in bboxes[obj_outside]: + x_min_out, y_min_out, x_max_out, y_max_out = int(obj_out_box[0] * W), \ + int(obj_out_box[1] * H), int(obj_out_box[2] * W), int(obj_out_box[3] * H) + + # att_copy[y_min: y_max, x_min: x_max] = 0. + if other_att_map_obj[y_min_out: y_max_out, x_min_out: x_max_out].numel() == 0: + max_outside_one= 0 + else: + max_outside_one = other_att_map_obj[y_min_out: y_max_out, x_min_out: x_max_out].max() + # max_outside = max(max_outside,max_outside_one ) + att_copy[y_min_out: y_max_out, x_min_out: x_max_out] = 0. + total_loss += max_outside_one + max_background = att_copy.max() + total_loss += len(bboxes[obj_idx]) *max_background /2. + + return total_loss/obj_number + diff --git a/gligen/ldm/models/diffusion/plms.py b/gligen/ldm/models/diffusion/plms.py new file mode 100644 index 0000000000000000000000000000000000000000..feb54d9928c495ead66b591b271539e7b6de7d05 --- /dev/null +++ b/gligen/ldm/models/diffusion/plms.py @@ -0,0 +1,295 @@ +import torch +import numpy as np +from tqdm import tqdm +from functools import partial +from copy import deepcopy +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like +import math +from ldm.models.diffusion.loss import caculate_loss_att_fixed_cnt, caculate_loss_self_att +class PLMSSampler(object): + def __init__(self, diffusion, model, schedule="linear", alpha_generator_func=None, set_alpha_scale=None): + super().__init__() + self.diffusion = diffusion + self.model = model + self.device = diffusion.betas.device + self.ddpm_num_timesteps = diffusion.num_timesteps + self.schedule = schedule + self.alpha_generator_func = alpha_generator_func + self.set_alpha_scale = set_alpha_scale + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + attr = attr.to(self.device) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=False): + if ddim_eta != 0: + raise ValueError('ddim_eta must be 0 for PLMS') + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + alphas_cumprod = self.diffusion.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device) + + self.register_buffer('betas', to_torch(self.diffusion.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.diffusion.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', 
to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + + # @torch.no_grad() + def sample(self, S, shape, input, uc=None, guidance_scale=1, mask=None, x0=None, loss_type='SAR_CAR'): + self.make_schedule(ddim_num_steps=S) + # import pdb; pdb.set_trace() + return self.plms_sampling(shape, input, uc, guidance_scale, mask=mask, x0=x0, loss_type=loss_type) + + + # @torch.no_grad() + def plms_sampling(self, shape, input, uc=None, guidance_scale=1, mask=None, x0=None, loss_type='SAR_CAR'): + + b = shape[0] + + img = input["x"] + if img == None: + img = torch.randn(shape, device=self.device) + input["x"] = img + + time_range = np.flip(self.ddim_timesteps) + total_steps = self.ddim_timesteps.shape[0] + + old_eps = [] + + if self.alpha_generator_func != None: + alphas = self.alpha_generator_func(len(time_range)) + + for i, step in enumerate(time_range): + + # set alpha and restore first conv layer + if self.alpha_generator_func != None: + self.set_alpha_scale(self.model, alphas[i]) + if alphas[i] == 0: + self.model.restore_first_conv_from_SD() + + # run + index = total_steps - i - 1 + ts = torch.full((b,), step, device=self.device, dtype=torch.long) + ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=self.device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.diffusion.q_sample(x0, ts) + img = img_orig * mask + (1. 
- mask) * img + input["x"] = img + # three loss types + if loss_type !=None and loss_type!='standard': + if input['object_position'] != []: + if loss_type=='SAR_CAR': + x = self.update_loss_self_cross( input,i, index, ts ) + elif loss_type=='SAR': + x = self.update_only_self( input,i, index, ts ) + elif loss_type=='CAR': + x = self.update_loss_only_cross( input,i, index, ts ) + input["x"] = x + img, pred_x0, e_t = self.p_sample_plms(input, ts, index=index, uc=uc, guidance_scale=guidance_scale, old_eps=old_eps, t_next=ts_next) + input["x"] = img + old_eps.append(e_t) + if len(old_eps) >= 4: + old_eps.pop(0) + + return img + + def update_loss_self_cross(self, input,index1, index, ts,type_loss='self_accross' ): + if index1 < 10: + loss_scale = 3 + max_iter = 5 + elif index1 < 20: + loss_scale = 2 + max_iter = 5 + else: + loss_scale = 0.8 + max_iter = 1 + + loss_threshold = 0.1 + max_index = 20 + x = deepcopy(input["x"]) + iteration = 0 + loss = torch.tensor(10000) + input["timesteps"] = ts + + print("optimize", index1) + self.model.train() + while loss.item() > loss_threshold and iteration < max_iter and (index1 < max_index) : + print('iter', iteration) + # import pdb; pdb.set_trace() + x = x.requires_grad_(True) + input['x'] = x + e_t, att_first, att_second, att_third, self_first, self_second, self_third = self.model(input) + bboxes = input['boxes_att'] + object_positions = input['object_position'] + loss1 = caculate_loss_self_att(self_first, self_second, self_third, bboxes=bboxes, + object_positions=object_positions, t = index1)*loss_scale + loss2 = caculate_loss_att_fixed_cnt(att_second,att_first,att_third, bboxes=bboxes, + object_positions=object_positions, t = index1)*loss_scale + loss = loss1 + loss2 + print('loss', loss, loss1, loss2) + hh = torch.autograd.backward(loss, retain_graph=True) + grad_cond = x.grad + x = x - grad_cond + x = x.detach() + iteration += 1 + torch.cuda.empty_cache() + return x + + def update_loss_only_cross(self, input,index1, index, ts,type_loss='self_accross'): + + if index1 < 10: + loss_scale = 3 + max_iter = 5 + elif index1 < 20: + loss_scale = 2 + max_iter = 5 + else: + loss_scale = 1 + max_iter = 1 + loss_threshold = 0.1 + + max_index = 30 + x = deepcopy(input["x"]) + iteration = 0 + loss = torch.tensor(10000) + input["timesteps"] = ts + + print("optimize", index1) + while loss.item() > loss_threshold and iteration < max_iter and (index1 < max_index) : + print('iter', iteration) + x = x.requires_grad_(True) + input['x'] = x + e_t, att_first, att_second, att_third, self_first, self_second, self_third = self.model(input) + + bboxes = input['boxes'] + object_positions = input['object_position'] + loss2 = caculate_loss_att_fixed_cnt(att_second,att_first,att_third, bboxes=bboxes, + object_positions=object_positions, t = index1)*loss_scale + loss = loss2 + print('loss', loss) + hh = torch.autograd.backward(loss) + grad_cond = x.grad + x = x - grad_cond + x = x.detach() + iteration += 1 + torch.cuda.empty_cache() + return x + + def update_only_self(self, input,index1, index, ts,type_loss='self_accross' ): + if index1 < 10: + loss_scale = 4 + max_iter = 5 + elif index1 < 20: + loss_scale = 3 + max_iter = 5 + else: + loss_scale = 1 + max_iter = 1 + loss_threshold = 0.1 + + max_index = 30 + x = deepcopy(input["x"]) + iteration = 0 + loss = torch.tensor(10000) + input["timesteps"] = ts + + print("optimize", index1) + while loss.item() > loss_threshold and iteration < max_iter and (index1 < max_index) : + print('iter', iteration) + x = x.requires_grad_(True) + 
input['x'] = x + e_t, att_first, att_second, att_third, self_first, self_second, self_third = self.model(input) + + bboxes = input['boxes'] + object_positions = input['object_position'] + loss = caculate_loss_self_att(self_first, self_second, self_third, bboxes=bboxes, + object_positions=object_positions, t = index1)*loss_scale + print('loss', loss) + hh = torch.autograd.backward(loss) + grad_cond = x.grad + + x = x - grad_cond + x = x.detach() + iteration += 1 + torch.cuda.empty_cache() + return x + + @torch.no_grad() + def p_sample_plms(self, input, t, index, guidance_scale=1., uc=None, old_eps=None, t_next=None): + x = deepcopy(input["x"]) + b = x.shape[0] + self.model.eval() + def get_model_output(input): + e_t, first, second, third,_,_,_ = self.model(input) + if uc is not None and guidance_scale != 1: + unconditional_input = dict(x=input["x"], timesteps=input["timesteps"], context=uc, inpainting_extra_input=None, grounding_extra_input=None) + # unconditional_input=input + e_t_uncond, _, _, _, _, _, _ = self.model( unconditional_input) + e_t = e_t_uncond + guidance_scale * (e_t - e_t_uncond) + return e_t + + + def get_x_prev_and_pred_x0(e_t, index): + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), self.ddim_alphas[index], device=self.device) + a_prev = torch.full((b, 1, 1, 1), self.ddim_alphas_prev[index], device=self.device) + sigma_t = torch.full((b, 1, 1, 1), self.ddim_sigmas[index], device=self.device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), self.ddim_sqrt_one_minus_alphas[index],device=self.device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * torch.randn_like(x) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + input["timesteps"] = t + e_t = get_model_output(input) + if len(old_eps) == 0: + # Pseudo Improved Euler (2nd order) + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) + input["x"] = x_prev + input["timesteps"] = t_next + e_t_next = get_model_output(input) + e_t_prime = (e_t + e_t_next) / 2 + elif len(old_eps) == 1: + # 2nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (3 * e_t - old_eps[-1]) / 2 + elif len(old_eps) == 2: + # 3nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 + elif len(old_eps) >= 3: + # 4nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 + + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) + + return x_prev, pred_x0, e_t + + diff --git a/gligen/ldm/modules/__pycache__/attention.cpython-38.pyc b/gligen/ldm/modules/__pycache__/attention.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab45dbe365eb7a3fdccee25685d18676c93762b2 Binary files /dev/null and b/gligen/ldm/modules/__pycache__/attention.cpython-38.pyc differ diff --git a/gligen/ldm/modules/__pycache__/x_transformer.cpython-38.pyc b/gligen/ldm/modules/__pycache__/x_transformer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..506b77ab3cb448eba85cb89937a96c6649eebf20 Binary files /dev/null and b/gligen/ldm/modules/__pycache__/x_transformer.cpython-38.pyc differ diff --git a/gligen/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-37.pyc b/gligen/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-37.pyc 
new file mode 100644 index 0000000000000000000000000000000000000000..93f5570e474602a9a3aea0ff5aa7e5c559b65b1d Binary files /dev/null and b/gligen/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-37.pyc differ diff --git a/gligen/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc b/gligen/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..801791333bcb0415e1f9fe39caf1c14d32b41019 Binary files /dev/null and b/gligen/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc differ diff --git a/gligen/ldm/modules/diffusionmodules/__pycache__/convnext.cpython-38.pyc b/gligen/ldm/modules/diffusionmodules/__pycache__/convnext.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fbb364cafadd2945347ab854337c8a7796e3f3af Binary files /dev/null and b/gligen/ldm/modules/diffusionmodules/__pycache__/convnext.cpython-38.pyc differ diff --git a/gligen/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc b/gligen/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d2479f7e20192b033f24aef127c6a28b70944e9 Binary files /dev/null and b/gligen/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc differ diff --git a/gligen/ldm/modules/diffusionmodules/__pycache__/normal_grounding_net.cpython-38.pyc b/gligen/ldm/modules/diffusionmodules/__pycache__/normal_grounding_net.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a23f0e3ce617bb29d18bf1b11247785f889758e Binary files /dev/null and b/gligen/ldm/modules/diffusionmodules/__pycache__/normal_grounding_net.cpython-38.pyc differ diff --git a/gligen/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc b/gligen/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e5e84dfe28e5bf086472a66bf1f7eaad5516221 Binary files /dev/null and b/gligen/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc differ diff --git a/gligen/ldm/modules/diffusionmodules/__pycache__/text_grounding_net.cpython-38.pyc b/gligen/ldm/modules/diffusionmodules/__pycache__/text_grounding_net.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5df70a6186c395c6a2baaa716781fe7b0455e624 Binary files /dev/null and b/gligen/ldm/modules/diffusionmodules/__pycache__/text_grounding_net.cpython-38.pyc differ diff --git a/gligen/ldm/modules/diffusionmodules/__pycache__/util.cpython-37.pyc b/gligen/ldm/modules/diffusionmodules/__pycache__/util.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b41b3506c39ddcac8985749922726679da227b57 Binary files /dev/null and b/gligen/ldm/modules/diffusionmodules/__pycache__/util.cpython-37.pyc differ diff --git a/gligen/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc b/gligen/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a62ce04191ef90b485b1d5d6f5cc22ffd279695c Binary files /dev/null and b/gligen/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc differ diff --git a/gligen/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc b/gligen/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9e4a6a210e228373cf9c9c6f3f9455029c4d145 Binary files 
/dev/null and b/gligen/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc differ diff --git a/gligen/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc b/gligen/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..663f17e84df6d157242a63297c17dc0f4aa7b926 Binary files /dev/null and b/gligen/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc differ diff --git a/gligen/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc b/gligen/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..732165b4f79c8221b53aaf08739ccd2134a0adff Binary files /dev/null and b/gligen/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc differ diff --git a/gligen/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc b/gligen/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d872efd7c5506a0bb150b9132246e35d4a5c3369 Binary files /dev/null and b/gligen/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc differ
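
The imagenet_preprocess / imagenet_deprocess helpers added above are inverses of each other (up to the optional rescale step). A minimal, self-contained round-trip sketch using only torchvision and the same hard-coded ImageNet statistics; the random 64x64 tensor is just a stand-in image:

import torch
import torchvision.transforms as T

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Forward normalisation, as returned by imagenet_preprocess().
normalize = T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)

# Inverse: undo the std scaling first, then the mean shift,
# mirroring the two-step Normalize chain in imagenet_deprocess().
denormalize = T.Compose([
    T.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0 / s for s in IMAGENET_STD]),
    T.Normalize(mean=[-m for m in IMAGENET_MEAN], std=[1.0, 1.0, 1.0]),
])

x = torch.rand(3, 64, 64)                           # fake image in [0, 1]
x_round_trip = denormalize(normalize(x))
print(torch.allclose(x, x_round_trip, atol=1e-5))   # True up to float error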
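
split_graph_batch undoes scene-graph batching: it slices the triples and any per-object tensors by image and re-indexes object ids so each image's triples start from 0. A toy invocation with made-up tensors is sketched below; the defining file's header falls outside this excerpt, so the import of split_graph_batch is deliberately left out rather than guessed:

import torch

# split_graph_batch is the helper defined in the utils module added above;
# import it from wherever this diff places that module.

obj_to_img = torch.tensor([0, 0, 0, 1, 1])        # objects 0-2 -> image 0, objects 3-4 -> image 1
triple_to_img = torch.tensor([0, 0, 1])
triples = torch.tensor([[0, 7, 1],                # rows are (subject_idx, predicate, object_idx)
                        [2, 3, 0],
                        [3, 5, 4]])
obj_data = [torch.arange(5)]                      # any per-object tensors, e.g. class labels

triples_out, obj_data_out = split_graph_batch(triples, obj_data, obj_to_img, triple_to_img)
print(triples_out[1])                             # tensor([[0, 5, 1]]) -- re-indexed within image 1
print(obj_data_out[0][1])                         # tensor([3, 4])      -- image 1's slice of obj_data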
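
DDIMSampler.p_sample_ddim (and the matching get_model_output / get_x_prev_and_pred_x0 helpers inside the PLMS sampler) combines classifier-free guidance with the standard DDIM update. The same arithmetic on plain tensors, with made-up scalar values for a_t, a_prev, sigma_t and guidance_scale standing in for the registered ddim_* buffers, and random tensors standing in for the model outputs:

import torch

b, c, h, w = 2, 4, 64, 64
x_t = torch.randn(b, c, h, w)                  # current latent
e_t_cond = torch.randn(b, c, h, w)             # eps prediction with grounding/context
e_t_uncond = torch.randn(b, c, h, w)           # eps prediction with the null context
guidance_scale = 7.5

# classifier-free guidance, as in p_sample_ddim
e_t = e_t_uncond + guidance_scale * (e_t_cond - e_t_uncond)

# stand-in schedule values for one step (the samplers read these from ddim_* buffers)
a_t, a_prev, sigma_t = 0.8, 0.85, 0.0
sqrt_one_minus_at = (1.0 - a_t) ** 0.5

pred_x0 = (x_t - sqrt_one_minus_at * e_t) / a_t ** 0.5        # current estimate of x_0
dir_xt = (1.0 - a_prev - sigma_t ** 2) ** 0.5 * e_t           # direction pointing to x_t
x_prev = a_prev ** 0.5 * pred_x0 + dir_xt + sigma_t * torch.randn_like(x_t)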
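
LatentDiffusion.q_sample draws x_t directly from x_0 with the cumulative-alpha buffers that DDPM.register_schedule precomputes, i.e. x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise. A standalone numeric sketch; the linspace betas below are a hand-rolled stand-in, not the exact output of make_beta_schedule:

import torch

timesteps = 1000
betas = torch.linspace(1e-4, 2e-2, timesteps)          # stand-in for the "linear" schedule
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

x0 = torch.randn(2, 4, 64, 64)
t = torch.tensor([10, 500])
noise = torch.randn_like(x0)

# broadcast the per-sample scalars over (C, H, W), as extract_into_tensor does
sqrt_ac = alphas_cumprod[t].sqrt().view(-1, 1, 1, 1)
sqrt_om = (1.0 - alphas_cumprod[t]).sqrt().view(-1, 1, 1, 1)
x_t = sqrt_ac * x0 + sqrt_om * noise
print(x_t.shape)                                       # torch.Size([2, 4, 64, 64])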
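
caculate_loss_att_fixed_cnt optionally smooths each token's cross-attention map with GaussianSmoothing before taking per-box maxima. The smoothing step in isolation, run on CPU (the loss code calls .cuda(); the module itself is device-agnostic) and assuming the gligen/ldm package from this diff is on the import path; the 16x16 random map is a stand-in for a real attention map:

import torch
import torch.nn.functional as F

from ldm.models.diffusion.gaussian_smoothing import GaussianSmoothing

att_map = torch.rand(16, 16)                   # one token's HxW attention map
smoothing = GaussianSmoothing(channels=1, kernel_size=3, sigma=0.5, dim=2)

# Reflect-pad by 1 so the 3x3 depthwise convolution preserves the 16x16 size,
# exactly as caculate_loss_att_fixed_cnt does before taking box maxima.
padded = F.pad(att_map.unsqueeze(0).unsqueeze(0), (1, 1, 1, 1), mode='reflect')
smoothed = smoothing(padded).squeeze(0).squeeze(0)
print(smoothed.shape)                          # torch.Size([16, 16])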
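
PLMSSampler.p_sample_plms keeps a short history of eps predictions (old_eps, most recent last) and blends them with pseudo linear multistep (Adams-Bashforth) coefficients before reusing the DDIM-style update shown above. The branch logic reduces to the sketch below; plms_eps is a hypothetical helper name, not a function in the diff:

def plms_eps(e_t, old_eps):
    """Pseudo linear multistep combination used in p_sample_plms.

    old_eps holds previous eps predictions, most recent last. When the history
    is empty the real sampler instead averages e_t with a prediction at the
    next timestep (the 2nd-order "pseudo improved Euler" branch)."""
    if len(old_eps) == 0:
        return e_t
    if len(old_eps) == 1:
        return (3 * e_t - old_eps[-1]) / 2
    if len(old_eps) == 2:
        return (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
    return (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24

print(plms_eps(1.0, [0.9, 0.95, 1.05]))        # scalar stand-ins for eps tensors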