# Fast tests for SFTTrainer across dummy model specifications and parallel backends.

import json
import os
import pathlib
import tempfile
import time
import unittest

import pytest
import torch
from diffusers.utils import export_to_video
from parameterized import parameterized
from PIL import Image

from finetrainers import BaseArgs, SFTTrainer, TrainingType, get_logger


# Configure the test environment before importing the dummy model specifications so
# the settings take effect for everything imported below.
os.environ["WANDB_MODE"] = "disabled"
os.environ["FINETRAINERS_LOG_LEVEL"] = "INFO"

from ..models.cogvideox.base_specification import DummyCogVideoXModelSpecification
from ..models.cogview4.base_specification import DummyCogView4ModelSpecification
from ..models.flux.base_specification import DummyFluxModelSpecification
from ..models.hunyuan_video.base_specification import DummyHunyuanVideoModelSpecification
from ..models.ltx_video.base_specification import DummyLTXVideoModelSpecification
from ..models.wan.base_specification import DummyWanModelSpecification


logger = get_logger()


@pytest.fixture(autouse=True)
def slow_down_tests():
    yield
    # Sleep between tests so processes and ports from a previous (possibly distributed)
    # test have time to shut down before the next one starts.
    time.sleep(5)


class SFTTrainerFastTestsMixin:
    model_specification_cls = None
    num_data_files = 4
    num_frames = 4
    height = 64
    width = 64

    def setUp(self):
        # Build a tiny synthetic dataset: a few short videos of blank RGB frames,
        # a CSV of captions, and a JSON dataset config pointing at them.
        self.tmpdir = tempfile.TemporaryDirectory()
        self.data_files = []
        for i in range(self.num_data_files):
            data_file = pathlib.Path(self.tmpdir.name) / f"{i}.mp4"
            export_to_video(
                [Image.new("RGB", (self.width, self.height))] * self.num_frames, data_file.as_posix(), fps=2
            )
            self.data_files.append(data_file.as_posix())

        csv_filename = pathlib.Path(self.tmpdir.name) / "metadata.csv"
        with open(csv_filename.as_posix(), "w") as f:
            f.write("file_name,caption\n")
            for i in range(self.num_data_files):
                prompt = f"A cat ruling the world - {i}"
                f.write(f'{i}.mp4,"{prompt}"\n')

        dataset_config = {
            "datasets": [
                {
                    "data_root": self.tmpdir.name,
                    "dataset_type": "video",
                    "id_token": "TEST",
                    "video_resolution_buckets": [[self.num_frames, self.height, self.width]],
                    "reshape_mode": "bicubic",
                }
            ]
        }

        self.dataset_config_filename = pathlib.Path(self.tmpdir.name) / "dataset_config.json"
        with open(self.dataset_config_filename.as_posix(), "w") as f:
            json.dump(dataset_config, f)

    def tearDown(self):
        self.tmpdir.cleanup()

        # If a distributed test initialized a process group, tear it down and give the
        # backend a moment to release its resources.
        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()
        time.sleep(3)

    def get_base_args(self) -> BaseArgs:
        args = BaseArgs()
        args.dataset_config = self.dataset_config_filename.as_posix()
        args.train_steps = 10
        args.max_data_samples = 25
        args.batch_size = 1
        args.gradient_checkpointing = True
        args.output_dir = self.tmpdir.name
        args.checkpointing_steps = 6
        args.enable_precomputation = False
        args.precomputation_items = self.num_data_files
        args.precomputation_dir = os.path.join(self.tmpdir.name, "precomputed")
        args.compile_scopes = "regional"
        return args

    def get_args(self) -> BaseArgs:
        raise NotImplementedError("`get_args` must be implemented in the subclass.")

    def _test_training(self, args: BaseArgs):
        model_specification = self.model_specification_cls()
        trainer = SFTTrainer(args, model_specification)
        trainer.run()
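

# The test classes below exercise the Hugging Face Accelerate parallel backend.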
class SFTTrainerLoRATestsMixin___Accelerate(SFTTrainerFastTestsMixin):
    def get_args(self) -> BaseArgs:
        args = self.get_base_args()
        args.parallel_backend = "accelerate"
        args.training_type = TrainingType.LORA
        args.rank = 4
        args.lora_alpha = 4
        args.target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
        return args

    @parameterized.expand([(False,), (True,)])
    def test___dp_degree_1___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 1
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(True,)])
    def test___layerwise_upcasting___dp_degree_1___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 1
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        args.layerwise_upcasting_modules = ["transformer"]
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___dp_degree_2___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 2
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)


class SFTTrainerFullFinetuneTestsMixin___Accelerate(SFTTrainerFastTestsMixin):
    def get_args(self) -> BaseArgs:
        args = self.get_base_args()
        args.parallel_backend = "accelerate"
        args.training_type = TrainingType.FULL_FINETUNE
        return args

    @parameterized.expand([(False,), (True,)])
    def test___dp_degree_1___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 1
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___dp_degree_2___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 2
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)
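

# Concrete test classes: each pairs a trainer test mixin with one dummy model specification.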
class SFTTrainerCogVideoXLoRATests___Accelerate(SFTTrainerLoRATestsMixin___Accelerate, unittest.TestCase):
    model_specification_cls = DummyCogVideoXModelSpecification


class SFTTrainerCogVideoXFullFinetuneTests___Accelerate(
    SFTTrainerFullFinetuneTestsMixin___Accelerate, unittest.TestCase
):
    model_specification_cls = DummyCogVideoXModelSpecification


class SFTTrainerCogView4LoRATests___Accelerate(SFTTrainerLoRATestsMixin___Accelerate, unittest.TestCase):
    model_specification_cls = DummyCogView4ModelSpecification


class SFTTrainerCogView4FullFinetuneTests___Accelerate(
    SFTTrainerFullFinetuneTestsMixin___Accelerate, unittest.TestCase
):
    model_specification_cls = DummyCogView4ModelSpecification


class SFTTrainerFluxLoRATests___Accelerate(SFTTrainerLoRATestsMixin___Accelerate, unittest.TestCase):
    model_specification_cls = DummyFluxModelSpecification


class SFTTrainerFluxFullFinetuneTests___Accelerate(SFTTrainerFullFinetuneTestsMixin___Accelerate, unittest.TestCase):
    model_specification_cls = DummyFluxModelSpecification


class SFTTrainerHunyuanVideoLoRATests___Accelerate(SFTTrainerLoRATestsMixin___Accelerate, unittest.TestCase):
    model_specification_cls = DummyHunyuanVideoModelSpecification


class SFTTrainerHunyuanVideoFullFinetuneTests___Accelerate(
    SFTTrainerFullFinetuneTestsMixin___Accelerate, unittest.TestCase
):
    model_specification_cls = DummyHunyuanVideoModelSpecification


class SFTTrainerLTXVideoLoRATests___Accelerate(SFTTrainerLoRATestsMixin___Accelerate, unittest.TestCase):
    model_specification_cls = DummyLTXVideoModelSpecification


class SFTTrainerLTXVideoFullFinetuneTests___Accelerate(
    SFTTrainerFullFinetuneTestsMixin___Accelerate, unittest.TestCase
):
    model_specification_cls = DummyLTXVideoModelSpecification


class SFTTrainerWanLoRATests___Accelerate(SFTTrainerLoRATestsMixin___Accelerate, unittest.TestCase):
    model_specification_cls = DummyWanModelSpecification


class SFTTrainerWanFullFinetuneTests___Accelerate(SFTTrainerFullFinetuneTestsMixin___Accelerate, unittest.TestCase):
    model_specification_cls = DummyWanModelSpecification
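

# The test classes below exercise the "ptd" parallel backend, covering data-parallel
# replication and sharding, tensor parallelism, and context parallelism.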
class SFTTrainerLoRATestsMixin___PTD(SFTTrainerFastTestsMixin):
    def get_args(self) -> BaseArgs:
        args = self.get_base_args()
        args.parallel_backend = "ptd"
        args.training_type = TrainingType.LORA
        args.rank = 4
        args.lora_alpha = 4
        args.target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
        return args

    @parameterized.expand([(False,), (True,)])
    def test___dp_degree_1___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 1
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(True,)])
    def test___layerwise_upcasting___dp_degree_1___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 1
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        args.layerwise_upcasting_modules = ["transformer"]
        self._test_training(args)

    @parameterized.expand([(True,)])
    def test___compile___dp_degree_1___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 1
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        args.compile_modules = ["transformer"]
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___dp_degree_1___batch_size_2(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 1
        args.batch_size = 2
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___dp_degree_2___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 2
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(True,)])
    def test___layerwise_upcasting___dp_degree_2___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 2
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        args.layerwise_upcasting_modules = ["transformer"]
        self._test_training(args)

    @parameterized.expand([(True,)])
    def test___compile___dp_degree_2___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 2
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        args.compile_modules = ["transformer"]
        self._test_training(args)

    @parameterized.expand([(True,)])
    def test___dp_degree_2___batch_size_2(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 2
        args.batch_size = 2
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___dp_shards_2___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_shards = 2
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(True,)])
    def test___dp_shards_2___batch_size_2(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_shards = 2
        args.batch_size = 2
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___dp_degree_2___dp_shards_2___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 2
        args.dp_shards = 2
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___tp_degree_2___batch_size_2(self, enable_precomputation: bool):
        args = self.get_args()
        args.tp_degree = 2
        args.batch_size = 2
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @unittest.skip(
        "TODO: The model specifications for CP with the cudnn/flash/efficient backends require the attention head dim to be a multiple of 8. Land the math backend first for fast tests, then enable this test."
    )
    @parameterized.expand([(True,)])
    def test___cp_degree_2___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.cp_degree = 2
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @unittest.skip(
        "TODO: The model specifications for CP with the cudnn/flash/efficient backends require the attention head dim to be a multiple of 8. Land the math backend first for fast tests, then enable this test."
    )
    @parameterized.expand([(True,)])
    def test___dp_degree_2___cp_degree_2___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 2
        args.cp_degree = 2
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)


class SFTTrainerFullFinetuneTestsMixin___PTD(SFTTrainerFastTestsMixin):
    def get_args(self) -> BaseArgs:
        args = self.get_base_args()
        args.parallel_backend = "ptd"
        args.training_type = TrainingType.FULL_FINETUNE
        return args

    @parameterized.expand([(False,), (True,)])
    def test___dp_degree_1___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 1
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(True,)])
    def test___compile___dp_degree_1___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 1
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        args.compile_modules = ["transformer"]
        self._test_training(args)

    @parameterized.expand([(True,)])
    def test___dp_degree_1___batch_size_2(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 1
        args.batch_size = 2
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___dp_degree_2___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 2
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(True,)])
    def test___compile___dp_degree_2___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 2
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        args.compile_modules = ["transformer"]
        self._test_training(args)

    @parameterized.expand([(True,)])
    def test___dp_degree_2___batch_size_2(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 2
        args.batch_size = 2
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___dp_shards_2___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_shards = 2
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(True,)])
    def test___dp_shards_2___batch_size_2(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_shards = 2
        args.batch_size = 2
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___dp_degree_2___dp_shards_2___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 2
        args.dp_shards = 2
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @parameterized.expand([(False,), (True,)])
    def test___tp_degree_2___batch_size_2(self, enable_precomputation: bool):
        args = self.get_args()
        args.tp_degree = 2
        args.batch_size = 2
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @unittest.skip(
        "TODO: The model specifications for CP with the cudnn/flash/efficient backends require the attention head dim to be a multiple of 8. Land the math backend first for fast tests, then enable this test."
    )
    @parameterized.expand([(True,)])
    def test___cp_degree_2___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.cp_degree = 2
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)

    @unittest.skip(
        "TODO: The model specifications for CP with the cudnn/flash/efficient backends require the attention head dim to be a multiple of 8. Land the math backend first for fast tests, then enable this test."
    )
    @parameterized.expand([(True,)])
    def test___dp_degree_2___cp_degree_2___batch_size_1(self, enable_precomputation: bool):
        args = self.get_args()
        args.dp_degree = 2
        args.cp_degree = 2
        args.batch_size = 1
        args.enable_precomputation = enable_precomputation
        self._test_training(args)


class SFTTrainerCogVideoXLoRATests___PTD(SFTTrainerLoRATestsMixin___PTD, unittest.TestCase):
    model_specification_cls = DummyCogVideoXModelSpecification


class SFTTrainerCogVideoXFullFinetuneTests___PTD(SFTTrainerFullFinetuneTestsMixin___PTD, unittest.TestCase):
    model_specification_cls = DummyCogVideoXModelSpecification


class SFTTrainerCogView4LoRATests___PTD(SFTTrainerLoRATestsMixin___PTD, unittest.TestCase):
    model_specification_cls = DummyCogView4ModelSpecification


class SFTTrainerCogView4FullFinetuneTests___PTD(SFTTrainerFullFinetuneTestsMixin___PTD, unittest.TestCase):
    model_specification_cls = DummyCogView4ModelSpecification


class SFTTrainerFluxLoRATests___PTD(SFTTrainerLoRATestsMixin___PTD, unittest.TestCase):
    model_specification_cls = DummyFluxModelSpecification


class SFTTrainerFluxFullFinetuneTests___PTD(SFTTrainerFullFinetuneTestsMixin___PTD, unittest.TestCase):
    model_specification_cls = DummyFluxModelSpecification


class SFTTrainerHunyuanVideoLoRATests___PTD(SFTTrainerLoRATestsMixin___PTD, unittest.TestCase):
    model_specification_cls = DummyHunyuanVideoModelSpecification


class SFTTrainerHunyuanVideoFullFinetuneTests___PTD(SFTTrainerFullFinetuneTestsMixin___PTD, unittest.TestCase):
    model_specification_cls = DummyHunyuanVideoModelSpecification


class SFTTrainerLTXVideoLoRATests___PTD(SFTTrainerLoRATestsMixin___PTD, unittest.TestCase):
    model_specification_cls = DummyLTXVideoModelSpecification


class SFTTrainerLTXVideoFullFinetuneTests___PTD(SFTTrainerFullFinetuneTestsMixin___PTD, unittest.TestCase):
    model_specification_cls = DummyLTXVideoModelSpecification


class SFTTrainerWanLoRATests___PTD(SFTTrainerLoRATestsMixin___PTD, unittest.TestCase):
    model_specification_cls = DummyWanModelSpecification


class SFTTrainerWanFullFinetuneTests___PTD(SFTTrainerFullFinetuneTestsMixin___PTD, unittest.TestCase):
    model_specification_cls = DummyWanModelSpecification
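

# Example invocation, selecting tests by name (assumes pytest is installed; the exact
# file path depends on the repository layout, so only a keyword filter is shown):
#   pytest -k "SFTTrainer and PTD" -x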