from typing import Optional

import torch

from ..models.attention import AttentionModuleMixin, FeedForward, LuminaFeedForward
from ..models.attention_processor import Attention, MochiAttention

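# Known attention and feed-forward layer classes, grouped into tuples suitable for isinstance checks.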
_ATTENTION_CLASSES = (Attention, MochiAttention, AttentionModuleMixin)
_FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward)

_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")

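# De-duplicated union of the identifier tuples above; the set literal removes overlaps
# (note that the resulting tuple order is unspecified).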
_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
    {
        *_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
        *_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
        *_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
    }
)

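# Leaf-level PyTorch compute layers (convolutions and linear); "GO_LC" presumably
# abbreviates group offloading at leaf-level granularity.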
_GO_LC_SUPPORTED_PYTORCH_LAYERS = (
    torch.nn.Conv1d,
    torch.nn.Conv2d,
    torch.nn.Conv3d,
    torch.nn.ConvTranspose1d,
    torch.nn.ConvTranspose2d,
    torch.nn.ConvTranspose3d,
    torch.nn.Linear,
)

def _get_submodule_from_fqn(module: torch.nn.Module, fqn: str) -> Optional[torch.nn.Module]:
    """Return the submodule reachable at the fully-qualified name `fqn`, or None if it does not exist."""
    for submodule_name, submodule in module.named_modules():
        if submodule_name == fqn:
            return submodule
    return None
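# Usage sketch (with a hypothetical attribute path):
#   _get_submodule_from_fqn(model, "transformer_blocks.0.attn1")
# returns model.transformer_blocks[0].attn1 when present, otherwise None.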