File size: 371 Bytes
6a97a18 d154b53 6a97a18 d154b53 6a97a18 d154b53 6a97a18 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
# modeling_ndlinear_dit.py
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig
from mlp import NdMlp
from ndlinear import NdLinear
from models_hf import DiT, DiTConfig
class DiTConfig(PretrainedConfig):
    """Configuration for the NdLinear DiT model.

    Registers the model type string ``"ndlinear_dit"``. No extra fields or
    ``__init__`` are declared here, so construction is handled entirely by
    ``PretrainedConfig``.

    NOTE(review): this class is bound to the same name as the ``DiTConfig``
    imported from ``models_hf`` and replaces it in this module's namespace —
    confirm that shadowing is intended.
    """

    # Identifier under which this configuration/model type is registered.
    model_type: str = "ndlinear_dit"
class DiT(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` entry point for the NdLinear DiT.

    Links the model to the ``DiTConfig`` class defined just before it in this
    module via ``config_class``.

    NOTE(review): no ``__init__`` or ``forward`` is defined here, and this
    class replaces the ``DiT`` imported from ``models_hf`` in this module's
    namespace. Confirm whether it was meant to wrap or subclass that
    implementation; as written it declares no layers of its own.
    """

    # Resolves to the module-level DiTConfig declared above, not the
    # models_hf import, because that name was rebound before this line runs.
    config_class = DiTConfig
# Public API exported by `from modeling_ndlinear_dit import *`.
__all__ = [
    "DiT",
    "DiTConfig",
]
|