| import os |
| import json |
| from IPython import display |
| import random |
| from torchvision.utils import make_grid |
| from einops import rearrange |
| import pandas as pd |
| import cv2 |
| import numpy as np |
| from PIL import Image |
| import pathlib |
| import torchvision.transforms as T |
|
|
| from .generate import generate, add_noise |
| from .prompt import sanitize |
| from .animation import DeformAnimKeys, sample_from_cv2, sample_to_cv2, anim_frame_warp, vid2frames |
| from .depth import DepthModel |
| from .colors import maintain_colors |
| from .load_images import prepare_overlay_mask |
|
|
def next_seed(args):
    """Advance ``args.seed`` according to ``args.seed_behavior`` and return it.

    Args:
        args: Object exposing a mutable ``seed`` attribute and a
            ``seed_behavior`` attribute.

    Returns:
        The value of ``args.seed`` after the update:

        * ``'fixed'`` -- the seed is left untouched.
        * ``'iter'``  -- the seed is incremented by one.
        * anything else -- a fresh random integer in ``[0, 2**32 - 1]``.
    """
    behavior = args.seed_behavior

    # A fixed seed never changes, so hand it back immediately.
    if behavior == 'fixed':
        return args.seed

    if behavior == 'iter':
        args.seed += 1
    else:
        args.seed = random.randint(0, 2**32 - 1)
    return args.seed
|
|
def _collect_init_images(args):
    """Return the list of init-image sources that each batch iterates over.

    When ``args.use_init`` is false a single empty string is returned so the
    per-batch loop still runs exactly once. Otherwise ``args.init_image`` may
    be an http(s) URL, a single file, or a directory whose image files
    (png/jpg/jpeg, matched case-insensitively) are returned in sorted order.

    Raises:
        FileNotFoundError: if ``args.use_init`` is set but ``args.init_image``
            is the empty string.
    """
    if not args.use_init:
        return [""]

    if args.init_image == "":
        raise FileNotFoundError("No path was given for init_image")

    if args.init_image.startswith(('http://', 'https://')):
        return [args.init_image]

    if os.path.isfile(args.init_image):
        return [args.init_image]

    # Anything that is neither a URL nor a file is treated as a directory.
    # Make sure it ends with "/" so plain concatenation yields valid paths.
    if args.init_image[-1] != "/":
        args.init_image += "/"

    # Compare the real extension, lower-cased, so "IMG.PNG" is accepted and
    # an extension-less file that is merely *named* "png" is not.
    return [
        args.init_image + name
        for name in sorted(os.listdir(args.init_image))
        if os.path.splitext(name)[1].lower() in (".png", ".jpg", ".jpeg")
    ]


