Update app.py
Browse files
app.py
CHANGED
|
@@ -104,6 +104,32 @@ RESOLUTIONS = {
|
|
| 104 |
class LTX23DistilledA2VPipeline:
|
| 105 |
"""DistilledPipeline with optional audio conditioning."""
|
| 106 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
def __call__(
|
| 108 |
self,
|
| 109 |
prompt: str,
|
|
|
|
| 104 |
class LTX23DistilledA2VPipeline:
|
| 105 |
"""DistilledPipeline with optional audio conditioning."""
|
| 106 |
|
| 107 |
+
def __init__(
|
| 108 |
+
self,
|
| 109 |
+
distilled_checkpoint_path: str,
|
| 110 |
+
gemma_root: str,
|
| 111 |
+
spatial_upsampler_path: str,
|
| 112 |
+
loras: tuple,
|
| 113 |
+
quantization: QuantizationPolicy | None = None,
|
| 114 |
+
):
|
| 115 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 116 |
+
self.dtype = torch.bfloat16
|
| 117 |
+
|
| 118 |
+
self.model_ledger = ModelLedger(
|
| 119 |
+
dtype=self.dtype,
|
| 120 |
+
device=self.device,
|
| 121 |
+
checkpoint_path=distilled_checkpoint_path,
|
| 122 |
+
spatial_upsampler_path=spatial_upsampler_path,
|
| 123 |
+
gemma_root_path=gemma_root,
|
| 124 |
+
loras=loras,
|
| 125 |
+
quantization=quantization,
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
self.pipeline_components = PipelineComponents(
|
| 129 |
+
dtype=self.dtype,
|
| 130 |
+
device=self.device,
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
def __call__(
|
| 134 |
self,
|
| 135 |
prompt: str,
|