Spaces:
Running
Running
from enum import Enum | |
from typing import Type | |
from .models import ModelSpecification | |
from .models.cogvideox import CogVideoXModelSpecification | |
from .models.hunyuan_video import HunyuanVideoModelSpecification | |
from .models.ltx_video import LTXVideoModelSpecification | |
from .models.wan import WanModelSpecification | |
class ModelType(str, Enum):
    """Identifiers of the model families this module can build specifications for.

    Each member is a key of ``SUPPORTED_MODEL_CONFIGS``. Because the enum mixes
    in ``str``, a member compares equal to its plain string value
    (``ModelType.WAN == "wan"``), so callers may pass either form.
    """

    # Declaration order defines iteration order; keep it stable.
    COGVIDEOX = "cogvideox"
    HUNYUAN_VIDEO = "hunyuan_video"
    LTX_VIDEO = "ltx_video"
    WAN = "wan"
class TrainingType(str, Enum):
    """Training strategies a model specification can be selected for.

    Mixes in ``str`` so members compare equal to their string values,
    e.g. ``TrainingType.LORA == "lora"``.
    """

    LORA = "lora"
    # Note the hyphen in the value (not an underscore like the member name).
    FULL_FINETUNE = "full-finetune"
# Registry: model type -> {training type -> ModelSpecification subclass}.
# Every model currently uses one specification class for both LoRA and full
# fine-tuning, so each inner table is built with dict.fromkeys over the same
# ordered pair of training types (LORA first, then FULL_FINETUNE). Each call
# creates a fresh inner dict, so per-model tables are independent objects.
SUPPORTED_MODEL_CONFIGS = {
    ModelType.HUNYUAN_VIDEO: dict.fromkeys(
        (TrainingType.LORA, TrainingType.FULL_FINETUNE),
        HunyuanVideoModelSpecification,
    ),
    ModelType.LTX_VIDEO: dict.fromkeys(
        (TrainingType.LORA, TrainingType.FULL_FINETUNE),
        LTXVideoModelSpecification,
    ),
    ModelType.COGVIDEOX: dict.fromkeys(
        (TrainingType.LORA, TrainingType.FULL_FINETUNE),
        CogVideoXModelSpecification,
    ),
    ModelType.WAN: dict.fromkeys(
        (TrainingType.LORA, TrainingType.FULL_FINETUNE),
        WanModelSpecification,
    ),
}
def _get_model_specifiction_cls(model_name: str, training_type: str) -> Type[ModelSpecification]:
    """Return the ModelSpecification class registered for a model / training type.

    Args:
        model_name: A ``ModelType`` member or its plain string value (e.g. ``"wan"``).
            Plain strings match the enum keys because ``ModelType`` mixes in ``str``.
        training_type: A ``TrainingType`` member or its plain string value
            (e.g. ``"lora"``).

    Returns:
        The specification class stored in ``SUPPORTED_MODEL_CONFIGS``.

    Raises:
        ValueError: If ``model_name`` is not registered, or ``training_type`` is not
            registered for that model. The message lists the accepted plain-string
            values rather than enum reprs.
    """
    # One lookup per level instead of separate membership test + indexing.
    training_configs = SUPPORTED_MODEL_CONFIGS.get(model_name)
    if training_configs is None:
        # Show the strings a user can actually pass (e.g. 'wan'), not
        # "<ModelType.WAN: 'wan'>" reprs from the enum keys.
        supported_models = [getattr(key, "value", key) for key in SUPPORTED_MODEL_CONFIGS]
        raise ValueError(
            f"Model {getattr(model_name, 'value', model_name)} not supported. "
            f"Supported models are: {supported_models}"
        )

    spec_cls = training_configs.get(training_type)
    if spec_cls is None:
        supported_training_types = [getattr(key, "value", key) for key in training_configs]
        raise ValueError(
            f"Training type {getattr(training_type, 'value', training_type)} not supported "
            f"for model {getattr(model_name, 'value', model_name)}. "
            f"Supported training types are: {supported_training_types}"
        )
    return spec_cls