ziyangmai committed on
Commit 6dd3263
1 Parent(s): 15761f7

update page

Files changed (1)
  1. app.py +170 -141
app.py CHANGED
@@ -1,154 +1,183 @@
 import gradio as gr
-import numpy as np
 import random

-# import spaces #[uncomment to use ZeroGPU]
-from diffusers import DiffusionPipeline
-import torch
-
-device = "cuda" if torch.cuda.is_available() else "cpu"
-model_repo_id = "stabilityai/sdxl-turbo"  # Replace to the model you would like to use
-
-if torch.cuda.is_available():
-    torch_dtype = torch.float16
-else:
-    torch_dtype = torch.float32
-
-pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
-pipe = pipe.to(device)
-
-MAX_SEED = np.iinfo(np.int32).max
-MAX_IMAGE_SIZE = 1024
-
-
-# @spaces.GPU #[uncomment to use ZeroGPU]
-def infer(
-    prompt,
-    negative_prompt,
-    seed,
-    randomize_seed,
-    width,
-    height,
-    guidance_scale,
-    num_inference_steps,
-    progress=gr.Progress(track_tqdm=True),
-):
-    if randomize_seed:
-        seed = random.randint(0, MAX_SEED)
-
-    generator = torch.Generator().manual_seed(seed)
-
-    image = pipe(
-        prompt=prompt,
-        negative_prompt=negative_prompt,
-        guidance_scale=guidance_scale,
-        num_inference_steps=num_inference_steps,
-        width=width,
-        height=height,
-        generator=generator,
-    ).images[0]
-
-    return image, seed
-
-
-examples = [
-    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-    "An astronaut riding a green horse",
-    "A delicious ceviche cheesecake slice",
-]
-
-css = """
-#col-container {
-    margin: 0 auto;
-    max-width: 640px;
-}
-"""
-
-with gr.Blocks(css=css) as demo:
-    with gr.Column(elem_id="col-container"):
-        gr.Markdown(" # Text-to-Image Gradio Template")
-
-        with gr.Row():
-            prompt = gr.Text(
-                label="Prompt",
-                show_label=False,
-                max_lines=1,
-                placeholder="Enter your prompt",
-                container=False,
-            )

-            run_button = gr.Button("Run", scale=0, variant="primary")

-        result = gr.Image(label="Result", show_label=False)

-        with gr.Accordion("Advanced Settings", open=False):
-            negative_prompt = gr.Text(
-                label="Negative prompt",
-                max_lines=1,
-                placeholder="Enter a negative prompt",
-                visible=False,
-            )

-            seed = gr.Slider(
-                label="Seed",
-                minimum=0,
-                maximum=MAX_SEED,
-                step=1,
-                value=0,
-            )

-            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

-            with gr.Row():
-                width = gr.Slider(
-                    label="Width",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=32,
-                    value=1024,  # Replace with defaults that work for your model
-                )
-
-                height = gr.Slider(
-                    label="Height",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=32,
-                    value=1024,  # Replace with defaults that work for your model
-                )

-            with gr.Row():
-                guidance_scale = gr.Slider(
-                    label="Guidance scale",
-                    minimum=0.0,
-                    maximum=10.0,
-                    step=0.1,
-                    value=0.0,  # Replace with defaults that work for your model
-                )
-
-                num_inference_steps = gr.Slider(
-                    label="Number of inference steps",
-                    minimum=1,
-                    maximum=50,
-                    step=1,
-                    value=2,  # Replace with defaults that work for your model
-                )
-
-        gr.Examples(examples=examples, inputs=[prompt])
-    gr.on(
-        triggers=[run_button.click, prompt.submit],
-        fn=infer,
-        inputs=[
-            prompt,
-            negative_prompt,
-            seed,
-            randomize_seed,
-            width,
-            height,
-            guidance_scale,
-            num_inference_steps,
-        ],
-        outputs=[result, seed],
-    )

 if __name__ == "__main__":
-    demo.launch()
 import gradio as gr
+import os
+import torch
+import tempfile
 import random
+import string
+import json
+from omegaconf import OmegaConf, ListConfig


+from train import main as train_main
+from inference import inference as inference_main
+
+
+# Training function
+def train_model(video, config):
+    output_dir = 'results'
+    os.makedirs(output_dir, exist_ok=True)
+    cur_save_dir = os.path.join(output_dir, str(len(os.listdir(output_dir))).zfill(2))
+
+    config.dataset.single_video_path = video
+    config.train.output_dir = cur_save_dir
+
+    # copy video to cur_save_dir
+    video_name = 'source.mp4'
+    video_path = os.path.join(cur_save_dir, video_name)
+    os.system(f"cp {video} {video_path}")
+
+    train_main(config)
+    # cur_save_dir = 'results/06'
+    return cur_save_dir
+
+
+# Inference function
+def inference_model(text, checkpoint, inference_steps, video_type, seed):
+    checkpoint = os.path.join('results', checkpoint)
+
+    embedding_dir = '/'.join(checkpoint.split('/')[:-1])
+    video_round = checkpoint.split('/')[-1]
+
+    video_path = inference_main(
+        embedding_dir=embedding_dir,
+        prompt=text,
+        video_round=video_round,
+        save_dir=os.path.join('outputs', embedding_dir.split('/')[-1]),
+        motion_type=video_type,
+        seed=seed,
+        inference_steps=inference_steps
+    )
+
+    return video_path
+
+
+# Get the list of checkpoint files
+def get_checkpoints(checkpoint_dir):
+    checkpoints = []
+    for root, dirs, files in os.walk(checkpoint_dir):
+        for file in files:
+            if file == 'motion_embed.pt':
+                checkpoints.append('/'.join(root.split('/')[-2:]))
+    return checkpoints
+
+
+def extract_combinations(motion_embeddings_combinations):
+    assert len(motion_embeddings_combinations) > 0, "At least one motion embedding combination is required"
+    combinations = []
+    for combination in motion_embeddings_combinations:
+        name, resolution = combination.split(" ")
+        combinations.append([name, int(resolution)])
+    return combinations
+
+
+def generate_config_train(motion_embeddings_combinations, unet, checkpointing_steps, max_train_steps):
+    default_config = OmegaConf.load('configs/config.yaml')
+
+    default_config.model.motion_embeddings.combinations = ListConfig(extract_combinations(motion_embeddings_combinations))
+    default_config.model.unet = unet
+    default_config.train.checkpointing_steps = checkpointing_steps
+    default_config.train.max_train_steps = max_train_steps
+
+    return default_config
+
+
+def generate_config_inference(motion_embeddings_combinations, unet, checkpointing_steps, max_train_steps):
+    default_config = OmegaConf.load('configs/config.yaml')
+
+    default_config.model.motion_embeddings.combinations = ListConfig(extract_combinations(motion_embeddings_combinations))
+    default_config.model.unet = unet
+    default_config.train.checkpointing_steps = checkpointing_steps
+    default_config.train.max_train_steps = max_train_steps
+
+    return default_config
+
+
+def update_preview_video(checkpoint_dir):
+    # get the parent dir of the checkpoint
+    parent_dir = '/'.join(checkpoint_dir.split('/')[:-1])
+    return gr.update(value=f'results/{parent_dir}/source.mp4')


 if __name__ == "__main__":
+    inject_motion_embeddings_combinations = ['down 1280', 'up 1280', 'down 640', 'up 640']
+    default_motion_embeddings_combinations = ['down 1280', 'up 1280']
+
+    examples_train = [
+        'assets/train/car_turn.mp4',
+        'assets/train/pan_up.mp4',
+        'assets/train/run_up.mp4',
+        'assets/train/train_ride.mp4',
+        'assets/train/orbit_shot.mp4',
+        'assets/train/dolly_zoom_out.mp4',
+        'assets/train/santa_dance.mp4',
+    ]
+
+    examples_inference = [
+        ['results/pan_up/source.mp4', 'A flora garden.', 'camera', 'pan_up/checkpoint'],
+        ['results/dolly_zoom/source.mp4', 'A firefighter standing in front of a burning forest captured with a dolly zoom.', 'camera', 'dolly_zoom/checkpoint-100'],
+        ['results/orbit_shot/source.mp4', 'A micro garden with orbit shot', 'camera', 'orbit_shot/checkpoint-300'],
+
+        ['results/walk/source.mp4', 'An elephant walking in a desert', 'object', 'walk/checkpoint'],
+        ['results/santa_dance/source.mp4', 'A skeleton in a suit is dancing with his hands', 'object', 'santa_dance/checkpoint-200'],
+        ['results/car_turn/source.mp4', 'A toy train chugs around a roundabout tree', 'object', 'car_turn/checkpoint'],
+        ['results/train_ride/source.mp4', 'A motorbike driving in a forest', 'object', 'train_ride/checkpoint-200'],
+    ]
+
+    # Build the Gradio interface
+    with gr.Blocks() as demo:
+        with gr.Tab("Train"):
+            with gr.Row():
+                with gr.Column():
+                    video_input = gr.Video(label="Upload Video")
+                    train_button = gr.Button("Train")
+                with gr.Column():
+                    checkpoint_output = gr.Textbox(label="Checkpoint Directory")
+
+            with gr.Accordion("Advanced Settings", open=False):
+                with gr.Row():
+                    motion_embeddings_combinations = gr.Dropdown(label="Motion Embeddings Combinations", choices=inject_motion_embeddings_combinations, multiselect=True, value=default_motion_embeddings_combinations)
+                    unet_dropdown = gr.Dropdown(label="Unet", choices=["videoCrafter2", "zeroscope_v2_576w"], value="videoCrafter2")
+                    checkpointing_steps = gr.Dropdown(label="Checkpointing Steps", choices=[100, 50], value=100)
+                    max_train_steps = gr.Slider(label="Max Train Steps", minimum=200, maximum=500, value=200, step=50)
+
+            # examples
+            gr.Examples(examples=examples_train, inputs=[video_input])
+
+            train_button.click(
+                lambda video, mec, u, cs, mts: train_model(video, generate_config_train(mec, u, cs, mts)),
+                inputs=[video_input, motion_embeddings_combinations, unet_dropdown, checkpointing_steps, max_train_steps],
+                outputs=checkpoint_output
+            )
+
+        with gr.Tab("Inference"):
+            with gr.Row():
+                with gr.Column():
+                    preview_video = gr.Video(label="Preview Video")
+                    text_input = gr.Textbox(label="Input Text")
+                    checkpoint_dropdown = gr.Dropdown(label="Select Checkpoint", choices=get_checkpoints('results'))
+                    seed = gr.Number(label="Seed", value=0)
+                    inference_button = gr.Button("Generate Video")
+
+                with gr.Column():
+                    output_video = gr.Video(label="Output Video")
+
+            with gr.Accordion("Advanced Settings", open=False):
+                with gr.Row():
+                    inference_steps = gr.Number(label="Inference Steps", value=30)
+                    motion_type = gr.Dropdown(label="Motion Type", choices=["camera", "object"], value="object")
+
+            gr.Examples(examples=examples_inference, inputs=[preview_video, text_input, motion_type, checkpoint_dropdown])
+
+        def update_checkpoints(checkpoint_dir):
+            return gr.update(choices=get_checkpoints('results'))
+
+        checkpoint_dropdown.change(fn=update_preview_video, inputs=checkpoint_dropdown, outputs=preview_video)
+        checkpoint_output.change(update_checkpoints, inputs=checkpoint_output, outputs=checkpoint_dropdown)
+        inference_button.click(inference_model, inputs=[text_input, checkpoint_dropdown, inference_steps, motion_type, seed], outputs=output_video)
+
+    # Launch the Gradio interface
+    demo.launch()
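The new app.py exposes four helpers that the UI wires together: generate_config_train builds an OmegaConf config, train_model fine-tunes on a single video and returns its results directory, get_checkpoints lists saved motion_embed.pt checkpoints under results/, and inference_model renders a video from a prompt. As a rough, non-authoritative sketch (not part of the commit) of driving that same flow without the UI, assuming it is run from the repository root with configs/config.yaml, the assets/train/ clips, and the train/inference modules in place, and that importing app has no side effects beyond its module-level imports:

```python
# Minimal sketch of the train -> inference flow added in this commit.
# Assumption: run from the repo root; values are taken from the examples in app.py,
# not from any documented CLI.
from app import generate_config_train, train_model, get_checkpoints, inference_model

# Build a training config and learn a motion embedding from one example clip.
config = generate_config_train(
    motion_embeddings_combinations=["down 1280", "up 1280"],
    unet="videoCrafter2",
    checkpointing_steps=100,
    max_train_steps=200,
)
save_dir = train_model("assets/train/car_turn.mp4", config)
print(f"training outputs written to {save_dir}")

# get_checkpoints walks results/ for motion_embed.pt files; order is not guaranteed,
# so this simply picks one of the discovered checkpoints.
checkpoint = get_checkpoints("results")[-1]
video_path = inference_model(
    text="A toy train chugs around a roundabout tree",
    checkpoint=checkpoint,
    inference_steps=30,
    video_type="object",
    seed=0,
)
print(f"generated video saved to {video_path}")
```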