soumickmj's picture
Upload UNetFTSR
301db00 verified
raw
history blame contribute delete
687 Bytes
from transformers import PretrainedConfig
from typing import List
class UNetFTSRConfig(PretrainedConfig):
    """Hugging Face configuration object for the UNetFTSR model.

    Every named constructor argument is stored as an attribute of the same
    name on the instance. Any additional keyword arguments are forwarded
    unchanged to ``PretrainedConfig.__init__``.

    Args:
        in_channels: Stored as ``self.in_channels`` (default ``1``);
            presumably the number of input image channels.
        n_classes: Stored as ``self.n_classes`` (default ``1``);
            presumably the number of output channels/classes.
        depth: Stored as ``self.depth`` (default ``3``); presumably the
            number of encoder/decoder levels — confirm against the model.
        wf: Stored as ``self.wf`` (default ``6``). NOTE(review): in common
            UNet implementations this is log2 of the first layer's filter
            count — confirm against the model code.
        padding: Stored as ``self.padding`` (default ``True``).
        batch_norm: Stored as ``self.batch_norm`` (default ``False``).
        up_mode: Stored as ``self.up_mode`` (default ``'upconv'``);
            presumably selects the upsampling strategy.
        dropout: Stored as ``self.dropout`` (default ``False``).
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    # Config-type identifier consulted by transformers. NOTE(review): confirm
    # "UNet" matches the name this config is registered under.
    model_type = "UNet"

    def __init__(
        self,
        in_channels=1,
        n_classes=1,
        depth=3,
        wf=6,
        padding=True,
        batch_norm=False,
        up_mode='upconv',
        dropout=False,
        **kwargs):
        # Attribute name -> constructor value. Insertion order matches the
        # order in which the attributes are assigned.
        model_fields = {
            "in_channels": in_channels,
            "n_classes": n_classes,
            "depth": depth,
            "wf": wf,
            "padding": padding,
            "batch_norm": batch_norm,
            "up_mode": up_mode,
            "dropout": dropout,
        }
        for field_name, field_value in model_fields.items():
            setattr(self, field_name, field_value)

        # Delegate the remaining generic options to the base config only
        # after the model-specific fields exist on the instance.
        super().__init__(**kwargs)