File size: 12,012 Bytes
653ce35
7ca5351
91b1f35
c93a0cb
771cad6
c93a0cb
120ac54
c93a0cb
 
454eedf
39489bd
0cbec0b
854b688
 
 
acf84db
 
 
 
 
 
 
 
 
 
91b1f35
 
 
 
 
 
3a66e37
91b1f35
 
 
 
3a66e37
1c8c6b0
eda2433
 
 
 
 
b827d8f
93f9e0a
6e02ec8
 
eda2433
35f86ae
 
eda2433
35f86ae
3a66e37
833e264
 
 
 
65a9a4b
28ef21d
833e264
28ef21d
833e264
1c8c6b0
69ec2ea
9f98966
1c8c6b0
 
 
 
 
9f98966
 
 
1c8c6b0
de0aaee
9f98966
 
 
 
 
 
 
1c8c6b0
9f98966
 
 
90e6ba4
9f98966
 
 
69ec2ea
1c8c6b0
618e51c
d5a4d02
 
 
 
 
 
 
454eedf
16450f6
7aa5be2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
771cad6
fe1dcee
 
d5a4d02
 
6b906fc
 
 
 
cab2223
 
b015e22
dd0e95b
cab2223
 
b89ba4a
dd0e95b
 
 
cb7d2da
dd0e95b
 
3a51f78
 
 
dd0e95b
 
70de192
 
 
dd0e95b
d5a4d02
2d6e7e3
 
b015e22
dd0e95b
b440bd5
 
dd0e95b
 
cdd5bef
 
b015e22
 
 
 
 
69ec2ea
 
b015e22
771cad6
b015e22
771cad6
967b7dd
a4ad89d
 
6e159a1
84e6bd9
 
 
b015e22
967b7dd
 
 
 
d5a4d02
 
 
967b7dd
d3c88cc
 
 
7aa5be2
967b7dd
771cad6
eadc7f8
1c8c6b0
 
65a9a4b
e23155a
771cad6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65a9a4b
 
6a046f7
65a9a4b
 
7bb183a
006cb6c
7bb183a
65a9a4b
91b1f35
 
 
7bb183a
b827d8f
91b1f35
 
 
 
 
 
ea76f7d
91b1f35
 
 
bce3142
91b1f35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b03a4d
 
91b1f35
1c8c6b0
 
7495fff
 
98aad5a
39489bd
7ca5351
0269ee9
 
b827d8f
771cad6
0269ee9
eadc7f8
106f93a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
import gradio as gr
import os
import shutil
import subprocess
from share_btn import community_icon_html, loading_icon_html, share_js
import cv2
import numpy as np
from moviepy.editor import VideoFileClip, concatenate_videoclips
import math

from huggingface_hub import snapshot_download

# Force synchronous CUDA kernel launches so a failing kernel reports its error
# at the call that caused it (presumably set to make CUDA/cuDNN failures easier
# to localize -- TODO confirm it is still needed, it slows GPU work down).
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'


# Hub repositories mirrored once at start-up: the Stable Diffusion 1.5 base
# model plus one ControlNet per conditioning type offered in the UI
# (depth / canny / openpose). Each repo lands in checkpoints/<repo name>.
model_ids = [
    'runwayml/stable-diffusion-v1-5',
    'lllyasviel/sd-controlnet-depth',
    'lllyasviel/sd-controlnet-canny',
    'lllyasviel/sd-controlnet-openpose',
]
for model_id in model_ids:
    # Keep only the segment after the last '/', e.g. 'sd-controlnet-depth'.
    model_name = model_id.rpartition('/')[2]
    snapshot_download(model_id, local_dir=f'checkpoints/{model_name}')

