Badr AlKhamissi committed
Commit e8f6bdd · 1 Parent(s): 9530fad

losses fix

Files changed (1)
  1. code/losses.py +2 -5
code/losses.py CHANGED
@@ -14,12 +14,11 @@ from transformers import CLIPProcessor, CLIPModel
 from diffusers import StableDiffusionPipeline
 
 class SDSLoss(nn.Module):
-    def __init__(self, cfg, device):
+    def __init__(self, cfg, device, model):
         super(SDSLoss, self).__init__()
         self.cfg = cfg
         self.device = device
-        self.pipe = StableDiffusionPipeline.from_pretrained(cfg.diffusion.model,
-                                                             torch_dtype=torch.float16, use_auth_token=cfg.token)
+        self.pipe = model
         self.pipe = self.pipe.to(self.device)
 
         self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(self.device)
@@ -55,8 +54,6 @@ class SDSLoss(nn.Module):
         text_embeddings = img_emb
         uncond_embeddings = img_emb
 
-        print(text_embeddings.size())
-        print(uncond_embeddings.size())
         self.text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
         self.text_embeddings = self.text_embeddings.repeat_interleave(self.cfg.batch_size, 0)
         del self.pipe.tokenizer
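
With this change, SDSLoss no longer loads the diffusion pipeline itself; the caller constructs it once and passes it in, and the debug prints of the embedding sizes are dropped. A minimal caller-side sketch, assuming the same cfg.diffusion.model and cfg.token fields that the old constructor read (cfg itself comes from the project's config, not from this diff):

import torch
from diffusers import StableDiffusionPipeline

# cfg is assumed to be the project's config object, as in the old constructor.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the Stable Diffusion pipeline once, then share it with SDSLoss.
model = StableDiffusionPipeline.from_pretrained(
    cfg.diffusion.model,
    torch_dtype=torch.float16,
    use_auth_token=cfg.token,
)
sds_loss = SDSLoss(cfg, device, model)

Loading the pipeline outside the loss lets several components (or several loss instances) reuse the same weights instead of each one calling from_pretrained.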