not-lain commited on
Commit
de880d5
·
verified ·
1 Parent(s): 51185ce

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +1 -14
train.py CHANGED
@@ -10,8 +10,7 @@ from pytorch_lightning.callbacks import ModelCheckpoint
10
  from torch.utils.data import DataLoader
11
  from huggingface_hub import PyTorchModelHubMixin
12
 
13
- from model import ISNetDIS, ISNetGTEncoder, U2NET, U2NET_full2, U2NET_lite2, MODNet \
14
- , InSPyReNet, InSPyReNet_Res2Net50, InSPyReNet_SwinB
15
 
16
 
17
  # warnings.filterwarnings("ignore")
@@ -23,18 +22,6 @@ def get_net(net_name, img_size):
23
  return ISNetDIS()
24
  elif net_name == "isnet_is":
25
  return ISNetDIS()
26
- elif net_name == "isnet_gt":
27
- return ISNetGTEncoder()
28
- elif net_name == "u2net":
29
- return U2NET_full2()
30
- elif net_name == "u2netl":
31
- return U2NET_lite2()
32
- elif net_name == "modnet":
33
- return MODNet()
34
- elif net_name == "inspyrnet_res":
35
- return InSPyReNet_Res2Net50(base_size=img_size)
36
- elif net_name == "inspyrnet_swin":
37
- return InSPyReNet_SwinB(base_size=img_size)
38
  raise NotImplementedError
39
 
40
 
 
10
  from torch.utils.data import DataLoader
11
  from huggingface_hub import PyTorchModelHubMixin
12
 
13
+ from isnet import ISNetDIS
 
14
 
15
 
16
  # warnings.filterwarnings("ignore")
 
22
  return ISNetDIS()
23
  elif net_name == "isnet_is":
24
  return ISNetDIS()
 
 
 
 
 
 
 
 
 
 
 
 
25
  raise NotImplementedError
26
 
27