def render_image_batch(args, prompts, root):
    """Generate ``args.n_batch`` batches of images for every prompt.

    For each prompt, and for each batch, ``generate(args, root)`` is called
    once per init image. Depending on the flags on ``args`` every resulting
    image is saved to ``args.outdir``, displayed, and/or collected into a
    per-prompt grid image. ``args.seed`` is advanced with ``next_seed`` after
    every ``generate`` call.

    Args:
        args: Run settings. This function reads (among others) ``outdir``,
            ``timestring``, ``save_settings``, ``save_samples``,
            ``display_samples``, ``make_grid``, ``grid_rows``, ``n_batch``,
            ``use_init``, ``init_image`` and ``filename_format``; it
            overwrites ``prompts``, ``prompt``, ``clip_prompt``,
            ``init_image`` and ``seed``.
        prompts: Sequence of prompt strings.
        root: Opaque object forwarded to ``generate``.

    Raises:
        FileNotFoundError: if ``args.use_init`` is set and ``args.init_image``
            is empty.
    """
    # Map each prompt to its zero-padded position in the prompt list.
    args.prompts = {prompt: f"{i:05d}" for i, prompt in enumerate(prompts)}

    # Create the output folder for the batch.
    os.makedirs(args.outdir, exist_ok=True)
    if args.save_settings or args.save_samples:
        print(f"Saving to {os.path.join(args.outdir, args.timestring)}_*")

    # Save the settings for the batch.
    if args.save_settings:
        filename = os.path.join(args.outdir, f"{args.timestring}_settings.txt")
        settings = dict(args.__dict__)
        # These bookkeeping entries are excluded from the dump. pop() rather
        # than del: they are not guaranteed to be present on every args object.
        for key in ('master_args', 'root', 'get_output_folder'):
            settings.pop(key, None)
        with open(filename, "w+", encoding="utf-8") as f:
            json.dump(settings, f, ensure_ascii=False, indent=4)

    # Running counter used in sample filenames; it spans all prompts.
    index = 0

    init_array = _collect_init_images(args)

    # When doing large batches, periodically clear the notebook output.
    clear_between_batches = args.n_batch >= 32

    for iprompt, prompt in enumerate(prompts):
        args.prompt = prompt
        args.clip_prompt = prompt
        print(f"Prompt {iprompt+1} of {len(prompts)}")
        print(f"{args.prompt}")

        all_images = []

        for batch_index in range(args.n_batch):
            if clear_between_batches and batch_index % 32 == 0:
                display.clear_output(wait=True)
            print(f"Batch {batch_index+1} of {args.n_batch}")

            for init_image in init_array:
                args.init_image = init_image
                results = generate(args, root)
                for image in results:
                    if args.make_grid:
                        all_images.append(T.functional.pil_to_tensor(image))
                    if args.save_samples:
                        if args.filename_format == "{timestring}_{index}_{prompt}.png":
                            filename = f"{args.timestring}_{index:05}_{sanitize(prompt)[:160]}.png"
                        else:
                            filename = f"{args.timestring}_{index:05}_{args.seed}.png"
                        image.save(os.path.join(args.outdir, filename))
                    if args.display_samples:
                        display.display(image)
                    index += 1
                args.seed = next_seed(args)

        # Skip the grid when nothing was generated (e.g. an init directory
        # without images): make_grid cannot stack an empty list.
        if args.make_grid and all_images:
            # Clamp to at least one image per row; a zero nrow (grid_rows
            # larger than the image count) makes make_grid divide by zero.
            nrow = max(1, int(len(all_images) / args.grid_rows))
            grid = make_grid(all_images, nrow=nrow)
            grid = rearrange(grid, 'c h w -> h w c').cpu().numpy()
            filename = f"{args.timestring}_{iprompt:05d}_grid_{args.seed}.png"
            grid_image = Image.fromarray(grid.astype(np.uint8))
            grid_image.save(os.path.join(args.outdir, filename))
            display.clear_output(wait=True)
            display.display(grid_image)
|
|
|
|
def _anim_to_device(sample, root):
    """Move *sample* to ``root.device``, calling ``.half()`` first when ``root.half_precision`` is set."""
    if root.half_precision:
        return sample.half().to(root.device)
    return sample.to(root.device)


def _warp_overlay_mask_state(args, anim_args, keys, frame_idx, depth_model, depth, root, prev_sample):
    """Run the mask-related samples stored on *args* through ``anim_frame_warp``.

    * When ``args.use_mask`` and ``args.overlay_mask`` are both set,
      ``args.init_sample_raw`` is warped and stored back.
    * When ``args.use_mask`` is set, ``args.mask_sample`` is warped and stored
      back; if it is still ``None`` it is first created with
      ``prepare_overlay_mask`` using ``prev_sample.shape``.

    ``prev_sample`` is only dereferenced in that lazy-creation case.
    """
    if args.use_mask and args.overlay_mask:
        init_image_raw, _ = anim_frame_warp(args.init_sample_raw, args, anim_args, keys, frame_idx, depth_model, depth, device=root.device)
        args.init_sample_raw = _anim_to_device(sample_from_cv2(init_image_raw), root)

    if args.use_mask:
        if args.mask_sample is None:
            args.mask_sample = prepare_overlay_mask(args, root, prev_sample.shape)
        mask_image, _ = anim_frame_warp(args.mask_sample, args, anim_args, keys, frame_idx, depth_model, depth, device=root.device)
        args.mask_sample = _anim_to_device(sample_from_cv2(mask_image), root)


