hysts HF staff commited on
Commit
1c9110a
1 Parent(s): 06be08c
Files changed (1) hide show
  1. trainer.py +5 -3
trainer.py CHANGED
@@ -32,8 +32,8 @@ class Trainer:
32
  self.is_running = False
33
  self.is_running_message = 'Another training is in progress.'
34
 
35
- self.instance_data_dir = pathlib.Path('training_data')
36
  self.output_dir = pathlib.Path('results')
 
37
 
38
  def check_if_running(self) -> dict:
39
  if self.is_running:
@@ -42,11 +42,10 @@ class Trainer:
42
  return gr.update(value='No training is running.')
43
 
44
  def cleanup_dirs(self) -> None:
45
- shutil.rmtree(self.instance_data_dir, ignore_errors=True)
46
  shutil.rmtree(self.output_dir, ignore_errors=True)
47
 
48
  def prepare_dataset(self, concept_images: list, resolution: int) -> None:
49
- self.instance_data_dir.mkdir()
50
  for i, temp_path in enumerate(concept_images):
51
  image = PIL.Image.open(temp_path.name)
52
  image = pad_image(image)
@@ -96,6 +95,9 @@ class Trainer:
96
  --lr_warmup_steps=0 \
97
  --max_train_steps={n_steps}
98
  '''
 
 
 
99
  res = subprocess.run(shlex.split(command))
100
  self.is_running = False
101
 
 
32
  self.is_running = False
33
  self.is_running_message = 'Another training is in progress.'
34
 
 
35
  self.output_dir = pathlib.Path('results')
36
+ self.instance_data_dir = self.output_dir / 'training_data'
37
 
38
  def check_if_running(self) -> dict:
39
  if self.is_running:
 
42
  return gr.update(value='No training is running.')
43
 
44
  def cleanup_dirs(self) -> None:
 
45
  shutil.rmtree(self.output_dir, ignore_errors=True)
46
 
47
  def prepare_dataset(self, concept_images: list, resolution: int) -> None:
48
+ self.instance_data_dir.mkdir(parents=True)
49
  for i, temp_path in enumerate(concept_images):
50
  image = PIL.Image.open(temp_path.name)
51
  image = pad_image(image)
 
95
  --lr_warmup_steps=0 \
96
  --max_train_steps={n_steps}
97
  '''
98
+ with open(self.output_dir / 'train.sh', 'w') as f:
99
+ command_s = ' '.join(command.split())
100
+ f.write(command_s)
101
  res = subprocess.run(shlex.split(command))
102
  self.is_running = False
103