def load_model(model_id):
    """Replace the base diffusion checkpoint with the Hub repo *model_id*.

    The repo is first downloaded into a staging directory. Only after that
    download succeeds is ``checkpoints/stable-diffusion-v1-5`` deleted and
    replaced by the staged copy. The previous version deleted the current
    checkpoint up front, so a failed download (network error, mistyped repo id
    in the free-text dropdown) left the app without any base model.

    Args:
        model_id: Hub repo id, e.g. ``'nitrosocke/Ghibli-Diffusion'``.

    Returns:
        The status string ``"model loaded"`` shown in the UI.
    """
    # Fixed location of the base model; presumably read by inference.py.
    target_dir = 'checkpoints/stable-diffusion-v1-5'
    model_name = model_id.split('/')[-1]

    # Staging path is distinct from target_dir even when the repo is itself
    # named 'stable-diffusion-v1-5', so the old weights stay intact until the
    # new ones are fully on disk.
    staging_dir = f'checkpoints/.staging-{model_name}'
    snapshot_download(model_id, local_dir=staging_dir)

    # Swap the freshly downloaded weights into the base-model path.
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.rename(staging_dir, target_dir)
    return "model loaded"

def get_frame_count(filepath):
    """Return a Gradio update that bounds the frame-count slider for *filepath*.

    Args:
        filepath: path of the uploaded video, or ``None`` when the input
            was cleared.

    Returns:
        ``gr.update(value=1, maximum=12)`` when there is no video. Otherwise
        ``gr.update(maximum=n)``, where ``n`` is the video's frame count
        clamped to the range 1..24.
    """
    if filepath is None:
        # No video: reset the slider to its initial range.
        return gr.update(value=1, maximum=12)

    video = cv2.VideoCapture(filepath)
    try:
        frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        video.release()

    # LIMITS: cap at 24 frames to avoid cuDNN errors. Never go below the
    # slider's minimum of 1: OpenCV returns 0 when the property is
    # unavailable (e.g. the file could not be opened), which would otherwise
    # produce a slider whose maximum is below its minimum.
    frame_count = max(1, min(frame_count, 24))
    return gr.update(maximum=frame_count)

def get_video_dimension(filepath):
    """Probe *filepath* with OpenCV and return its basic stream properties.

    Returns:
        A tuple ``(width, height, fps, frame_count)`` of ints. Each value is
        the OpenCV property converted with ``int()``, so a fractional frame
        rate such as 29.97 is truncated to 29.
    """
    capture = cv2.VideoCapture(filepath)
    wanted = (
        cv2.CAP_PROP_FRAME_WIDTH,
        cv2.CAP_PROP_FRAME_HEIGHT,
        cv2.CAP_PROP_FPS,
        cv2.CAP_PROP_FRAME_COUNT,
    )
    props = tuple(int(capture.get(prop_id)) for prop_id in wanted)
    capture.release()
    return props

def resize_video(input_vid, output_vid, width, height, fps):
    """Re-encode *input_vid* into *output_vid* at ``width x height`` and *fps*.

    Every decoded frame is stretched to the target size with ``cv2.resize``
    (the aspect ratio is not preserved here; callers choose the size) and
    written with the ``mp4v`` codec.

    Args:
        input_vid: path of the source video.
        output_vid: path of the video to write.
        width, height: target frame size in pixels.
        fps: frame rate recorded in the output container.

    Returns:
        *output_vid*.
    """
    print("RESIZING ...")
    video = cv2.VideoCapture(input_vid)
    output_video = None
    try:
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Codec for the output video
        output_video = cv2.VideoWriter(output_vid, fourcc, fps, (width, height))

        while True:
            ret, frame = video.read()
            if not ret:
                break
            output_video.write(cv2.resize(frame, (width, height)))
    finally:
        # Release both handles even if decoding or resizing raises, so the
        # writer closes its file and no capture handle is leaked.
        video.release()
        if output_video is not None:
            output_video.release()

    print("RESIZE VIDEO DONE!")
    return output_vid

