multimodalart HF staff commited on
Commit
d6f9b71
1 Parent(s): 7af4a09

Add spherical dist loss

Browse files
Files changed (1) hide show
  1. app.py +5 -0
app.py CHANGED
@@ -46,6 +46,11 @@ class MakeCutouts(nn.Module):
46
  cutouts.append(cutout)
47
  return torch.cat(cutouts)
48
 
 
 
 
 
 
49
  cc12m_model = hf_hub_download(repo_id="multimodalart/crowsonkb-v-diffusion-cc12m-1-cfg", filename="cc12m_1_cfg.pth")
50
  model = get_model('cc12m_1_cfg')()
51
  _, side_y, side_x = model.shape
 
46
  cutouts.append(cutout)
47
  return torch.cat(cutouts)
48
 
49
+ def spherical_dist_loss(x, y):
50
+ x = F.normalize(x, dim=-1)
51
+ y = F.normalize(y, dim=-1)
52
+ return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
53
+
54
  cc12m_model = hf_hub_download(repo_id="multimodalart/crowsonkb-v-diffusion-cc12m-1-cfg", filename="cc12m_1_cfg.pth")
55
  model = get_model('cc12m_1_cfg')()
56
  _, side_y, side_x = model.shape