Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -61,7 +61,7 @@ def load_models(
|
|
61 |
dit_path=None,
|
62 |
ae_path=None,
|
63 |
qwen2vl_model_path=None,
|
64 |
-
device="
|
65 |
max_length=256,
|
66 |
dtype=torch.bfloat16,
|
67 |
):
|
@@ -118,7 +118,7 @@ class ImageGenerator:
|
|
118 |
dit_path=None,
|
119 |
ae_path=None,
|
120 |
qwen2vl_model_path=None,
|
121 |
-
device="
|
122 |
max_length=640,
|
123 |
dtype=torch.bfloat16,
|
124 |
) -> None:
|
@@ -135,9 +135,9 @@ class ImageGenerator:
|
|
135 |
self.llm_encoder = self.llm_encoder.to(device=self.device, dtype=dtype)
|
136 |
|
137 |
def to_cuda(self):
|
138 |
-
self.ae.to(device='
|
139 |
-
self.dit.to(device='
|
140 |
-
self.llm_encoder.to(device='
|
141 |
|
142 |
def prepare(self, prompt, img, ref_image, ref_image_raw):
|
143 |
bs, _, h, w = img.shape
|
|
|
61 |
dit_path=None,
|
62 |
ae_path=None,
|
63 |
qwen2vl_model_path=None,
|
64 |
+
device="cpu",
|
65 |
max_length=256,
|
66 |
dtype=torch.bfloat16,
|
67 |
):
|
|
|
118 |
dit_path=None,
|
119 |
ae_path=None,
|
120 |
qwen2vl_model_path=None,
|
121 |
+
device="cpu",
|
122 |
max_length=640,
|
123 |
dtype=torch.bfloat16,
|
124 |
) -> None:
|
|
|
135 |
self.llm_encoder = self.llm_encoder.to(device=self.device, dtype=dtype)
|
136 |
|
137 |
def to_cuda(self):
|
138 |
+
self.ae.to(device='cpu', dtype=torch.float32)
|
139 |
+
self.dit.to(device='cpu', dtype=torch.bfloat16)
|
140 |
+
self.llm_encoder.to(device='cpu', dtype=torch.bfloat16)
|
141 |
|
142 |
def prepare(self, prompt, img, ref_image, ref_image_raw):
|
143 |
bs, _, h, w = img.shape
|