|
from enum import Enum |
|
from typing import Type |
|
|
|
from .models import ModelSpecification |
|
from .models.cogvideox import CogVideoXModelSpecification |
|
from .models.cogview4 import CogView4ControlModelSpecification, CogView4ModelSpecification |
|
from .models.flux import FluxModelSpecification |
|
from .models.hunyuan_video import HunyuanVideoModelSpecification |
|
from .models.ltx_video import LTXVideoModelSpecification |
|
from .models.wan import WanControlModelSpecification, WanModelSpecification |
|
|
|
|
|
class ModelType(str, Enum): |
|
COGVIDEOX = "cogvideox" |
|
COGVIEW4 = "cogview4" |
|
FLUX = "flux" |
|
HUNYUAN_VIDEO = "hunyuan_video" |
|
LTX_VIDEO = "ltx_video" |
|
WAN = "wan" |
|
|
|
|
|
class TrainingType(str, Enum): |
|
|
|
LORA = "lora" |
|
FULL_FINETUNE = "full-finetune" |
|
|
|
|
|
CONTROL_LORA = "control-lora" |
|
CONTROL_FULL_FINETUNE = "control-full-finetune" |
|
|
|
|
|
SUPPORTED_MODEL_CONFIGS = { |
|
|
|
|
|
ModelType.COGVIDEOX: { |
|
TrainingType.LORA: CogVideoXModelSpecification, |
|
TrainingType.FULL_FINETUNE: CogVideoXModelSpecification, |
|
}, |
|
ModelType.COGVIEW4: { |
|
TrainingType.LORA: CogView4ModelSpecification, |
|
TrainingType.FULL_FINETUNE: CogView4ModelSpecification, |
|
TrainingType.CONTROL_LORA: CogView4ControlModelSpecification, |
|
TrainingType.CONTROL_FULL_FINETUNE: CogView4ControlModelSpecification, |
|
}, |
|
ModelType.FLUX: { |
|
TrainingType.LORA: FluxModelSpecification, |
|
TrainingType.FULL_FINETUNE: FluxModelSpecification, |
|
}, |
|
ModelType.HUNYUAN_VIDEO: { |
|
TrainingType.LORA: HunyuanVideoModelSpecification, |
|
TrainingType.FULL_FINETUNE: HunyuanVideoModelSpecification, |
|
}, |
|
ModelType.LTX_VIDEO: { |
|
TrainingType.LORA: LTXVideoModelSpecification, |
|
TrainingType.FULL_FINETUNE: LTXVideoModelSpecification, |
|
}, |
|
ModelType.WAN: { |
|
TrainingType.LORA: WanModelSpecification, |
|
TrainingType.FULL_FINETUNE: WanModelSpecification, |
|
TrainingType.CONTROL_LORA: WanControlModelSpecification, |
|
TrainingType.CONTROL_FULL_FINETUNE: WanControlModelSpecification, |
|
}, |
|
} |
|
|
|
|
|
def _get_model_specifiction_cls(model_name: str, training_type: str) -> Type[ModelSpecification]: |
|
if model_name not in SUPPORTED_MODEL_CONFIGS: |
|
raise ValueError( |
|
f"Model {model_name} not supported. Supported models are: {list(SUPPORTED_MODEL_CONFIGS.keys())}" |
|
) |
|
if training_type not in SUPPORTED_MODEL_CONFIGS[model_name]: |
|
raise ValueError( |
|
f"Training type {training_type} not supported for model {model_name}. Supported training types are: {list(SUPPORTED_MODEL_CONFIGS[model_name].keys())}" |
|
) |
|
return SUPPORTED_MODEL_CONFIGS[model_name][training_type] |
|
|