def make_nearest_multiple_of_32(number):
    """Round *number* to the nearest multiple of 32.

    A remainder of exactly 16 rounds down. A positive input never rounds
    down to 0, which would be an invalid frame dimension; it rounds up to 32
    instead.

    >>> make_nearest_multiple_of_32(300)
    288
    >>> make_nearest_multiple_of_32(10)
    32
    """
    remainder = number % 32
    if remainder <= 16:
        rounded = number - remainder
    else:
        rounded = number + (32 - remainder)
    # cv2.resize rejects a zero-sized target, so tiny positive sizes (1..16)
    # snap up to the smallest valid multiple instead of collapsing to 0.
    if number > 0 and rounded == 0:
        rounded = 32
    return rounded

def change_video_fps(input_path):
    """Re-encode *input_path* into ``output_video.mp4`` tagged at 12 fps.

    Each decoded frame is written ``output_fps // 8`` times. With the current
    constants that is once, so the frames are kept as-is and only the
    container frame rate changes to 12.

    Args:
        input_path: video to re-encode (e.g. the generated
            ``output/result.mp4``).

    Returns:
        ``'output_video.mp4'``. Any existing file at that path is replaced.
    """
    print(f"CHANGING FIANL OUTPUT FPS")
    output_path = 'output_video.mp4'
    output_fps = 12
    # NOTE(review): the divisor 8 looks like an assumed generator frame rate.
    # With output_fps == 12 this yields 1 write per frame; confirm intent.
    writes_per_frame = output_fps // 8

    # Delete a result left over from a previous run.
    if os.path.exists(output_path):
        os.remove(output_path)

    cap = cv2.VideoCapture(input_path)
    out = None
    try:
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        output_size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                       int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
        out = cv2.VideoWriter(output_path, fourcc, output_fps, output_size)

        frame_count = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            for _ in range(writes_per_frame):
                out.write(frame)
            frame_count += 1
            print(f'Processed frame {frame_count}')
    finally:
        # Always release both handles. cv2.destroyAllWindows() was dropped:
        # this function never opens a HighGUI window, and on headless OpenCV
        # builds that call can raise instead of being a no-op.
        cap.release()
        if out is not None:
            out.release()

    return output_path

