williamberman
commited on
Commit
•
bde23cb
1
Parent(s):
d89243e
fix
Browse files
sdxl.py
CHANGED
@@ -15,7 +15,6 @@ from torch.utils.data import default_collate
|
|
15 |
from transformers import (CLIPTextModel, CLIPTextModelWithProjection,
|
16 |
CLIPTokenizerFast)
|
17 |
|
18 |
-
import wandb
|
19 |
from diffusion import (default_num_train_timesteps,
|
20 |
euler_ode_solver_diffusion_loop, make_sigmas)
|
21 |
from sdxl_models import (SDXLAdapter, SDXLControlNet, SDXLControlNetFull,
|
@@ -234,6 +233,8 @@ class SDXLTraining:
|
|
234 |
|
235 |
@torch.no_grad()
|
236 |
def log_validation(self, step, num_validation_images: int, validation_prompts: Optional[List[str]] = None, validation_images: Optional[List[str]] = None):
|
|
|
|
|
237 |
if isinstance(self.unet, DDP):
|
238 |
unet = self.unet.module
|
239 |
unet.eval()
|
|
|
15 |
from transformers import (CLIPTextModel, CLIPTextModelWithProjection,
|
16 |
CLIPTokenizerFast)
|
17 |
|
|
|
18 |
from diffusion import (default_num_train_timesteps,
|
19 |
euler_ode_solver_diffusion_loop, make_sigmas)
|
20 |
from sdxl_models import (SDXLAdapter, SDXLControlNet, SDXLControlNetFull,
|
|
|
233 |
|
234 |
@torch.no_grad()
|
235 |
def log_validation(self, step, num_validation_images: int, validation_prompts: Optional[List[str]] = None, validation_images: Optional[List[str]] = None):
|
236 |
+
import wandb
|
237 |
+
|
238 |
if isinstance(self.unet, DDP):
|
239 |
unet = self.unet.module
|
240 |
unet.eval()
|