# modeling_ndlinear_dit.py
# Thin `transformers` wrapper around the NdLinear DiT, exposed through the
# PreTrainedModel / PretrainedConfig interface.

import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig

from mlp import NdMlp
from ndlinear import NdLinear
# Alias the underlying DiT so the wrapper classes below do not shadow it.
from models_hf import DiT as BaseDiT, DiTConfig as BaseDiTConfig

class DiTConfig(PretrainedConfig):
    """Hugging Face configuration for the NdLinear DiT."""
    model_type = "ndlinear_dit"


class DiT(PreTrainedModel):
    """Wrapper that delegates to the underlying DiT implementation."""
    config_class = DiTConfig

    def __init__(self, config: DiTConfig):
        super().__init__(config)
        # Assumption: the underlying DiT can be built directly from this
        # config; adapt the call to its real constructor signature.
        self.model = BaseDiT(config)

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)

__all__ = ["DiT", "DiTConfig"]
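
# Usage sketch (an assumption, not part of the original module): register the
# wrapper with the transformers Auto classes so it can be instantiated from its
# "ndlinear_dit" model type. AutoConfig.register / AutoModel.register are
# standard transformers APIs; the empty DiTConfig() below is a placeholder.
if __name__ == "__main__":
    from transformers import AutoConfig, AutoModel

    AutoConfig.register("ndlinear_dit", DiTConfig)
    AutoModel.register(DiTConfig, DiT)

    config = DiTConfig()
    model = AutoModel.from_config(config)  # builds the DiT wrapper above
    print(type(model).__name__)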