# Copyright (c) Facebook, Inc. and its affiliates.
# Modified from: https://github.com/facebookresearch/detectron2/blob/master/demo/demo.py
import argparse
import glob
import multiprocessing as mp
import os
import random
import sys
import tempfile
import time
import warnings

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
import torchvision.io as vision_io
import tqdm
from PIL import Image
from transformers import pipeline

# fmt: off
# The parent directory must be on sys.path BEFORE importing the project-local
# ``models`` package; previously this insertion ran after that import.
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from models.t2i_pipeline import StableDiffusionPipelineSpatialAware  # noqa: E402
# fmt: on

warnings.filterwarnings("ignore")

# constants
WINDOW_NAME = "demo"


def generate_image(pipe, overall_prompt, latents, get_latents=False,
                   num_inference_steps=50, fg_masks=None,
                   fg_masked_latents=None, frozen_steps=0, frozen_prompt=None,
                   custom_attention_mask=None, fg_prompt=None):
    """Run the spatially-aware diffusion pipeline once for ``overall_prompt``.

    Args:
        pipe: the ``StableDiffusionPipelineSpatialAware`` instance to call.
        overall_prompt: text prompt for the whole image.
        latents: initial noise latents the generation starts from.
        get_latents: if True, also return the pipeline output for
            ``output_type="latent"`` computed with the same conditioning.
        num_inference_steps: number of denoising steps.
        fg_masks: forwarded to the pipeline as ``frozen_mask``.
        fg_masked_latents: forwarded to the pipeline as ``latents_all_input``.
        frozen_steps, frozen_prompt, custom_attention_mask, fg_prompt:
            forwarded unchanged to the pipeline.

    Returns:
        The first generated PIL image, or ``(image, latents)`` when
        ``get_latents`` is True.

    Side effects:
        Writes ``img.pt`` (and ``img_latents.pt`` when ``get_latents``) to the
        current working directory.
    """
    # One set of conditioning arguments shared by both pipeline calls, so the
    # latents returned with get_latents=True correspond to the returned image
    # (previously the second call silently dropped the mask/attention args).
    pipe_kwargs = dict(
        latents=latents,
        num_inference_steps=num_inference_steps,
        frozen_mask=fg_masks,
        frozen_steps=frozen_steps,
        latents_all_input=fg_masked_latents,
        frozen_prompt=frozen_prompt,
        custom_attention_mask=custom_attention_mask,
        fg_prompt=fg_prompt,
        make_attention_mask_2d=True,
        attention_mask_block_diagonal=True,
    )
    image = pipe(overall_prompt, output_type='pil', **pipe_kwargs).images[0]
    torch.save(image, "img.pt")
    if get_latents:
        video_latents = pipe(
            overall_prompt, output_type="latent", **pipe_kwargs).images
        torch.save(video_latents, "img_latents.pt")
        return image, video_latents
    return image


def save_frames(path, root="demo3"):
    """Dump every frame of ``{root}/{path}.mp4`` as PNGs into ``{root}/{path}/``.

    Frames are written as ``frame_0000.png``, ``frame_0001.png``, ...

    Args:
        path: video file stem (without the ``.mp4`` extension).
        root: directory holding the video; defaults to ``"demo3"``.
    """
    video, _audio, _video_info = vision_io.read_video(
        f"{root}/{path}.mp4", pts_unit='sec')
    out_dir = f"{root}/{path}"
    os.makedirs(out_dir, exist_ok=True)
    # read_video returns frames as T x H x W x C by default, so each frame is
    # already in the H x W x C layout PIL expects; no permute is needed.
    for i, frame in enumerate(video):
        img = Image.fromarray(frame.numpy().astype('uint8'))
        img.save(f"{out_dir}/frame_{i:04d}.png")


def create_boxes(img_width=96, img_height=96,
                 object_sizes=(20, 30, 40, 50, 60)):
    """Build binary square-box masks laid out on a 3 x 3 grid.

    For every size in ``object_sizes`` nine boxes are generated (outer loop
    over columns ``i``, inner loop over rows ``j``), giving
    ``9 * len(object_sizes)`` masks (45 with the defaults).

    Args:
        img_width: mask width in pixels.
        img_height: mask height in pixels.
        object_sizes: side lengths of the square boxes, in pixels.

    Returns:
        list of float tensors of shape (1, 1, img_height, img_width), 1.0
        inside the box and 0.0 elsewhere.
    """
    start_x, start_y = 3, 4  # top-left margin of the first box
    bboxes = []
    for object_size in object_sizes:
        obj_width = obj_height = object_size
        total_obj_width = 3 * obj_width
        total_obj_height = 3 * obj_height
        # Offset added between neighbouring boxes. It goes negative (boxes
        # overlap) once three boxes plus the margin exceed the image size,
        # e.g. sizes 40-60 at 96 px.
        spacing_horizontal = (img_width - total_obj_width - start_x) // 2
        spacing_vertical = (img_height - total_obj_height - start_y) // 2
        for i in range(3):
            for j in range(3):
                x_start = start_x + i * (obj_width + spacing_horizontal)
                y_start = start_y + j * (obj_height + spacing_vertical)
                # Clip the far edge to the image border.
                x_end = min(x_start + obj_width, img_width)
                y_end = min(y_start + obj_height, img_height)
                bboxes.append([x_start, y_start, x_end, y_end])

    masks_list = []
    for x_start, y_start, x_end, y_end in bboxes:
        mask = torch.zeros(1, 1, img_height, img_width)
        mask[0, 0, y_start:y_end, x_start:x_end] = 1.0
        masks_list.append(mask)
    return masks_list


