# Copyright (c) Facebook, Inc. and its affiliates.
# Modified from: https://github.com/facebookresearch/detectron2/blob/master/demo/demo.py
# fmt: off
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))  # make models/ importable before the local import below
# fmt: on
import glob
import random
import warnings

import numpy as np
import torch
import torchvision
import torchvision.io as vision_io
from PIL import Image

from models.t2i_pipeline import StableDiffusionPipelineSpatialAware

warnings.filterwarnings("ignore")


def generate_image(pipe, overall_prompt, latents, get_latents=False, num_inference_steps=50, fg_masks=None,
                   fg_masked_latents=None, frozen_steps=0, frozen_prompt=None, custom_attention_mask=None,
                   fg_prompt=None):
    '''
    Main wrapper around the image diffusion model.
    latents: input noise from which the generation starts
    get_latents: if True, also returns the latents of the generated image
    '''
    image = pipe(overall_prompt, latents=latents, num_inference_steps=num_inference_steps, frozen_mask=fg_masks,
                 frozen_steps=frozen_steps, latents_all_input=fg_masked_latents, frozen_prompt=frozen_prompt,
                 custom_attention_mask=custom_attention_mask, output_type='pil',
                 fg_prompt=fg_prompt, make_attention_mask_2d=True, attention_mask_block_diagonal=True).images[0]
    torch.save(image, "img.pt")  # pickle the PIL image for later inspection
    if get_latents:
        image_latents = pipe(overall_prompt, latents=latents,
                             num_inference_steps=num_inference_steps, output_type="latent").images
        torch.save(image_latents, "img_latents.pt")
        return image, image_latents
    return image
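

# Illustrative sketch only (an assumption about the t2i_pipeline internals,
# which are not shown in this file): the frozen-mask arguments above
# presumably composite the foreground latents back into the denoising state
# for the first `frozen_steps` steps, roughly along these lines:
def composite_frozen_latents(latents, fg_masked_latents, fg_masks, step, frozen_steps):
    """Keep the masked foreground latents fixed for the first `frozen_steps` steps."""
    if step < frozen_steps and fg_masked_latents is not None:
        # fg_masks is 1.0 inside the foreground box and 0.0 elsewhere
        return fg_masks * fg_masked_latents + (1 - fg_masks) * latents
    return latents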


def save_frames(path):
    video, audio, video_info = vision_io.read_video(
        f"demo3/{path}.mp4", pts_unit='sec')
    # Number of frames
    num_frames = video.size(0)
    # Save each frame; read_video returns frames as T x H x W x C uint8,
    # so each frame is already in the layout PIL expects
    os.makedirs(f"demo3/{path}", exist_ok=True)
    for i in range(num_frames):
        frame = video[i].numpy()
        img = Image.fromarray(frame.astype('uint8'))
        img.save(f"demo3/{path}/frame_{i:04d}.png")
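

# Companion sketch for save_frames (not called anywhere in this script):
# reassemble the dumped frames into a clip with torchvision.io.write_video.
# The fps value and the _rebuilt suffix are assumptions.
def frames_to_video(path, fps=8):
    """Re-encode demo3/{path}/frame_*.png into demo3/{path}_rebuilt.mp4."""
    frame_files = sorted(glob.glob(f"demo3/{path}/frame_*.png"))
    frames = torch.stack([
        torch.from_numpy(np.array(Image.open(f))) for f in frame_files
    ])  # T x H x W x C uint8, the layout write_video expects
    vision_io.write_video(f"demo3/{path}_rebuilt.mp4", frames, fps=fps)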


def create_boxes():
    img_width = 96
    img_height = 96
    # initialize bboxes list
    sbboxes = []
    # lay out a 3x3 grid of square boxes for each object size
    for object_size in [20, 30, 40, 50, 60]:
        obj_width, obj_height = object_size, object_size
        # starting position
        start_x = 3
        start_y = 4
        # total size occupied by the objects in the grid
        total_obj_width = 3 * obj_width
        total_obj_height = 3 * obj_height
        # horizontal and vertical spacing between boxes
        spacing_horizontal = (img_width - total_obj_width - start_x) // 2
        spacing_vertical = (img_height - total_obj_height - start_y) // 2
        for i in range(3):
            for j in range(3):
                x_start = start_x + i * (obj_width + spacing_horizontal)
                y_start = start_y + j * (obj_height + spacing_vertical)
                # clamp to the image bounds so boxes never spill outside
                x_end = min(x_start + obj_width, img_width)
                y_end = min(y_start + obj_height, img_height)
                sbboxes.append([x_start, y_start, x_end, y_end])
    # turn each box into a binary mask over the 96x96 latent grid
    masks_list = []
    for mask_id, sbbox in enumerate(sbboxes):
        smask = torch.zeros(1, 1, 96, 96)
        smask[0, 0, sbbox[1]:sbbox[3], sbbox[0]:sbbox[2]] = 1.0
        masks_list.append(smask)
        # torchvision.utils.save_image(smask, f"{SAVE_DIR}/masks/mask_{mask_id}.png")  # save masks as images
    return masks_list
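

# Optional sanity check (not part of the original flow): tile all 45 masks
# into a single contact sheet so the 3x3 placement per object size can be
# eyeballed. The output filename is an assumption.
def save_mask_grid(masks_list, out_path="mask_grid.png"):
    """Save a 5x9 grid image of all masks (one row per object size)."""
    grid = torchvision.utils.make_grid(torch.cat(masks_list, dim=0), nrow=9)
    torchvision.utils.save_image(grid, out_path)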


def objects_list():
    objects_settings = [
        ("apple", "on a table"),
        ("ball", "in a park"),
        ("cat", "on a couch"),
        ("dog", "in a backyard"),
        ("elephant", "in a jungle"),
        ("fountain pen", "on a desk"),
        ("guitar", "on a stage"),
        ("helicopter", "in the sky"),
        ("island", "in the sea"),
        ("jar", "on a shelf"),
        ("kite", "in the sky"),
        ("lamp", "in a room"),
        ("motorbike", "on a road"),
        ("notebook", "on a table"),
        ("owl", "on a tree"),
        ("piano", "in a hall"),
        ("queen", "in a castle"),
        ("robot", "in a lab"),
        ("snake", "in a forest"),
        ("tent", "in the mountains"),
        ("umbrella", "on a beach"),
        ("violin", "in an orchestra"),
        ("wheel", "in a garage"),
        ("xylophone", "in a music class"),
        ("yacht", "in a marina"),
        ("zebra", "in a savannah"),
        ("aeroplane", "in the clouds"),
        ("bridge", "over a river"),
        ("computer", "in an office"),
        ("dragon", "in a cave"),
        ("egg", "in a nest"),
        ("flower", "in a garden"),
        ("globe", "in a library"),
        ("hat", "on a rack"),
        ("ice cube", "in a glass"),
        ("jewelry", "in a box"),
        ("kangaroo", "in a desert"),
        ("lion", "in a den"),
        ("mug", "on a counter"),
        ("nest", "on a branch"),
        ("octopus", "in the ocean"),
        ("parrot", "in a rainforest"),
        ("quilt", "on a bed"),
        ("rose", "in a vase"),
        ("ship", "in a dock"),
        ("train", "on the tracks"),
        ("utensils", "in a kitchen"),
        ("vase", "on a window sill"),
        ("watch", "in a store"),
        ("x-ray", "in a hospital"),
        ("yarn", "in a basket"),
        ("zeppelin", "above a city"),
    ]
    objects_settings.extend([
        ("muffin", "on a bakery shelf"),
        ("notebook", "on a student's desk"),
        ("owl", "in a tree"),
        ("piano", "in a concert hall"),
        ("quill", "on parchment"),
        ("robot", "in a factory"),
        ("snake", "in the grass"),
        ("telescope", "in an observatory"),
        ("umbrella", "at the beach"),
        ("violin", "in an orchestra"),
        ("whale", "in the ocean"),
        ("xylophone", "in a music store"),
        ("yacht", "in a marina"),
        ("zebra", "on a savanna"),
        # Kitchen items
        ("spoon", "in a drawer"),
        ("plate", "in a cupboard"),
        ("cup", "on a shelf"),
        ("frying pan", "on a stove"),
        ("jar", "in the refrigerator"),
        # Office items
        ("computer", "in an office"),
        ("printer", "by a desk"),
        ("chair", "around a conference table"),
        ("lamp", "on a workbench"),
        ("calendar", "on a wall"),
        # Outdoor items
        ("bicycle", "on a street"),
        ("tent", "in a campsite"),
        ("fire", "in a fireplace"),
        ("mountain", "in the distance"),
        ("river", "through the woods"),
        # and so on ...
    ])
    # To expedite generation, themes and objects can be combined:
    themes = [
        ("wild animals", ["tiger", "lion", "cheetah", "giraffe", "hippopotamus"], "in the wild"),
        ("household items", ["sofa", "tv", "clock", "vase", "photo frame"], "in a living room"),
        ("clothes", ["shirt", "pants", "shoes", "hat", "jacket"], "in a wardrobe"),
        ("musical instruments", ["drum", "trumpet", "harp", "saxophone", "tuba"], "in a band"),
        ("cosmic entities", ["planet", "star", "comet", "nebula", "asteroid"], "in space"),
        # ... add more themes
    ]
    # Use the themes to extend the list
    for theme_name, theme_objects, theme_location in themes:
        for theme_object in theme_objects:
            objects_settings.append((theme_object, theme_location))
    # Sports equipment
    objects_settings.extend([
        ("basketball", "on a court"),
        ("golf ball", "on a golf course"),
        ("tennis racket", "on a tennis court"),
        ("baseball bat", "in a stadium"),
        ("hockey stick", "on an ice rink"),
        ("football", "on a field"),
        ("skateboard", "in a skatepark"),
        ("boxing gloves", "in a boxing ring"),
        ("ski", "on a snowy slope"),
        ("surfboard", "on a beach shore"),
    ])
    # Toys and games
    objects_settings.extend([
        ("teddy bear", "on a child's bed"),
        ("doll", "in a toy store"),
        ("toy car", "on a carpet"),
        ("board game", "on a table"),
        ("yo-yo", "in a child's hand"),
        ("kite", "in the sky on a windy day"),
        ("Lego bricks", "on a construction table"),
        ("jigsaw puzzle", "partially completed"),
        ("rubik's cube", "on a shelf"),
        ("action figure", "on display"),
    ])
    # Transportation
    objects_settings.extend([
        ("bus", "at a bus stop"),
        ("motorcycle", "on a road"),
        ("helicopter", "landing on a pad"),
        ("scooter", "on a sidewalk"),
        ("train", "at a station"),
        ("bicycle", "parked by a post"),
        ("boat", "in a harbor"),
        ("tractor", "on a farm"),
        ("airplane", "taking off from a runway"),
        ("submarine", "below sea level"),
    ])
    # Medieval theme
    objects_settings.extend([
        ("castle", "on a hilltop"),
        ("knight", "riding a horse"),
        ("bow and arrow", "in an archery range"),
        ("crown", "in a treasure chest"),
        ("dragon", "flying over mountains"),
        ("shield", "next to a warrior"),
        ("dagger", "on a wooden table"),
        ("torch", "lighting a dark corridor"),
        ("scroll", "sealed with wax"),
        ("cauldron", "with bubbling potion"),
    ])
    # Modern technology
    objects_settings.extend([
        ("smartphone", "on a charger"),
        ("laptop", "in a cafe"),
        ("headphones", "around a neck"),
        ("camera", "on a tripod"),
        ("drone", "flying over a park"),
        ("USB stick", "plugged into a computer"),
        ("watch", "on a wrist"),
        ("microphone", "on a podcast desk"),
        ("tablet", "with a digital pen"),
        ("VR headset", "ready for gaming"),
    ])
    # Nature
    objects_settings.extend([
        ("tree", "in a forest"),
        ("flower", "in a garden"),
        ("mountain", "on a horizon"),
        ("cloud", "in a blue sky"),
        ("waterfall", "in a scenic location"),
        ("beach", "next to an ocean"),
        ("cactus", "in a desert"),
        ("volcano", "erupting with lava"),
        ("coral", "under the sea"),
        ("moon", "in a night sky"),
    ])
    # The prompts themselves are built in __main__ as f"A {obj} {setting}"
    return objects_settings
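

# Optional helper (unused below): the list above intentionally repeats some
# objects ("owl", "piano", "violin", ...) in different settings, but a few
# pairs such as ("violin", "in an orchestra") are exact duplicates; they can
# be dropped while preserving order like this:
def unique_settings(objects_settings):
    """Order-preserving dedupe of (object, setting) pairs."""
    return list(dict.fromkeys(objects_settings))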


if __name__ == "__main__":
    SAVE_DIR = "/scr/image/"
    save_path = "img43-att_mask"
    torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    random_latents = torch.randn(
        [1, 4, 96, 96], generator=torch.Generator().manual_seed(1)).to(torch_device)
    # Prefer the cluster-local cache dir; fall back to the default HF cache
    try:
        pipe = StableDiffusionPipelineSpatialAware.from_pretrained(
            "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float,
            variant="fp32", cache_dir="/gscratch/scrubbed/anasery/").to(torch_device)
    except Exception:
        pipe = StableDiffusionPipelineSpatialAware.from_pretrained(
            "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float,
            variant="fp32").to(torch_device)

    fg_object = "apple"  # fg_object stores the object to be masked
    # default prompt; overwritten inside the loop below
    overall_prompt = f"An {fg_object} on a plate"
    os.makedirs(f"{SAVE_DIR}/{overall_prompt}", exist_ok=True)
    masks_list = create_boxes()

    obj_settings = objects_list()  # 166 (object, setting) pairs in total
    for obj_setting in obj_settings[120:]:
        fg_object = obj_setting[0]
        overall_prompt = f"A {obj_setting[0]} {obj_setting[1]}"
        print(overall_prompt)
        # randomly select 3 of the 45 candidate masks
        selected_mask_ids = random.sample(range(len(masks_list)), 3)
        for mask_id in selected_mask_ids:
            os.makedirs(
                f"{SAVE_DIR}/{overall_prompt}/mask{mask_id}", exist_ok=True)
            torchvision.utils.save_image(
                masks_list[mask_id][0][0], f"{SAVE_DIR}/{overall_prompt}/mask{mask_id}/mask.png")
            for frozen_steps in range(0, 5):
                img = generate_image(pipe, overall_prompt, random_latents, get_latents=False,
                                     num_inference_steps=50,
                                     fg_masks=masks_list[mask_id].to(torch_device),
                                     fg_masked_latents=None, frozen_steps=frozen_steps,
                                     frozen_prompt=None, fg_prompt=fg_object)
                img.save(
                    f"{SAVE_DIR}/{overall_prompt}/mask{mask_id}/{frozen_steps}.png")
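

# Memory-saving alternative for the loading step above (an assumption that
# half precision is acceptable here; stabilityai/stable-diffusion-2-1 does
# ship an fp16 variant on the Hub):
# pipe = StableDiffusionPipelineSpatialAware.from_pretrained(
#     "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16,
#     variant="fp16").to(torch_device)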