from transformers import PretrainedConfig | |
from typing import List | |
class UNetFTSRConfig(PretrainedConfig):
    """Configuration for a U-Net model, stored as a Hugging Face ``PretrainedConfig``.

    Each constructor argument is saved verbatim as an instance attribute of
    the same name, so it round-trips through ``save_pretrained`` /
    ``from_pretrained``. Any remaining keyword arguments are forwarded
    unchanged to ``PretrainedConfig.__init__``.

    Args:
        in_channels: Number of input channels. Defaults to 1.
        n_classes: Number of output classes/channels. Defaults to 1.
        depth: Network depth. Defaults to 3. Presumably the number of
            encoder/decoder levels (TODO confirm against the model class).
        wf: Width factor. Defaults to 6. Presumably the first level uses
            ``2 ** wf`` filters (TODO confirm against the model class).
        padding: Whether convolutions pad their input. Defaults to True.
        batch_norm: Whether batch normalisation is enabled. Defaults to False.
        up_mode: Name of the upsampling mode. Defaults to ``'upconv'``. Not
            validated here; the accepted values are defined by the model
            that consumes this config.
        dropout: Dropout setting. Defaults to False.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    # Serialized into config.json and used by the Auto* classes to match a
    # saved config to this class. Changing it breaks loading of existing
    # checkpoints, so it is kept exactly as-is.
    model_type = "UNet"

    def __init__(
        self,
        in_channels: int = 1,
        n_classes: int = 1,
        depth: int = 3,
        wf: int = 6,
        padding: bool = True,
        batch_norm: bool = False,
        up_mode: str = 'upconv',
        dropout: bool = False,
        **kwargs,
    ):
        # Model-specific hyperparameters, stored under the same names so that
        # PretrainedConfig serializes and restores them automatically.
        self.in_channels = in_channels
        self.n_classes = n_classes
        self.depth = depth
        self.wf = wf
        self.padding = padding
        self.batch_norm = batch_norm
        self.up_mode = up_mode
        self.dropout = dropout
        # Base class handles the generic config fields (e.g. anything loaded
        # from a saved config.json) passed through **kwargs.
        super().__init__(**kwargs)