def objects_list():
    """Return the ``(object, setting)`` pairs used to build prompts.

    The list has 166 entries. Its ORDER is significant because callers slice
    it by index. Some pairs appear more than once; they are kept to preserve
    those indices.
    """
    objects_settings = [
        ("apple", "on a table"), ("ball", "in a park"), ("cat", "on a couch"),
        ("dog", "in a backyard"), ("elephant", "in a jungle"),
        ("fountain pen", "on a desk"), ("guitar", "on a stage"),
        ("helicopter", "in the sky"), ("island", "in the sea"),
        ("jar", "on a shelf"), ("kite", "in the sky"), ("lamp", "in a room"),
        ("motorbike", "on a road"), ("notebook", "on a table"),
        ("owl", "on a tree"), ("piano", "in a hall"),
        ("queen", "in a castle"), ("robot", "in a lab"),
        ("snake", "in a forest"), ("tent", "in the mountains"),
        ("umbrella", "on a beach"), ("violin", "in an orchestra"),
        ("wheel", "in a garage"), ("xylophone", "in a music class"),
        ("yacht", "in a marina"), ("zebra", "in a savannah"),
        ("aeroplane", "in the clouds"), ("bridge", "over a river"),
        ("computer", "in an office"), ("dragon", "in a cave"),
        ("egg", "in a nest"), ("flower", "in a garden"),
        ("globe", "in a library"), ("hat", "on a rack"),
        ("ice cube", "in a glass"), ("jewelry", "in a box"),
        ("kangaroo", "in a desert"), ("lion", "in a den"),
        ("mug", "on a counter"), ("nest", "on a branch"),
        ("octopus", "in the ocean"), ("parrot", "in a rainforest"),
        ("quilt", "on a bed"), ("rose", "in a vase"), ("ship", "in a dock"),
        ("train", "on the tracks"), ("utensils", "in a kitchen"),
        ("vase", "on a window sill"), ("watch", "in a store"),
        ("x-ray", "in a hospital"), ("yarn", "in a basket"),
        ("zeppelin", "above a city"),
    ]

    objects_settings.extend([
        ("muffin", "on a bakery shelf"),
        ("notebook", "on a student's desk"),
        ("owl", "in a tree"),
        ("piano", "in a concert hall"),
        ("quill", "on parchment"),
        ("robot", "in a factory"),
        ("snake", "in the grass"),
        ("telescope", "in an observatory"),
        ("umbrella", "at the beach"),
        ("violin", "in an orchestra"),
        ("whale", "in the ocean"),
        ("xylophone", "in a music store"),
        ("yacht", "in a marina"),
        ("zebra", "on a savanna"),
        # Kitchen items
        ("spoon", "in a drawer"),
        ("plate", "in a cupboard"),
        ("cup", "on a shelf"),
        ("frying pan", "on a stove"),
        ("jar", "in the refrigerator"),
        # Office items
        ("computer", "in an office"),
        ("printer", "by a desk"),
        ("chair", "around a conference table"),
        ("lamp", "on a workbench"),
        ("calendar", "on a wall"),
        # Outdoor items
        ("bicycle", "on a street"),
        ("tent", "in a campsite"),
        ("fire", "in a fireplace"),
        ("mountain", "in the distance"),
        ("river", "through the woods"),
        # and so on ...
    ])

    # Themes share one setting across several objects.
    themes = [
        ("wild animals", ["tiger", "lion", "cheetah", "giraffe", "hippopotamus"], "in the wild"),
        ("household items", ["sofa", "tv", "clock", "vase", "photo frame"], "in a living room"),
        ("clothes", ["shirt", "pants", "shoes", "hat", "jacket"], "in a wardrobe"),
        ("musical instruments", ["drum", "trumpet", "harp", "saxophone", "tuba"], "in a band"),
        ("cosmic entities", ["planet", "star", "comet", "nebula", "asteroid"], "in space"),
        # ... add more themes
    ]
    for _theme_name, theme_objects, theme_location in themes:
        for theme_object in theme_objects:
            objects_settings.append((theme_object, theme_location))

    # Sports equipment
    objects_settings.extend([
        ("basketball", "on a court"),
        ("golf ball", "on a golf course"),
        ("tennis racket", "on a tennis court"),
        ("baseball bat", "in a stadium"),
        ("hockey stick", "on an ice rink"),
        ("football", "on a field"),
        ("skateboard", "in a skatepark"),
        ("boxing gloves", "in a boxing ring"),
        ("ski", "on a snowy slope"),
        ("surfboard", "on a beach shore"),
    ])

    # Toys and games
    objects_settings.extend([
        ("teddy bear", "on a child's bed"),
        ("doll", "in a toy store"),
        ("toy car", "on a carpet"),
        ("board game", "on a table"),
        ("yo-yo", "in a child's hand"),
        ("kite", "in the sky on a windy day"),
        ("Lego bricks", "on a construction table"),
        ("jigsaw puzzle", "partially completed"),
        ("rubik's cube", "on a shelf"),
        ("action figure", "on display"),
    ])

    # Transportation
    objects_settings.extend([
        ("bus", "at a bus stop"),
        ("motorcycle", "on a road"),
        ("helicopter", "landing on a pad"),
        ("scooter", "on a sidewalk"),
        ("train", "at a station"),
        ("bicycle", "parked by a post"),
        ("boat", "in a harbor"),
        ("tractor", "on a farm"),
        ("airplane", "taking off from a runway"),
        ("submarine", "below sea level"),
    ])

    # Medieval theme
    objects_settings.extend([
        ("castle", "on a hilltop"),
        ("knight", "riding a horse"),
        ("bow and arrow", "in an archery range"),
        ("crown", "in a treasure chest"),
        ("dragon", "flying over mountains"),
        ("shield", "next to a warrior"),
        ("dagger", "on a wooden table"),
        ("torch", "lighting a dark corridor"),
        ("scroll", "sealed with wax"),
        ("cauldron", "with bubbling potion"),
    ])

    # Modern technology
    objects_settings.extend([
        ("smartphone", "on a charger"),
        ("laptop", "in a cafe"),
        ("headphones", "around a neck"),
        ("camera", "on a tripod"),
        ("drone", "flying over a park"),
        ("USB stick", "plugged into a computer"),
        ("watch", "on a wrist"),
        ("microphone", "on a podcast desk"),
        ("tablet", "with a digital pen"),
        ("VR headset", "ready for gaming"),
    ])

    # Nature
    objects_settings.extend([
        ("tree", "in a forest"),
        ("flower", "in a garden"),
        ("mountain", "on a horizon"),
        ("cloud", "in a blue sky"),
        ("waterfall", "in a scenic location"),
        ("beach", "next to an ocean"),
        ("cactus", "in a desert"),
        ("volcano", "erupting with lava"),
        ("coral", "under the sea"),
        ("moon", "in a night sky"),
    ])

    return objects_settings