def render_animation(args, anim_args, animation_prompts, root):
    """Render an animation frame by frame into ``args.outdir``.

    Each loop iteration warps the previous sample with ``anim_frame_warp``,
    optionally colour-matches it, scales it by the contrast schedule, adds
    noise and uses the result as the init sample for ``generate``. Frames are
    written as ``{args.timestring}_{frame_idx:05}.png``.

    When ``anim_args.diffusion_cadence`` is greater than 1 (and the mode is
    not ``'Video Input'``) only every ``cadence``-th frame is generated; the
    frames in between are produced by warping the two neighbouring generated
    frames and blending them linearly.

    Args:
        args: Run settings; mutated extensively (``prompts``, ``prompt``,
            ``clip_prompt``, ``timestring``, ``n_samples``, ``use_init``,
            ``init_sample``, ``init_image``, ``mask_file``, ``strength``,
            ``seed``, ``init_sample_raw``, ``mask_sample``).
        anim_args: Animation settings. ``save_depth_maps`` is forced to
            ``False`` when no depth model is loaded.
        animation_prompts: Mapping of key-frame index (anything ``int()``
            accepts) to prompt text.
        root: Object providing ``device``, ``models_path`` and
            ``half_precision``; also forwarded to ``generate``.

    Raises:
        FileNotFoundError: when resuming and the last rendered frame image
            cannot be read from ``args.outdir``.
    """
    # Animations use key-framed prompts.
    args.prompts = animation_prompts

    # Per-frame schedule series (indexed by frame number below).
    keys = DeformAnimKeys(anim_args)

    # Resume: derive the start frame from the number of entries in the output
    # folder whose name starts with "<resume_timestring>_", minus one.
    # NOTE(review): this counts *every* entry with that prefix, including the
    # "<timestring>_depth_*.png" files this function writes when
    # save_depth_maps is on -- confirm the count is what is intended.
    start_frame = 0
    if anim_args.resume_from_timestring:
        for tmp in os.listdir(args.outdir):
            if tmp.split("_")[0] == anim_args.resume_timestring:
                start_frame += 1
        start_frame = start_frame - 1

    # Create the output folder for the batch.
    os.makedirs(args.outdir, exist_ok=True)
    print(f"Saving animation frames to {args.outdir}")

    # NOTE: unlike render_interpolation, no settings file is written here; the
    # settings dump that used to live at this point was disabled.

    # Resume from timestring.
    if anim_args.resume_from_timestring:
        args.timestring = anim_args.resume_timestring

    # Expand the key-framed prompts to one prompt per frame: place each prompt
    # at its key frame, then fill forward and finally backward.
    # dtype=object because the series holds strings; assigning strings into a
    # float series is deprecated in recent pandas.
    prompt_series = pd.Series([np.nan] * anim_args.max_frames, dtype=object)
    for i, prompt in animation_prompts.items():
        prompt_series[int(i)] = prompt
    prompt_series = prompt_series.ffill().bfill()

    # Check for video inits.
    using_vid_init = anim_args.animation_mode == 'Video Input'

    # Load the depth model when depth is needed for 3D warping or for saving
    # depth maps.
    predict_depths = (anim_args.animation_mode == '3D' and anim_args.use_depth_warping) or anim_args.save_depth_maps
    if predict_depths:
        depth_model = DepthModel(root.device)
        depth_model.load_midas(root.models_path)
        if anim_args.midas_weight < 1.0:
            depth_model.load_adabins(root.models_path)
    else:
        depth_model = None
        anim_args.save_depth_maps = False

    # State for interpolating between generated frames (cadence).
    turbo_steps = 1 if using_vid_init else int(anim_args.diffusion_cadence)
    turbo_prev_image, turbo_prev_frame_idx = None, 0
    turbo_next_image, turbo_next_frame_idx = None, 0

    # Resume: reload the last rendered frame as the previous sample.
    prev_sample = None
    color_match_sample = None
    if anim_args.resume_from_timestring:
        last_frame = start_frame-1
        if turbo_steps > 1:
            # Snap back to the last frame that was actually generated.
            last_frame -= last_frame%turbo_steps
        path = os.path.join(args.outdir,f"{args.timestring}_{last_frame:05}.png")
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Cannot resume animation: unable to read frame {path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        prev_sample = sample_from_cv2(img)
        if anim_args.color_coherence != 'None':
            color_match_sample = img
        if turbo_steps > 1:
            turbo_next_image, turbo_next_frame_idx = sample_to_cv2(prev_sample, type=np.float32), last_frame
            turbo_prev_image, turbo_prev_frame_idx = turbo_next_image, turbo_next_frame_idx
            start_frame = last_frame+turbo_steps

    args.n_samples = 1
    frame_idx = start_frame
    while frame_idx < anim_args.max_frames:
        print(f"Rendering animation frame {frame_idx} of {anim_args.max_frames}")
        noise = keys.noise_schedule_series[frame_idx]
        strength = keys.strength_schedule_series[frame_idx]
        contrast = keys.contrast_schedule_series[frame_idx]
        depth = None

        # Emit the in-between frames leading up to frame_idx.
        if turbo_steps > 1:
            tween_frame_start_idx = max(0, frame_idx-turbo_steps)
            for tween_frame_idx in range(tween_frame_start_idx, frame_idx):
                # Blend weight: reaches 1.0 on the last in-between frame.
                tween = float(tween_frame_idx - tween_frame_start_idx + 1) / float(frame_idx - tween_frame_start_idx)
                print(f" creating in between frame {tween_frame_idx} tween:{tween:0.2f}")

                advance_prev = turbo_prev_image is not None and tween_frame_idx > turbo_prev_frame_idx
                advance_next = tween_frame_idx > turbo_next_frame_idx

                if depth_model is not None:
                    assert turbo_next_image is not None
                    depth = depth_model.predict(turbo_next_image, anim_args)

                if advance_prev:
                    turbo_prev_image, _ = anim_frame_warp(turbo_prev_image, args, anim_args, keys, tween_frame_idx, depth_model, depth=depth, device=root.device)
                if advance_next:
                    turbo_next_image, _ = anim_frame_warp(turbo_next_image, args, anim_args, keys, tween_frame_idx, depth_model, depth=depth, device=root.device)

                # Warp the raw init sample and the mask (used for mask
                # overlay). These use frame_idx, not tween_frame_idx.
                _warp_overlay_mask_state(args, anim_args, keys, frame_idx, depth_model, depth, root, prev_sample)

                turbo_prev_frame_idx = turbo_next_frame_idx = tween_frame_idx

                if turbo_prev_image is not None and tween < 1.0:
                    img = turbo_prev_image*(1.0-tween) + turbo_next_image*tween
                else:
                    img = turbo_next_image

                filename = f"{args.timestring}_{tween_frame_idx:05}.png"
                cv2.imwrite(os.path.join(args.outdir, filename), cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_RGB2BGR))
                if anim_args.save_depth_maps:
                    depth_model.save(os.path.join(args.outdir, f"{args.timestring}_depth_{tween_frame_idx:05}.png"), depth)
            if turbo_next_image is not None:
                prev_sample = sample_from_cv2(turbo_next_image)

        # Apply transforms to the previous frame.
        if prev_sample is not None:
            prev_img, depth = anim_frame_warp(prev_sample, args, anim_args, keys, frame_idx, depth_model, depth=None, device=root.device)

            # Warp the raw init sample and the mask with the depth that the
            # warp above returned.
            _warp_overlay_mask_state(args, anim_args, keys, frame_idx, depth_model, depth, root, prev_sample)

            # Colour matching: the first warped frame becomes the reference,
            # later frames are matched against it.
            if anim_args.color_coherence != 'None':
                if color_match_sample is None:
                    color_match_sample = prev_img.copy()
                else:
                    prev_img = maintain_colors(prev_img, color_match_sample, anim_args.color_coherence)

            # Apply contrast scaling, then frame noising.
            contrast_sample = prev_img * contrast
            noised_sample = add_noise(sample_from_cv2(contrast_sample), noise)

            # Use the transformed previous frame as init for the current one.
            args.use_init = True
            args.init_sample = _anim_to_device(noised_sample, root)
            args.strength = max(0.0, min(1.0, strength))

        # Grab the prompt for the current frame.
        args.prompt = prompt_series[frame_idx]
        args.clip_prompt = args.prompt
        print(f"{args.prompt} {args.seed}")
        if not using_vid_init:
            print(f"Angle: {keys.angle_series[frame_idx]} Zoom: {keys.zoom_series[frame_idx]}")
            print(f"Tx: {keys.translation_x_series[frame_idx]} Ty: {keys.translation_y_series[frame_idx]} Tz: {keys.translation_z_series[frame_idx]}")
            print(f"Rx: {keys.rotation_3d_x_series[frame_idx]} Ry: {keys.rotation_3d_y_series[frame_idx]} Rz: {keys.rotation_3d_z_series[frame_idx]}")

        # Grab the init image for the current frame; extracted frame files
        # are numbered from 1, hence frame_idx+1.
        if using_vid_init:
            init_frame = os.path.join(args.outdir, 'inputframes', f"{frame_idx+1:05}.jpg")
            print(f"Using video init frame {init_frame}")
            args.init_image = init_frame
            if anim_args.use_mask_video:
                mask_frame = os.path.join(args.outdir, 'maskframes', f"{frame_idx+1:05}.jpg")
                args.mask_file = mask_frame

        # Sample the diffusion model.
        sample, image = generate(args, root, frame_idx, return_latent=False, return_sample=True)

        # Outside video-input mode the new sample feeds the next iteration;
        # the very first sample is also kept as the raw image for mask overlay.
        if not using_vid_init:
            prev_sample = sample
            if args.use_mask and args.overlay_mask:
                if args.init_sample_raw is None:
                    args.init_sample_raw = sample

        if turbo_steps > 1:
            # With cadence the generated frame is not saved here; it is
            # written by the in-between pass of the next iteration.
            turbo_prev_image, turbo_prev_frame_idx = turbo_next_image, turbo_next_frame_idx
            turbo_next_image, turbo_next_frame_idx = sample_to_cv2(sample, type=np.float32), frame_idx
            frame_idx += turbo_steps
        else:
            filename = f"{args.timestring}_{frame_idx:05}.png"
            image.save(os.path.join(args.outdir, filename))
            if anim_args.save_depth_maps:
                depth = depth_model.predict(sample_to_cv2(sample), anim_args)
                depth_model.save(os.path.join(args.outdir, f"{args.timestring}_depth_{frame_idx:05}.png"), depth)
            frame_idx += 1

        display.clear_output(wait=True)
        display.display(image)

        args.seed = next_seed(args)
