Bingsu commited on
Commit
c36f73b
1 Parent(s): a90835f

fix: textual_inversion

Browse files
Files changed (1) hide show
  1. textual_inversion.py +24 -25
textual_inversion.py CHANGED
@@ -25,7 +25,8 @@ from diffusers import (
25
  )
26
  from diffusers.optimization import get_scheduler
27
  from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
28
- from diffusers.utils import check_min_version
 
29
  from huggingface_hub import HfFolder, Repository, whoami
30
 
31
  # TODO: remove and import from diffusers.utils when the new version of diffusers is released
@@ -56,7 +57,7 @@ else:
56
 
57
 
58
  # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
59
- check_min_version("0.10.0.dev0")
60
 
61
 
62
  logger = get_logger(__name__)
@@ -316,25 +317,25 @@ imagenet_templates_small = [
316
  ]
317
 
318
  imagenet_style_templates_small = [
319
- "a painting in the style of {}",
320
- "a rendering in the style of {}",
321
- "a cropped painting in the style of {}",
322
- "the painting in the style of {}",
323
- "a clean painting in the style of {}",
324
- "a dirty painting in the style of {}",
325
- "a dark painting in the style of {}",
326
- "a picture in the style of {}",
327
- "a cool painting in the style of {}",
328
- "a close-up painting in the style of {}",
329
- "a bright painting in the style of {}",
330
- "a cropped painting in the style of {}",
331
- "a good painting in the style of {}",
332
- "a close-up painting in the style of {}",
333
- "a rendition in the style of {}",
334
- "a nice painting in the style of {}",
335
- "a small painting in the style of {}",
336
- "a weird painting in the style of {}",
337
- "a large painting in the style of {}",
338
  ]
339
 
340
 
@@ -392,7 +393,7 @@ class TextualInversionDataset(Dataset):
392
  example = {}
393
  image = Image.open(self.image_paths[i % self.num_images])
394
 
395
- if not image.mode == "RGB":
396
  image = image.convert("RGB")
397
 
398
  placeholder_string = self.placeholder_token
@@ -449,13 +450,11 @@ def freeze_params(params):
449
 
450
  def main():
451
  args = parse_args()
452
- logging_dir = os.path.join(args.output_dir, args.logging_dir)
453
 
454
  accelerator = Accelerator(
455
  gradient_accumulation_steps=args.gradient_accumulation_steps,
456
  mixed_precision=args.mixed_precision,
457
- log_with="tensorboard",
458
- logging_dir=logging_dir,
459
  )
460
 
461
  # If passed along, set the training seed now.
 
25
  )
26
  from diffusers.optimization import get_scheduler
27
  from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
28
+
29
+ # from diffusers.utils import check_min_version
30
  from huggingface_hub import HfFolder, Repository, whoami
31
 
32
  # TODO: remove and import from diffusers.utils when the new version of diffusers is released
 
57
 
58
 
59
  # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
60
+ # check_min_version("0.10.0.dev0")
61
 
62
 
63
  logger = get_logger(__name__)
 
317
  ]
318
 
319
  imagenet_style_templates_small = [
320
+ "a painting of {}, art by *",
321
+ "a rendering of {}, art by *",
322
+ "a cropped painting of {}, art by *",
323
+ "the painting of {}, art by *",
324
+ "a clean painting of {}, art by *",
325
+ "a dirty painting of {}, art by *",
326
+ "a dark painting of {}, art by *",
327
+ "a picture of {}, art by *",
328
+ "a cool painting of {}, art by *",
329
+ "a close-up painting of {}, art by *",
330
+ "a bright painting of {}, art by *",
331
+ "a cropped painting of {}, art by *",
332
+ "a good painting of {}, art by *",
333
+ "a close-up painting of {}, art by *",
334
+ "a rendition of {}, art by *",
335
+ "a nice painting of {}, art by *",
336
+ "a small painting of {}, art by *",
337
+ "a weird painting of {}, art by *",
338
+ "a large painting of {}, art by *",
339
  ]
340
 
341
 
 
393
  example = {}
394
  image = Image.open(self.image_paths[i % self.num_images])
395
 
396
+ if image.mode != "RGB":
397
  image = image.convert("RGB")
398
 
399
  placeholder_string = self.placeholder_token
 
450
 
451
  def main():
452
  args = parse_args()
453
+ # logging_dir = os.path.join(args.output_dir, args.logging_dir)
454
 
455
  accelerator = Accelerator(
456
  gradient_accumulation_steps=args.gradient_accumulation_steps,
457
  mixed_precision=args.mixed_precision,
 
 
458
  )
459
 
460
  # If passed along, set the training seed now.