Spaces:
Sleeping
Sleeping
Yaron Koresh
commited on
Update app.py
Browse files
app.py
CHANGED
|
@@ -22,6 +22,21 @@ from diffusers import DiffusionPipeline, AnimateDiffPipeline, MotionAdapter, Eul
|
|
| 22 |
import jax
|
| 23 |
import jax.numpy as jnp
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
def forest_schnell():
|
| 26 |
PIPE = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, token=os.getenv("hf_token")).to("cuda")
|
| 27 |
return PIPE
|
|
@@ -193,7 +208,9 @@ def main():
|
|
| 193 |
global time
|
| 194 |
global last_motion
|
| 195 |
global base
|
| 196 |
-
|
|
|
|
|
|
|
| 197 |
last_motion=None
|
| 198 |
fps=20
|
| 199 |
time=16
|
|
@@ -208,13 +225,12 @@ def main():
|
|
| 208 |
|
| 209 |
repo="stabilityai/sd-vae-ft-mse-original"
|
| 210 |
ckpt="vae-ft-mse-840000-ema-pruned.safetensors"
|
| 211 |
-
|
| 212 |
vae = "./vae"
|
| 213 |
|
| 214 |
repo="ByteDance/SDXL-Lightning"
|
| 215 |
ckpt=f"sdxl_lightning_{step}step_unet.safetensors"
|
| 216 |
-
|
| 217 |
-
unet = "./unet"
|
| 218 |
|
| 219 |
#repo = "SG161222/Realistic_Vision_V6.0_B1_noVAE"
|
| 220 |
|
|
|
|
| 22 |
import jax
|
| 23 |
import jax.numpy as jnp
|
| 24 |
|
| 25 |
+
class MyModel(nn.Module):
|
| 26 |
+
def __init__(self):
|
| 27 |
+
super().__init__()
|
| 28 |
+
self.register_buffer('a', torch.ones(1, 1))
|
| 29 |
+
|
| 30 |
+
def forward(self, x: torch.Tensor, extend: bool):
|
| 31 |
+
if extend:
|
| 32 |
+
new_tensor = torch.randn(1, 1)
|
| 33 |
+
self.a = torch.cat([self.a, new_tensor], dim=0)
|
| 34 |
+
return x
|
| 35 |
+
|
| 36 |
+
for _ in range(10):
|
| 37 |
+
out = model(x, extend=True)
|
| 38 |
+
print(model.state_dict())
|
| 39 |
+
|
| 40 |
def forest_schnell():
|
| 41 |
PIPE = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, token=os.getenv("hf_token")).to("cuda")
|
| 42 |
return PIPE
|
|
|
|
| 208 |
global time
|
| 209 |
global last_motion
|
| 210 |
global base
|
| 211 |
+
global model
|
| 212 |
+
|
| 213 |
+
model = MyModel()
|
| 214 |
last_motion=None
|
| 215 |
fps=20
|
| 216 |
time=16
|
|
|
|
| 225 |
|
| 226 |
repo="stabilityai/sd-vae-ft-mse-original"
|
| 227 |
ckpt="vae-ft-mse-840000-ema-pruned.safetensors"
|
| 228 |
+
vae = model(load_file(hf_hub_download(repo, ckpt), device=device)
|
| 229 |
vae = "./vae"
|
| 230 |
|
| 231 |
repo="ByteDance/SDXL-Lightning"
|
| 232 |
ckpt=f"sdxl_lightning_{step}step_unet.safetensors"
|
| 233 |
+
unet = model(load_file(hf_hub_download(repo, ckpt), device=device)
|
|
|
|
| 234 |
|
| 235 |
#repo = "SG161222/Realistic_Vision_V6.0_B1_noVAE"
|
| 236 |
|