svjack commited on
Commit
98fed26
1 Parent(s): d3ede38

Upload gradio_app_with_frames.py

Browse files
Files changed (1) hide show
  1. gradio_app_with_frames.py +138 -0
gradio_app_with_frames.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import shutil
4
+ import uuid
5
+ import subprocess
6
+ import gradio as gr
7
+ import shutil
8
+ from glob import glob
9
+ from huggingface_hub import snapshot_download, hf_hub_download
10
+ from moviepy.editor import VideoFileClip # Import MoviePy
11
+
12
+ # Download models
13
+ os.makedirs("pretrained_weights", exist_ok=True)
14
+
15
+ # List of subdirectories to create inside "checkpoints"
16
+ subfolders = [
17
+ "stable-video-diffusion-img2vid-xt"
18
+ ]
19
+
20
+ # Create each subdirectory
21
+ for subfolder in subfolders:
22
+ os.makedirs(os.path.join("pretrained_weights", subfolder), exist_ok=True)
23
+
24
+ snapshot_download(
25
+ repo_id="stabilityai/stable-video-diffusion-img2vid",
26
+ local_dir="./pretrained_weights/stable-video-diffusion-img2vid-xt"
27
+ )
28
+
29
+ snapshot_download(
30
+ repo_id="Yhmeng1106/anidoc",
31
+ local_dir="./pretrained_weights"
32
+ )
33
+
34
+ hf_hub_download(
35
+ repo_id="facebook/cotracker",
36
+ filename="cotracker2.pth",
37
+ local_dir="./pretrained_weights"
38
+ )
39
+
40
+ def generate(control_sequence, ref_image):
41
+ control_image = control_sequence # "data_test/sample4.mp4"
42
+ ref_image = ref_image # "data_test/sample4.png"
43
+ unique_id = str(uuid.uuid4())
44
+ output_dir = f"results_{unique_id}"
45
+
46
+ try:
47
+ # Use MoviePy to get the number of frames in the control_sequence video
48
+ video_clip = VideoFileClip(control_image)
49
+ num_frames = int(video_clip.fps * video_clip.duration) # Calculate total frames
50
+ video_clip.close() # Close the video clip to free resources
51
+
52
+ # Run the inference command
53
+ subprocess.run(
54
+ [
55
+ "python", "scripts_infer/anidoc_inference.py",
56
+ "--all_sketch",
57
+ "--matching",
58
+ "--tracking",
59
+ "--control_image", f"{control_image}",
60
+ "--ref_image", f"{ref_image}",
61
+ "--output_dir", f"{output_dir}",
62
+ "--max_point", "10",
63
+ "--num_frames", str(num_frames) # Pass the calculated num_frames
64
+ ],
65
+ check=True
66
+ )
67
+
68
+ # Search for the mp4 file in a subfolder of output_dir
69
+ output_video = glob(os.path.join(output_dir, "*.mp4"))
70
+ print(output_video)
71
+
72
+ if output_video:
73
+ output_video_path = output_video[0] # Get the first match
74
+ else:
75
+ output_video_path = None
76
+
77
+ print(output_video_path)
78
+ return output_video_path
79
+
80
+ except subprocess.CalledProcessError as e:
81
+ raise gr.Error(f"Error during inference: {str(e)}")
82
+ except Exception as e:
83
+ raise gr.Error(f"Error processing video: {str(e)}")
84
+
85
+ css = """
86
+ div#col-container{
87
+ margin: 0 auto;
88
+ max-width: 982px;
89
+ }
90
+ """
91
+ with gr.Blocks(css=css) as demo:
92
+ with gr.Column(elem_id="col-container"):
93
+ gr.Markdown("# AniDoc: Animation Creation Made Easier")
94
+ gr.Markdown("AniDoc colorizes a sequence of sketches based on a character design reference with high fidelity, even when the sketches significantly differ in pose and scale.")
95
+ gr.HTML("""
96
+ <div style="display:flex;column-gap:4px;">
97
+ <a href="https://github.com/yihao-meng/AniDoc">
98
+ <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
99
+ </a>
100
+ <a href="https://yihao-meng.github.io/AniDoc_demo/">
101
+ <img src='https://img.shields.io/badge/Project-Page-green'>
102
+ </a>
103
+ <a href="https://arxiv.org/pdf/2412.14173">
104
+ <img src='https://img.shields.io/badge/ArXiv-Paper-red'>
105
+ </a>
106
+ <a href="https://huggingface.co/spaces/fffiloni/AniDoc?duplicate=true">
107
+ <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
108
+ </a>
109
+ <a href="https://huggingface.co/fffiloni">
110
+ <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/follow-me-on-HF-sm-dark.svg" alt="Follow me on HF">
111
+ </a>
112
+ </div>
113
+ """)
114
+ with gr.Row():
115
+ with gr.Column():
116
+ control_sequence = gr.Video(label="Control Sequence", format="mp4")
117
+ ref_image = gr.Image(label="Reference Image", type="filepath")
118
+ submit_btn = gr.Button("Submit")
119
+ with gr.Column():
120
+ video_result = gr.Video(label="Result")
121
+
122
+ gr.Examples(
123
+ examples=[
124
+ ["data_test/sample1.mp4", "data_test/sample1.png"],
125
+ ["data_test/sample2.mp4", "data_test/sample2.png"],
126
+ ["data_test/sample3.mp4", "data_test/sample3.png"],
127
+ ["data_test/sample4.mp4", "data_test/sample4.png"]
128
+ ],
129
+ inputs=[control_sequence, ref_image]
130
+ )
131
+
132
+ submit_btn.click(
133
+ fn=generate,
134
+ inputs=[control_sequence, ref_image],
135
+ outputs=[video_result]
136
+ )
137
+
138
+ demo.queue().launch(show_api=False, show_error=True, share=True)