Spaces:
Build error
Build error
Fix
Browse files- trainer.py +4 -7
trainer.py
CHANGED
@@ -67,13 +67,12 @@ class Trainer:
|
|
67 |
gradient_accumulation: int,
|
68 |
fp16: bool,
|
69 |
use_8bit_adam: bool,
|
70 |
-
) -> tuple[dict,
|
71 |
if not torch.cuda.is_available():
|
72 |
raise gr.Error('CUDA is not available.')
|
73 |
|
74 |
-
out_path = ''
|
75 |
if self.is_running:
|
76 |
-
return gr.update(value=self.is_running_message),
|
77 |
|
78 |
if concept_images is None:
|
79 |
raise gr.Error('You need to upload images.')
|
@@ -116,9 +115,7 @@ class Trainer:
|
|
116 |
|
117 |
if res.returncode == 0:
|
118 |
result_message = 'Training Completed!'
|
119 |
-
weight_path = self.output_dir / 'lora_weight.pt'
|
120 |
-
if weight_path.exists():
|
121 |
-
out_path = weight_path.as_posix()
|
122 |
else:
|
123 |
result_message = 'Training Failed!'
|
124 |
-
|
|
|
|
67 |
gradient_accumulation: int,
|
68 |
fp16: bool,
|
69 |
use_8bit_adam: bool,
|
70 |
+
) -> tuple[dict, list[pathlib.Path]]:
|
71 |
if not torch.cuda.is_available():
|
72 |
raise gr.Error('CUDA is not available.')
|
73 |
|
|
|
74 |
if self.is_running:
|
75 |
+
return gr.update(value=self.is_running_message), []
|
76 |
|
77 |
if concept_images is None:
|
78 |
raise gr.Error('You need to upload images.')
|
|
|
115 |
|
116 |
if res.returncode == 0:
|
117 |
result_message = 'Training Completed!'
|
|
|
|
|
|
|
118 |
else:
|
119 |
result_message = 'Training Failed!'
|
120 |
+
weight_paths = sorted(self.output_dir.glob('*.pt'))
|
121 |
+
return gr.update(value=result_message), weight_paths
|