multimodalart HF staff commited on
Commit
612ce40
1 Parent(s): 26ca94f

Remove model small

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -52,15 +52,15 @@ def spherical_dist_loss(x, y):
52
  return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
53
 
54
  cc12m_model = hf_hub_download(repo_id="multimodalart/crowsonkb-v-diffusion-cc12m-1-cfg", filename="cc12m_1_cfg.pth")
55
- cc12m_small_model = hf_hub_download(repo_id="multimodalart/crowsonkb-v-diffusion-cc12m-1-cfg", filename="cc12m_1.pth")
56
  model = get_model('cc12m_1_cfg')()
57
  _, side_y, side_x = model.shape
58
  model.load_state_dict(torch.load(cc12m_model, map_location='cpu'))
59
  model = model.half().cuda().eval().requires_grad_(False)
60
 
61
- model_small = get_model('cc12m_1')()
62
- model_small.load_state_dict(torch.load(cc12m_model, map_location='cpu'))
63
- model_small = model_small.half().cuda().eval().requires_grad_(False)
64
 
65
  print(model.clip_model)
66
  clip_model = clip.load(model.clip_model, jit=False, device='cuda')[0]
@@ -147,7 +147,7 @@ def run_all(prompt, steps, n_images, weight, clip_guided):
147
  else:
148
  extra_args = {'clip_embed': clip_embed}
149
  cond_fn_ = cond_fn
150
- model_fn = make_cond_model_fn(model_small, cond_fn_)
151
  outs = sampling.plms_sample(model_fn, x, step_list, extra_args)
152
  images_out = []
153
  for i, out in enumerate(outs):
 
52
  return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
53
 
54
  cc12m_model = hf_hub_download(repo_id="multimodalart/crowsonkb-v-diffusion-cc12m-1-cfg", filename="cc12m_1_cfg.pth")
55
+ #cc12m_small_model = hf_hub_download(repo_id="multimodalart/crowsonkb-v-diffusion-cc12m-1-cfg", filename="cc12m_1.pth")
56
  model = get_model('cc12m_1_cfg')()
57
  _, side_y, side_x = model.shape
58
  model.load_state_dict(torch.load(cc12m_model, map_location='cpu'))
59
  model = model.half().cuda().eval().requires_grad_(False)
60
 
61
+ #model_small = get_model('cc12m_1')()
62
+ #model_small.load_state_dict(torch.load(cc12m_model, map_location='cpu'))
63
+ #model_small = model_small.half().cuda().eval().requires_grad_(False)
64
 
65
  print(model.clip_model)
66
  clip_model = clip.load(model.clip_model, jit=False, device='cuda')[0]
 
147
  else:
148
  extra_args = {'clip_embed': clip_embed}
149
  cond_fn_ = cond_fn
150
+ model_fn = make_cond_model_fn(model, cond_fn_)
151
  outs = sampling.plms_sample(model_fn, x, step_list, extra_args)
152
  images_out = []
153
  for i, out in enumerate(outs):