tedlasai committed
Commit 6a2f159 · Parent(s): 6a1328e

fixed full pipeline and added options

Files changed (2):
  1. gradio/app.py +77 -26
  2. inference.py +14 -6
gradio/app.py CHANGED
@@ -19,6 +19,8 @@ args.pretrained_model_path = "THUDM/CogVideoX-2b"
 args.model_config_path = "training/configs/outsidephotos.yaml"
 args.video_width = 1280
 args.video_height = 720
+# args.video_width = 960
+# args.video_height = 540
 args.seed = None
 
 pipe, model_config = load_model(args)
@@ -27,40 +29,62 @@ OUTPUT_DIR = Path("/tmp/generated_videos")
 OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
 
 
-@spaces.GPU
-def generate_video_from_image(image: Image.Image) -> str:
+@spaces.GPU(timeout=300)
+def generate_video_from_image(image: Image.Image, interval_key: str, orientation_mode: str, num_inference_steps: int) -> str:
+    """
+    Wrapper for Gradio. Takes an image and returns a video path.
+    """
+    if image is None:
+        raise gr.Error("Please upload an image first.")
+
     print("Generating video")
+    import torch
+    print("CUDA:", torch.cuda.is_available())
+    print("Device:", torch.cuda.get_device_name(0))
+    print("bf16 supported:", torch.cuda.is_bf16_supported())
+
+    if orientation_mode == "Landscape (1280×720)":
+        print("Choosing resolution 1280×720 (landscape)")
+        args.video_width = 1280
+        args.video_height = 720
+    elif orientation_mode == "Portrait (720×1280)":
+        print("Choosing resolution 720×1280 (portrait)")
+        args.video_height = 1280
+        args.video_width = 720
+    else:
+        print("Unknown orientation mode", orientation_mode, "defaulting to 1280x720")
+        args.video_width = 1280
+        args.video_height = 720
+
+    args.num_inference_steps = num_inference_steps
 
     video_id = uuid.uuid4().hex
     output_path = OUTPUT_DIR / f"{video_id}.mp4"
 
     args.device = "cuda"
 
-    processed_image, video = inference_on_image(pipe, image, "past_present_and_future", model_config, args)
+    pipe.to(args.device)
+    processed_image, video = inference_on_image(pipe, image, interval_key, model_config, args)
     export_to_video(video, output_path, fps=20)
 
-    return str(output_path)
-
-
-def demo_predict(image: Image.Image) -> str:
-    """
-    Wrapper for Gradio. Takes an image and returns a video path.
-    """
-    if image is None:
-        raise gr.Error("Please upload an image first.")
-
-    video_path = generate_video_from_image(image)
-    if not os.path.exists(video_path):
+    if not os.path.exists(output_path):
         raise gr.Error("Video generation failed: output file not found.")
-    return video_path
+
+    return str(output_path)
 
 
 with gr.Blocks(css="footer {visibility: hidden}") as demo:
     gr.Markdown(
         """
-        # 🖼️ ➜ 🎬 Recover motion from a blurry image!
+        # 🖼️ ➜ 🎬 Recover Motion from a Blurry Image
+
+        This demo accompanies the paper **“Generating the Past, Present, and Future from a Motion-Blurred Image”**
+        by Tedla *et al.*, ACM Transactions on Graphics (SIGGRAPH Asia 2025).
+
+        - 🌐 **Project page:** <https://blur2vid.github.io/>
+        - 💻 **Code:** <https://github.com/tedlasai/blur2vid/>
 
-        Upload an image and the model will generate a short video.
+        Upload a blurry image and the model will generate a short video containing the recovered motion, depending on your selection.
         """
     )
 
@@ -71,24 +95,51 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
             label="Input image",
             interactive=True,
         )
-        tense_choice = gr.Dropdown(
-            label="I want to generate the",
-            choices=["present", "past, present and future"],
-            value="past, present and future",  # default selection
-            interactive=True,
+
+        with gr.Row():
+            tense_choice = gr.Radio(
+                label="Select the interval to be generated:",
+                choices=["present", "past, present and future"],
+                value="past, present and future",
+                interactive=True,
+            )
+
+        with gr.Row():
+            mode_choice = gr.Radio(
+                label="Orientation",
+                choices=["Landscape (1280×720)", "Portrait (720×1280)"],
+                value="Landscape (1280×720)",
+                interactive=True,
+            )
+
+        gr.Markdown(
+            "<span style='font-size: 12px; color: gray;'>"
+            "Note: Model was trained on 1280×720 (Landscape). Portrait mode will degrade performance."
+            "</span>"
+        )
+
+        num_inference_steps = gr.Slider(
+            label="Number of inference steps",
+            minimum=4,
+            maximum=50,
+            step=1,
+            value=20,
+            info="More steps = better quality but slower",
         )
+
         generate_btn = gr.Button("Generate video", variant="primary")
+
     with gr.Column():
         video_out = gr.Video(
             label="Generated video",
-            format="mp4",  # ensures browser-friendly output
+            format="mp4",
             autoplay=True,
             loop=True,
         )
 
     generate_btn.click(
-        fn=demo_predict,
-        inputs=image_in,
+        fn=generate_video_from_image,
+        inputs=[image_in, tense_choice, mode_choice, num_inference_steps],
         outputs=video_out,
         api_name="predict",
     )
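
With this change the click handler passes the image, interval choice, orientation, and step count directly into generate_video_from_image, which is also exposed as the "predict" API endpoint. As a rough sketch (not part of the commit), the updated endpoint could be exercised from gradio_client roughly as follows; the Space ID is a placeholder, and the option strings mirror the Radio and Slider values defined above.

    # Sketch only: "tedlasai/blur2vid-demo" is a placeholder Space ID, not a confirmed deployment.
    from gradio_client import Client, handle_file

    client = Client("tedlasai/blur2vid-demo")  # hypothetical Space ID
    video_path = client.predict(
        handle_file("blurry_photo.jpg"),       # image_in
        "past, present and future",            # tense_choice
        "Landscape (1280×720)",                # mode_choice
        20,                                    # num_inference_steps
        api_name="/predict",
    )
    print("Generated video:", video_path)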
inference.py CHANGED
@@ -122,6 +122,7 @@ def load_model(args):
         revision=model_config["revision"],
         variant=model_config["variant"],
         low_cpu_mem_usage=False,
+        attn_implementation="flash_attention_2",
     )
     weight_path = hf_hub_download(
         repo_id=args.blur2vid_hf_repo_path,
@@ -159,11 +160,12 @@ def load_model(args):
 
     # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
     # as these weights are only used for inference, keeping weights in full precision is not required.
-    weight_dtype = torch.bfloat16
+    # Somehow for HF Spaces we do need to keep them in full precision
+    weight_dtype = torch.bfloat16  # torch.float32 # torch.bfloat16
 
-    # text_encoder.to(dtype=weight_dtype)
-    # transformer.to(dtype=weight_dtype)
-    # vae.to(dtype=weight_dtype)
+    text_encoder.to(dtype=weight_dtype)
+    transformer.to(dtype=weight_dtype)
+    vae.to(dtype=weight_dtype)
 
     pipe = ControlnetCogVideoXPipeline.from_pretrained(
         args.pretrained_model_path,
@@ -199,7 +201,7 @@ def inference_on_image(pipe, image, interval_key, model_config, args):
     # run inference
     generator = torch.Generator(device=args.device).manual_seed(args.seed) if args.seed else None
 
-    with torch.autocast(args.device, enabled=True):
+    with torch.autocast(device_type=args.device, dtype=torch.bfloat16, enabled=True):
         batch = convert_to_batch(image, interval_key, (args.video_height, args.video_width))
 
         frame = batch["blur_img"].permute(0, 2, 3, 1).cpu().numpy()
@@ -216,7 +218,7 @@ def inference_on_image(pipe, image, interval_key, model_config, args):
         "height": batch["height"],
         "width": batch["width"],
         "num_frames": torch.tensor([[model_config["max_num_frames"]]]),  # torch.tensor([[batch["num_frames"]]]),
-        "num_inference_steps": model_config["num_inference_steps"],
+        "num_inference_steps": args.num_inference_steps,
     }
 
     input_image = frame
@@ -305,6 +307,12 @@ if __name__ == "__main__":
         default=720,
         help="video resolution height",
     )
+    parser.add_argument(
+        "--num_inference_steps",
+        type=int,
+        default=50,
+        help="number of DDIM steps",
+    )
     parser.add_argument(
         "--seed",
         type=int,
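
The new --num_inference_steps flag defaults to 50 and replaces the value previously read from model_config, so the step count can now be tuned per run (the Gradio demo forwards it through args). Below is a rough sketch (not part of the commit) of how the option flows through when inference.py is driven from Python rather than the CLI; the attribute names come from the diff, the interval key follows the old hard-coded call in gradio/app.py, the full set of fields the script expects may differ, and the weights repo path is left unspecified.

    # Sketch only: attribute names taken from the diff; values other than the defaults are illustrative.
    import argparse
    from PIL import Image
    from inference import load_model, inference_on_image

    args = argparse.Namespace(
        pretrained_model_path="THUDM/CogVideoX-2b",
        model_config_path="training/configs/outsidephotos.yaml",
        blur2vid_hf_repo_path="...",    # weights repo, left unspecified here
        video_width=1280,
        video_height=720,
        num_inference_steps=20,         # fewer steps than the CLI default of 50, for a faster run
        seed=None,
        device="cuda",
    )

    pipe, model_config = load_model(args)
    pipe.to(args.device)
    image = Image.open("blurry_photo.jpg")
    processed_image, video = inference_on_image(pipe, image, "past_present_and_future", model_config, args)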