def run_inference(prompt, video_path, condition, video_length, seed):
    """Run ControlVideo on an uploaded clip and return the post-processed result.

    Pipeline:
    1. Probe the input video.
    2. Shrink it to at most 512 px wide, snapping both sides to a multiple
       of 32.
    3. Cap its fps at 12 and write the result to ``resized.mp4``.
    4. Run ``inference.py`` as a subprocess.
    5. Re-time ``output/result.mp4`` via ``change_video_fps``.

    Args:
        prompt: text prompt forwarded to ``inference.py``.
        video_path: path of the uploaded input video.
        condition: control signal name chosen in the UI
            (``"depth"``, ``"canny"`` or ``"pose"``).
        video_length: number of frames to generate. It is clamped to the
            number of frames in the resized clip.
        seed: random seed, floored to an int.

    Returns:
        ``(final_video_path, gr.Group.update(visible=True))``. The second
        item reveals the share-button group.

    Raises:
        gr.Error: if no video was supplied, the video cannot be read, or the
            inference subprocess fails to produce ``output/result.mp4``.
    """
    if video_path is None:
        raise gr.Error("Please upload an input video first.")

    seed = math.floor(seed)
    # Normalise so the command line receives e.g. '12' rather than '12.0'.
    video_length = int(video_length)

    # Probe the input once instead of re-opening it for every property.
    o_width, o_height, original_fps, _ = get_video_dimension(video_path)
    if o_width <= 0 or o_height <= 0:
        raise gr.Error("Could not read the input video.")

    # Prepare dimensions: downscale videos wider than 512 px, keeping the
    # aspect ratio. Narrower videos keep their size. (Previously n_width and
    # n_height were only assigned in the wide case, so any video <= 512 px
    # wide crashed with UnboundLocalError.)
    if o_width > 512:
        n_width = 512
        n_height = int(o_height / o_width * 512)
    else:
        n_width, n_height = o_width, o_height

    # Make sure new dimensions are multiples of 32
    r_width = make_nearest_multiple_of_32(n_width)
    r_height = make_nearest_multiple_of_32(n_height)
    print(f"multiple of 32 sizes : {r_width}x{r_height}")

    # Cap the input frame rate at 12 fps.
    if original_fps > 12:
        print(f"FPS is too high: {original_fps}")
        target_fps = 12
    else:
        target_fps = original_fps
    print(f"NEW INPUT FPS: {target_fps}, NEW LENGTH: {video_length}")

    # Replace any resized clip left over from a previous run.
    if os.path.exists('resized.mp4'):
        os.remove('resized.mp4')
    resized = resize_video(video_path, 'resized.mp4', r_width, r_height, target_fps)
    resized_video_fcount = get_video_dimension(resized)[3]
    print(f"RESIZED VIDEO FRAME COUNT: {resized_video_fcount}")

    # Make sure the requested length does not exceed the available frames.
    video_length = min(video_length, resized_video_fcount)

    output_path = 'output/'
    os.makedirs(output_path, exist_ok=True)
    video_path_output = os.path.join(output_path, "result.mp4")
    # Remove a stale result so a failed run cannot be mistaken for success.
    if os.path.exists(video_path_output):
        os.remove(video_path_output)

    print("RUNNING INFERENCE ...")
    # Build argv as a list and run it without a shell. The prompt is
    # user-supplied text; interpolating it into a single-quoted shell string
    # let a quote character break the command or inject arbitrary commands.
    command = [
        "python", "inference.py",
        "--prompt", prompt,
        "--inference_steps", "50",
        "--condition", condition,
        "--video_path", resized,
        "--output_path", output_path,
        "--temp_chunk_path", "result",
        "--width", str(r_width),
        "--height", str(r_height),
        "--fps", str(target_fps),
        "--seed", str(seed),
        "--video_length", str(video_length),
        "--smoother_steps", "19", "20",
    ]
    if video_length > 12:
        command.append("--is_long_video")

    completed = subprocess.run(command)
    # subprocess.run does not raise on a non-zero exit status, so check the
    # status (and the expected output file) explicitly.
    if completed.returncode != 0 or not os.path.exists(video_path_output):
        raise gr.Error(f"Inference failed (exit code {completed.returncode}).")

    # Check generated video FPS
    gen_fps = get_video_dimension(video_path_output)[2]
    print(f"GEN VIDEO FPS: {gen_fps}")
    final = change_video_fps(video_path_output)
    print("FINISHED !")

    return final, gr.Group.update(visible=True)
 

