Badr AlKhamissi committed on
Commit 35c104c
1 Parent(s): e8f6bdd

fp32 instead of fp16

Files changed (2)
  1. app.py +1 -1
  2. code/losses.py +11 -20
app.py CHANGED
@@ -30,7 +30,7 @@ from diffusers import StableDiffusionPipeline
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
  model = None
- model = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to(device)
+ model = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to(device)
 
  from typing import Mapping
  from tqdm import tqdm
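The fp16 → fp32 switch matters when the app falls back to CPU: half-precision weights are only practical on CUDA, and many fp16 kernels are not implemented for CPU tensors. A minimal sketch of a device-conditional load, assuming one wants to keep fp16 on GPU (illustrative only, not part of this commit):

import torch
from diffusers import StableDiffusionPipeline

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# fp16 kernels are effectively CUDA-only, so fall back to fp32 on CPU
dtype = torch.float16 if device.type == "cuda" else torch.float32

model = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=dtype
).to(device)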
code/losses.py CHANGED
@@ -21,8 +21,8 @@ class SDSLoss(nn.Module):
  self.pipe = model
  self.pipe = self.pipe.to(self.device)
 
- self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(self.device)
- self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+ # self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(self.device)
+ # self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
 
  # default scheduler: PNDMScheduler(beta_start=0.00085, beta_end=0.012,
  # beta_schedule="scaled_linear", num_train_timesteps=1000)
@@ -35,24 +35,15 @@ class SDSLoss(nn.Module):
  def embed_text(self):
  # tokenizer and embed text
 
- if "jpeg" not in self.cfg.caption:
- text_input = self.pipe.tokenizer(self.cfg.caption, padding="max_length",
- max_length=self.pipe.tokenizer.model_max_length,
- truncation=True, return_tensors="pt")
- uncond_input = self.pipe.tokenizer([""], padding="max_length",
- max_length=text_input.input_ids.shape[-1],
- return_tensors="pt")
- with torch.no_grad():
- text_embeddings = self.pipe.text_encoder(text_input.input_ids.to(self.device))[0]
- uncond_embeddings = self.pipe.text_encoder(uncond_input.input_ids.to(self.device))[0]
- else:
- print(f"> Reading Image {self.cfg.caption}")
- with torch.no_grad():
- image = Image.open(self.cfg.caption)
- inputs = self.clip_processor(images=image, return_tensors="pt").to(self.device)
- img_emb = self.clip_model.get_image_features(**inputs)
- text_embeddings = img_emb
- uncond_embeddings = img_emb
+ text_input = self.pipe.tokenizer(self.cfg.caption, padding="max_length",
+ max_length=self.pipe.tokenizer.model_max_length,
+ truncation=True, return_tensors="pt")
+ uncond_input = self.pipe.tokenizer([""], padding="max_length",
+ max_length=text_input.input_ids.shape[-1],
+ return_tensors="pt")
+ with torch.no_grad():
+ text_embeddings = self.pipe.text_encoder(text_input.input_ids.to(self.device))[0]
+ uncond_embeddings = self.pipe.text_encoder(uncond_input.input_ids.to(self.device))[0]
 
  self.text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
  self.text_embeddings = self.text_embeddings.repeat_interleave(self.cfg.batch_size, 0)
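After this change, embed_text always encodes the caption with the pipeline's own text encoder and stacks the unconditional and conditional embeddings, repeated per batch element. A minimal sketch of how a [uncond; cond] tensor like self.text_embeddings is typically consumed for classifier-free guidance; unet, latents, t, and guidance_scale are assumed names for the surrounding SDS setup, not defined in this commit:

import torch

def guided_noise_pred(unet, latents, t, text_embeddings, guidance_scale=100):
    # Run the UNet once on the doubled batch: first half unconditional, second half conditional.
    latent_model_input = torch.cat([latents] * 2)
    noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    # Push the prediction away from the unconditional branch toward the caption.
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)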