Commit c6546ad by Julian Bilcke (parent: 38cfbff): "cleaning code"

Files changed:
- vms/config.py (+116, -89)
- vms/services/trainer.py (+90, -47)
- vms/tabs/train_tab.py (+153, -177)
- vms/ui/video_trainer_ui.py (+79, -43)
vms/config.py (CHANGED)

@@ -58,9 +58,9 @@ JPEG_QUALITY = int(os.environ.get('JPEG_QUALITY', '97'))
 
 # Expanded model types to include Wan-2.1-T2V
 MODEL_TYPES = {
-    "HunyuanVideo…
-    "LTX-Video…
-    "Wan-2.1-T2V…
+    "HunyuanVideo": "hunyuan_video",
+    "LTX-Video": "ltx_video",
+    "Wan-2.1-T2V": "wan"
 }
 
 # Training types
@@ -69,6 +69,23 @@ TRAINING_TYPES = {
     "Full Finetune": "full-finetune"
 }
 
+DEFAULT_SEED = 42
+
+DEFAULT_NB_TRAINING_STEPS = 1000
+
+DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS = 200
+
+DEFAULT_LORA_RANK = 128
+DEFAULT_LORA_RANK_STR = str(DEFAULT_LORA_RANK)
+
+DEFAULT_LORA_ALPHA = 128
+DEFAULT_LORA_ALPHA_STR = str(DEFAULT_LORA_ALPHA)
+
+DEFAULT_CAPTION_DROPOUT_P = 0.05
+
+DEFAULT_BATCH_SIZE = 1
+
+DEFAULT_LEARNING_RATE = 3e-5
 
 # it is best to use resolutions that are powers of 8
 # The resolution should be divisible by 32
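A note on the paired constants: LoRA rank and alpha are edited through `gr.Dropdown` widgets whose choices are strings, while `TrainingConfig` stores integers, which is presumably why both `DEFAULT_LORA_RANK` and `DEFAULT_LORA_RANK_STR` exist. A minimal sketch of the round-trip (`build_config` is a hypothetical helper, not in the commit):

```python
# Minimal sketch of why both int and str defaults are kept: the Gradio
# dropdowns use string choices while the training config wants ints.
# `build_config` is illustrative, not part of the commit.
DEFAULT_LORA_RANK = 128
DEFAULT_LORA_RANK_STR = str(DEFAULT_LORA_RANK)  # what the gr.Dropdown returns

def build_config(lora_rank_from_ui: str = DEFAULT_LORA_RANK_STR) -> dict:
    # The UI hands back a string; the training config needs an int
    return {"lora_rank": int(lora_rank_from_ui)}

assert build_config()["lora_rank"] == DEFAULT_LORA_RANK
```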
@@ -87,39 +104,49 @@ MEDIUM_19_9_RATIO_HEIGHT = 512 # 32 * 16
 NB_FRAMES_1 = 1 # 1
 NB_FRAMES_9 = 8 + 1 # 8 + 1
 NB_FRAMES_17 = 8 * 2 + 1 # 16 + 1
-… [13 lines illegible in source]
+NB_FRAMES_33 = 8 * 4 + 1 # 32 + 1
+NB_FRAMES_49 = 8 * 6 + 1 # 48 + 1
+NB_FRAMES_65 = 8 * 8 + 1 # 64 + 1
+NB_FRAMES_81 = 8 * 10 + 1 # 80 + 1
+NB_FRAMES_97 = 8 * 12 + 1 # 96 + 1
+NB_FRAMES_113 = 8 * 14 + 1 # 112 + 1
+NB_FRAMES_129 = 8 * 16 + 1 # 128 + 1
+NB_FRAMES_145 = 8 * 18 + 1 # 144 + 1
+NB_FRAMES_161 = 8 * 20 + 1 # 160 + 1
+NB_FRAMES_177 = 8 * 22 + 1 # 176 + 1
+NB_FRAMES_193 = 8 * 24 + 1 # 192 + 1
+NB_FRAMES_225 = 8 * 28 + 1 # 224 + 1
+NB_FRAMES_257 = 8 * 32 + 1 # 256 + 1
 # 256 isn't a lot by the way, especially with 60 FPS videos..
 # can we crank it and put more frames in here?
 
+NB_FRAMES_273 = 8 * 34 + 1 # 272 + 1
+NB_FRAMES_289 = 8 * 36 + 1 # 288 + 1
+NB_FRAMES_305 = 8 * 38 + 1 # 304 + 1
+NB_FRAMES_321 = 8 * 40 + 1 # 320 + 1
+NB_FRAMES_337 = 8 * 42 + 1 # 336 + 1
+NB_FRAMES_353 = 8 * 44 + 1 # 352 + 1
+NB_FRAMES_369 = 8 * 46 + 1 # 368 + 1
+NB_FRAMES_385 = 8 * 48 + 1 # 384 + 1
+NB_FRAMES_401 = 8 * 50 + 1 # 400 + 1
+
 SMALL_TRAINING_BUCKETS = [
     (NB_FRAMES_1, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 1
     (NB_FRAMES_9, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 8 + 1
     (NB_FRAMES_17, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 16 + 1
-… [13 bucket lines removed; illegible in source]
+    (NB_FRAMES_33, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 32 + 1
+    (NB_FRAMES_49, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 48 + 1
+    (NB_FRAMES_65, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 64 + 1
+    (NB_FRAMES_81, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 80 + 1
+    (NB_FRAMES_97, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 96 + 1
+    (NB_FRAMES_113, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 112 + 1
+    (NB_FRAMES_129, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 128 + 1
+    (NB_FRAMES_145, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 144 + 1
+    (NB_FRAMES_161, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 160 + 1
+    (NB_FRAMES_177, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 176 + 1
+    (NB_FRAMES_193, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 192 + 1
+    (NB_FRAMES_225, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 224 + 1
+    (NB_FRAMES_257, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 256 + 1
 ]
 
 MEDIUM_19_9_RATIO_WIDTH = 928 # 32 * 29
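Every frame count above follows the 8 * k + 1 pattern, presumably because causal video VAEs with 8x temporal compression typically accept only clip lengths of that shape. A small illustrative helper (not part of the commit) that snaps a raw clip length down to the nearest valid count:

```python
# Illustrative helper (not in the commit): snap a raw clip length down to
# the nearest 8*k + 1 frame count, the pattern all NB_FRAMES_* constants follow.
def snap_to_valid_frames(n_frames: int) -> int:
    if n_frames < 1:
        raise ValueError("need at least one frame")
    return 8 * ((n_frames - 1) // 8) + 1

assert snap_to_valid_frames(1) == 1      # NB_FRAMES_1
assert snap_to_valid_frames(50) == 49    # NB_FRAMES_49
assert snap_to_valid_frames(256) == 249  # note the buckets jump 193 -> 225 -> 257
```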
@@ -129,19 +156,19 @@ MEDIUM_19_9_RATIO_BUCKETS = [
     (NB_FRAMES_1, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 1
     (NB_FRAMES_9, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 8 + 1
     (NB_FRAMES_17, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 16 + 1
-… [13 bucket lines removed; illegible in source]
+    (NB_FRAMES_33, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 32 + 1
+    (NB_FRAMES_49, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 48 + 1
+    (NB_FRAMES_65, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 64 + 1
+    (NB_FRAMES_81, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 80 + 1
+    (NB_FRAMES_97, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 96 + 1
+    (NB_FRAMES_113, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 112 + 1
+    (NB_FRAMES_129, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 128 + 1
+    (NB_FRAMES_145, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 144 + 1
+    (NB_FRAMES_161, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 160 + 1
+    (NB_FRAMES_177, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 176 + 1
+    (NB_FRAMES_193, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 192 + 1
+    (NB_FRAMES_225, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 224 + 1
+    (NB_FRAMES_257, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 256 + 1
 ]
 
 # Updated training presets to include Wan-2.1-T2V and support both LoRA and full-finetune
@@ -149,24 +176,24 @@ TRAINING_PRESETS = {
     "HunyuanVideo (normal)": {
         "model_type": "hunyuan_video",
         "training_type": "lora",
-        "lora_rank": …
-        "lora_alpha": …
-        "…
-        "batch_size": …
+        "lora_rank": DEFAULT_LORA_RANK_STR,
+        "lora_alpha": DEFAULT_LORA_ALPHA_STR,
+        "train_steps": DEFAULT_NB_TRAINING_STEPS,
+        "batch_size": DEFAULT_BATCH_SIZE,
         "learning_rate": 2e-5,
-        "save_iterations": …
+        "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
         "training_buckets": SMALL_TRAINING_BUCKETS,
         "flow_weighting_scheme": "none"
     },
     "LTX-Video (normal)": {
         "model_type": "ltx_video",
         "training_type": "lora",
-        "lora_rank": …
-        "lora_alpha": …
-        "…
-        "batch_size": …
-        "learning_rate": …
-        "save_iterations": …
+        "lora_rank": DEFAULT_LORA_RANK_STR,
+        "lora_alpha": DEFAULT_LORA_ALPHA_STR,
+        "train_steps": DEFAULT_NB_TRAINING_STEPS,
+        "batch_size": DEFAULT_BATCH_SIZE,
+        "learning_rate": DEFAULT_LEARNING_RATE,
+        "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
         "training_buckets": SMALL_TRAINING_BUCKETS,
         "flow_weighting_scheme": "logit_normal"
     },
@@ -174,21 +201,21 @@ TRAINING_PRESETS = {
         "model_type": "ltx_video",
         "training_type": "lora",
         "lora_rank": "256",
-        "lora_alpha": …
-        "…
-        "batch_size": …
-        "learning_rate": …
-        "save_iterations": …
+        "lora_alpha": DEFAULT_LORA_ALPHA_STR,
+        "train_steps": DEFAULT_NB_TRAINING_STEPS,
+        "batch_size": DEFAULT_BATCH_SIZE,
+        "learning_rate": DEFAULT_LEARNING_RATE,
+        "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
         "training_buckets": MEDIUM_19_9_RATIO_BUCKETS,
         "flow_weighting_scheme": "logit_normal"
     },
     "LTX-Video (Full Finetune)": {
         "model_type": "ltx_video",
         "training_type": "full-finetune",
-        "…
-        "batch_size": …
-        "learning_rate": …
-        "save_iterations": …
+        "train_steps": DEFAULT_NB_TRAINING_STEPS,
+        "batch_size": DEFAULT_BATCH_SIZE,
+        "learning_rate": DEFAULT_LEARNING_RATE,
+        "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
         "training_buckets": SMALL_TRAINING_BUCKETS,
         "flow_weighting_scheme": "logit_normal"
     },
@@ -197,10 +224,10 @@ TRAINING_PRESETS = {
         "training_type": "lora",
         "lora_rank": "32",
         "lora_alpha": "32",
-        "…
-        "batch_size": …
+        "train_steps": DEFAULT_NB_TRAINING_STEPS,
+        "batch_size": DEFAULT_BATCH_SIZE,
         "learning_rate": 5e-5,
-        "save_iterations": …
+        "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
         "training_buckets": SMALL_TRAINING_BUCKETS,
         "flow_weighting_scheme": "logit_normal"
     },
@@ -209,10 +236,10 @@ TRAINING_PRESETS = {
         "training_type": "lora",
         "lora_rank": "64",
         "lora_alpha": "64",
-        "…
-        "batch_size": …
-        "learning_rate": …
-        "save_iterations": …
+        "train_steps": DEFAULT_NB_TRAINING_STEPS,
+        "batch_size": DEFAULT_BATCH_SIZE,
+        "learning_rate": DEFAULT_LEARNING_RATE,
+        "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
        "training_buckets": MEDIUM_19_9_RATIO_BUCKETS,
         "flow_weighting_scheme": "logit_normal"
     }
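Each preset now stores only the values it overrides, with everything else falling back to the new module-level defaults. A sketch of the resolution step (`resolve_preset` is illustrative; the equivalent logic lives in train_tab.py's `get_default_params`):

```python
# Illustrative resolution of a preset against the DEFAULT_* fallbacks.
# `resolve_preset` is not part of the commit; train_tab.py does the
# equivalent inside get_default_params().
def resolve_preset(preset: dict) -> dict:
    return {
        "train_steps": preset.get("train_steps", DEFAULT_NB_TRAINING_STEPS),
        "batch_size": preset.get("batch_size", DEFAULT_BATCH_SIZE),
        "learning_rate": preset.get("learning_rate", DEFAULT_LEARNING_RATE),
        "save_iterations": preset.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS),
        "lora_rank": preset.get("lora_rank", DEFAULT_LORA_RANK_STR),
        "lora_alpha": preset.get("lora_alpha", DEFAULT_LORA_ALPHA_STR),
    }
```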
@@ -244,7 +271,7 @@ class TrainingConfig:
     id_token: Optional[str] = None
     video_resolution_buckets: List[Tuple[int, int, int]] = field(default_factory=lambda: SMALL_TRAINING_BUCKETS)
     video_reshape_mode: str = "center"
-    caption_dropout_p: float = …
+    caption_dropout_p: float = DEFAULT_CAPTION_DROPOUT_P
     caption_dropout_technique: str = "empty"
     precompute_conditions: bool = False
 
@@ -257,16 +284,16 @@ class TrainingConfig:
 
     # Training arguments
     training_type: str = "lora"
-    seed: int = …
+    seed: int = DEFAULT_SEED
     mixed_precision: str = "bf16"
     batch_size: int = 1
-… [illegible in source]
-    lora_rank: int = …
-    lora_alpha: int = …
+    train_steps: int = DEFAULT_NB_TRAINING_STEPS
+    lora_rank: int = DEFAULT_LORA_RANK
+    lora_alpha: int = DEFAULT_LORA_ALPHA
     target_modules: List[str] = field(default_factory=lambda: ["to_q", "to_k", "to_v", "to_out.0"])
     gradient_accumulation_steps: int = 1
     gradient_checkpointing: bool = True
-    checkpointing_steps: int = …
+    checkpointing_steps: int = DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS
     checkpointing_limit: Optional[int] = 2
     resume_from_checkpoint: Optional[str] = None
     enable_slicing: bool = True
@@ -300,15 +327,15 @@ class TrainingConfig:
             data_root=data_path,
             output_dir=output_path,
             batch_size=1,
-… [illegible in source]
+            train_steps=DEFAULT_NB_TRAINING_STEPS,
             lr=2e-5,
             gradient_checkpointing=True,
             id_token="afkx",
             gradient_accumulation_steps=1,
-            lora_rank=…
-            lora_alpha=…
+            lora_rank=DEFAULT_LORA_RANK,
+            lora_alpha=DEFAULT_LORA_ALPHA,
             video_resolution_buckets=buckets or SMALL_TRAINING_BUCKETS,
-            caption_dropout_p=…
+            caption_dropout_p=DEFAULT_CAPTION_DROPOUT_P,
             flow_weighting_scheme="none", # Hunyuan specific
             training_type="lora"
         )
@@ -322,15 +349,15 @@ class TrainingConfig:
             data_root=data_path,
             output_dir=output_path,
             batch_size=1,
-… [illegible in source]
-            lr=…
+            train_steps=DEFAULT_NB_TRAINING_STEPS,
+            lr=DEFAULT_LEARNING_RATE,
             gradient_checkpointing=True,
             id_token="BW_STYLE",
             gradient_accumulation_steps=4,
-            lora_rank=…
-            lora_alpha=…
+            lora_rank=DEFAULT_LORA_RANK,
+            lora_alpha=DEFAULT_LORA_ALPHA,
             video_resolution_buckets=buckets or SMALL_TRAINING_BUCKETS,
-            caption_dropout_p=…
+            caption_dropout_p=DEFAULT_CAPTION_DROPOUT_P,
             flow_weighting_scheme="logit_normal", # LTX specific
             training_type="lora"
         )
@@ -344,13 +371,13 @@ class TrainingConfig:
             data_root=data_path,
             output_dir=output_path,
             batch_size=1,
-… [illegible in source]
+            train_steps=DEFAULT_NB_TRAINING_STEPS,
             lr=1e-5,
             gradient_checkpointing=True,
             id_token="BW_STYLE",
             gradient_accumulation_steps=1,
             video_resolution_buckets=buckets or SMALL_TRAINING_BUCKETS,
-            caption_dropout_p=…
+            caption_dropout_p=DEFAULT_CAPTION_DROPOUT_P,
             flow_weighting_scheme="logit_normal", # LTX specific
             training_type="full-finetune"
         )
@@ -364,7 +391,7 @@ class TrainingConfig:
             data_root=data_path,
             output_dir=output_path,
             batch_size=1,
-… [illegible in source]
+            train_steps=DEFAULT_NB_TRAINING_STEPS,
             lr=5e-5,
             gradient_checkpointing=True,
             id_token=None, # Default is no ID token for Wan
@@ -373,7 +400,7 @@ class TrainingConfig:
             lora_alpha=32,
             target_modules=["blocks.*(to_q|to_k|to_v|to_out.0)"], # Wan-specific target modules
             video_resolution_buckets=buckets or SMALL_TRAINING_BUCKETS,
-            caption_dropout_p=…
+            caption_dropout_p=DEFAULT_CAPTION_DROPOUT_P,
             flow_weighting_scheme="logit_normal", # Wan specific
             training_type="lora"
         )
@@ -428,7 +455,7 @@ class TrainingConfig:
         #args.extend(["--mixed_precision", self.mixed_precision])
 
         args.extend(["--batch_size", str(self.batch_size)])
-        args.extend(["--train_steps", str(self.…
+        args.extend(["--train_steps", str(self.train_steps)])
 
         # LoRA specific arguments
         if self.training_type == "lora":
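`to_args()` flattens the config into CLI flags for the finetrainers process. A reduced sketch of the pattern, limited to the two flags visible in this hunk:

```python
# Reduced sketch of the to_args() pattern shown above: each config field
# becomes a "--flag value" pair for the finetrainers CLI. Only the two
# flags visible in this hunk are included here.
from dataclasses import dataclass

@dataclass
class MiniConfig:
    batch_size: int = 1
    train_steps: int = 1000

    def to_args(self) -> list[str]:
        args: list[str] = []
        args.extend(["--batch_size", str(self.batch_size)])
        args.extend(["--train_steps", str(self.train_steps)])
        return args

assert MiniConfig(train_steps=500).to_args() == ["--batch_size", "1", "--train_steps", "500"]
```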
vms/services/trainer.py (CHANGED)

@@ -23,7 +23,12 @@ from huggingface_hub import upload_folder, create_repo
 from ..config import (
     TrainingConfig, TRAINING_PRESETS, LOG_FILE_PATH, TRAINING_VIDEOS_PATH,
     STORAGE_PATH, TRAINING_PATH, MODEL_PATH, OUTPUT_PATH, HF_API_TOKEN,
-    MODEL_TYPES, TRAINING_TYPES
+    MODEL_TYPES, TRAINING_TYPES, DEFAULT_SEED,
+    DEFAULT_NB_TRAINING_STEPS, DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
+    DEFAULT_BATCH_SIZE, DEFAULT_CAPTION_DROPOUT_P,
+    DEFAULT_LEARNING_RATE,
+    DEFAULT_LORA_RANK, DEFAULT_LORA_ALPHA,
+    DEFAULT_LORA_RANK_STR, DEFAULT_LORA_ALPHA_STR
 )
 from ..utils import make_archive, parse_training_log, is_image_file, is_video_file, prepare_finetrainers_dataset, copy_files_to_training_dir
@@ -111,18 +116,19 @@
         except Exception as e:
             logger.error(f"Error saving UI state: {str(e)}")
 
+    # Additional fix for the load_ui_state method in trainer.py to clean up old values
     def load_ui_state(self) -> Dict[str, Any]:
         """Load saved UI state"""
         ui_state_file = OUTPUT_PATH / "ui_state.json"
         default_state = {
             "model_type": list(MODEL_TYPES.keys())[0],
             "training_type": list(TRAINING_TYPES.keys())[0],
-            "lora_rank": …
-            "lora_alpha": …
-            "…
-            "batch_size": …
-            "learning_rate": …
-            "save_iterations": …
+            "lora_rank": DEFAULT_LORA_RANK_STR,
+            "lora_alpha": DEFAULT_LORA_ALPHA_STR,
+            "train_steps": DEFAULT_NB_TRAINING_STEPS,
+            "batch_size": DEFAULT_BATCH_SIZE,
+            "learning_rate": DEFAULT_LEARNING_RATE,
+            "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
             "training_preset": list(TRAINING_PRESETS.keys())[0]
         }
 
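`load_ui_state` treats `default_state` as a schema: saved values win, and any key missing from the saved JSON keeps its default via the copy/update idiom used in the hunks below. A minimal sketch of that merge:

```python
# Minimal sketch of the defaults-merge idiom load_ui_state relies on:
# start from the full default dict, then overlay whatever was saved.
import json

def merge_ui_state(saved_json: str, default_state: dict) -> dict:
    merged = default_state.copy()          # every key gets a fallback
    merged.update(json.loads(saved_json))  # saved values win
    return merged

state = merge_ui_state('{"batch_size": 4}', {"batch_size": 1, "train_steps": 1000})
assert state == {"batch_size": 4, "train_steps": 1000}
```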
@@ -145,9 +151,14 @@
 
             saved_state = json.loads(file_content)
 
+            # Clean up model type if it contains " (LoRA)" suffix
+            if "model_type" in saved_state and " (LoRA)" in saved_state["model_type"]:
+                saved_state["model_type"] = saved_state["model_type"].replace(" (LoRA)", "")
+                logger.info(f"Removed (LoRA) suffix from saved model type: {saved_state['model_type']}")
+
             # Convert numeric values to appropriate types
-            if "…
-                saved_state["…
+            if "train_steps" in saved_state:
+                saved_state["train_steps"] = int(saved_state["train_steps"])
             if "batch_size" in saved_state:
                 saved_state["batch_size"] = int(saved_state["batch_size"])
             if "learning_rate" in saved_state:
@@ -158,6 +169,40 @@
             # Make sure we have all keys (in case structure changed)
             merged_state = default_state.copy()
             merged_state.update(saved_state)
+
+            # Validate model_type is in available choices
+            if merged_state["model_type"] not in MODEL_TYPES:
+                # Try to map from internal name
+                model_found = False
+                for display_name, internal_name in MODEL_TYPES.items():
+                    if internal_name == merged_state["model_type"]:
+                        merged_state["model_type"] = display_name
+                        model_found = True
+                        break
+                # If still not found, use default
+                if not model_found:
+                    merged_state["model_type"] = default_state["model_type"]
+                    logger.warning(f"Invalid model type in saved state, using default")
+
+            # Validate training_type is in available choices
+            if merged_state["training_type"] not in TRAINING_TYPES:
+                # Try to map from internal name
+                training_found = False
+                for display_name, internal_name in TRAINING_TYPES.items():
+                    if internal_name == merged_state["training_type"]:
+                        merged_state["training_type"] = display_name
+                        training_found = True
+                        break
+                # If still not found, use default
+                if not training_found:
+                    merged_state["training_type"] = default_state["training_type"]
+                    logger.warning(f"Invalid training type in saved state, using default")
+
+            # Validate training_preset is in available choices
+            if merged_state["training_preset"] not in TRAINING_PRESETS:
+                merged_state["training_preset"] = default_state["training_preset"]
+                logger.warning(f"Invalid training preset in saved state, using default")
+
             return merged_state
         except json.JSONDecodeError as e:
             logger.error(f"Error parsing UI state JSON: {str(e)}")
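The two validation blocks above share one shape: a value stored under its internal name (for example "hunyuan_video") is mapped back to its display key, and anything unknown falls back to the default. An illustrative generalization (`to_display_name` is not in the commit; the committed code inlines the loop twice):

```python
# Illustrative generalization of the repeated validation blocks above:
# map an internal name back to its display name, else fall back.
def to_display_name(value: str, mapping: dict, default: str) -> str:
    if value in mapping:          # already a display name
        return value
    for display, internal in mapping.items():
        if internal == value:     # stored as internal name; map it back
            return display
    return default                # unknown value: use the default

MODEL_TYPES = {"HunyuanVideo": "hunyuan_video", "LTX-Video": "ltx_video", "Wan-2.1-T2V": "wan"}
assert to_display_name("ltx_video", MODEL_TYPES, "HunyuanVideo") == "LTX-Video"
```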
@@ -176,12 +221,12 @@
         default_state = {
             "model_type": list(MODEL_TYPES.keys())[0],
             "training_type": list(TRAINING_TYPES.keys())[0],
-            "lora_rank": …
-            "lora_alpha": …
-            "…
-            "batch_size": …
-            "learning_rate": …
-            "save_iterations": …
+            "lora_rank": DEFAULT_LORA_RANK_STR,
+            "lora_alpha": DEFAULT_LORA_ALPHA_STR,
+            "train_steps": DEFAULT_NB_TRAINING_STEPS,
+            "batch_size": DEFAULT_BATCH_SIZE,
+            "learning_rate": DEFAULT_LEARNING_RATE,
+            "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
             "training_preset": list(TRAINING_PRESETS.keys())[0]
         }
         self.save_ui_state(default_state)
@@ -209,12 +254,12 @@
         default_state = {
             "model_type": list(MODEL_TYPES.keys())[0],
             "training_type": list(TRAINING_TYPES.keys())[0],
-            "lora_rank": …
-            "lora_alpha": …
-            "…
-            "batch_size": …
-            "learning_rate": …
-            "save_iterations": …
+            "lora_rank": DEFAULT_LORA_RANK_STR,
+            "lora_alpha": DEFAULT_LORA_ALPHA_STR,
+            "train_steps": DEFAULT_NB_TRAINING_STEPS,
+            "batch_size": DEFAULT_BATCH_SIZE,
+            "learning_rate": DEFAULT_LEARNING_RATE,
+            "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
             "training_preset": list(TRAINING_PRESETS.keys())[0]
         }
         self.save_ui_state(default_state)
@@ -361,7 +406,7 @@
         model_type: str,
         lora_rank: str,
         lora_alpha: str,
-… [illegible in source]
+        train_steps: int,
         batch_size: int,
         learning_rate: float,
         save_iterations: int,
@@ -508,7 +553,7 @@
             return error_msg, "Unsupported model"
 
         # Update with UI parameters
-        config.…
+        config.train_steps = int(train_steps)
         config.batch_size = int(batch_size)
         config.lr = float(learning_rate)
         config.checkpointing_steps = int(save_iterations)
@@ -530,11 +575,11 @@
 
         # Common settings for both models
         config.mixed_precision = "bf16"
-        config.seed = …
+        config.seed = DEFAULT_SEED
         config.gradient_checkpointing = True
         config.enable_slicing = True
         config.enable_tiling = True
-        config.caption_dropout_p = …
+        config.caption_dropout_p = DEFAULT_CAPTION_DROPOUT_P
 
         validation_error = self.validate_training_config(config, model_type)
         if validation_error:
@@ -626,7 +671,7 @@
             "training_type": training_type,
             "lora_rank": lora_rank,
             "lora_alpha": lora_alpha,
-            "…
+            "train_steps": train_steps,
             "batch_size": batch_size,
             "learning_rate": learning_rate,
             "save_iterations": save_iterations,
@@ -635,14 +680,12 @@
         })
 
         # Update initial training status
-        total_steps = …
+        total_steps = int(train_steps)
         self.save_status(
             state='training',
-            epoch=0,
             step=0,
             total_steps=total_steps,
             loss=0.0,
-            total_epochs=num_epochs,
             message='Training started',
             repo_id=repo_id,
             model_type=model_type,
@@ -789,12 +832,12 @@
             "params": {
                 "model_type": MODEL_TYPES.get(ui_state.get("model_type", list(MODEL_TYPES.keys())[0])),
                 "training_type": TRAINING_TYPES.get(ui_state.get("training_type", list(TRAINING_TYPES.keys())[0])),
-                "lora_rank": ui_state.get("lora_rank", …
-                "lora_alpha": ui_state.get("lora_alpha", …
-                "…
-                "batch_size": ui_state.get("batch_size", …
-                "learning_rate": ui_state.get("learning_rate", …
-                "save_iterations": ui_state.get("save_iterations", …
+                "lora_rank": ui_state.get("lora_rank", DEFAULT_LORA_RANK_STR),
+                "lora_alpha": ui_state.get("lora_alpha", DEFAULT_LORA_ALPHA_STR),
+                "train_steps": ui_state.get("train_steps", DEFAULT_NB_TRAINING_STEPS),
+                "batch_size": ui_state.get("batch_size", DEFAULT_BATCH_SIZE),
+                "learning_rate": ui_state.get("learning_rate", DEFAULT_LEARNING_RATE),
+                "save_iterations": ui_state.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS),
                 "preset_name": ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0]),
                 "repo_id": "" # Default empty repo ID
             }
@@ -853,12 +896,12 @@
             ui_updates.update({
                 "model_type": model_type_display, # Use the display name for the UI dropdown
                 "training_type": training_type_display, # Use the display name for training type
-                "lora_rank": params.get('lora_rank', …
-                "lora_alpha": params.get('lora_alpha', …
-                "…
-                "batch_size": params.get('batch_size', …
-                "learning_rate": params.get('learning_rate', …
-                "save_iterations": params.get('save_iterations', …
+                "lora_rank": params.get('lora_rank', DEFAULT_LORA_RANK_STR),
+                "lora_alpha": params.get('lora_alpha', DEFAULT_LORA_ALPHA_STR),
+                "train_steps": params.get('train_steps', DEFAULT_NB_TRAINING_STEPS),
+                "batch_size": params.get('batch_size', DEFAULT_BATCH_SIZE),
+                "learning_rate": params.get('learning_rate', DEFAULT_LEARNING_RATE),
+                "save_iterations": params.get('save_iterations', DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS),
                 "training_preset": params.get('preset_name', list(TRAINING_PRESETS.keys())[0])
             })
 
@@ -872,12 +915,12 @@
             # But keep model_type_display for the UI
             result = self.start_training(
                 model_type=model_type_internal,
-                lora_rank=params.get('lora_rank', …
-                lora_alpha=params.get('lora_alpha', …
-… [illegible in source]
-                batch_size=params.get('batch_size', …
-                learning_rate=params.get('learning_rate', …
-                save_iterations=params.get('save_iterations', …
+                lora_rank=params.get('lora_rank', DEFAULT_LORA_RANK_STR),
+                lora_alpha=params.get('lora_alpha', DEFAULT_LORA_ALPHA_STR),
+                train_steps=params.get('train_steps', DEFAULT_NB_TRAINING_STEPS),
+                batch_size=params.get('batch_size', DEFAULT_BATCH_SIZE),
+                learning_rate=params.get('learning_rate', DEFAULT_LEARNING_RATE),
+                save_iterations=params.get('save_iterations', DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS),
                 repo_id=params.get('repo_id', ''),
                 preset_name=params.get('preset_name', list(TRAINING_PRESETS.keys())[0]),
                 training_type=training_type_internal,
vms/tabs/train_tab.py (CHANGED)

@@ -9,7 +9,14 @@ from typing import Dict, Any, List, Optional, Tuple
 from pathlib import Path
 
 from .base_tab import BaseTab
-from ..config import …
+from ..config import (
+    TRAINING_PRESETS, OUTPUT_PATH, MODEL_TYPES, ASK_USER_TO_DUPLICATE_SPACE, SMALL_TRAINING_BUCKETS, TRAINING_TYPES,
+    DEFAULT_NB_TRAINING_STEPS, DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
+    DEFAULT_BATCH_SIZE, DEFAULT_CAPTION_DROPOUT_P,
+    DEFAULT_LEARNING_RATE,
+    DEFAULT_LORA_RANK, DEFAULT_LORA_ALPHA,
+    DEFAULT_LORA_RANK_STR, DEFAULT_LORA_ALPHA_STR
+)
 
 logger = logging.getLogger(__name__)
@@ -63,20 +70,20 @@
                     self.components["lora_rank"] = gr.Dropdown(
                         label="LoRA Rank",
                         choices=["16", "32", "64", "128", "256", "512", "1024"],
-                        value=…
+                        value=DEFAULT_LORA_RANK_STR,
                         type="value"
                     )
                     self.components["lora_alpha"] = gr.Dropdown(
                         label="LoRA Alpha",
                         choices=["16", "32", "64", "128", "256", "512", "1024"],
-                        value=…
+                        value=DEFAULT_LORA_ALPHA_STR,
                         type="value"
                     )
 
                 with gr.Row():
-                    self.components["…
-                        label="Number of …
-                        value=…
+                    self.components["train_steps"] = gr.Number(
+                        label="Number of Training Steps",
+                        value=DEFAULT_NB_TRAINING_STEPS,
                         minimum=1,
                         precision=0
                     )
@@ -89,13 +96,13 @@
                 with gr.Row():
                     self.components["learning_rate"] = gr.Number(
                         label="Learning Rate",
-                        value=…
-                        minimum=1e-…
+                        value=DEFAULT_LEARNING_RATE,
+                        minimum=1e-8
                     )
                     self.components["save_iterations"] = gr.Number(
                         label="Save checkpoint every N iterations",
-                        value=…
-                        minimum=…
+                        value=DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
+                        minimum=1,
                         precision=0,
                         info="Model will be saved periodically after these many steps"
                     )
@@ -170,7 +177,7 @@
 
         return {
             self.components["model_info"]: info,
-            self.components["…
+            self.components["train_steps"]: params["train_steps"],
             self.components["batch_size"]: params["batch_size"],
             self.components["learning_rate"]: params["learning_rate"],
             self.components["save_iterations"]: params["save_iterations"],
@@ -186,7 +193,7 @@
             inputs=[self.components["model_type"], self.components["training_type"]],
             outputs=[
                 self.components["model_info"],
-                self.components["…
+                self.components["train_steps"],
                 self.components["batch_size"],
                 self.components["learning_rate"],
                 self.components["save_iterations"],
@@ -204,7 +211,7 @@
             inputs=[self.components["model_type"], self.components["training_type"]],
             outputs=[
                 self.components["model_info"],
-                self.components["…
+                self.components["train_steps"],
                 self.components["batch_size"],
                 self.components["learning_rate"],
                 self.components["save_iterations"],
@@ -225,9 +232,9 @@
             outputs=[]
         )
 
-        self.components["…
-            fn=lambda v: self.app.update_ui_state(…
-            inputs=[self.components["…
+        self.components["train_steps"].change(
+            fn=lambda v: self.app.update_ui_state(train_steps=v),
+            inputs=[self.components["train_steps"]],
             outputs=[]
         )
 
@@ -262,7 +269,7 @@
                 self.components["training_type"],
                 self.components["lora_rank"],
                 self.components["lora_alpha"],
-                self.components["…
+                self.components["train_steps"],
                 self.components["batch_size"],
                 self.components["learning_rate"],
                 self.components["save_iterations"],
@@ -280,7 +287,7 @@
                 self.components["training_type"],
                 self.components["lora_rank"],
                 self.components["lora_alpha"],
-                self.components["…
+                self.components["train_steps"],
                 self.components["batch_size"],
                 self.components["learning_rate"],
                 self.components["save_iterations"],
@@ -290,27 +297,20 @@
                 self.components["status_box"],
                 self.components["log_box"]
             ]
-        ).success(
-            fn=self.get_latest_status_message_logs_and_button_labels,
-            outputs=[
-                self.components["status_box"],
-                self.components["log_box"],
-                self.components["start_btn"],
-                self.components["stop_btn"],
-                self.components["pause_resume_btn"],
-                self.components["current_task_box"] # Include new component
-            ]
         )
 
+        # Use simplified event handlers for pause/resume and stop
+        third_btn = self.components["delete_checkpoints_btn"] if "delete_checkpoints_btn" in self.components else self.components["pause_resume_btn"]
+
         self.components["pause_resume_btn"].click(
             fn=self.handle_pause_resume,
             outputs=[
                 self.components["status_box"],
                 self.components["log_box"],
+                self.components["current_task_box"],
                 self.components["start_btn"],
                 self.components["stop_btn"],
-… [illegible in source]
-                self.components["current_task_box"] # Include new component
+                third_btn
             ]
         )
@@ -319,10 +319,10 @@
             outputs=[
                 self.components["status_box"],
                 self.components["log_box"],
+                self.components["current_task_box"],
                 self.components["start_btn"],
                 self.components["stop_btn"],
-… [illegible in source]
-                self.components["current_task_box"] # Include new component
+                third_btn
             ]
         )
@@ -330,16 +330,6 @@
         self.components["delete_checkpoints_btn"].click(
             fn=lambda: self.app.trainer.delete_all_checkpoints(),
             outputs=[self.components["status_box"]]
-        ).then(
-            fn=self.get_latest_status_message_logs_and_button_labels,
-            outputs=[
-                self.components["status_box"],
-                self.components["log_box"],
-                self.components["start_btn"],
-                self.components["stop_btn"],
-                self.components["delete_checkpoints_btn"],
-                self.components["current_task_box"] # Include new component
-            ]
         )
 
     def handle_training_start(self, preset, model_type, training_type, *args):
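This hunk, together with the pause/stop hunks above, drops the chained `.success()`/`.then()` refresh steps in favor of handlers that return every output at once. A minimal sketch of the two styles, with hypothetical `run` and `refresh` callbacks:

```python
# Minimal sketch (hypothetical `run`/`refresh` callbacks) of the two Gradio
# event styles this commit trades between: a chained refresh after the click
# versus a single handler that returns all outputs at once.
import gradio as gr

with gr.Blocks() as demo:
    btn = gr.Button("Go")
    status = gr.Textbox(label="status")
    log = gr.Textbox(label="log")

    def run_and_refresh():
        # one handler returning every output, as the simplified code does
        return "done", "log line"

    btn.click(fn=run_and_refresh, outputs=[status, log])
    # The removed pattern chained a second step instead:
    # btn.click(fn=run, outputs=[status]).then(fn=refresh, outputs=[status, log])
```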
@@ -391,7 +381,7 @@
 
     def get_model_info(self, model_type: str, training_type: str) -> str:
         """Get information about the selected model type and training method"""
-        if model_type == "HunyuanVideo…
+        if model_type == "HunyuanVideo":
             base_info = """### HunyuanVideo
 - Required VRAM: ~48GB minimum
 - Recommended batch size: 1-2
@@ -403,7 +393,7 @@
             else:
                 return base_info + "\n- Required VRAM: ~48GB minimum\n- **Full finetune not recommended due to VRAM requirements**"
 
-        elif model_type == "LTX-Video…
+        elif model_type == "LTX-Video":
             base_info = """### LTX-Video
 - Recommended batch size: 1-4
 - Typical training time: 1-3 hours
@@ -414,14 +404,14 @@
             else:
                 return base_info + "\n- Required VRAM: ~21GB minimum\n- Full model size: ~8GB"
 
-        elif model_type == "Wan-2.1-T2V…
+        elif model_type == "Wan-2.1-T2V":
             base_info = """### Wan-2.1-T2V
-- Recommended batch size: …
-- Typical training time: …
+- Recommended batch size: ?
+- Typical training time: ? hours
 - Default resolution: 49x512x768"""
 
             if training_type == "LoRA Finetune":
-                return base_info + "\n- Required VRAM: …
+                return base_info + "\n- Required VRAM: ?GB minimum\n- Default LoRA rank: 32 (~120 MB)"
             else:
                 return base_info + "\n- **Full finetune not recommended due to VRAM requirements**"
 
@@ -440,51 +430,51 @@
             # Use the first matching preset
             preset = matching_presets[0]
             return {
-                "…
-                "batch_size": preset.get("batch_size", …
-                "learning_rate": preset.get("learning_rate", …
-                "save_iterations": preset.get("save_iterations", …
-                "lora_rank": preset.get("lora_rank", …
-                "lora_alpha": preset.get("lora_alpha", …
+                "train_steps": preset.get("train_steps", DEFAULT_NB_TRAINING_STEPS),
+                "batch_size": preset.get("batch_size", DEFAULT_BATCH_SIZE),
+                "learning_rate": preset.get("learning_rate", DEFAULT_LEARNING_RATE),
+                "save_iterations": preset.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS),
+                "lora_rank": preset.get("lora_rank", DEFAULT_LORA_RANK_STR),
+                "lora_alpha": preset.get("lora_alpha", DEFAULT_LORA_ALPHA_STR)
             }
 
         # Default fallbacks
         if model_type == "hunyuan_video":
             return {
-                "…
-                "batch_size": …
+                "train_steps": DEFAULT_NB_TRAINING_STEPS,
+                "batch_size": DEFAULT_BATCH_SIZE,
                 "learning_rate": 2e-5,
-                "save_iterations": …
-                "lora_rank": …
-                "lora_alpha": …
+                "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
+                "lora_rank": DEFAULT_LORA_RANK_STR,
+                "lora_alpha": DEFAULT_LORA_ALPHA_STR
             }
         elif model_type == "ltx_video":
             return {
-                "…
-                "batch_size": …
-                "learning_rate": …
-                "save_iterations": …
-                "lora_rank": …
-                "lora_alpha": …
+                "train_steps": DEFAULT_NB_TRAINING_STEPS,
+                "batch_size": DEFAULT_BATCH_SIZE,
+                "learning_rate": DEFAULT_LEARNING_RATE,
+                "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
+                "lora_rank": DEFAULT_LORA_RANK_STR,
+                "lora_alpha": DEFAULT_LORA_ALPHA_STR
             }
         elif model_type == "wan":
             return {
-                "…
-                "batch_size": …
+                "train_steps": DEFAULT_NB_TRAINING_STEPS,
+                "batch_size": DEFAULT_BATCH_SIZE,
                 "learning_rate": 5e-5,
-                "save_iterations": …
+                "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
                 "lora_rank": "32",
                 "lora_alpha": "32"
             }
         else:
             # Generic defaults
             return {
-                "…
-                "batch_size": …
-                "learning_rate": …
-                "save_iterations": …
-                "lora_rank": …
-                "lora_alpha": …
+                "train_steps": DEFAULT_NB_TRAINING_STEPS,
+                "batch_size": DEFAULT_BATCH_SIZE,
+                "learning_rate": DEFAULT_LEARNING_RATE,
+                "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
+                "lora_rank": DEFAULT_LORA_RANK_STR,
+                "lora_alpha": DEFAULT_LORA_ALPHA_STR
            }
 
     def update_training_params(self, preset_name: str) -> Tuple:
@@ -522,12 +512,12 @@
         show_lora_params = preset["training_type"] == "lora"
 
         # Use preset defaults but preserve user-modified values if they exist
-        lora_rank_val = current_state.get("lora_rank") if current_state.get("lora_rank") != preset.get("lora_rank", …
-        lora_alpha_val = current_state.get("lora_alpha") if current_state.get("lora_alpha") != preset.get("lora_alpha", …
-… [illegible in source]
-        batch_size_val = current_state.get("batch_size") if current_state.get("batch_size") != preset.get("batch_size", …
-        learning_rate_val = current_state.get("learning_rate") if current_state.get("learning_rate") != preset.get("learning_rate", …
-        save_iterations_val = current_state.get("save_iterations") if current_state.get("save_iterations") != preset.get("save_iterations", …
+        lora_rank_val = current_state.get("lora_rank") if current_state.get("lora_rank") != preset.get("lora_rank", DEFAULT_LORA_RANK_STR) else preset.get("lora_rank", DEFAULT_LORA_RANK_STR)
+        lora_alpha_val = current_state.get("lora_alpha") if current_state.get("lora_alpha") != preset.get("lora_alpha", DEFAULT_LORA_ALPHA_STR) else preset.get("lora_alpha", DEFAULT_LORA_ALPHA_STR)
+        train_steps_val = current_state.get("train_steps") if current_state.get("train_steps") != preset.get("train_steps", DEFAULT_NB_TRAINING_STEPS) else preset.get("train_steps", DEFAULT_NB_TRAINING_STEPS)
+        batch_size_val = current_state.get("batch_size") if current_state.get("batch_size") != preset.get("batch_size", DEFAULT_BATCH_SIZE) else preset.get("batch_size", DEFAULT_BATCH_SIZE)
+        learning_rate_val = current_state.get("learning_rate") if current_state.get("learning_rate") != preset.get("learning_rate", DEFAULT_LEARNING_RATE) else preset.get("learning_rate", DEFAULT_LEARNING_RATE)
+        save_iterations_val = current_state.get("save_iterations") if current_state.get("save_iterations") != preset.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS) else preset.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS)
 
         # Return values in the same order as the output components
         return (
@@ -535,7 +525,7 @@
             training_display_name,
             lora_rank_val,
             lora_alpha_val,
-… [illegible in source]
+            train_steps_val,
             batch_size_val,
             learning_rate_val,
             save_iterations_val,
@@ -543,66 +533,6 @@
             gr.Row(visible=show_lora_params)
         )
 
-    def update_training_ui(self, training_state: Dict[str, Any]):
-        """Update UI components based on training state"""
-        updates = {}
-
-        # Update status box with high-level information
-        status_text = []
-        if training_state["status"] != "idle":
-            status_text.extend([
-                f"Status: {training_state['status']}",
-                f"Progress: {training_state['progress']}",
-                f"Step: {training_state['current_step']}/{training_state['total_steps']}",
-                f"Time elapsed: {training_state['elapsed']}",
-                f"Estimated remaining: {training_state['remaining']}",
-                "",
-                f"Current loss: {training_state['step_loss']}",
-                f"Learning rate: {training_state['learning_rate']}",
-                f"Gradient norm: {training_state['grad_norm']}",
-                f"Memory usage: {training_state['memory']}"
-            ])
-
-        if training_state["error_message"]:
-            status_text.append(f"\nError: {training_state['error_message']}")
-
-        updates["status_box"] = "\n".join(status_text)
-
-        # Add current task information to the dedicated box
-        if training_state.get("current_task"):
-            updates["current_task_box"] = training_state["current_task"]
-        else:
-            updates["current_task_box"] = "No active task" if training_state["status"] != "training" else "Waiting for task information..."
-
-        # Update button states
-        updates["start_btn"] = gr.Button(
-            "Start training",
-            interactive=(training_state["status"] in ["idle", "completed", "error", "stopped"]),
-            variant="primary" if training_state["status"] == "idle" else "secondary"
-        )
-
-        updates["stop_btn"] = gr.Button(
-            "Stop training",
-            interactive=(training_state["status"] in ["training", "initializing"]),
-            variant="stop"
-        )
-
-        return updates
-
-    def handle_pause_resume(self):
-        status, _, _ = self.get_latest_status_message_and_logs()
-
-        if status == "paused":
-            self.app.trainer.resume_training()
-        else:
-            self.app.trainer.pause_training()
-
-        return self.get_latest_status_message_logs_and_button_labels()
-
-    def handle_stop(self):
-        self.app.trainer.stop_training()
-        return self.get_latest_status_message_logs_and_button_labels()
-
     def get_latest_status_message_and_logs(self) -> Tuple[str, str, str]:
         """Get latest status message, log content, and status code in a safer way"""
         state = self.app.trainer.get_status()
@@ -663,61 +593,107 @@
 
         return (state["status"], state["message"], logs)
 
-    def …
-        """Get …
         status, message, logs = self.get_latest_status_message_and_logs()
 
-        # Add checkpoints detection
-        has_checkpoints = len(list(OUTPUT_PATH.glob("checkpoint-*"))) > 0
-
-        button_updates = self.update_training_buttons(status, has_checkpoints).values()
-
         # Get current task if available
         current_task = ""
         if hasattr(self.app, 'log_parser') and self.app.log_parser is not None:
             current_task = self.app.log_parser.get_current_task_display()
 
-… [8 lines illegible in source]
         is_training = status in ["training", "initializing"]
         is_completed = status in ["completed", "error", "stopped"]
 
         start_text = "Continue Training" if has_checkpoints else "Start Training"
 
-        # …
-… [5 lines illegible in source]
-            ),
-            "stop_btn": gr.Button(
-                value="Stop at Last Checkpoint",
-                interactive=is_training,
-                variant="primary" if is_training else "secondary",
-            )
-        }
 
-        …
         if "delete_checkpoints_btn" in self.components:
-            …
-            …
                 interactive=has_checkpoints and not is_training,
-                variant="stop"
             )
         else:
-            …
-            …
-                value="Resume Training" if status == "paused" else "Pause Training",
                 interactive=(is_training or status == "paused") and not is_completed,
                 variant="secondary",
                 visible=False
             )
 
-        return …
+… [the 107 added lines of this hunk are cut off in the source, as is the
+   remainder of the page, including the vms/ui/video_trainer_ui.py diff]
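The added side of the final hunk is cut off in this extract, but the surviving fragments (`has_checkpoints`, the "Continue Training" label, the "Stop at Last Checkpoint" button) point to a helper that derives button state from the training status. A speculative sketch only, not the committed code:

```python
# Speculative sketch of an update_training_buttons-style helper, inferred
# from the surviving fragments of the truncated hunk. Not the committed code.
import gradio as gr

def update_training_buttons(status: str, has_checkpoints: bool) -> dict:
    is_training = status in ["training", "initializing"]
    start_text = "Continue Training" if has_checkpoints else "Start Training"
    return {
        "start_btn": gr.Button(
            value=start_text,
            interactive=not is_training,
            variant="primary" if not is_training else "secondary",
        ),
        "stop_btn": gr.Button(
            value="Stop at Last Checkpoint",
            interactive=is_training,
            variant="primary" if is_training else "secondary",
        ),
    }
```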
|
| 9 |
from pathlib import Path
|
| 10 |
|
| 11 |
from .base_tab import BaseTab
|
| 12 |
+
from ..config import (
|
| 13 |
+
TRAINING_PRESETS, OUTPUT_PATH, MODEL_TYPES, ASK_USER_TO_DUPLICATE_SPACE, SMALL_TRAINING_BUCKETS, TRAINING_TYPES,
|
| 14 |
+
DEFAULT_NB_TRAINING_STEPS, DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
|
| 15 |
+
DEFAULT_BATCH_SIZE, DEFAULT_CAPTION_DROPOUT_P,
|
| 16 |
+
DEFAULT_LEARNING_RATE,
|
| 17 |
+
DEFAULT_LORA_RANK, DEFAULT_LORA_ALPHA,
|
| 18 |
+
DEFAULT_LORA_RANK_STR, DEFAULT_LORA_ALPHA_STR
|
| 19 |
+
)
|
| 20 |
|
| 21 |
logger = logging.getLogger(__name__)
|
| 22 |
|
|
|
|
| 70 |
self.components["lora_rank"] = gr.Dropdown(
|
| 71 |
label="LoRA Rank",
|
| 72 |
choices=["16", "32", "64", "128", "256", "512", "1024"],
|
| 73 |
+
value=DEFAULT_LORA_RANK_STR,
|
| 74 |
type="value"
|
| 75 |
)
|
| 76 |
self.components["lora_alpha"] = gr.Dropdown(
|
| 77 |
label="LoRA Alpha",
|
| 78 |
choices=["16", "32", "64", "128", "256", "512", "1024"],
|
| 79 |
+
value=DEFAULT_LORA_ALPHA_STR,
|
| 80 |
type="value"
|
| 81 |
)
|
| 82 |
|
| 83 |
with gr.Row():
|
| 84 |
+
self.components["train_steps"] = gr.Number(
|
| 85 |
+
label="Number of Training Steps",
|
| 86 |
+
value=DEFAULT_NB_TRAINING_STEPS,
|
| 87 |
minimum=1,
|
| 88 |
precision=0
|
| 89 |
)
|
|
|
|
| 96 |
with gr.Row():
|
| 97 |
self.components["learning_rate"] = gr.Number(
|
| 98 |
label="Learning Rate",
|
| 99 |
+
value=DEFAULT_LEARNING_RATE,
|
| 100 |
+
minimum=1e-8
|
| 101 |
)
|
| 102 |
self.components["save_iterations"] = gr.Number(
|
| 103 |
label="Save checkpoint every N iterations",
|
| 104 |
+
value=DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
|
| 105 |
+
minimum=1,
|
| 106 |
precision=0,
|
| 107 |
info="Model will be saved periodically after these many steps"
|
| 108 |
)
|
|
|
|
| 177 |
|
| 178 |
return {
|
| 179 |
self.components["model_info"]: info,
|
| 180 |
+
self.components["train_steps"]: params["train_steps"],
|
| 181 |
self.components["batch_size"]: params["batch_size"],
|
| 182 |
self.components["learning_rate"]: params["learning_rate"],
|
| 183 |
self.components["save_iterations"]: params["save_iterations"],
|
|
|
|
| 193 |
inputs=[self.components["model_type"], self.components["training_type"]],
|
| 194 |
outputs=[
|
| 195 |
self.components["model_info"],
|
| 196 |
+
self.components["train_steps"],
|
| 197 |
self.components["batch_size"],
|
| 198 |
self.components["learning_rate"],
|
| 199 |
self.components["save_iterations"],
|
|
|
|
| 211 |
inputs=[self.components["model_type"], self.components["training_type"]],
|
| 212 |
outputs=[
|
| 213 |
self.components["model_info"],
|
| 214 |
+
self.components["train_steps"],
|
| 215 |
self.components["batch_size"],
|
| 216 |
self.components["learning_rate"],
|
| 217 |
self.components["save_iterations"],
|
|
|
|
| 232 |
outputs=[]
|
| 233 |
)
|
| 234 |
|
| 235 |
+
self.components["train_steps"].change(
|
| 236 |
+
fn=lambda v: self.app.update_ui_state(train_steps=v),
|
| 237 |
+
inputs=[self.components["train_steps"]],
|
| 238 |
outputs=[]
|
| 239 |
)
|
| 240 |
|
|
|
|
| 269 |
self.components["training_type"],
|
| 270 |
self.components["lora_rank"],
|
| 271 |
self.components["lora_alpha"],
|
| 272 |
+
self.components["train_steps"],
|
| 273 |
self.components["batch_size"],
|
| 274 |
self.components["learning_rate"],
|
| 275 |
self.components["save_iterations"],
|
|
|
|
| 287 |
self.components["training_type"],
|
| 288 |
self.components["lora_rank"],
|
| 289 |
self.components["lora_alpha"],
|
| 290 |
+
self.components["train_steps"],
|
| 291 |
self.components["batch_size"],
|
| 292 |
self.components["learning_rate"],
|
| 293 |
self.components["save_iterations"],
|
|
|
|
| 297 |
self.components["status_box"],
|
| 298 |
self.components["log_box"]
|
| 299 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 300 |
)
|
| 301 |
|
| 302 |
+
# Use simplified event handlers for pause/resume and stop
|
| 303 |
+
third_btn = self.components["delete_checkpoints_btn"] if "delete_checkpoints_btn" in self.components else self.components["pause_resume_btn"]
|
| 304 |
+
|
| 305 |
self.components["pause_resume_btn"].click(
|
| 306 |
fn=self.handle_pause_resume,
|
| 307 |
outputs=[
|
| 308 |
self.components["status_box"],
|
| 309 |
self.components["log_box"],
|
| 310 |
+
self.components["current_task_box"],
|
| 311 |
self.components["start_btn"],
|
| 312 |
self.components["stop_btn"],
|
| 313 |
+
third_btn
|
|
|
|
| 314 |
]
|
| 315 |
)
|
| 316 |
|
|
|
|
| 319 |
outputs=[
|
| 320 |
self.components["status_box"],
|
| 321 |
self.components["log_box"],
|
| 322 |
+
self.components["current_task_box"],
|
| 323 |
self.components["start_btn"],
|
| 324 |
self.components["stop_btn"],
|
| 325 |
+
third_btn
|
|
|
|
| 326 |
]
|
| 327 |
)
|
| 328 |
|
|
|
|
| 330 |
self.components["delete_checkpoints_btn"].click(
|
| 331 |
fn=lambda: self.app.trainer.delete_all_checkpoints(),
|
| 332 |
outputs=[self.components["status_box"]]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 333 |
)
|
| 334 |
|
| 335 |
def handle_training_start(self, preset, model_type, training_type, *args):

    def get_model_info(self, model_type: str, training_type: str) -> str:
        """Get information about the selected model type and training method"""
+       if model_type == "HunyuanVideo":
            base_info = """### HunyuanVideo
- Required VRAM: ~48GB minimum
- Recommended batch size: 1-2

        else:
            return base_info + "\n- Required VRAM: ~48GB minimum\n- **Full finetune not recommended due to VRAM requirements**"

+       elif model_type == "LTX-Video":
            base_info = """### LTX-Video
- Recommended batch size: 1-4
- Typical training time: 1-3 hours

        else:
            return base_info + "\n- Required VRAM: ~21GB minimum\n- Full model size: ~8GB"

+       elif model_type == "Wan-2.1-T2V":
            base_info = """### Wan-2.1-T2V
+- Recommended batch size: ?
+- Typical training time: ? hours
- Default resolution: 49x512x768"""

        if training_type == "LoRA Finetune":
+           return base_info + "\n- Required VRAM: ?GB minimum\n- Default LoRA rank: 32 (~120 MB)"
        else:
            return base_info + "\n- **Full finetune not recommended due to VRAM requirements**"
        # Use the first matching preset
        preset = matching_presets[0]
        return {
+           "train_steps": preset.get("train_steps", DEFAULT_NB_TRAINING_STEPS),
+           "batch_size": preset.get("batch_size", DEFAULT_BATCH_SIZE),
+           "learning_rate": preset.get("learning_rate", DEFAULT_LEARNING_RATE),
+           "save_iterations": preset.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS),
+           "lora_rank": preset.get("lora_rank", DEFAULT_LORA_RANK_STR),
+           "lora_alpha": preset.get("lora_alpha", DEFAULT_LORA_ALPHA_STR)
        }

        # Default fallbacks
        if model_type == "hunyuan_video":
            return {
+               "train_steps": DEFAULT_NB_TRAINING_STEPS,
+               "batch_size": DEFAULT_BATCH_SIZE,
                "learning_rate": 2e-5,
+               "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
+               "lora_rank": DEFAULT_LORA_RANK_STR,
+               "lora_alpha": DEFAULT_LORA_ALPHA_STR
            }
        elif model_type == "ltx_video":
            return {
+               "train_steps": DEFAULT_NB_TRAINING_STEPS,
+               "batch_size": DEFAULT_BATCH_SIZE,
+               "learning_rate": DEFAULT_LEARNING_RATE,
+               "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
+               "lora_rank": DEFAULT_LORA_RANK_STR,
+               "lora_alpha": DEFAULT_LORA_ALPHA_STR
            }
        elif model_type == "wan":
            return {
+               "train_steps": DEFAULT_NB_TRAINING_STEPS,
+               "batch_size": DEFAULT_BATCH_SIZE,
                "learning_rate": 5e-5,
+               "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
                "lora_rank": "32",
                "lora_alpha": "32"
            }
        else:
            # Generic defaults
            return {
+               "train_steps": DEFAULT_NB_TRAINING_STEPS,
+               "batch_size": DEFAULT_BATCH_SIZE,
+               "learning_rate": DEFAULT_LEARNING_RATE,
+               "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
+               "lora_rank": DEFAULT_LORA_RANK_STR,
+               "lora_alpha": DEFAULT_LORA_ALPHA_STR
            }

    def update_training_params(self, preset_name: str) -> Tuple:

        show_lora_params = preset["training_type"] == "lora"

        # Use preset defaults but preserve user-modified values if they exist
+       lora_rank_val = current_state.get("lora_rank") if current_state.get("lora_rank") != preset.get("lora_rank", DEFAULT_LORA_RANK_STR) else preset.get("lora_rank", DEFAULT_LORA_RANK_STR)
+       lora_alpha_val = current_state.get("lora_alpha") if current_state.get("lora_alpha") != preset.get("lora_alpha", DEFAULT_LORA_ALPHA_STR) else preset.get("lora_alpha", DEFAULT_LORA_ALPHA_STR)
+       train_steps_val = current_state.get("train_steps") if current_state.get("train_steps") != preset.get("train_steps", DEFAULT_NB_TRAINING_STEPS) else preset.get("train_steps", DEFAULT_NB_TRAINING_STEPS)
+       batch_size_val = current_state.get("batch_size") if current_state.get("batch_size") != preset.get("batch_size", DEFAULT_BATCH_SIZE) else preset.get("batch_size", DEFAULT_BATCH_SIZE)
+       learning_rate_val = current_state.get("learning_rate") if current_state.get("learning_rate") != preset.get("learning_rate", DEFAULT_LEARNING_RATE) else preset.get("learning_rate", DEFAULT_LEARNING_RATE)
+       save_iterations_val = current_state.get("save_iterations") if current_state.get("save_iterations") != preset.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS) else preset.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS)

        # Return values in the same order as the output components
        return (
            training_display_name,
            lora_rank_val,
            lora_alpha_val,
+           train_steps_val,
            batch_size_val,
            learning_rate_val,
            save_iterations_val,

            gr.Row(visible=show_lora_params)
        )
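A note on the "preserve user-modified values" expressions above: `x if x != preset_default else preset_default` evaluates to the current value whenever one is stored, so it reads as "prefer the user's stored value, fall back to the preset". Be aware that a missing key makes `current_state.get(...)` return None, which the inequality passes through. A minimal sketch of the same merge that guards against that case (hypothetical names, not code from this commit):

# Minimal sketch of the "prefer stored value, fall back to preset"
# merge used above; `current_state` and `preset` are example dicts.
from typing import Any

def merge_params(current_state: dict, preset: dict, defaults: dict) -> dict:
    merged: dict[str, Any] = {}
    for key, default in defaults.items():
        preset_val = preset.get(key, default)
        current_val = current_state.get(key)
        # Unlike the raw inequality form, this never leaks None through:
        merged[key] = current_val if current_val is not None else preset_val
    return merged

print(merge_params({"batch_size": 4}, {"batch_size": 1}, {"batch_size": 1, "train_steps": 500}))
# -> {'batch_size': 4, 'train_steps': 500}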
    def get_latest_status_message_and_logs(self) -> Tuple[str, str, str]:
        """Get latest status message, log content, and status code in a safer way"""
        state = self.app.trainer.get_status()

        return (state["status"], state["message"], logs)

+   def get_status_updates(self):
+       """Get status updates for text components (no variant property)"""
        status, message, logs = self.get_latest_status_message_and_logs()

        # Get current task if available
        current_task = ""
        if hasattr(self.app, 'log_parser') and self.app.log_parser is not None:
            current_task = self.app.log_parser.get_current_task_display()

+       return message, logs, current_task
+
+   def get_button_updates(self):
+       """Get button updates (with variant property)"""
+       status, _, _ = self.get_latest_status_message_and_logs()
+
+       # Add checkpoints detection
+       has_checkpoints = len(list(OUTPUT_PATH.glob("checkpoint-*"))) > 0
+
        is_training = status in ["training", "initializing"]
        is_completed = status in ["completed", "error", "stopped"]

        start_text = "Continue Training" if has_checkpoints else "Start Training"

+       # Create button updates
+       start_btn = gr.Button(
+           value=start_text,
+           interactive=not is_training,
+           variant="primary" if not is_training else "secondary"
+       )

+       stop_btn = gr.Button(
+           value="Stop at Last Checkpoint",
+           interactive=is_training,
+           variant="primary" if is_training else "secondary"
+       )
+
+       # Add delete_checkpoints_btn or pause_resume_btn
        if "delete_checkpoints_btn" in self.components:
+           third_btn = gr.Button(
+               "Delete All Checkpoints",
                interactive=has_checkpoints and not is_training,
+               variant="stop"
            )
        else:
+           third_btn = gr.Button(
+               "Resume Training" if status == "paused" else "Pause Training",
                interactive=(is_training or status == "paused") and not is_completed,
                variant="secondary",
                visible=False
            )

+       return start_btn, stop_btn, third_btn
+
+   def update_training_ui(self, training_state: Dict[str, Any]):
+       """Update UI components based on training state"""
+       updates = {}
+
+       # Update status box with high-level information
+       status_text = []
+       if training_state["status"] != "idle":
+           status_text.extend([
+               f"Status: {training_state['status']}",
+               f"Progress: {training_state['progress']}",
+               f"Step: {training_state['current_step']}/{training_state['total_steps']}",
+               f"Time elapsed: {training_state['elapsed']}",
+               f"Estimated remaining: {training_state['remaining']}",
+               "",
+               f"Current loss: {training_state['step_loss']}",
+               f"Learning rate: {training_state['learning_rate']}",
+               f"Gradient norm: {training_state['grad_norm']}",
+               f"Memory usage: {training_state['memory']}"
+           ])
+
+       if training_state["error_message"]:
+           status_text.append(f"\nError: {training_state['error_message']}")
+
+       updates["status_box"] = "\n".join(status_text)
+
+       # Add current task information to the dedicated box
+       if training_state.get("current_task"):
+           updates["current_task_box"] = training_state["current_task"]
+       else:
+           updates["current_task_box"] = "No active task" if training_state["status"] != "training" else "Waiting for task information..."
+
+       return updates
+
+   def handle_pause_resume(self):
+       """Handle pause/resume button click"""
+       status, _, _ = self.get_latest_status_message_and_logs()
+
+       if status == "paused":
+           self.app.trainer.resume_training()
+       else:
+           self.app.trainer.pause_training()
+
+       # Return the updates separately for text and buttons
+       return (*self.get_status_updates(), *self.get_button_updates())
+
+   def handle_stop(self):
+       """Handle stop button click"""
+       self.app.trainer.stop_training()
+
+       # Return the updates separately for text and buttons
+       return (*self.get_status_updates(), *self.get_button_updates())
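For readers unfamiliar with the Gradio idiom used in `get_button_updates`: with recent Gradio 4.x versions, returning a freshly constructed gr.Button from a callback updates the rendered button in place (label, variant, interactivity). A minimal, self-contained sketch of that behavior, with illustrative names rather than the app's components:

# Minimal sketch: a callback that re-renders a button's label/variant.
# Returning a new gr.Button(...) from a handler updates the existing
# rendered button in place (Gradio 4.x behavior).
import gradio as gr

running = {"value": False}

def toggle():
    running["value"] = not running["value"]
    label = "Stop" if running["value"] else "Start"
    return gr.Button(value=label,
                     variant="stop" if running["value"] else "primary")

with gr.Blocks() as demo:
    btn = gr.Button(value="Start", variant="primary")
    btn.click(fn=toggle, outputs=[btn])

demo.launch()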
vms/ui/video_trainer_ui.py
CHANGED
@@ -9,7 +9,12 @@ from ..services import TrainingService, CaptioningService, SplittingService, Imp
 from ..config import (
     STORAGE_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH, OUTPUT_PATH,
     TRAINING_PATH, LOG_FILE_PATH, TRAINING_PRESETS, TRAINING_VIDEOS_PATH, MODEL_PATH, OUTPUT_PATH,
-    MODEL_TYPES, SMALL_TRAINING_BUCKETS, TRAINING_TYPES
+    MODEL_TYPES, SMALL_TRAINING_BUCKETS, TRAINING_TYPES,
+    DEFAULT_NB_TRAINING_STEPS, DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
+    DEFAULT_BATCH_SIZE, DEFAULT_CAPTION_DROPOUT_P,
+    DEFAULT_LEARNING_RATE,
+    DEFAULT_LORA_RANK, DEFAULT_LORA_ALPHA,
+    DEFAULT_LORA_RANK_STR, DEFAULT_LORA_ALPHA_STR
 )
 from ..utils import count_media_files, format_media_title, TrainingLogParser
 from ..tabs import ImportTab, SplitTab, CaptionTab, TrainTab, ManageTab
@@ -92,7 +97,7 @@ class VideoTrainerUI:
             self.tabs["train_tab"].components["training_type"],
             self.tabs["train_tab"].components["lora_rank"],
             self.tabs["train_tab"].components["lora_alpha"],
-            self.tabs["train_tab"].components["
+            self.tabs["train_tab"].components["train_steps"],
             self.tabs["train_tab"].components["batch_size"],
             self.tabs["train_tab"].components["learning_rate"],
             self.tabs["train_tab"].components["save_iterations"],
@@ -104,31 +109,33 @@ class VideoTrainerUI:
 
     def _add_timers(self):
         """Add auto-refresh timers to the UI"""
-        # Status update timer (every 1 second)
+        # Status update timer for text components (every 1 second)
         status_timer = gr.Timer(value=1)
-        #
-
-
-        self.tabs["train_tab"].components["log_box"],
+        status_timer.tick(
+            fn=self.tabs["train_tab"].get_status_updates,  # Use a new function that returns appropriate updates
+            outputs=[
+                self.tabs["train_tab"].components["status_box"],
+                self.tabs["train_tab"].components["log_box"],
+                self.tabs["train_tab"].components["current_task_box"] if "current_task_box" in self.tabs["train_tab"].components else None
+            ]
+        )
+
+        # Button update timer for button components (every 1 second)
+        button_timer = gr.Timer(value=1)
+        button_outputs = [
             self.tabs["train_tab"].components["start_btn"],
             self.tabs["train_tab"].components["stop_btn"]
         ]
 
-        # Add
-        if "current_task_box" in self.tabs["train_tab"].components:
-            outputs.append(self.tabs["train_tab"].components["current_task_box"])
-
-        # Add delete_checkpoints_btn only if it exists
+        # Add delete_checkpoints_btn or pause_resume_btn as the third button
         if "delete_checkpoints_btn" in self.tabs["train_tab"].components:
-
-
-
-            outputs.append(self.tabs["train_tab"].components["pause_resume_btn"])
+            button_outputs.append(self.tabs["train_tab"].components["delete_checkpoints_btn"])
+        elif "pause_resume_btn" in self.tabs["train_tab"].components:
+            button_outputs.append(self.tabs["train_tab"].components["pause_resume_btn"])
 
-            fn=self.tabs["train_tab"].
-            outputs=
+        button_timer.tick(
+            fn=self.tabs["train_tab"].get_button_updates,  # Use a new function for button-specific updates
+            outputs=button_outputs
        )
 
         # Dataset refresh timer (every 5 seconds)
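The split into two timers in the hunk above exists because text components and buttons need different update payloads (buttons carry a `variant`, textboxes do not). A minimal sketch of the underlying gr.Timer polling pattern (requires a recent Gradio 4.x release that ships gr.Timer; `get_status` is an illustrative stand-in for the app's trainer status):

# Minimal sketch of polling UI state with gr.Timer (recent Gradio 4.x).
# `get_status` is a hypothetical stand-in for the trainer's status query.
import time
import gradio as gr

START = time.time()

def get_status() -> str:
    return f"running for {time.time() - START:.0f}s"

with gr.Blocks() as demo:
    status_box = gr.Textbox(label="Status")
    timer = gr.Timer(value=1)          # fires every second
    timer.tick(fn=get_status, outputs=[status_box])

demo.launch()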
@@ -175,6 +182,11 @@ class VideoTrainerUI:
         if "model_type" in recovery_ui:
             model_type_value = recovery_ui["model_type"]
 
+            # Remove " (LoRA)" suffix if present
+            if " (LoRA)" in model_type_value:
+                model_type_value = model_type_value.replace(" (LoRA)", "")
+                logger.info(f"Removed (LoRA) suffix from model type: {model_type_value}")
+
             # If it's an internal name, convert to display name
             if model_type_value not in MODEL_TYPES:
                 # Find the display name for this internal model type
@@ -201,7 +213,7 @@ class VideoTrainerUI:
             ui_state["training_type"] = training_type_value
 
         # Copy other parameters
-        for param in ["lora_rank", "lora_alpha", "
+        for param in ["lora_rank", "lora_alpha", "train_steps",
                       "batch_size", "learning_rate", "save_iterations", "training_preset"]:
             if param in recovery_ui:
                 ui_state[param] = recovery_ui[param]
@@ -216,31 +228,55 @@ class VideoTrainerUI:
         # Load values (potentially with recovery updates applied)
         ui_state = self.load_ui_values()
 
-        # Ensure model_type is a display name
+        # Ensure model_type is a valid display name
         model_type_val = ui_state.get("model_type", list(MODEL_TYPES.keys())[0])
+        # Remove " (LoRA)" suffix if present
+        if " (LoRA)" in model_type_val:
+            model_type_val = model_type_val.replace(" (LoRA)", "")
+            logger.info(f"Removed (LoRA) suffix from model type: {model_type_val}")
+
+        # Ensure it's a valid model type in the dropdown
         if model_type_val not in MODEL_TYPES:
-            # Convert from internal to display name
+            # Convert from internal to display name or use default
+            model_type_found = False
             for display_name, internal_name in MODEL_TYPES.items():
                 if internal_name == model_type_val:
                     model_type_val = display_name
+                    model_type_found = True
                     break
+            # If still not found, use the first model type
+            if not model_type_found:
+                model_type_val = list(MODEL_TYPES.keys())[0]
+                logger.warning(f"Invalid model type '{model_type_val}', using default: {model_type_val}")
 
-        # Ensure training_type is a display name
+        # Ensure training_type is a valid display name
         training_type_val = ui_state.get("training_type", list(TRAINING_TYPES.keys())[0])
         if training_type_val not in TRAINING_TYPES:
-            # Convert from internal to display name
+            # Convert from internal to display name or use default
+            training_type_found = False
             for display_name, internal_name in TRAINING_TYPES.items():
                 if internal_name == training_type_val:
                     training_type_val = display_name
+                    training_type_found = True
                     break
+            # If still not found, use the first training type
+            if not training_type_found:
+                training_type_val = list(TRAINING_TYPES.keys())[0]
+                logger.warning(f"Invalid training type '{training_type_val}', using default: {training_type_val}")
 
+        # Validate training preset
         training_preset = ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0])
-
-
-
-
-
-
+        if training_preset not in TRAINING_PRESETS:
+            training_preset = list(TRAINING_PRESETS.keys())[0]
+            logger.warning(f"Invalid training preset '{training_preset}', using default: {training_preset}")
+
+        # Rest of the function remains unchanged
+        lora_rank_val = ui_state.get("lora_rank", DEFAULT_LORA_RANK_STR)
+        lora_alpha_val = ui_state.get("lora_alpha", DEFAULT_LORA_ALPHA_STR)
+        train_steps_val = int(ui_state.get("train_steps", DEFAULT_NB_TRAINING_STEPS))
+        batch_size_val = int(ui_state.get("batch_size", DEFAULT_BATCH_SIZE))
+        learning_rate_val = float(ui_state.get("learning_rate", DEFAULT_LEARNING_RATE))
+        save_iterations_val = int(ui_state.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS))
 
         # Initial current task value
         current_task_val = ""
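The validation added above repeats the same internal-to-display-name scan for model types and training types. As a side note, one way that lookup could be factored out (a sketch with a hypothetical `to_display_name` helper, not code from this commit):

# Hypothetical helper factoring out the repeated "internal name ->
# display name, else first entry with a warning" normalization.
import logging

logger = logging.getLogger(__name__)

def to_display_name(value: str, mapping: dict, kind: str) -> str:
    if value in mapping:                       # already a display name
        return value
    for display_name, internal_name in mapping.items():
        if internal_name == value:             # internal -> display
            return display_name
    fallback = next(iter(mapping))             # not found: first entry
    logger.warning("Invalid %s '%s', using default: %s", kind, value, fallback)
    return fallback

MODEL_TYPES = {"HunyuanVideo": "hunyuan_video", "LTX-Video": "ltx_video"}
print(to_display_name("ltx_video", MODEL_TYPES, "model type"))  # LTX-Video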
@@ -259,7 +295,7 @@ class VideoTrainerUI:
             training_type_val,
             lora_rank_val,
             lora_alpha_val,
-
+            train_steps_val,
             batch_size_val,
             learning_rate_val,
             save_iterations_val,
@@ -275,12 +311,12 @@ class VideoTrainerUI:
             ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0]),
             ui_state.get("model_type", list(MODEL_TYPES.keys())[0]),
             ui_state.get("training_type", list(TRAINING_TYPES.keys())[0]),
-            ui_state.get("lora_rank",
-            ui_state.get("lora_alpha",
-            ui_state.get("
-            ui_state.get("batch_size",
-            ui_state.get("learning_rate",
-            ui_state.get("save_iterations",
+            ui_state.get("lora_rank", DEFAULT_LORA_RANK_STR),
+            ui_state.get("lora_alpha", DEFAULT_LORA_ALPHA_STR),
+            ui_state.get("train_steps", DEFAULT_NB_TRAINING_STEPS),
+            ui_state.get("batch_size", DEFAULT_BATCH_SIZE),
+            ui_state.get("learning_rate", DEFAULT_LEARNING_RATE),
+            ui_state.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS)
         )
 
     def update_ui_state(self, **kwargs):
@@ -296,12 +332,12 @@ class VideoTrainerUI:
         ui_state = self.trainer.load_ui_state()
 
         # Ensure proper type conversion for numeric values
-        ui_state["lora_rank"] = ui_state.get("lora_rank",
-        ui_state["lora_alpha"] = ui_state.get("lora_alpha",
-        ui_state["
-        ui_state["batch_size"] = int(ui_state.get("batch_size",
-        ui_state["learning_rate"] = float(ui_state.get("learning_rate",
-        ui_state["save_iterations"] = int(ui_state.get("save_iterations",
+        ui_state["lora_rank"] = ui_state.get("lora_rank", DEFAULT_LORA_RANK_STR)
+        ui_state["lora_alpha"] = ui_state.get("lora_alpha", DEFAULT_LORA_ALPHA_STR)
+        ui_state["train_steps"] = int(ui_state.get("train_steps", DEFAULT_NB_TRAINING_STEPS))
+        ui_state["batch_size"] = int(ui_state.get("batch_size", DEFAULT_BATCH_SIZE))
+        ui_state["learning_rate"] = float(ui_state.get("learning_rate", DEFAULT_LEARNING_RATE))
+        ui_state["save_iterations"] = int(ui_state.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS))
 
         return ui_state
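One caveat on the type conversions in the hunk above: `int(...)` and `float(...)` will raise if a stale state file holds a non-numeric string. A defensive variant one could layer on top (a sketch, not part of this commit; `coerce` is a hypothetical helper):

# Sketch of a defensive numeric coercion for values loaded from a
# persisted UI-state file; falls back to the default on bad data.
def coerce(value, caster, default):
    try:
        return caster(value)
    except (TypeError, ValueError):
        return default

ui_state = {"train_steps": "oops", "learning_rate": "3e-5"}
print(coerce(ui_state.get("train_steps"), int, 500))        # 500
print(coerce(ui_state.get("learning_rate"), float, 3e-5))   # 3e-05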