hysts HF staff commited on
Commit
d87e974
1 Parent(s): 5520f5c
Files changed (1) hide show
  1. trainer.py +4 -7
trainer.py CHANGED
@@ -67,13 +67,12 @@ class Trainer:
67
  gradient_accumulation: int,
68
  fp16: bool,
69
  use_8bit_adam: bool,
70
- ) -> tuple[dict, str]:
71
  if not torch.cuda.is_available():
72
  raise gr.Error('CUDA is not available.')
73
 
74
- out_path = ''
75
  if self.is_running:
76
- return gr.update(value=self.is_running_message), out_path
77
 
78
  if concept_images is None:
79
  raise gr.Error('You need to upload images.')
@@ -116,9 +115,7 @@ class Trainer:
116
 
117
  if res.returncode == 0:
118
  result_message = 'Training Completed!'
119
- weight_path = self.output_dir / 'lora_weight.pt'
120
- if weight_path.exists():
121
- out_path = weight_path.as_posix()
122
  else:
123
  result_message = 'Training Failed!'
124
- return gr.update(value=result_message), out_path
 
 
67
  gradient_accumulation: int,
68
  fp16: bool,
69
  use_8bit_adam: bool,
70
+ ) -> tuple[dict, list[pathlib.Path]]:
71
  if not torch.cuda.is_available():
72
  raise gr.Error('CUDA is not available.')
73
 
 
74
  if self.is_running:
75
+ return gr.update(value=self.is_running_message), []
76
 
77
  if concept_images is None:
78
  raise gr.Error('You need to upload images.')
 
115
 
116
  if res.returncode == 0:
117
  result_message = 'Training Completed!'
 
 
 
118
  else:
119
  result_message = 'Training Failed!'
120
+ weight_paths = sorted(self.output_dir.glob('*.pt'))
121
+ return gr.update(value=result_message), weight_paths