jadechoghari
commited on
Commit
•
694a3fd
1
Parent(s):
efc43f7
Update unet/openaimodel.py
Browse files- unet/openaimodel.py +4 -3
unet/openaimodel.py
CHANGED
@@ -8,7 +8,7 @@ import torch as th
|
|
8 |
import torch.nn as nn
|
9 |
import torch.nn.functional as F
|
10 |
|
11 |
-
from util import (
|
12 |
checkpoint,
|
13 |
conv_nd,
|
14 |
linear,
|
@@ -16,13 +16,14 @@ from util import (
|
|
16 |
zero_module,
|
17 |
normalization,
|
18 |
timestep_embedding,
|
|
|
19 |
)
|
20 |
|
21 |
# replace with custom transformer
|
22 |
-
from mv_attention import SPADTransformer as SpatialTransformer
|
|
|
23 |
|
24 |
|
25 |
-
from util import exists
|
26 |
from torch import autocast
|
27 |
|
28 |
# dummy replace
|
|
|
8 |
import torch.nn as nn
|
9 |
import torch.nn.functional as F
|
10 |
|
11 |
+
from .util import (
|
12 |
checkpoint,
|
13 |
conv_nd,
|
14 |
linear,
|
|
|
16 |
zero_module,
|
17 |
normalization,
|
18 |
timestep_embedding,
|
19 |
+
exists
|
20 |
)
|
21 |
|
22 |
# replace with custom transformer
|
23 |
+
from .mv_attention import SPADTransformer as SpatialTransformer
|
24 |
+
|
25 |
|
26 |
|
|
|
27 |
from torch import autocast
|
28 |
|
29 |
# dummy replace
|