Commit c8aa68b · Damian Stewart committed
Parent(s): a29f551

fix seed and export

Files changed:
- StableDiffuser.py +1 -1
- app.py +11 -8
- train.py +1 -1
StableDiffuser.py CHANGED
@@ -63,7 +63,7 @@ class StableDiffuser(torch.nn.Module):
     def get_noise(self, batch_size, width, height, generator=None):
         param = list(self.parameters())[0]
         return torch.randn(
-            (batch_size, self.unet.in_channels, width // 8, height // 8),
+            (batch_size, self.unet.config.in_channels, width // 8, height // 8),
             generator=generator).type(param.dtype).to(param.device)
 
     def add_noise(self, latents, noise, step):
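A note on the change above: recent diffusers releases moved model hyperparameters onto the config object and deprecated direct attribute access, so unet.in_channels warns (and later breaks) while unet.config.in_channels stays valid. A minimal version-tolerant accessor, as a sketch; the helper name is hypothetical and not part of this repo:

    def unet_in_channels(unet):
        # Prefer the config attribute (newer diffusers API); fall back to the
        # plain attribute so older diffusers versions keep working.
        config = getattr(unet, "config", None)
        if config is not None and hasattr(config, "in_channels"):
            return config.in_channels
        return unet.in_channels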
app.py CHANGED
@@ -199,12 +199,13 @@ class Demo:
 
             with gr.Column():
                 self.train_memory_options = gr.Markdown(interactive=False,
-                    value='Performance and VRAM usage optimizations, may not work on all devices
+                    value='Performance and VRAM usage optimizations, may not work on all devices:')
                 with gr.Row():
                     self.train_use_adamw8bit_input = gr.Checkbox(label="8bit AdamW", value=True)
                     self.train_use_xformers_input = gr.Checkbox(label="xformers", value=True)
                     self.train_use_amp_input = gr.Checkbox(label="AMP", value=True)
-                    self.train_use_gradient_checkpointing_input = gr.Checkbox(
+                    self.train_use_gradient_checkpointing_input = gr.Checkbox(
+                        label="Gradient checkpointing", value=False)
 
             with gr.Column(scale=1):
 
@@ -248,9 +249,10 @@ class Demo:
                 )
 
             with gr.Column(scale=1):
+                self.export_status = gr.Button(
+                    value='', variant='primary', label='Status', interactive=False)
                 self.export_button = gr.Button(
-                    value="Export"
-                )
+                    value="Export")
 
         self.infr_button.click(self.inference, inputs = [
             self.prompt_input_infr,
@@ -288,12 +290,12 @@ class Demo:
             self.save_path_input_export,
             self.save_half_export
         ],
-            outputs=[self.
+            outputs=[self.export_status]
         )
 
     def train(self, repo_id_or_path, img_size, prompt, train_method, neg_guidance, iterations, lr,
               use_adamw8bit=True, use_xformers=False, use_amp=False, use_gradient_checkpointing=False,
-              seed
+              seed=-1,
               pbar = gr.Progress(track_tqdm=True)):
 
         if self.training:
@@ -330,7 +332,7 @@ class Demo:
         try:
             self.training = True
             train(repo_id_or_path, img_size, prompt, modules, frozen, iterations, neg_guidance, lr, save_path,
-                  use_adamw8bit, use_xformers, use_amp, use_gradient_checkpointing, seed=seed)
+                  use_adamw8bit, use_xformers, use_amp, use_gradient_checkpointing, seed=int(seed))
         finally:
             self.training = False
 
@@ -355,8 +357,9 @@ class Demo:
         with finetuner:
             if save_half:
                 diffuser = diffuser.half()
-            diffuser.pipeline.to(
+            diffuser.pipeline.to('cpu', torch_dtype=torch.float16)
             diffuser.pipeline.save_pretrained(save_path)
+        return [gr.update(value=f'Done! Your model is at {save_path}.')]
 
 
     def inference(self, prompt, negative_prompt, seed, width, height, model_name, base_repo_id_or_path, pbar = gr.Progress(track_tqdm=True)):
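The app.py side of the commit wires the Export click to a new read-only status button: the handler now returns [gr.update(value=...)] and Gradio applies that update to whatever is listed in outputs=[...]. A minimal self-contained sketch of the same pattern, assuming Gradio 3.x and using illustrative names rather than the app's real components:

    import gradio as gr

    def do_export(path):
        # The real handler would save the model here before reporting back.
        return [gr.update(value=f'Done! Your model is at {path}.')]

    with gr.Blocks() as demo:
        path_box = gr.Textbox(value='models/exported', label='Save path')
        status = gr.Button(value='', interactive=False)  # doubles as a status display
        export = gr.Button(value='Export')
        # Listing the status button in outputs lets the returned gr.update(...)
        # rewrite its caption once the handler finishes.
        export.click(do_export, inputs=[path_box], outputs=[status])

    demo.launch()

The export path itself also changes: diffuser.pipeline.to('cpu', torch_dtype=torch.float16) moves the pipeline to CPU in half precision before save_pretrained, so a half export actually writes fp16 weights.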
train.py CHANGED
@@ -52,7 +52,7 @@ def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations
 
     if seed == -1:
         seed = random.randint(0, 2 ** 30)
-    set_seed(seed)
+    set_seed(int(seed))
 
     for i in pbar:
         with torch.no_grad():
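Why the int(...) casts in app.py and train.py: a Gradio Number field hands the handler a Python float, and seeding helpers such as the set_seed used here typically forward the value to NumPy, whose np.random.seed rejects non-integer seeds. A small sketch of the failure mode and the fix, assuming that library behaviour; set_seed_safe is an illustrative name:

    import random
    import numpy as np
    import torch

    def set_seed_safe(seed):
        # Coerce once at the boundary so every downstream RNG sees an int.
        seed = int(seed)
        random.seed(seed)
        np.random.seed(seed)    # would raise for a float seed in most versions
        torch.manual_seed(seed)
        return seed

    set_seed_safe(42.0)  # e.g. the float value a gr.Number input returns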