|
from transformers import PretrainedConfig |
|
from typing import List |
|
|
|
class WNet3DConfig(PretrainedConfig):
    """Configuration object for the model registered under ``model_type="WNet"``.

    Three model-specific hyper-parameters are kept as instance attributes;
    every other keyword argument is handed to
    :class:`transformers.PretrainedConfig` unchanged.

    Args:
        in_ch: Stored as ``self.in_ch`` (default ``1``); presumably the
            number of input channels — confirm against the model code.
        out_ch: Stored as ``self.out_ch`` (default ``5``); presumably the
            number of output channels/classes — confirm against the model.
        init_features: Stored as ``self.init_features`` (default ``64``);
            presumably the feature width of the first stage.
        **kwargs: Forwarded verbatim to ``PretrainedConfig.__init__``.
    """

    model_type = "WNet"

    def __init__(
        self,
        in_ch: int = 1,
        out_ch: int = 5,
        init_features: int = 64,
        **kwargs,
    ) -> None:
        # Set the model-specific fields first (left-to-right, same order as
        # before) and only then let the base class consume the rest.
        self.in_ch, self.out_ch, self.init_features = in_ch, out_ch, init_features
        super().__init__(**kwargs)
|
|
|
class AttWNet3DConfig(PretrainedConfig):
    """Configuration object for the model registered under ``model_type="AttWNet"``.

    Three model-specific hyper-parameters are kept as instance attributes;
    every other keyword argument is handed to
    :class:`transformers.PretrainedConfig` unchanged.

    Args:
        in_ch: Stored as ``self.in_ch`` (default ``1``); presumably the
            number of input channels — confirm against the model code.
        out_ch: Stored as ``self.out_ch`` (default ``5``); presumably the
            number of output channels/classes — confirm against the model.
        init_features: Stored as ``self.init_features`` (default ``64``);
            presumably the feature width of the first stage.
        **kwargs: Forwarded verbatim to ``PretrainedConfig.__init__``.
    """

    model_type = "AttWNet"

    def __init__(
        self,
        in_ch: int = 1,
        out_ch: int = 5,
        init_features: int = 64,
        **kwargs,
    ) -> None:
        # Set the model-specific fields first (left-to-right, same order as
        # before) and only then let the base class consume the rest.
        self.in_ch, self.out_ch, self.init_features = in_ch, out_ch, init_features
        super().__init__(**kwargs)
|
|
|
class WNetMSS3DConfig(PretrainedConfig):
    """Configuration object for the model registered under ``model_type="WNetMSS"``.

    Three model-specific hyper-parameters are kept as instance attributes;
    every other keyword argument is handed to
    :class:`transformers.PretrainedConfig` unchanged.

    Args:
        in_ch: Stored as ``self.in_ch`` (default ``1``); presumably the
            number of input channels — confirm against the model code.
        out_ch: Stored as ``self.out_ch`` (default ``5``); presumably the
            number of output channels/classes — confirm against the model.
        init_features: Stored as ``self.init_features`` (default ``64``);
            presumably the feature width of the first stage.
        **kwargs: Forwarded verbatim to ``PretrainedConfig.__init__``.
    """

    model_type = "WNetMSS"

    def __init__(
        self,
        in_ch: int = 1,
        out_ch: int = 5,
        init_features: int = 64,
        **kwargs,
    ) -> None:
        # Set the model-specific fields first (left-to-right, same order as
        # before) and only then let the base class consume the rest.
        self.in_ch, self.out_ch, self.init_features = in_ch, out_ch, init_features
        super().__init__(**kwargs)