import contextlib
import functools
import os
from typing import Callable, List, Optional, Tuple

import torch
import torch.backends
from diffusers.hooks import HookRegistry, ModelHook

from finetrainers import logging, parallel, patches
from finetrainers.args import BaseArgsType
from finetrainers.logging import get_logger
from finetrainers.models.attention_dispatch import AttentionProvider, _AttentionProviderRegistry
from finetrainers.state import State


logger = get_logger()

_LATEST_ACTIVE_MODULE_HOOK = "latest_active_module_hook"


class Trainer:
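    """Base trainer: sets up distributed state, global config options, and per-module attention providers."""
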
    def __init__(self, args: BaseArgsType):
        self.args = args

        self.state = State()

        self._module_name_providers_training = _parse_attention_providers(args.attn_provider_training)
        self._module_name_providers_inference = _parse_attention_providers(args.attn_provider_inference)

        self._init_distributed()
        self._init_config_options()

        patches.perform_patches_for_training(self.args, self.state.parallel_backend)

    @contextlib.contextmanager
    def attention_provider_ctx(self, training: bool = True):
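        """Temporarily activate the configured attention providers while the block runs.

        Modules named in ``args.attn_provider_training`` / ``args.attn_provider_inference`` use their
        configured provider; every other ``torch.nn.Module`` attribute of the trainer falls back to the
        provider that was active on entry. A hypothetical usage sketch (``inputs`` is illustrative):

            with trainer.attention_provider_ctx(training=True):
                output = trainer.transformer(**inputs)

        On exit, the previously active provider and context-parallel settings are restored and the
        registered hooks are removed.
        """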
        name_providers_active = (
            self._module_name_providers_training if training else self._module_name_providers_inference
        )
        name_providers_dict = dict(name_providers_active)
        default_provider = _AttentionProviderRegistry._active_provider

        all_registered_module_names = [
            attr for attr in dir(self) if isinstance(getattr(self, attr, None), torch.nn.Module)
        ]
        for module_name in all_registered_module_names:
            if module_name in name_providers_dict:
                continue
            name_providers_dict[module_name] = default_provider

        module_providers_dict = {}
        for module_name, provider in name_providers_dict.items():
            module = getattr(self, module_name, None)
            if module is not None:
                module_providers_dict[module] = (module_name, provider)

        def callback(m: torch.nn.Module):
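            """Pre-forward hook target: activate the provider mapped to ``m`` (plus context-parallel settings for the transformer)."""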
            module_name, provider = module_providers_dict[m]

            if module_name in ["transformer"] and self.state.parallel_backend.context_parallel_enabled:
                if not _AttentionProviderRegistry.supports_context_parallel(provider):
                    raise ValueError(
                        f"Attention provider {provider} does not support context parallel. Please use a different provider."
                    )
                _AttentionProviderRegistry._set_context_parallel(
                    mesh=self.state.parallel_backend.get_mesh()["cp"], convert_to_fp32=True, rotate_method="allgather"
                )
            _AttentionProviderRegistry._active_provider = provider

        if "vae" in name_providers_dict:
            _apply_forward_hooks_hack(self.vae, name_providers_dict["vae"])

        for module in module_providers_dict.keys():
            registry = HookRegistry.check_if_exists_or_initialize(module)
            hook = LatestActiveModuleHook(callback)
            registry.register_hook(hook, _LATEST_ACTIVE_MODULE_HOOK)

        try:
            yield
        finally:
            _AttentionProviderRegistry._active_provider = default_provider
            _AttentionProviderRegistry._set_context_parallel(reset=True)
            for module in module_providers_dict.keys():
                registry: HookRegistry = module._diffusers_hook
                registry.remove_hook(_LATEST_ACTIVE_MODULE_HOOK)

    def _init_distributed(self) -> None:
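        """Construct the parallel backend from the configured degrees/shards and enable determinism when a seed is set."""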
        world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count()))

        backend_cls: parallel.ParallelBackendType = parallel.get_parallel_backend_cls(self.args.parallel_backend)
        self.state.parallel_backend = backend_cls(
            world_size=world_size,
            pp_degree=self.args.pp_degree,
            dp_degree=self.args.dp_degree,
            dp_shards=self.args.dp_shards,
            cp_degree=self.args.cp_degree,
            tp_degree=self.args.tp_degree,
            backend="nccl",
            timeout=self.args.init_timeout,
            logging_dir=self.args.logging_dir,
            output_dir=self.args.output_dir,
            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
        )

        if self.args.seed is not None:
            self.state.parallel_backend.enable_determinism(self.args.seed)

    def _init_logging(self) -> None:
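        """Attach the parallel backend to the logging module and set dependency log levels based on verbosity."""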
        logging._set_parallel_backend(self.state.parallel_backend)
        logging.set_dependency_log_level(self.args.verbose, self.state.parallel_backend.is_local_main_process)
        logger.info("Initialized FineTrainers")

    def _init_trackers(self) -> None:
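        """Initialize the experiment tracker configured via ``args.report_to`` with the current training config."""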
        trackers = [self.args.report_to]
        experiment_name = self.args.tracker_name or "finetrainers-experiment"
        self.state.parallel_backend.initialize_trackers(
            trackers, experiment_name=experiment_name, config=self._get_training_info(), log_dir=self.args.logging_dir
        )

    def _init_config_options(self) -> None:
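        """Apply global precision settings (TF32 matmuls, float32 matmul precision)."""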
        if self.args.allow_tf32 and torch.cuda.is_available():
            torch.backends.cuda.matmul.allow_tf32 = True
        torch.set_float32_matmul_precision(self.args.float32_matmul_precision)

    @property
    def tracker(self):
        return self.state.parallel_backend.tracker


class LatestActiveModuleHook(ModelHook):
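    """ModelHook that calls ``callback(module)`` immediately before the hooked module's forward pass."""
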
    def __init__(self, callback: Optional[Callable[[torch.nn.Module], None]] = None):
        super().__init__()
        self.callback = callback

    def pre_forward(self, module, *args, **kwargs):
        self.callback(module)
        return args, kwargs


def _parse_attention_providers(attn_providers: Optional[List[str]] = None) -> List[Tuple[str, AttentionProvider]]:
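    """Parse ``"module_name:provider_name"`` strings into ``(module_name, AttentionProvider)`` tuples.

    For example, a hypothetical entry like ``"transformer:flash"`` would map the ``transformer`` module to
    the ``flash`` provider, assuming that name is a valid ``AttentionProvider`` value. Malformed entries
    raise ``ValueError``.
    """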
    parsed_providers = []
    if attn_providers:
        for provider_str in attn_providers:
            parts = provider_str.split(":")
            if len(parts) != 2:
                raise ValueError(
                    f"Invalid attention provider format: '{provider_str}'. Expected 'module_name:provider_name'."
                )
            parsed_providers.append((parts[0], AttentionProvider(parts[1])))
    return parsed_providers


def _apply_forward_hooks_hack(module: torch.nn.Module, provider: AttentionProvider):
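    """Wrap ``module``'s encode/decode entry points so each call temporarily activates ``provider``.

    These methods (e.g. VAE ``encode``/``decode``/``tiled_decode``) are not guaranteed to route through
    ``forward``, where ``LatestActiveModuleHook`` would fire, hence the method-level wrapping. Wrapping
    happens at most once per module, tracked via ``_finetrainers_wrapped_methods``.
    """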
    if hasattr(module, "_finetrainers_wrapped_methods"):
        return

    def create_wrapper(old_method):
        @functools.wraps(old_method)
        def wrapper(*args, **kwargs):
            _AttentionProviderRegistry._set_context_parallel(reset=True)
            old_provider = _AttentionProviderRegistry._active_provider
            _AttentionProviderRegistry._active_provider = provider
            try:
                output = old_method(*args, **kwargs)
            finally:
                _AttentionProviderRegistry._active_provider = old_provider
            return output

        return wrapper

    methods = ["encode", "decode", "_encode", "_decode", "tiled_encode", "tiled_decode"]
    finetrainers_wrapped_methods = []
    for method_name in methods:
        if not hasattr(module, method_name):
            continue
        method = getattr(module, method_name)
        wrapper = create_wrapper(method)
        setattr(module, method_name, wrapper)
        finetrainers_wrapped_methods.append(method_name)
    module._finetrainers_wrapped_methods = finetrainers_wrapped_methods