# Provenance (page header captured together with this file; not code):
# author jbilcke-hf (HF Staff), commit 9fd1204 — "we are going to hack into finetrainers"
from enum import Enum
from typing import Type
from .models import ModelSpecification
from .models.cogvideox import CogVideoXModelSpecification
from .models.cogview4 import CogView4ControlModelSpecification, CogView4ModelSpecification
from .models.flux import FluxModelSpecification
from .models.hunyuan_video import HunyuanVideoModelSpecification
from .models.ltx_video import LTXVideoModelSpecification
from .models.wan import WanControlModelSpecification, WanModelSpecification
class ModelType(str, Enum):
    """Model families that have an entry in ``SUPPORTED_MODEL_CONFIGS``.

    Members mix in ``str``, so each member compares equal to its string value
    (e.g. ``ModelType.WAN == "wan"``) and plain strings can be used as lookup keys.
    """

    COGVIDEOX = "cogvideox"
    COGVIEW4 = "cogview4"
    FLUX = "flux"
    HUNYUAN_VIDEO = "hunyuan_video"
    LTX_VIDEO = "ltx_video"
    WAN = "wan"
class TrainingType(str, Enum):
    """Training modes that can be requested for a model.

    Members mix in ``str``, so plain strings such as ``"lora"`` compare equal to them.
    Which modes a given model supports is defined by ``SUPPORTED_MODEL_CONFIGS``.
    """

    # SFT modes: plain LoRA or full fine-tuning.
    LORA = "lora"
    FULL_FINETUNE = "full-finetune"
    # Control modes: LoRA / full fine-tuning variants that map to the *Control* model specifications.
    CONTROL_LORA = "control-lora"
    CONTROL_FULL_FINETUNE = "control-full-finetune"
# Registry: ModelType -> {TrainingType -> ModelSpecification subclass}.
# Within each model, LoRA and full fine-tuning (the "SFT" modes) share one specification
# class; only CogView4 and Wan additionally register the two control modes, which share
# a separate *Control* specification class.
# TODO(aryan): autogenerate this
SUPPORTED_MODEL_CONFIGS = {
    ModelType.COGVIDEOX: dict.fromkeys(
        (TrainingType.LORA, TrainingType.FULL_FINETUNE),
        CogVideoXModelSpecification,
    ),
    ModelType.COGVIEW4: {
        **dict.fromkeys(
            (TrainingType.LORA, TrainingType.FULL_FINETUNE),
            CogView4ModelSpecification,
        ),
        **dict.fromkeys(
            (TrainingType.CONTROL_LORA, TrainingType.CONTROL_FULL_FINETUNE),
            CogView4ControlModelSpecification,
        ),
    },
    ModelType.FLUX: dict.fromkeys(
        (TrainingType.LORA, TrainingType.FULL_FINETUNE),
        FluxModelSpecification,
    ),
    ModelType.HUNYUAN_VIDEO: dict.fromkeys(
        (TrainingType.LORA, TrainingType.FULL_FINETUNE),
        HunyuanVideoModelSpecification,
    ),
    ModelType.LTX_VIDEO: dict.fromkeys(
        (TrainingType.LORA, TrainingType.FULL_FINETUNE),
        LTXVideoModelSpecification,
    ),
    ModelType.WAN: {
        **dict.fromkeys(
            (TrainingType.LORA, TrainingType.FULL_FINETUNE),
            WanModelSpecification,
        ),
        **dict.fromkeys(
            (TrainingType.CONTROL_LORA, TrainingType.CONTROL_FULL_FINETUNE),
            WanControlModelSpecification,
        ),
    },
}
def _get_model_specifiction_cls(model_name: str, training_type: str) -> Type[ModelSpecification]:
    """Return the ``ModelSpecification`` subclass for a model / training-type pair.

    The (misspelled) function name is kept unchanged so existing callers keep working.

    Args:
        model_name: A ``ModelType`` member or its string value (e.g. ``"wan"``). ``ModelType``
            subclasses ``str``, so either form matches the keys of ``SUPPORTED_MODEL_CONFIGS``.
        training_type: A ``TrainingType`` member or its string value (e.g. ``"control-lora"``).

    Returns:
        The class stored at ``SUPPORTED_MODEL_CONFIGS[model_name][training_type]``.

    Raises:
        ValueError: If ``model_name`` is not a key of ``SUPPORTED_MODEL_CONFIGS``, or if
            ``training_type`` is not supported for that model. The message lists the accepted
            string values rather than enum reprs such as ``<ModelType.WAN: 'wan'>``.
    """

    def _display(key):
        # Enum members are shown by their .value (the string users pass); plain strings pass through.
        return getattr(key, "value", key)

    # One .get() per level instead of a membership test followed by a second lookup.
    # No value in the registry is None, so None unambiguously means "not registered".
    training_configs = SUPPORTED_MODEL_CONFIGS.get(model_name)
    if training_configs is None:
        supported_models = [_display(key) for key in SUPPORTED_MODEL_CONFIGS]
        raise ValueError(
            f"Model {_display(model_name)} not supported. Supported models are: {supported_models}"
        )

    spec_cls = training_configs.get(training_type)
    if spec_cls is None:
        supported_training_types = [_display(key) for key in training_configs]
        raise ValueError(
            f"Training type {_display(training_type)} not supported for model {_display(model_name)}. "
            f"Supported training types are: {supported_training_types}"
        )
    return spec_cls