williamberman commited on
Commit
bde23cb
1 Parent(s): d89243e
Files changed (1) hide show
  1. sdxl.py +2 -1
sdxl.py CHANGED
@@ -15,7 +15,6 @@ from torch.utils.data import default_collate
15
  from transformers import (CLIPTextModel, CLIPTextModelWithProjection,
16
  CLIPTokenizerFast)
17
 
18
- import wandb
19
  from diffusion import (default_num_train_timesteps,
20
  euler_ode_solver_diffusion_loop, make_sigmas)
21
  from sdxl_models import (SDXLAdapter, SDXLControlNet, SDXLControlNetFull,
@@ -234,6 +233,8 @@ class SDXLTraining:
234
 
235
  @torch.no_grad()
236
  def log_validation(self, step, num_validation_images: int, validation_prompts: Optional[List[str]] = None, validation_images: Optional[List[str]] = None):
 
 
237
  if isinstance(self.unet, DDP):
238
  unet = self.unet.module
239
  unet.eval()
 
15
  from transformers import (CLIPTextModel, CLIPTextModelWithProjection,
16
  CLIPTokenizerFast)
17
 
 
18
  from diffusion import (default_num_train_timesteps,
19
  euler_ode_solver_diffusion_loop, make_sigmas)
20
  from sdxl_models import (SDXLAdapter, SDXLControlNet, SDXLControlNetFull,
 
233
 
234
  @torch.no_grad()
235
  def log_validation(self, step, num_validation_images: int, validation_prompts: Optional[List[str]] = None, validation_images: Optional[List[str]] = None):
236
+ import wandb
237
+
238
  if isinstance(self.unet, DDP):
239
  unet = self.unet.module
240
  unet.eval()