vobecant commited on
Commit
179cb5d
1 Parent(s): bd42ce3

Initial commit.

Browse files
Files changed (2) hide show
  1. .idea/workspace.xml +10 -4
  2. segmenter_model/factory.py +10 -16
.idea/workspace.xml CHANGED
@@ -2,8 +2,7 @@
2
  <project version="4">
3
  <component name="ChangeListManager">
4
  <list default="true" id="5dd22f22-8223-4d55-99f9-57d1e00622d7" name="Default Changelist" comment="Initial commit.">
5
- <change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
6
- <change beforePath="$PROJECT_DIR$/app.py" beforeDir="false" afterPath="$PROJECT_DIR$/app.py" afterDir="false" />
7
  </list>
8
  <option name="SHOW_DIALOG" value="false" />
9
  <option name="HIGHLIGHT_CONFLICTS" value="true" />
@@ -51,7 +50,7 @@
51
  <option name="number" value="Default" />
52
  <option name="presentableId" value="Default" />
53
  <updated>1647350746642</updated>
54
- <workItem from="1647350750956" duration="5496000" />
55
  </task>
56
  <task id="LOCAL-00001" summary="Initial commit.">
57
  <created>1647352693910</created>
@@ -137,7 +136,14 @@
137
  <option name="project" value="LOCAL" />
138
  <updated>1647356274640</updated>
139
  </task>
140
- <option name="localTasksCounter" value="13" />
 
 
 
 
 
 
 
141
  <servers />
142
  </component>
143
  <component name="TypeScriptGeneratedFilesManager">
 
2
  <project version="4">
3
  <component name="ChangeListManager">
4
  <list default="true" id="5dd22f22-8223-4d55-99f9-57d1e00622d7" name="Default Changelist" comment="Initial commit.">
5
+ <change beforePath="$PROJECT_DIR$/segmenter_model/factory.py" beforeDir="false" afterPath="$PROJECT_DIR$/segmenter_model/factory.py" afterDir="false" />
 
6
  </list>
7
  <option name="SHOW_DIALOG" value="false" />
8
  <option name="HIGHLIGHT_CONFLICTS" value="true" />
 
50
  <option name="number" value="Default" />
51
  <option name="presentableId" value="Default" />
52
  <updated>1647350746642</updated>
53
+ <workItem from="1647350750956" duration="5731000" />
54
  </task>
55
  <task id="LOCAL-00001" summary="Initial commit.">
56
  <created>1647352693910</created>
 
136
  <option name="project" value="LOCAL" />
137
  <updated>1647356274640</updated>
138
  </task>
139
+ <task id="LOCAL-00013" summary="Initial commit.">
140
+ <created>1647356326582</created>
141
+ <option name="number" value="00013" />
142
+ <option name="presentableId" value="LOCAL-00013" />
143
+ <option name="project" value="LOCAL" />
144
+ <updated>1647356326582</updated>
145
+ </task>
146
+ <option name="localTasksCounter" value="14" />
147
  <servers />
148
  </component>
149
  <component name="TypeScriptGeneratedFilesManager">
segmenter_model/factory.py CHANGED
@@ -1,18 +1,17 @@
1
- from pathlib import Path
2
- import yaml
3
- import torch
4
- import math
5
  import os
6
- import torch.nn as nn
7
 
 
 
8
  from timm.models.helpers import load_pretrained, load_custom_pretrained
9
- from timm.models.vision_transformer import default_cfgs, checkpoint_filter_fn
10
  from timm.models.registry import register_model
11
  from timm.models.vision_transformer import _create_vision_transformer
 
 
 
 
12
  from segmenter_model.decoder import MaskTransformer
13
  from segmenter_model.segmenter import Segmenter
14
- import segmenter_model.torch as ptu
15
-
16
  from segmenter_model.vit_dino import vit_small, VisionTransformer
17
 
18
 
@@ -48,14 +47,9 @@ def create_vit(model_cfg):
48
  model_cfg['drop_rate'] = model_cfg['dropout']
49
  model = vit_small(**model_cfg)
50
  # hard-coded for now, too lazy
51
- ciirc_path = '/home/vobecant/PhD/weights/dino/dino_deitsmall16_pretrain.pth'
52
- karolina_path = '/scratch/project/dd-21-20/pretrained_weights/dino/dino_deitsmall16_pretrain.pth'
53
- if os.path.exists(ciirc_path):
54
- pretrained_weights = ciirc_path
55
- elif os.path.exists(karolina_path):
56
- pretrained_weights = karolina_path
57
- else:
58
- raise Exception('DINO weights not found!')
59
  model.load_state_dict(torch.load(pretrained_weights), strict=True)
60
  else:
61
  model = torch.hub.load('facebookresearch/dino:main', backbone)
 
 
 
 
 
1
  import os
2
+ from pathlib import Path
3
 
4
+ import requests
5
+ import yaml
6
  from timm.models.helpers import load_pretrained, load_custom_pretrained
 
7
  from timm.models.registry import register_model
8
  from timm.models.vision_transformer import _create_vision_transformer
9
+ from timm.models.vision_transformer import default_cfgs, checkpoint_filter_fn
10
+
11
+ import segmenter_model.torch as ptu
12
+ import torch
13
  from segmenter_model.decoder import MaskTransformer
14
  from segmenter_model.segmenter import Segmenter
 
 
15
  from segmenter_model.vit_dino import vit_small, VisionTransformer
16
 
17
 
 
47
  model_cfg['drop_rate'] = model_cfg['dropout']
48
  model = vit_small(**model_cfg)
49
  # hard-coded for now, too lazy
50
+ pretrained_weights = 'dino_deitsmall16_pretrain.pth'
51
+ if not os.path.exists(pretrained_weights):
52
+ requests.get(pretrained_weights, allow_redirects=True)
 
 
 
 
 
53
  model.load_state_dict(torch.load(pretrained_weights), strict=True)
54
  else:
55
  model = torch.hub.load('facebookresearch/dino:main', backbone)