if __name__ == "__main__":
    SAVE_DIR = "/scr/image/"

    torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Fixed-seed noise: every prompt/mask/frozen_steps run starts from the
    # same initial latents.
    random_latents = torch.randn(
        [1, 4, 96, 96],
        generator=torch.Generator().manual_seed(1)).to(torch_device)

    try:
        # Prefer the cluster-local cache directory ...
        pipe = StableDiffusionPipelineSpatialAware.from_pretrained(
            "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float,
            variant="fp32",
            cache_dir="/gscratch/scrubbed/anasery/").to(torch_device)
    except Exception as err:  # not bare: let Ctrl-C / SystemExit through
        # ... and fall back to the default cache when it is unavailable.
        print(f"Loading with custom cache_dir failed ({err!r}); "
              "retrying with the default cache")
        pipe = StableDiffusionPipelineSpatialAware.from_pretrained(
            "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float,
            variant="fp32").to(torch_device)

    masks_list = create_boxes()
    obj_settings = objects_list()  # 166 entries

    # Only entries from index 120 onward are processed (presumably earlier
    # ones were generated in a previous run -- confirm before changing).
    for fg_object, setting in obj_settings[120:]:
        # fg_object names the object that is placed inside the mask.
        overall_prompt = f"A {fg_object} {setting}"
        print(overall_prompt)
        # Pick 3 distinct mask ids at random (unseeded, so not reproducible).
        selected_mask_ids = random.sample(range(len(masks_list)), 3)
        for mask_id in selected_mask_ids:
            mask_dir = f"{SAVE_DIR}/{overall_prompt}/mask{mask_id}"
            os.makedirs(mask_dir, exist_ok=True)
            torchvision.utils.save_image(
                masks_list[mask_id][0][0], f"{mask_dir}/mask.png")
            # Move the mask to the device once, not once per frozen_steps.
            fg_mask = masks_list[mask_id].to(torch_device)
            for frozen_steps in range(0, 5):
                img = generate_image(
                    pipe, overall_prompt, random_latents, get_latents=False,
                    num_inference_steps=50, fg_masks=fg_mask,
                    fg_masked_latents=None, frozen_steps=frozen_steps,
                    frozen_prompt=None, fg_prompt=fg_object)
                img.save(f"{mask_dir}/{frozen_steps}.png")