timm / swinv2_large_window12to24_192to384.ms_in22k_ft_in1k

Image Classification · timm · PyTorch · Safetensors

rwightman committed
Commit 1513219
1 Parent(s): 09e8a8f

Files changed (4)
  1. README.md +146 -0
  2. config.json +37 -0
  3. model.safetensors +3 -0
  4. pytorch_model.bin +3 -0
README.md ADDED
@@ -0,0 +1,146 @@
+ ---
+ tags:
+ - image-classification
+ - timm
+ library_tag: timm
+ license: mit
+ datasets:
+ - imagenet-1k
+ - imagenet-22k
+ ---
+ # Model card for swinv2_large_window12to24_192to384.ms_in22k_ft_in1k
+
+ A Swin Transformer V2 image classification model. Pretrained on ImageNet-22k and fine-tuned on ImageNet-1k by the paper authors.
+
+ ## Model Details
+ - **Model Type:** Image classification / feature backbone
+ - **Model Stats:**
+   - Params (M): 196.7
+   - GMACs: 116.1
+   - Activations (M): 407.8
+   - Image size: 384 x 384
+ - **Papers:**
+   - Swin Transformer V2: Scaling Up Capacity and Resolution: https://arxiv.org/abs/2111.09883
+ - **Original:** https://github.com/microsoft/Swin-Transformer
+ - **Dataset:** ImageNet-1k
+ - **Pretrain Dataset:** ImageNet-22k
+
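+ The parameter count quoted above can be sanity-checked from the instantiated model itself; a minimal sketch in plain PyTorch (pretrained weights are not needed just to count parameters):
+
+ ```python
+ import timm
+
+ model = timm.create_model('swinv2_large_window12to24_192to384.ms_in22k_ft_in1k', pretrained=False)
+
+ # total parameter count in millions; expected ~196.7 per the stats above
+ n_params = sum(p.numel() for p in model.parameters())
+ print(f'{n_params / 1e6:.1f}M parameters')
+ ```
+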
+ ## Model Usage
+ ### Image Classification
+ ```python
+ from urllib.request import urlopen
+ from PIL import Image
+ import timm
+ import torch
+
+ img = Image.open(urlopen(
+     'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
+ ))
+
+ model = timm.create_model('swinv2_large_window12to24_192to384.ms_in22k_ft_in1k', pretrained=True)
+ model = model.eval()
+
+ # get model-specific transforms (normalization, resize)
+ data_config = timm.data.resolve_model_data_config(model)
+ transforms = timm.data.create_transform(**data_config, is_training=False)
+
+ output = model(transforms(img).unsqueeze(0))  # unsqueeze single image into batch of 1
+
+ top5_probabilities, top5_class_indices = torch.topk(output.softmax(dim=1) * 100, k=5)
+ ```
+
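+ For pure inference it is worth disabling gradient tracking and, where available, running on the GPU; a minimal sketch in standard PyTorch, reusing `model`, `transforms`, and `img` from the snippet above:
+
+ ```python
+ import torch
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ model = model.to(device)
+
+ with torch.no_grad():  # skip autograd bookkeeping for inference-only use
+     output = model(transforms(img).unsqueeze(0).to(device))
+
+ top5_probabilities, top5_class_indices = torch.topk(output.softmax(dim=1) * 100, k=5)
+ ```
+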
+ ### Feature Map Extraction
+ ```python
+ from urllib.request import urlopen
+ from PIL import Image
+ import timm
+
+ img = Image.open(urlopen(
+     'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
+ ))
+
+ model = timm.create_model(
+     'swinv2_large_window12to24_192to384.ms_in22k_ft_in1k',
+     pretrained=True,
+     features_only=True,
+ )
+ model = model.eval()
+
+ # get model-specific transforms (normalization, resize)
+ data_config = timm.data.resolve_model_data_config(model)
+ transforms = timm.data.create_transform(**data_config, is_training=False)
+
+ output = model(transforms(img).unsqueeze(0))  # unsqueeze single image into batch of 1
+
+ for o in output:
+     # print shape of each feature map in output
+     # e.g. for swin_base_patch4_window7_224 (NHWC output)
+     #  torch.Size([1, 56, 56, 128])
+     #  torch.Size([1, 28, 28, 256])
+     #  torch.Size([1, 14, 14, 512])
+     #  torch.Size([1, 7, 7, 1024])
+     # e.g. for swinv2_cr_small_ns_224 (NCHW output)
+     #  torch.Size([1, 96, 56, 56])
+     #  torch.Size([1, 192, 28, 28])
+     #  torch.Size([1, 384, 14, 14])
+     #  torch.Size([1, 768, 7, 7])
+     print(o.shape)
+ ```
+
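+ With `features_only=True`, timm also exposes metadata about the returned maps via `model.feature_info`; a short sketch (the values in the comments are what this model is expected to report, not figures from the card above):
+
+ ```python
+ # reuses `model` from the snippet above (created with features_only=True)
+ print(model.feature_info.channels())   # channels per feature map, e.g. [192, 384, 768, 1536]
+ print(model.feature_info.reduction())  # downsampling factor per map, e.g. [4, 8, 16, 32]
+ ```
+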
+ ### Image Embeddings
+ ```python
+ from urllib.request import urlopen
+ from PIL import Image
+ import timm
+
+ img = Image.open(urlopen(
+     'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
+ ))
+
+ model = timm.create_model(
+     'swinv2_large_window12to24_192to384.ms_in22k_ft_in1k',
+     pretrained=True,
+     num_classes=0,  # remove classifier nn.Linear
+ )
+ model = model.eval()
+
+ # get model-specific transforms (normalization, resize)
+ data_config = timm.data.resolve_model_data_config(model)
+ transforms = timm.data.create_transform(**data_config, is_training=False)
+
+ output = model(transforms(img).unsqueeze(0))  # output is (batch_size, num_features) shaped tensor
+
+ # or equivalently (without needing to set num_classes=0)
+ output = model.forward_features(transforms(img).unsqueeze(0))
+ # output is unpooled (i.e. a (batch_size, H, W, num_features) tensor for swin / swinv2,
+ # or (batch_size, num_features, H, W) for swinv2_cr)
+
+ output = model.forward_head(output, pre_logits=True)
+ # output is (batch_size, num_features) tensor
+ ```
+
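+ The pooled embeddings can be compared directly, e.g. for image retrieval; a minimal sketch using plain PyTorch cosine similarity, where `img1` and `img2` are two hypothetical PIL images preprocessed as above:
+
+ ```python
+ import torch.nn.functional as F
+
+ # reuses `model` (num_classes=0) and `transforms` from the snippet above
+ emb1 = model(transforms(img1).unsqueeze(0))  # (1, num_features)
+ emb2 = model(transforms(img2).unsqueeze(0))
+
+ # cosine similarity in [-1, 1]; higher means more visually similar
+ print(F.cosine_similarity(emb1, emb2).item())
+ ```
+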
+ ## Model Comparison
+ Explore the dataset and runtime metrics of this model in the timm [model results](https://github.com/huggingface/pytorch-image-models/tree/main/results).
+
+ ## Citation
+ ```bibtex
+ @inproceedings{liu2021swinv2,
+   title={Swin Transformer V2: Scaling Up Capacity and Resolution},
+   author={Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
+   booktitle={International Conference on Computer Vision and Pattern Recognition (CVPR)},
+   year={2022}
+ }
+ ```
+ ```bibtex
+ @misc{rw2019timm,
+   author = {Ross Wightman},
+   title = {PyTorch Image Models},
+   year = {2019},
+   publisher = {GitHub},
+   journal = {GitHub repository},
+   doi = {10.5281/zenodo.4414861},
+   howpublished = {\url{https://github.com/huggingface/pytorch-image-models}}
+ }
+ ```
config.json ADDED
@@ -0,0 +1,37 @@
+ {
+     "architecture": "swinv2_large_window12to24_192to384",
+     "num_classes": 1000,
+     "num_features": 1536,
+     "global_pool": "avg",
+     "pretrained_cfg": {
+         "tag": "ms_in22k_ft_in1k",
+         "custom_load": false,
+         "input_size": [
+             3,
+             384,
+             384
+         ],
+         "fixed_input_size": true,
+         "interpolation": "bicubic",
+         "crop_pct": 1.0,
+         "crop_mode": "center",
+         "mean": [
+             0.485,
+             0.456,
+             0.406
+         ],
+         "std": [
+             0.229,
+             0.224,
+             0.225
+         ],
+         "num_classes": 1000,
+         "pool_size": [
+             12,
+             12
+         ],
+         "first_conv": "patch_embed.proj",
+         "classifier": "head.fc",
+         "license": "mit"
+     }
+ }
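The `pretrained_cfg` block above is what `timm.data.resolve_model_data_config` reads when building the eval transform in the README examples; passing the same values explicitly should yield an equivalent pipeline. A sketch under that assumption, using `timm.data.create_transform`'s keyword arguments:

```python
import timm.data

# mirrors the pretrained_cfg fields above: 384x384 input, bicubic resize,
# full (crop_pct=1.0) center crop, ImageNet mean/std normalization
transforms = timm.data.create_transform(
    input_size=(3, 384, 384),
    interpolation='bicubic',
    crop_pct=1.0,
    crop_mode='center',
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
    is_training=False,
)
print(transforms)
```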
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:204631b53b5d8e2a4382403673e15938caa59b052fbb6472ac428355a39f7545
+ size 813545880
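The LFS pointer records the blob's SHA-256 and byte size; a downloaded copy can be verified against them with only the standard library (a sketch; the local filename is an assumption):

```python
import hashlib

path = 'model.safetensors'  # hypothetical local path to the downloaded weights

h = hashlib.sha256()
with open(path, 'rb') as f:
    for chunk in iter(lambda: f.read(1 << 20), b''):  # hash in 1 MiB chunks
        h.update(chunk)

# expected: 204631b53b5d8e2a4382403673e15938caa59b052fbb6472ac428355a39f7545
print(h.hexdigest())
```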
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9309fd934dbbc4dd2ef3e5363ac714947a3bf0bd5478a3087e9134632cb328c6
+ size 813662073