|
|
def render_input_video(args, anim_args, animation_prompts, root):
    """Extract frames from the input video (and mask video), then animate them.

    Frames are exported with ``vid2frames`` into ``<args.outdir>/inputframes``
    and, when ``anim_args.use_mask_video`` is set, the mask video is exported
    into ``<args.outdir>/maskframes``. ``anim_args.max_frames`` is replaced by
    the number of ``*.jpg`` files found in the input-frame folder and control
    is handed to ``render_animation``.

    Args:
        args: Run settings; ``use_init`` is forced to ``True`` and, with a
            mask video, ``use_mask`` and ``overlay_mask`` as well.
        anim_args: Animation settings; ``max_frames`` is overwritten.
        animation_prompts: Forwarded unchanged to ``render_animation``.
        root: Forwarded unchanged to ``render_animation``.
    """
    # Folder that receives the extracted input frames.
    video_in_frame_path = os.path.join(args.outdir, 'inputframes')
    os.makedirs(video_in_frame_path, exist_ok=True)

    # Dump the init video to individual frames.
    print(f"Exporting Video Frames (1 every {anim_args.extract_nth_frame}) frames to {video_in_frame_path}...")
    vid2frames(
        anim_args.video_init_path,
        video_in_frame_path,
        anim_args.extract_nth_frame,
        anim_args.overwrite_extracted_frames,
    )

    # The animation length is dictated by how many frames are on disk.
    extracted_frames = pathlib.Path(video_in_frame_path).glob('*.jpg')
    anim_args.max_frames = sum(1 for _ in extracted_frames)
    args.use_init = True
    print(f"Loading {anim_args.max_frames} input frames from {video_in_frame_path} and saving video frames to {args.outdir}")

    if anim_args.use_mask_video:
        # Folder that receives the extracted mask frames.
        mask_in_frame_path = os.path.join(args.outdir, 'maskframes')
        os.makedirs(mask_in_frame_path, exist_ok=True)

        # Dump the mask video to individual frames.
        print(f"Exporting Video Frames (1 every {anim_args.extract_nth_frame}) frames to {mask_in_frame_path}...")
        vid2frames(
            anim_args.video_mask_path,
            mask_in_frame_path,
            anim_args.extract_nth_frame,
            anim_args.overwrite_extracted_frames,
        )
        args.use_mask = True
        args.overlay_mask = True

    render_animation(args, anim_args, animation_prompts, root)
