multimodalart HF staff commited on
Commit
568d1c7
1 Parent(s): d6f9b71

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -1
app.py CHANGED
@@ -52,10 +52,16 @@ def spherical_dist_loss(x, y):
52
  return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
53
 
54
  cc12m_model = hf_hub_download(repo_id="multimodalart/crowsonkb-v-diffusion-cc12m-1-cfg", filename="cc12m_1_cfg.pth")
 
55
  model = get_model('cc12m_1_cfg')()
56
  _, side_y, side_x = model.shape
57
  model.load_state_dict(torch.load(cc12m_model, map_location='cpu'))
58
  model = model.half().cuda().eval().requires_grad_(False)
 
 
 
 
 
59
  clip_model = clip.load(model.clip_model, jit=False, device='cuda')[0]
60
  clip_model.eval().requires_grad_(False)
61
  normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
@@ -138,7 +144,7 @@ def run_all(prompt, steps, n_images, weight, clip_guided):
138
  else:
139
  extra_args = {'clip_embed': clip_embed}
140
  cond_fn_ = cond_fn
141
- model_fn = make_cond_model_fn(model, cond_fn_)
142
  outs = sampling.plms_sample(model_fn, x, step_list, extra_args)
143
  images_out = []
144
  for i, out in enumerate(outs):
 
52
  return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
53
 
54
  cc12m_model = hf_hub_download(repo_id="multimodalart/crowsonkb-v-diffusion-cc12m-1-cfg", filename="cc12m_1_cfg.pth")
55
+ cc12m_small_model = hf_hub_download(repo_id="multimodalart/crowsonkb-v-diffusion-cc12m-1-cfg", filename="cc12m_1.pth")
56
  model = get_model('cc12m_1_cfg')()
57
  _, side_y, side_x = model.shape
58
  model.load_state_dict(torch.load(cc12m_model, map_location='cpu'))
59
  model = model.half().cuda().eval().requires_grad_(False)
60
+
61
+ model_small = get_model('cc12m_1')()
62
+ model_small.load_state_dict(torch.load(cc12m_model, map_location='cpu'))
63
+ model_small = model.half().cuda().eval().requires_grad_(False)
64
+
65
  clip_model = clip.load(model.clip_model, jit=False, device='cuda')[0]
66
  clip_model.eval().requires_grad_(False)
67
  normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
 
144
  else:
145
  extra_args = {'clip_embed': clip_embed}
146
  cond_fn_ = cond_fn
147
+ model_fn = make_cond_model_fn(model_small, cond_fn_)
148
  outs = sampling.plms_sample(model_fn, x, step_list, extra_args)
149
  images_out = []
150
  for i, out in enumerate(outs):