Spaces:
Running
on
Zero
Running
on
Zero
wondervictor
commited on
Update model.py
Browse files
model.py
CHANGED
@@ -92,7 +92,7 @@ class Model:
|
|
92 |
preprocessor_name: str,
|
93 |
) -> list[PIL.Image.Image]:
|
94 |
self.t5_model.model.to('cuda').to(torch.bfloat16)
|
95 |
-
self.
|
96 |
self.vq_model.to('cuda')
|
97 |
if isinstance(image, np.ndarray):
|
98 |
image = Image.fromarray(image)
|
@@ -147,6 +147,7 @@ class Model:
|
|
147 |
top_k=top_k,
|
148 |
top_p=top_p,
|
149 |
sample_logits=True,
|
|
|
150 |
)
|
151 |
sampling_time = time.time() - t1
|
152 |
print(f"Full sampling takes about {sampling_time:.2f} seconds.")
|
@@ -183,7 +184,7 @@ class Model:
|
|
183 |
control_strength: float,
|
184 |
preprocessor_name: str
|
185 |
) -> list[PIL.Image.Image]:
|
186 |
-
self.
|
187 |
self.t5_model.model.to(self.device)
|
188 |
self.gpt_model_depth.to(self.device)
|
189 |
self.get_control_depth.model.to(self.device)
|
@@ -237,6 +238,7 @@ class Model:
|
|
237 |
top_k=top_k,
|
238 |
top_p=top_p,
|
239 |
sample_logits=True,
|
|
|
240 |
)
|
241 |
sampling_time = time.time() - t1
|
242 |
print(f"Full sampling takes about {sampling_time:.2f} seconds.")
|
|
|
92 |
preprocessor_name: str,
|
93 |
) -> list[PIL.Image.Image]:
|
94 |
self.t5_model.model.to('cuda').to(torch.bfloat16)
|
95 |
+
self.gpt_model_edge.to('cuda').to(torch.bfloat16)
|
96 |
self.vq_model.to('cuda')
|
97 |
if isinstance(image, np.ndarray):
|
98 |
image = Image.fromarray(image)
|
|
|
147 |
top_k=top_k,
|
148 |
top_p=top_p,
|
149 |
sample_logits=True,
|
150 |
+
control_strength=control_strength,
|
151 |
)
|
152 |
sampling_time = time.time() - t1
|
153 |
print(f"Full sampling takes about {sampling_time:.2f} seconds.")
|
|
|
184 |
control_strength: float,
|
185 |
preprocessor_name: str
|
186 |
) -> list[PIL.Image.Image]:
|
187 |
+
self.gpt_model_edge.to('cpu')
|
188 |
self.t5_model.model.to(self.device)
|
189 |
self.gpt_model_depth.to(self.device)
|
190 |
self.get_control_depth.model.to(self.device)
|
|
|
238 |
top_k=top_k,
|
239 |
top_p=top_p,
|
240 |
sample_logits=True,
|
241 |
+
control_strength=control_strength,
|
242 |
)
|
243 |
sampling_time = time.time() - t1
|
244 |
print(f"Full sampling takes about {sampling_time:.2f} seconds.")
|