CiaraRowles committed
Commit 522db09
1 Parent(s): dc9d306

script to run

Files changed (1):
runtemporalnetxl.py  +110 -0
runtemporalnetxl.py ADDED
@@ -0,0 +1,110 @@
import os
import cv2
import torch
import argparse
from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
from diffusers.utils import load_image
import numpy as np
from PIL import Image

def split_video_into_frames(video_path, frames_dir):
    if not os.path.exists(frames_dir):
        os.makedirs(frames_dir)
    print("splitting video")
    vidcap = cv2.VideoCapture(video_path)
    success, image = vidcap.read()
    count = 0
    while success:
        frame_path = os.path.join(frames_dir, f"frame{count:04d}.png")
        cv2.imwrite(frame_path, image)
        success, image = vidcap.read()
        count += 1

def frame_number(frame_filename):
    # Extract the frame number from the filename and convert it to an integer
    return int(frame_filename[5:-4])

# Argument parser
parser = argparse.ArgumentParser(description='Generate images based on video frames.')
parser.add_argument('--prompt', default='a woman', help='The Stable Diffusion prompt.')
parser.add_argument('--video_path', default='./None.mp4', help='Path to the input video file.')
parser.add_argument('--frames_dir', default='./frames', help='Directory to save the extracted video frames.')
parser.add_argument('--output_frames_dir', default='./output_frames', help='Directory to save the generated images.')
parser.add_argument('--init_image_path', default=None, help='Path to the initial conditioning image.')

args = parser.parse_args()

video_path = args.video_path
frames_dir = args.frames_dir
output_frames_dir = args.output_frames_dir
init_image_path = args.init_image_path
prompt = args.prompt

# If frames do not already exist, split the video into frames
if not os.path.exists(frames_dir):
    split_video_into_frames(video_path, frames_dir)

# Create the output frames directory if it doesn't exist
if not os.path.exists(output_frames_dir):
    os.makedirs(output_frames_dir)

# Load the initial conditioning image, if provided
if init_image_path:
    print(f"using image {init_image_path}")
    last_generated_image = load_image(init_image_path)
else:
    initial_frame_path = os.path.join(frames_dir, "frame0000.png")
    last_generated_image = load_image(initial_frame_path)

base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
controlnet1_path = "CiaraRowles/TemporalNet1XL"
controlnet2_path = "diffusers/controlnet-canny-sdxl-1.0"

controlnet = [
    ControlNetModel.from_pretrained(controlnet1_path, torch_dtype=torch.float16),
    ControlNetModel.from_pretrained(controlnet2_path, torch_dtype=torch.float16)
]
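# With a list of ControlNets, the conditioning images and the
# controlnet_conditioning_scale values passed to the pipeline call below
# are matched to these models by position: [TemporalNet, Canny].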
#controlnet = ControlNetModel.from_pretrained(controlnet2_path, torch_dtype=torch.float16)

pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    base_model_path, controlnet=controlnet, torch_dtype=torch.float16
)

#pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
#pipe.enable_xformers_memory_efficient_attention()
pipe.enable_model_cpu_offload()

generator = torch.manual_seed(7)

# Loop over the saved frames in numerical order
frame_files = sorted(os.listdir(frames_dir), key=frame_number)

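# For each source frame below, two conditioning images are built: the
# previously generated output drives TemporalNet (temporal consistency),
# and a Canny edge map of the current frame drives the Canny ControlNet
# (per-frame structure).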
for i, frame_file in enumerate(frame_files):
    # Use the original video frame to create a Canny edge-detected image as
    # the conditioning image for the second ControlNetModel
    control_image_path = os.path.join(frames_dir, frame_file)
    control_image = load_image(control_image_path)

    canny_image = np.array(control_image)
    canny_image = cv2.Canny(canny_image, 25, 200)
    canny_image = canny_image[:, :, None]
    canny_image = np.concatenate([canny_image, canny_image, canny_image], axis=2)
    canny_image = Image.fromarray(canny_image)

    # Generate image
    image = pipe(
        prompt,
        num_inference_steps=20,
        generator=generator,
        image=[last_generated_image, canny_image],
        controlnet_conditioning_scale=[0.6, 0.7]
        #prompt, num_inference_steps=20, generator=generator, image=canny_image, controlnet_conditioning_scale=0.5
    ).images[0]

    # Save the generated image to the output folder
    output_path = os.path.join(output_frames_dir, f"output{str(i).zfill(4)}.png")
    image.save(output_path)

    # Save the Canny image for reference
    canny_image_path = os.path.join(output_frames_dir, f"outputcanny{str(i).zfill(4)}.png")
    canny_image.save(canny_image_path)

    # Update last_generated_image with the newly generated image for the next iteration
    last_generated_image = image

    print(f"Saved generated image for frame {i} to {output_path}")
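
The commit itself includes no usage notes, but from the argument parser above a run would look like this (the video path is a placeholder):

    python runtemporalnetxl.py --video_path ./input.mp4 --prompt "a woman"

Generated frames are written to the output directory as output0000.png, output0001.png, ..., with matching outputcanny####.png edge maps saved for reference. As a minimal sketch, not part of this commit, the generated frames can be stitched back into a video with OpenCV (assuming a fixed 30 fps; match this to the source video's frame rate):

    import os
    import cv2

    output_frames_dir = "./output_frames"
    # Keep only the generated frames, skipping the Canny reference images.
    frame_files = sorted(f for f in os.listdir(output_frames_dir)
                         if f.startswith("output") and "canny" not in f)

    # Read the first frame to fix the video dimensions.
    first = cv2.imread(os.path.join(output_frames_dir, frame_files[0]))
    height, width = first.shape[:2]

    writer = cv2.VideoWriter("output.mp4", cv2.VideoWriter_fourcc(*"mp4v"),
                             30, (width, height))
    for name in frame_files:
        writer.write(cv2.imread(os.path.join(output_frames_dir, name)))
    writer.release()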