Tsmith2024 committed on
Commit 72d7a72 · verified · 1 Parent(s): 627335d

Upload handler.py with huggingface_hub

Files changed (1)
  1. handler.py +42 -29
handler.py CHANGED
@@ -2,68 +2,81 @@ import base64
 import io
 import os
 import tempfile
-from typing import Any, Dict
-
 import torch
+from typing import Any, Dict
 from PIL import Image
 from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
 from diffusers.utils import export_to_video
 
-
 class EndpointHandler:
     def __init__(self, path: str = ""):
-        model_path = path or os.environ.get("MODEL_ID", "/repository")
-        print(f"Loading Wan2.2-TI2V-5B from {model_path}…")
-        dtype = torch.bfloat16
+        # Use the MODEL_ID env var or default to the 5B TI2V model
+        model_id = os.environ.get("MODEL_ID", "Wan-AI/Wan2.2-TI2V-5B-Diffusers")
+        print(f"Loading Wan2.2-TI2V-5B from {model_id}...")
+
+        dtype = torch.bfloat16
         device = "cuda" if torch.cuda.is_available() else "cpu"
-        vae = AutoencoderKLWan.from_pretrained(
-            model_path, subfolder="vae", torch_dtype=torch.float32,
-        )
+
+        # VAE in float32 for precision, rest in bfloat16 for speed/memory
+        vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
         self.pipe = WanImageToVideoPipeline.from_pretrained(
-            model_path, vae=vae, torch_dtype=dtype,
+            model_id,
+            vae=vae,
+            torch_dtype=dtype,
+            device_map="auto"
         )
-        self.pipe.to(device)
-        self.pipe.enable_attention_slicing()
         self.device = device
         print("✓ Model loaded and ready")
 
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
-        inputs = data.get("inputs", data)
-        start_img = self._decode_image(inputs["start_image"])
-        end_img = self._decode_image(inputs["end_image"])
-        prompt = inputs.get("prompt", "Smooth cinematic motion, natural movement")
+        inputs = data.get("inputs", data)
+
+        # Decode start and end images
+        start_img = self._decode_image(inputs["start_image"])
+        end_img = self._decode_image(inputs["end_image"])
+
+        prompt = inputs.get("prompt", "Smooth cinematic motion")
         num_frames = int(inputs.get("num_frames", 41))
-        guidance = float(inputs.get("guidance_scale", 5.0))
-        steps = int(inputs.get("num_inference_steps", 20))
-        fps = int(inputs.get("fps", 16))
+        guidance = float(inputs.get("guidance_scale", 5.0))
+        steps = int(inputs.get("num_inference_steps", 20))
+
+        # Wan requires (4N + 1) frames
         num_frames = max(9, ((num_frames - 1) // 4) * 4 + 1)
-        w, h = start_img.size
-        width = (w // 32) * 32
-        height = (h // 32) * 32
-        start_img = start_img.resize((width, height))
-        end_img = end_img.resize((width, height))
+
+        # Dimension snapping
+        w, h = start_img.size
+        width = (w // 32) * 32
+        height = (h // 32) * 32
+
+        start_img = start_img.resize((width, height))
+        end_img = end_img.resize((width, height))
+
         with torch.inference_mode():
             output = self.pipe(
                 image=start_img,
                 last_image=end_img,
                 prompt=prompt,
-                negative_prompt="",
                 height=height,
                 width=width,
                 num_frames=num_frames,
                 guidance_scale=guidance,
                 num_inference_steps=steps,
             ).frames[0]
+
+        # Export video to bytes
         with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
             tmp_path = tmp.name
-        export_to_video(output, tmp_path, fps=fps)
+
+        export_to_video(output, tmp_path, fps=16)
+
         with open(tmp_path, "rb") as f:
             video_b64 = base64.b64encode(f.read()).decode("utf-8")
+
         os.unlink(tmp_path)
         return {"video": video_b64}
 
-    @staticmethod
-    def _decode_image(b64_str: str) -> Image.Image:
+    def _decode_image(self, b64_str: str) -> Image.Image:
         if "," in b64_str:
             b64_str = b64_str.split(",", 1)[1]
-        return Image.open(io.BytesIO(base64.b64decode(b64_str))).convert("RGB")
+        img_bytes = base64.b64decode(b64_str)
+        return Image.open(io.BytesIO(img_bytes)).convert("RGB")