TabPFN commited on
Commit
6022ee7
1 Parent(s): fcd83f4

Update TabPFN/layer.py

Browse files
Files changed (1) hide show
  1. TabPFN/layer.py +10 -2
TabPFN/layer.py CHANGED
@@ -1,8 +1,16 @@
1
  from functools import partial
2
 
3
  from torch import nn
4
- from torch.nn.modules.transformer import *
5
- from torch.nn.modules.transformer import _get_activation_fn
 
 
 
 
 
 
 
 
6
 
7
  from torch.utils.checkpoint import checkpoint
8
 
 
1
  from functools import partial
2
 
3
  from torch import nn
4
+ from torch.nn.modules.transformer import (
5
+ _get_activation_fn,
6
+ Module,
7
+ Tensor,
8
+ Optional,
9
+ MultiheadAttention,
10
+ Linear,
11
+ Dropout,
12
+ LayerNorm,
13
+ )
14
 
15
  from torch.utils.checkpoint import checkpoint
16