user commited on
Commit
013a220
1 Parent(s): 0d6d3ac

Added class data support

Browse files
Files changed (2) hide show
  1. app.py +2 -0
  2. trainer.py +11 -1
app.py CHANGED
@@ -77,6 +77,7 @@ def create_training_demo(trainer: Trainer, pipe: InferencePipeline) -> gr.Blocks
77
  with gr.Box():
78
  gr.Markdown("Training Data")
79
  concept_images = gr.Files(label="Images for your concept")
 
80
  concept_prompt = gr.Textbox(label="Concept Prompt", max_lines=1)
81
  gr.Markdown(
82
  """
@@ -202,6 +203,7 @@ def create_training_demo(trainer: Trainer, pipe: InferencePipeline) -> gr.Blocks
202
  num_training_steps,
203
  concept_images,
204
  concept_prompt,
 
205
  learning_rate,
206
  gradient_accumulation,
207
  fp16,
 
77
  with gr.Box():
78
  gr.Markdown("Training Data")
79
  concept_images = gr.Files(label="Images for your concept")
80
+ class_images = gr.Files(label="Class images")
81
  concept_prompt = gr.Textbox(label="Concept Prompt", max_lines=1)
82
  gr.Markdown(
83
  """
 
203
  num_training_steps,
204
  concept_images,
205
  concept_prompt,
206
+ class_images,
207
  learning_rate,
208
  gradient_accumulation,
209
  fp16,
trainer.py CHANGED
@@ -31,6 +31,7 @@ class Trainer:
31
  self.is_running_message = "Another training is in progress."
32
 
33
  self.output_dir = pathlib.Path("results")
 
34
  self.instance_data_dir = self.output_dir / "training_data"
35
 
36
  def check_if_running(self) -> dict:
@@ -52,6 +53,13 @@ class Trainer:
52
  out_path = self.instance_data_dir / f"{i:03d}.jpg"
53
  image.save(out_path, format="JPEG", quality=100)
54
 
 
 
 
 
 
 
 
55
  def run(
56
  self,
57
  base_model: str,
@@ -59,6 +67,7 @@ class Trainer:
59
  n_steps: int,
60
  concept_images: list | None,
61
  concept_prompt: str,
 
62
  learning_rate: float,
63
  gradient_accumulation: int,
64
  fp16: bool,
@@ -93,6 +102,7 @@ class Trainer:
93
 
94
  self.cleanup_dirs()
95
  self.prepare_dataset(concept_images, resolution)
 
96
 
97
  command = f"""
98
  accelerate launch train_dreambooth.py \
@@ -116,7 +126,7 @@ class Trainer:
116
  command += f""" --with_prior_preservation \
117
  --prior_loss_weight={prior_loss_weight} \
118
  --class_prompt="{class_prompt}" \
119
- --class_data_dir={self.output_dir / 'class_data'}
120
  """
121
 
122
  command += f""" --use_lora \
 
31
  self.is_running_message = "Another training is in progress."
32
 
33
  self.output_dir = pathlib.Path("results")
34
+ self.class_dir = self.output_dir / "class_data"
35
  self.instance_data_dir = self.output_dir / "training_data"
36
 
37
  def check_if_running(self) -> dict:
 
53
  out_path = self.instance_data_dir / f"{i:03d}.jpg"
54
  image.save(out_path, format="JPEG", quality=100)
55
 
56
+ def copy_class_data(self, class_images: list) -> None:
57
+ self.class_dir.mkdir(parents=True)
58
+ for i, temp_path in enumerate(class_images):
59
+ image = PIL.Image.open(temp_path.name)
60
+ out_path = self.class_dir / f"{i:03d}.jpg"
61
+ image.save(out_path, format="JPEG", quality=100)
62
+
63
  def run(
64
  self,
65
  base_model: str,
 
67
  n_steps: int,
68
  concept_images: list | None,
69
  concept_prompt: str,
70
+ class_images: list | None,
71
  learning_rate: float,
72
  gradient_accumulation: int,
73
  fp16: bool,
 
102
 
103
  self.cleanup_dirs()
104
  self.prepare_dataset(concept_images, resolution)
105
+ self.copy_class_data(class_images)
106
 
107
  command = f"""
108
  accelerate launch train_dreambooth.py \
 
126
  command += f""" --with_prior_preservation \
127
  --prior_loss_weight={prior_loss_weight} \
128
  --class_prompt="{class_prompt}" \
129
+ --class_data_dir={self.class_dir}
130
  """
131
 
132
  command += f""" --use_lora \