Spaces:
Build error
Build error
Update
Browse files- trainer.py +5 -3
trainer.py
CHANGED
@@ -32,8 +32,8 @@ class Trainer:
|
|
32 |
self.is_running = False
|
33 |
self.is_running_message = 'Another training is in progress.'
|
34 |
|
35 |
-
self.instance_data_dir = pathlib.Path('training_data')
|
36 |
self.output_dir = pathlib.Path('results')
|
|
|
37 |
|
38 |
def check_if_running(self) -> dict:
|
39 |
if self.is_running:
|
@@ -42,11 +42,10 @@ class Trainer:
|
|
42 |
return gr.update(value='No training is running.')
|
43 |
|
44 |
def cleanup_dirs(self) -> None:
|
45 |
-
shutil.rmtree(self.instance_data_dir, ignore_errors=True)
|
46 |
shutil.rmtree(self.output_dir, ignore_errors=True)
|
47 |
|
48 |
def prepare_dataset(self, concept_images: list, resolution: int) -> None:
|
49 |
-
self.instance_data_dir.mkdir()
|
50 |
for i, temp_path in enumerate(concept_images):
|
51 |
image = PIL.Image.open(temp_path.name)
|
52 |
image = pad_image(image)
|
@@ -96,6 +95,9 @@ class Trainer:
|
|
96 |
--lr_warmup_steps=0 \
|
97 |
--max_train_steps={n_steps}
|
98 |
'''
|
|
|
|
|
|
|
99 |
res = subprocess.run(shlex.split(command))
|
100 |
self.is_running = False
|
101 |
|
|
|
32 |
self.is_running = False
|
33 |
self.is_running_message = 'Another training is in progress.'
|
34 |
|
|
|
35 |
self.output_dir = pathlib.Path('results')
|
36 |
+
self.instance_data_dir = self.output_dir / 'training_data'
|
37 |
|
38 |
def check_if_running(self) -> dict:
|
39 |
if self.is_running:
|
|
|
42 |
return gr.update(value='No training is running.')
|
43 |
|
44 |
def cleanup_dirs(self) -> None:
|
|
|
45 |
shutil.rmtree(self.output_dir, ignore_errors=True)
|
46 |
|
47 |
def prepare_dataset(self, concept_images: list, resolution: int) -> None:
|
48 |
+
self.instance_data_dir.mkdir(parents=True)
|
49 |
for i, temp_path in enumerate(concept_images):
|
50 |
image = PIL.Image.open(temp_path.name)
|
51 |
image = pad_image(image)
|
|
|
95 |
--lr_warmup_steps=0 \
|
96 |
--max_train_steps={n_steps}
|
97 |
'''
|
98 |
+
with open(self.output_dir / 'train.sh', 'w') as f:
|
99 |
+
command_s = ' '.join(command.split())
|
100 |
+
f.write(command_s)
|
101 |
res = subprocess.run(shlex.split(command))
|
102 |
self.is_running = False
|
103 |
|