|
|
def _render_interpolation_frame(args, root, frame_idx):
    """Generate one frame from the current ``args.init_c``, save and display it, then advance the seed."""
    results = generate(args, root)
    image = results[0]

    filename = f"{args.timestring}_{frame_idx:05}.png"
    image.save(os.path.join(args.outdir, filename))

    display.clear_output(wait=True)
    display.display(image)

    args.seed = next_seed(args)


def render_interpolation(args, anim_args, animation_prompts, root):
    """Render frames that interpolate between the conditionings of key-framed prompts.

    One conditioning is obtained per prompt via ``generate(..., return_c=True)``.
    Consecutive conditionings are then blended linearly and a frame is
    generated for every blend step, followed by one final frame for the last
    prompt. Frames are saved as ``{args.timestring}_{frame_idx:05}.png`` in
    ``args.outdir`` and a ``{args.timestring}_settings.txt`` file is written.

    The number of steps between two prompts is the distance between their
    key-frame indices when ``anim_args.interpolate_key_frames`` is set, and
    ``anim_args.interpolate_x_frames + 1`` otherwise. If a key-frame distance
    is zero or negative, a message is printed and rendering stops (frames
    rendered up to that point stay on disk).

    Args:
        args: Run settings; ``prompts``, ``n_samples`` (set to 1),
            ``seed_behavior`` (set to ``'fixed'``), ``prompt``,
            ``clip_prompt``, ``seed`` and ``init_c`` are overwritten.
            ``init_c`` is always reset to ``None`` on exit.
        anim_args: Animation settings.
        animation_prompts: Mapping of key-frame index to prompt text, iterated
            in insertion order.
        root: Opaque object forwarded to ``generate``.
    """
    # Animations use key-framed prompts.
    args.prompts = animation_prompts

    # Create the output folder for the batch.
    os.makedirs(args.outdir, exist_ok=True)
    print(f"Saving animation frames to {args.outdir}")

    # Save the merged settings for the batch.
    settings_filename = os.path.join(args.outdir, f"{args.timestring}_settings.txt")
    s = {**dict(args.__dict__), **dict(anim_args.__dict__)}
    # These bookkeeping entries are excluded from the dump. pop() rather than
    # del: they are not guaranteed to be present on every args object.
    for key in ('master_args', 'opt', 'root', 'get_output_folder'):
        s.pop(key, None)
    with open(settings_filename, "w+", encoding="utf-8") as f:
        json.dump(s, f, ensure_ascii=False, indent=4)

    # Interpolation uses a single sample and a fixed seed.
    args.n_samples = 1
    args.seed_behavior = 'fixed'
    prompts_c_s = []  # one conditioning per prompt, in prompt order

    print("Preparing for interpolation of the following...")

    for prompt in animation_prompts.values():
        args.prompt = prompt
        args.clip_prompt = args.prompt

        # Sample the diffusion model; results[0] is the conditioning and
        # results[1] the image.
        results = generate(args, root, return_c=True)
        c, image = results[0], results[1]
        prompts_c_s.append(c)

        # Preview the key-frame image.
        display.display(image)

        args.seed = next_seed(args)

    display.clear_output(wait=True)
    print("Interpolation start...")

    frame_idx = 0

    # try/finally guarantees args.init_c is cleared even on the early return
    # below or when generation raises.
    try:
        segments = list(zip(prompts_c_s, prompts_c_s[1:]))

        if anim_args.interpolate_key_frames:
            # Coerce key-frame indices with int(), as render_animation does,
            # so string keys such as "10" can be subtracted.
            key_frames = [int(k) for k in animation_prompts]
            for i, (prompt1_c, prompt2_c) in enumerate(segments):
                dist_frames = key_frames[i + 1] - key_frames[i]
                if dist_frames <= 0:
                    print("key frames duplicated or reversed. interpolation skipped.")
                    return
                for j in range(dist_frames):
                    # Interpolate the text embedding.
                    args.init_c = prompt1_c.add(prompt2_c.sub(prompt1_c).mul(j / dist_frames))
                    _render_interpolation_frame(args, root, frame_idx)
                    frame_idx += 1
        else:
            steps = anim_args.interpolate_x_frames + 1
            for prompt1_c, prompt2_c in segments:
                for j in range(steps):
                    # Interpolate the text embedding.
                    args.init_c = prompt1_c.add(prompt2_c.sub(prompt1_c).mul(j / steps))
                    _render_interpolation_frame(args, root, frame_idx)
                    frame_idx += 1

        # Generate the frame for the last prompt.
        args.init_c = prompts_c_s[-1]
        _render_interpolation_frame(args, root, frame_idx)
    finally:
        # Clear init_c so later generate() calls are not affected.
        args.init_c = None
|
|