# Custom CSS passed to gr.Blocks(css=css). It centres the main column
# (#col-container), defines a spin animation (.animate-spin, presumably used
# by loading_icon_html from share_btn) and styles the "Share to community"
# pill (#share-btn-container / #share-btn). It also centres images whose src
# contains '#center', i.e. the Duplicate badge in the header Markdown.
css="""
#col-container {max-width: 810px; margin-left: auto; margin-right: auto;}
.animate-spin {
  animation: spin 1s linear infinite;
}
@keyframes spin {
  from {
      transform: rotate(0deg);
  }
  to {
      transform: rotate(360deg);
  }
}
#share-btn-container {
  display: flex; 
  padding-left: 0.5rem !important; 
  padding-right: 0.5rem !important; 
  background-color: #000000; 
  justify-content: center; 
  align-items: center; 
  border-radius: 9999px !important; 
  max-width: 13rem;
}
#share-btn-container:hover {
  background-color: #060606;
}
#share-btn {
  all: initial; 
  color: #ffffff;
  font-weight: 600; 
  cursor:pointer; 
  font-family: 'IBM Plex Sans', sans-serif; 
  margin-left: 0.5rem !important; 
  padding-top: 0.5rem !important; 
  padding-bottom: 0.5rem !important;
  right:0;
}
#share-btn * {
  all: unset;
}
#share-btn-container div:nth-child(-n+2){
  width: auto !important;
  min-height: 0px !important;
}
#share-btn-container .wrap {
  display: none !important;
}
#share-btn-container.hidden {
  display: none!important;
}
img[src*='#center'] { 
    display: block;
    margin: auto;
}
"""
# Gradio UI. It contains:
# - the input video, and the result video with a share button group that
#   stays hidden until the first result;
# - an optional custom-model loader;
# - prompt, length, condition and seed controls;
# - one example;
# - the event wiring at the bottom.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""
            <h1 style="text-align: center;">ControlVideo</h1>
            <p style="text-align: center;"> Pytorch implementation of "<a href='https://github.com/chenxwh/ControlVideo' target='_blank'>ControlVideo</a>: Training-free Controllable Text-to-Video Generation" </p>
            
            [![Duplicate this Space](https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-sm.svg#center)](https://huggingface.co/spaces/fffiloni/ControlVideo?duplicate=true)
        """)
        
        with gr.Column():
            with gr.Row():
                # Input video; handlers receive it as a file path (type="filepath").
                video_path = gr.Video(label="Input video", source="upload", type="filepath", visible=True, elem_id="video-in")
                with gr.Column():
                    # Result video plus the share group. run_inference returns
                    # gr.Group.update(visible=True) to reveal the group.
                    video_res = gr.Video(label="result", elem_id="video-out")
                    with gr.Group(elem_id="share-btn-container", visible=False) as share_group:
                        community_icon = gr.HTML(community_icon_html)
                        loading_icon = gr.HTML(loading_icon_html)
                        share_button = gr.Button("Share to community", elem_id="share-btn")
            # Optional: replace the base checkpoint via load_model. Custom repo
            # ids are allowed (allow_custom_value=True).
            with gr.Row():
                    chosen_model = gr.Dropdown(label="Diffusion model (*1.5)", choices=['runwayml/stable-diffusion-v1-5','nitrosocke/Ghibli-Diffusion'], value="runwayml/stable-diffusion-v1-5", allow_custom_value=True)
                    model_status = gr.Textbox(label="status")
                    load_model_btn = gr.Button("load model (optional)")
            prompt = gr.Textbox(label="prompt", info="If you loaded a custom model, do not forget to include Prompt trigger", elem_id="prompt-in")
            with gr.Column():
                # NOTE(review): the initial maximum is 12 although the info text
                # says 24. get_frame_count raises the maximum to min(frame count, 24)
                # once a video is uploaded.
                video_length = gr.Slider(label="Video length", info="How many frames do you want to process ? For demo purpose, max is set to 24", minimum=1, maximum=12, step=1, value=2)
                with gr.Row():
                    condition = gr.Dropdown(label="Condition", choices=["depth", "canny", "pose"], value="depth")
                    seed = gr.Number(label="seed", value=42)
            submit_btn = gr.Button("Submit")
        
            
            
            # NOTE(review): the example asks for 24 frames, above the slider's
            # initial maximum of 12. Presumably this works once get_frame_count
            # updates the maximum for the example video; confirm.
            gr.Examples(
                examples=[["Indiana Jones moonwalk in the wild jungle", "./examples/moonwalk.mp4", 'depth', 24, 192837465]],
                fn=run_inference,
                inputs=[prompt,
                         video_path,
                         condition,
                         video_length,
                         seed
                        ],
                outputs=[video_res, share_group],
                cache_examples=False
            )
                
    # Event wiring. The share button runs only client-side JS (fn=None, _js=share_js).
    share_button.click(None, [], [], _js=share_js)
    # Model loading bypasses the request queue (queue=False).
    load_model_btn.click(fn=load_model, inputs=[chosen_model], outputs=[model_status], queue=False)
    # Re-bound the frame-count slider whenever the input video changes.
    video_path.change(fn=get_frame_count,
                      inputs=[video_path],
                      outputs=[video_length],
                      queue=False
                     )
    # Main generation: returns the final video and reveals the share group.
    submit_btn.click(fn=run_inference, 
                     inputs=[prompt,
                             video_path,
                             condition,
                             video_length,
                             seed
                            ],
                    outputs=[video_res, share_group])

# Enable the request queue (max_size=12 caps how many events it holds) and start the server.
demo.queue(max_size=12).launch()