EuuIia commited on
Commit
8815ceb
·
verified ·
1 Parent(s): db005a9

Create managers/vae_manager.py

Browse files
Files changed (1) hide show
  1. managers/vae_manager.py +55 -0
managers/vae_manager.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # vae_manager.py — versão simples (beta 1.0)
2
+ # Responsável por decodificar latentes (B,C,T,H,W) → pixels (B,C,T,H',W') em [0,1].
3
+
4
+ import torch
5
+ import contextlib
6
+
7
class _SimpleVAEManager:
    """Decode 5D video latents into pixel tensors normalized to [0, 1].

    The manager wraps a pipeline object that exposes either
    ``decode_latents(latents)`` or ``vae.decode(latents)`` and runs the
    decode on ``self.device``, under CUDA autocast when that device is a
    CUDA device.
    """

    def __init__(self, pipeline=None, device=None, autocast_dtype=torch.float32):
        """Store the decode configuration.

        Args:
            pipeline: Object exposing ``decode_latents(...)`` or
                ``.vae.decode(...)``. May be ``None`` and attached later
                through ``attach_pipeline``.
            device: Device on which decoding runs (e.g. ``"cuda"``,
                ``"cuda:0"``, ``"cpu"`` or a ``torch.device``). When falsy,
                ``"cuda"`` is chosen if available, otherwise ``"cpu"``.
            autocast_dtype: dtype handed to ``torch.autocast`` when decoding
                on a CUDA device.
        """
        self.pipeline = pipeline
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.autocast_dtype = autocast_dtype

    def attach_pipeline(self, pipeline, device=None, autocast_dtype=None):
        """Replace the pipeline; optionally override device and autocast dtype.

        ``device`` and ``autocast_dtype`` are left untouched when ``None``.
        """
        self.pipeline = pipeline
        if device is not None:
            self.device = device
        if autocast_dtype is not None:
            self.autocast_dtype = autocast_dtype

    def _uses_cuda(self) -> bool:
        """Return True when ``self.device`` designates any CUDA device."""
        # Normalizing through torch.device makes "cuda", "cuda:0" and
        # torch.device("cuda", 1) all count as CUDA; a plain string
        # comparison against "cuda" would miss the indexed forms.
        return torch.device(self.device).type == "cuda"

    @staticmethod
    def _unwrap_decoded(decoded):
        """Extract the pixel tensor from whatever the decoder returned."""
        if isinstance(decoded, torch.Tensor):
            return decoded
        # Output objects that carry the tensor in a ``sample`` attribute.
        sample = getattr(decoded, "sample", None)
        if isinstance(sample, torch.Tensor):
            return sample
        # Tuple/list returns carry the tensor as the first element.
        if (
            isinstance(decoded, (tuple, list))
            and decoded
            and isinstance(decoded[0], torch.Tensor)
        ):
            return decoded[0]
        raise RuntimeError(
            f"Saída de decodificação não suportada: {type(decoded).__name__}."
        )

    @torch.no_grad()
    def decode(self, latents_5d: torch.Tensor) -> torch.Tensor:
        """Decode the whole latent block in a single call.

        Args:
            latents_5d: Latent tensor; moved to ``self.device`` before
                decoding. Documented by the original author as
                (B, C, T, H, W).

        Returns:
            The decoded tensor with values in [0, 1].

        Raises:
            RuntimeError: If no pipeline is attached, if the pipeline exposes
                neither ``decode_latents`` nor ``vae.decode``, or if the
                decoder returns something no tensor can be extracted from.
        """
        if self.pipeline is None:
            raise RuntimeError("VAE Manager sem pipeline. Chame attach_pipeline primeiro.")

        # Make sure the latents live on the decode device.
        latents_5d = latents_5d.to(self.device, non_blocking=True)

        if self._uses_cuda():
            ctx = torch.autocast(device_type="cuda", dtype=self.autocast_dtype)
        else:
            ctx = contextlib.nullcontext()

        with ctx:
            if hasattr(self.pipeline, "decode_latents"):
                decoded = self.pipeline.decode_latents(latents_5d)
            elif hasattr(self.pipeline, "vae") and hasattr(self.pipeline.vae, "decode"):
                decoded = self.pipeline.vae.decode(latents_5d)
            else:
                raise RuntimeError("Pipeline não expõe decode_latents nem vae.decode.")

        pixels_5d = self._unwrap_decoded(decoded)

        # Heuristic range detection: any negative value is taken to mean the
        # decoder emitted [-1, 1], which is rescaled to [0, 1]; otherwise the
        # output is assumed to already be in [0, 1] and is only clamped.
        if pixels_5d.min() < 0:
            pixels_5d = (pixels_5d.clamp(-1, 1) + 1.0) / 2.0
        else:
            pixels_5d = pixels_5d.clamp(0, 1)

        return pixels_5d


# Global singleton for simple use; attach a pipeline before calling decode().
vae_manager_singleton = _SimpleVAEManager()