Spaces:
Runtime error
Runtime error
Commit
•
568d1c7
1
Parent(s):
d6f9b71
Update app.py
Browse files
app.py
CHANGED
@@ -52,10 +52,16 @@ def spherical_dist_loss(x, y):
|
|
52 |
return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
|
53 |
|
54 |
cc12m_model = hf_hub_download(repo_id="multimodalart/crowsonkb-v-diffusion-cc12m-1-cfg", filename="cc12m_1_cfg.pth")
|
|
|
55 |
model = get_model('cc12m_1_cfg')()
|
56 |
_, side_y, side_x = model.shape
|
57 |
model.load_state_dict(torch.load(cc12m_model, map_location='cpu'))
|
58 |
model = model.half().cuda().eval().requires_grad_(False)
|
|
|
|
|
|
|
|
|
|
|
59 |
clip_model = clip.load(model.clip_model, jit=False, device='cuda')[0]
|
60 |
clip_model.eval().requires_grad_(False)
|
61 |
normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
|
@@ -138,7 +144,7 @@ def run_all(prompt, steps, n_images, weight, clip_guided):
|
|
138 |
else:
|
139 |
extra_args = {'clip_embed': clip_embed}
|
140 |
cond_fn_ = cond_fn
|
141 |
-
model_fn = make_cond_model_fn(
|
142 |
outs = sampling.plms_sample(model_fn, x, step_list, extra_args)
|
143 |
images_out = []
|
144 |
for i, out in enumerate(outs):
|
|
|
52 |
return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
|
53 |
|
54 |
cc12m_model = hf_hub_download(repo_id="multimodalart/crowsonkb-v-diffusion-cc12m-1-cfg", filename="cc12m_1_cfg.pth")
|
55 |
+
cc12m_small_model = hf_hub_download(repo_id="multimodalart/crowsonkb-v-diffusion-cc12m-1-cfg", filename="cc12m_1.pth")
|
56 |
model = get_model('cc12m_1_cfg')()
|
57 |
_, side_y, side_x = model.shape
|
58 |
model.load_state_dict(torch.load(cc12m_model, map_location='cpu'))
|
59 |
model = model.half().cuda().eval().requires_grad_(False)
|
60 |
+
|
61 |
+
model_small = get_model('cc12m_1')()
|
62 |
+
model_small.load_state_dict(torch.load(cc12m_model, map_location='cpu'))
|
63 |
+
model_small = model.half().cuda().eval().requires_grad_(False)
|
64 |
+
|
65 |
clip_model = clip.load(model.clip_model, jit=False, device='cuda')[0]
|
66 |
clip_model.eval().requires_grad_(False)
|
67 |
normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
|
|
|
144 |
else:
|
145 |
extra_args = {'clip_embed': clip_embed}
|
146 |
cond_fn_ = cond_fn
|
147 |
+
model_fn = make_cond_model_fn(model_small, cond_fn_)
|
148 |
outs = sampling.plms_sample(model_fn, x, step_list, extra_args)
|
149 |
images_out = []
|
150 |
for i, out in enumerate(outs):
|