vobecant
commited on
Commit
•
179cb5d
1
Parent(s):
bd42ce3
Initial commit.
Browse files- .idea/workspace.xml +10 -4
- segmenter_model/factory.py +10 -16
.idea/workspace.xml
CHANGED
@@ -2,8 +2,7 @@
|
|
2 |
<project version="4">
|
3 |
<component name="ChangeListManager">
|
4 |
<list default="true" id="5dd22f22-8223-4d55-99f9-57d1e00622d7" name="Default Changelist" comment="Initial commit.">
|
5 |
-
<change beforePath="$PROJECT_DIR
|
6 |
-
<change beforePath="$PROJECT_DIR$/app.py" beforeDir="false" afterPath="$PROJECT_DIR$/app.py" afterDir="false" />
|
7 |
</list>
|
8 |
<option name="SHOW_DIALOG" value="false" />
|
9 |
<option name="HIGHLIGHT_CONFLICTS" value="true" />
|
@@ -51,7 +50,7 @@
|
|
51 |
<option name="number" value="Default" />
|
52 |
<option name="presentableId" value="Default" />
|
53 |
<updated>1647350746642</updated>
|
54 |
-
<workItem from="1647350750956" duration="
|
55 |
</task>
|
56 |
<task id="LOCAL-00001" summary="Initial commit.">
|
57 |
<created>1647352693910</created>
|
@@ -137,7 +136,14 @@
|
|
137 |
<option name="project" value="LOCAL" />
|
138 |
<updated>1647356274640</updated>
|
139 |
</task>
|
140 |
-
<
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
141 |
<servers />
|
142 |
</component>
|
143 |
<component name="TypeScriptGeneratedFilesManager">
|
|
|
2 |
<project version="4">
|
3 |
<component name="ChangeListManager">
|
4 |
<list default="true" id="5dd22f22-8223-4d55-99f9-57d1e00622d7" name="Default Changelist" comment="Initial commit.">
|
5 |
+
<change beforePath="$PROJECT_DIR$/segmenter_model/factory.py" beforeDir="false" afterPath="$PROJECT_DIR$/segmenter_model/factory.py" afterDir="false" />
|
|
|
6 |
</list>
|
7 |
<option name="SHOW_DIALOG" value="false" />
|
8 |
<option name="HIGHLIGHT_CONFLICTS" value="true" />
|
|
|
50 |
<option name="number" value="Default" />
|
51 |
<option name="presentableId" value="Default" />
|
52 |
<updated>1647350746642</updated>
|
53 |
+
<workItem from="1647350750956" duration="5731000" />
|
54 |
</task>
|
55 |
<task id="LOCAL-00001" summary="Initial commit.">
|
56 |
<created>1647352693910</created>
|
|
|
136 |
<option name="project" value="LOCAL" />
|
137 |
<updated>1647356274640</updated>
|
138 |
</task>
|
139 |
+
<task id="LOCAL-00013" summary="Initial commit.">
|
140 |
+
<created>1647356326582</created>
|
141 |
+
<option name="number" value="00013" />
|
142 |
+
<option name="presentableId" value="LOCAL-00013" />
|
143 |
+
<option name="project" value="LOCAL" />
|
144 |
+
<updated>1647356326582</updated>
|
145 |
+
</task>
|
146 |
+
<option name="localTasksCounter" value="14" />
|
147 |
<servers />
|
148 |
</component>
|
149 |
<component name="TypeScriptGeneratedFilesManager">
|
segmenter_model/factory.py
CHANGED
@@ -1,18 +1,17 @@
|
|
1 |
-
from pathlib import Path
|
2 |
-
import yaml
|
3 |
-
import torch
|
4 |
-
import math
|
5 |
import os
|
6 |
-
|
7 |
|
|
|
|
|
8 |
from timm.models.helpers import load_pretrained, load_custom_pretrained
|
9 |
-
from timm.models.vision_transformer import default_cfgs, checkpoint_filter_fn
|
10 |
from timm.models.registry import register_model
|
11 |
from timm.models.vision_transformer import _create_vision_transformer
|
|
|
|
|
|
|
|
|
12 |
from segmenter_model.decoder import MaskTransformer
|
13 |
from segmenter_model.segmenter import Segmenter
|
14 |
-
import segmenter_model.torch as ptu
|
15 |
-
|
16 |
from segmenter_model.vit_dino import vit_small, VisionTransformer
|
17 |
|
18 |
|
@@ -48,14 +47,9 @@ def create_vit(model_cfg):
|
|
48 |
model_cfg['drop_rate'] = model_cfg['dropout']
|
49 |
model = vit_small(**model_cfg)
|
50 |
# hard-coded for now, too lazy
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
pretrained_weights = ciirc_path
|
55 |
-
elif os.path.exists(karolina_path):
|
56 |
-
pretrained_weights = karolina_path
|
57 |
-
else:
|
58 |
-
raise Exception('DINO weights not found!')
|
59 |
model.load_state_dict(torch.load(pretrained_weights), strict=True)
|
60 |
else:
|
61 |
model = torch.hub.load('facebookresearch/dino:main', backbone)
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
+
from pathlib import Path
|
3 |
|
4 |
+
import requests
|
5 |
+
import yaml
|
6 |
from timm.models.helpers import load_pretrained, load_custom_pretrained
|
|
|
7 |
from timm.models.registry import register_model
|
8 |
from timm.models.vision_transformer import _create_vision_transformer
|
9 |
+
from timm.models.vision_transformer import default_cfgs, checkpoint_filter_fn
|
10 |
+
|
11 |
+
import segmenter_model.torch as ptu
|
12 |
+
import torch
|
13 |
from segmenter_model.decoder import MaskTransformer
|
14 |
from segmenter_model.segmenter import Segmenter
|
|
|
|
|
15 |
from segmenter_model.vit_dino import vit_small, VisionTransformer
|
16 |
|
17 |
|
|
|
47 |
model_cfg['drop_rate'] = model_cfg['dropout']
|
48 |
model = vit_small(**model_cfg)
|
49 |
# hard-coded for now, too lazy
|
50 |
+
pretrained_weights = 'dino_deitsmall16_pretrain.pth'
|
51 |
+
if not os.path.exists(pretrained_weights):
|
52 |
+
requests.get(pretrained_weights, allow_redirects=True)
|
|
|
|
|
|
|
|
|
|
|
53 |
model.load_state_dict(torch.load(pretrained_weights), strict=True)
|
54 |
else:
|
55 |
model = torch.hub.load('facebookresearch/dino:main', backbone)
|