hysts (HF staff) committed

Commit e755b70
Parent: b1a4d93

Support training text encoder
Files changed (4):

  1. app.py +16 -1
  2. inference.py +20 -0
  3. lora +1 -1
  4. trainer.py +16 -4
app.py CHANGED

@@ -83,6 +83,10 @@ def create_training_demo(trainer: Trainer,
             num_training_steps = gr.Number(
                 label='Number of Training Steps', value=1000, precision=0)
             learning_rate = gr.Number(label='Learning Rate', value=0.0001)
+            train_text_encoder = gr.Checkbox(label='Train Text Encoder',
+                                             value=False)
+            learning_rate_text = gr.Number(
+                label='Learning Rate for Text Encoder', value=0.00005)
             gradient_accumulation = gr.Number(
                 label='Number of Gradient Accumulation',
                 value=1,
@@ -113,6 +117,8 @@ def create_training_demo(trainer: Trainer,
                     concept_prompt,
                     num_training_steps,
                     learning_rate,
+                    train_text_encoder,
+                    learning_rate_text,
                     gradient_accumulation,
                     fp16,
                     use_8bit_adam,
@@ -136,6 +142,7 @@ def create_training_demo(trainer: Trainer,
     def find_weight_files() -> list[str]:
         curr_dir = pathlib.Path(__file__).parent
         paths = sorted(curr_dir.rglob('*.pt'))
+        paths = [path for path in paths if not path.stem.endswith('.text_encoder')]
         return [path.relative_to(curr_dir).as_posix() for path in paths]


@@ -165,6 +172,11 @@ def create_inference_demo(pipe: InferencePipeline) -> gr.Blocks:
                                   maximum=2,
                                   step=0.05,
                                   value=1)
+                alpha_for_text = gr.Slider(label='Alpha for Text Encoder',
+                                           minimum=0,
+                                           maximum=2,
+                                           step=0.05,
+                                           value=1)
                 seed = gr.Slider(label='Seed',
                                  minimum=0,
                                  maximum=100000,
@@ -185,7 +197,8 @@ def create_inference_demo(pipe: InferencePipeline) -> gr.Blocks:
                 run_button = gr.Button('Generate')

                 gr.Markdown('''
-                - The pretrained models are trained with the concept prompt "style of sks".
+                - The pretrained models for "disney", "illust" and "pop" are trained with the concept prompt "style of sks".
+                - The pretrained model for "kiriko" is trained with the concept prompt "game character bnha". For this model, the text encoder is also trained.
                 ''')
             with gr.Column():
                 result = gr.Image(label='Result')
@@ -199,6 +212,7 @@ def create_inference_demo(pipe: InferencePipeline) -> gr.Blocks:
                 lora_weight_name,
                 prompt,
                 alpha,
+                alpha_for_text,
                 seed,
                 num_steps,
                 guidance_scale,
@@ -211,6 +225,7 @@ def create_inference_demo(pipe: InferencePipeline) -> gr.Blocks:
                 lora_weight_name,
                 prompt,
                 alpha,
+                alpha_for_text,
                 seed,
                 num_steps,
                 guidance_scale,
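For orientation, here is a minimal standalone sketch (not part of this commit) of how a gr.Checkbox and gr.Number feed a callback through Button.click, mirroring the wiring added above; the train callback and its body are hypothetical placeholders.

# Minimal sketch of the Gradio wiring above; the `train` callback is a
# hypothetical placeholder, not code from this Space.
import gradio as gr

def train(train_text_encoder: bool, learning_rate_text: float) -> str:
    # A real handler would forward these values to Trainer.run.
    return f'train_text_encoder={train_text_encoder}, lr_text={learning_rate_text}'

with gr.Blocks() as demo:
    train_text_encoder = gr.Checkbox(label='Train Text Encoder', value=False)
    learning_rate_text = gr.Number(label='Learning Rate for Text Encoder',
                                   value=0.00005)
    output = gr.Textbox(label='Output')
    run_button = gr.Button('Run')
    run_button.click(fn=train,
                     inputs=[train_text_encoder, learning_rate_text],
                     outputs=output)

demo.launch()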
inference.py CHANGED

@@ -32,6 +32,14 @@ class InferencePipeline:
         curr_dir = pathlib.Path(__file__).parent
         return curr_dir / name

+    @staticmethod
+    def get_lora_text_encoder_weight_path(path: pathlib.Path) -> str:
+        parent_dir = path.parent
+        stem = path.stem
+        text_encoder_filename = f'{stem}.text_encoder.pt'
+        path = parent_dir / text_encoder_filename
+        return path.as_posix() if path.exists() else ''
+
     def load_pipe(self, model_id: str, lora_filename: str) -> None:
         weight_path = self.get_lora_weight_path(lora_filename)
         if weight_path == self.weight_path:
@@ -47,6 +55,16 @@ class InferencePipeline:
         pipe = pipe.to(self.device)

         monkeypatch_lora(pipe.unet, lora_weight)
+
+        lora_text_encoder_weight_path = self.get_lora_text_encoder_weight_path(
+            weight_path)
+        if lora_text_encoder_weight_path:
+            lora_text_encoder_weight = torch.load(
+                lora_text_encoder_weight_path, map_location=self.device)
+            monkeypatch_lora(pipe.text_encoder,
+                             lora_text_encoder_weight,
+                             target_replace_module=['CLIPAttention'])
+
         self.pipe = pipe

     def run(
@@ -55,6 +73,7 @@ class InferencePipeline:
         lora_weight_name: str,
         prompt: str,
         alpha: float,
+        alpha_for_text: float,
         seed: int,
         n_steps: int,
         guidance_scale: float,
@@ -66,6 +85,7 @@ class InferencePipeline:

         generator = torch.Generator(device=self.device).manual_seed(seed)
         tune_lora_scale(self.pipe.unet, alpha)  # type: ignore
+        tune_lora_scale(self.pipe.text_encoder, alpha_for_text)  # type: ignore
         out = self.pipe(prompt,
                         num_inference_steps=n_steps,
                         guidance_scale=guidance_scale,
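Outside the pipeline class, the same text-encoder loading pattern looks roughly like the sketch below; it assumes the lora_diffusion package provided by the lora submodule, and the model id and weight paths are hypothetical placeholders.

# Sketch of the load-then-scale pattern above, under the assumption that
# lora_diffusion (from the submodule) is installed; paths are placeholders.
import torch
from diffusers import StableDiffusionPipeline
from lora_diffusion import monkeypatch_lora, tune_lora_scale

pipe = StableDiffusionPipeline.from_pretrained(
    'stabilityai/stable-diffusion-2-1-base',  # hypothetical base model
    torch_dtype=torch.float16).to('cuda')

# UNet LoRA weights; a sibling '<stem>.text_encoder.pt' file is optional,
# matching get_lora_text_encoder_weight_path above.
unet_weight = torch.load('kiriko/lora.pt', map_location='cuda')
monkeypatch_lora(pipe.unet, unet_weight)
tune_lora_scale(pipe.unet, 1.0)

text_encoder_weight = torch.load('kiriko/lora.text_encoder.pt',
                                 map_location='cuda')
monkeypatch_lora(pipe.text_encoder, text_encoder_weight,
                 target_replace_module=['CLIPAttention'])
tune_lora_scale(pipe.text_encoder, 1.0)  # analogous to alpha_for_text

image = pipe('game character bnha, portrait',
             num_inference_steps=25).images[0]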
lora CHANGED

@@ -1 +1 @@
-Subproject commit ba349e56e23e92e3b128c7c67ae58d3067540daa
+Subproject commit 26787a09bff4ebcb08f0ad4e848b67bce4389a7a
trainer.py CHANGED

@@ -54,10 +54,20 @@ class Trainer:
             out_path = self.instance_data_dir / f'{i:03d}.jpg'
             image.save(out_path, format='JPEG', quality=100)

-    def run(self, base_model: str, resolution_s: str,
-            concept_images: list | None, concept_prompt: str, n_steps: int,
-            learning_rate: float, gradient_accumulation: int, fp16: bool,
-            use_8bit_adam: bool) -> tuple[dict, str]:
+    def run(
+        self,
+        base_model: str,
+        resolution_s: str,
+        concept_images: list | None,
+        concept_prompt: str,
+        n_steps: int,
+        learning_rate: float,
+        train_text_encoder: bool,
+        learning_rate_text: float,
+        gradient_accumulation: int,
+        fp16: bool,
+        use_8bit_adam: bool,
+    ) -> tuple[dict, str]:
         if not torch.cuda.is_available():
             raise gr.Error('CUDA is not available.')

@@ -93,6 +103,8 @@ class Trainer:
             command += ' --mixed_precision fp16 '
         if use_8bit_adam:
             command += ' --use_8bit_adam'
+        if train_text_encoder:
+            command += f' --train_text_encoder --learning_rate_text={learning_rate_text} --color_jitter'

         with open(self.output_dir / 'train.sh', 'w') as f:
             command_s = ' '.join(command.split())
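Put together, a call into the widened Trainer.run signature might look like the sketch below; every value is an illustrative placeholder, not a recommendation from this commit.

# Illustrative call showing where the two new arguments sit in Trainer.run;
# all values below are placeholders.
from trainer import Trainer

concept_images = ['concept_images/000.jpg']  # placeholder image list
trainer = Trainer()
result, message = trainer.run(
    base_model='stabilityai/stable-diffusion-2-1-base',  # hypothetical model id
    resolution_s='512',
    concept_images=concept_images,
    concept_prompt='game character bnha',
    n_steps=1000,
    learning_rate=0.0001,
    train_text_encoder=True,     # appends --train_text_encoder,
    learning_rate_text=0.00005,  # --learning_rate_text and --color_jitter
    gradient_accumulation=1,
    fp16=True,
    use_8bit_adam=True,
)  # returns tuple[dict, str], per the signature above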