Warlord-K commited on
Commit
968b459
1 Parent(s): b8b5212

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. inference.py +11 -0
  2. model.pt +3 -0
  3. resnet.py +209 -0
  4. test.jpg +0 -0
inference.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from resnet import get_model
2
+ import torch
3
+ from PIL import Image
4
+ from torchvision.transforms.functional import pil_to_tensor
5
+
6
+
7
+ model = get_model("r100", dropout=0.0, fp16=True, num_features=512).cuda()
8
+ model.load_state_dict(torch.load("model.pt"))
9
+ model.eval()
10
+ img = pil_to_tensor(Image.open("test.jpg").resize((112,112))).permute(0, 1, 2).to("cuda", torch.float16).unsqueeze(dim = 0)
11
+ embeddings = model(img)
model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:460c0b4bb44130276107536c725c0a5a5500c05533bfb2871d9112f69bc7077f
3
+ size 261152502
resnet.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torch.utils.checkpoint import checkpoint
4
+
5
+ __all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200']
6
+ using_ckpt = False
7
+
8
+ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
9
+ """3x3 convolution with padding"""
10
+ return nn.Conv2d(in_planes,
11
+ out_planes,
12
+ kernel_size=3,
13
+ stride=stride,
14
+ padding=dilation,
15
+ groups=groups,
16
+ bias=False,
17
+ dilation=dilation)
18
+
19
+
20
+ def conv1x1(in_planes, out_planes, stride=1):
21
+ """1x1 convolution"""
22
+ return nn.Conv2d(in_planes,
23
+ out_planes,
24
+ kernel_size=1,
25
+ stride=stride,
26
+ bias=False)
27
+
28
+
29
+ class IBasicBlock(nn.Module):
30
+ expansion = 1
31
+ def __init__(self, inplanes, planes, stride=1, downsample=None,
32
+ groups=1, base_width=64, dilation=1):
33
+ super(IBasicBlock, self).__init__()
34
+ if groups != 1 or base_width != 64:
35
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
36
+ if dilation > 1:
37
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
38
+ self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,)
39
+ self.conv1 = conv3x3(inplanes, planes)
40
+ self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,)
41
+ self.prelu = nn.PReLU(planes)
42
+ self.conv2 = conv3x3(planes, planes, stride)
43
+ self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,)
44
+ self.downsample = downsample
45
+ self.stride = stride
46
+
47
+ def forward_impl(self, x):
48
+ identity = x
49
+ out = self.bn1(x)
50
+ out = self.conv1(out)
51
+ out = self.bn2(out)
52
+ out = self.prelu(out)
53
+ out = self.conv2(out)
54
+ out = self.bn3(out)
55
+ if self.downsample is not None:
56
+ identity = self.downsample(x)
57
+ out += identity
58
+ return out
59
+
60
+ def forward(self, x):
61
+ if self.training and using_ckpt:
62
+ return checkpoint(self.forward_impl, x)
63
+ else:
64
+ return self.forward_impl(x)
65
+
66
+
67
+ class IResNet(nn.Module):
68
+ fc_scale = 7 * 7
69
+ def __init__(self,
70
+ block, layers, dropout=0, num_features=512, zero_init_residual=False,
71
+ groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
72
+ super(IResNet, self).__init__()
73
+ self.extra_gflops = 0.0
74
+ self.fp16 = fp16
75
+ self.inplanes = 64
76
+ self.dilation = 1
77
+ if replace_stride_with_dilation is None:
78
+ replace_stride_with_dilation = [False, False, False]
79
+ if len(replace_stride_with_dilation) != 3:
80
+ raise ValueError("replace_stride_with_dilation should be None "
81
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
82
+ self.groups = groups
83
+ self.base_width = width_per_group
84
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
85
+ self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
86
+ self.prelu = nn.PReLU(self.inplanes)
87
+ self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
88
+ self.layer2 = self._make_layer(block,
89
+ 128,
90
+ layers[1],
91
+ stride=2,
92
+ dilate=replace_stride_with_dilation[0])
93
+ self.layer3 = self._make_layer(block,
94
+ 256,
95
+ layers[2],
96
+ stride=2,
97
+ dilate=replace_stride_with_dilation[1])
98
+ self.layer4 = self._make_layer(block,
99
+ 512,
100
+ layers[3],
101
+ stride=2,
102
+ dilate=replace_stride_with_dilation[2])
103
+ self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,)
104
+ self.dropout = nn.Dropout(p=dropout, inplace=True)
105
+ self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
106
+ self.features = nn.BatchNorm1d(num_features, eps=1e-05)
107
+ nn.init.constant_(self.features.weight, 1.0)
108
+ self.features.weight.requires_grad = False
109
+
110
+ for m in self.modules():
111
+ if isinstance(m, nn.Conv2d):
112
+ nn.init.normal_(m.weight, 0, 0.1)
113
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
114
+ nn.init.constant_(m.weight, 1)
115
+ nn.init.constant_(m.bias, 0)
116
+
117
+ if zero_init_residual:
118
+ for m in self.modules():
119
+ if isinstance(m, IBasicBlock):
120
+ nn.init.constant_(m.bn2.weight, 0)
121
+
122
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
123
+ downsample = None
124
+ previous_dilation = self.dilation
125
+ if dilate:
126
+ self.dilation *= stride
127
+ stride = 1
128
+ if stride != 1 or self.inplanes != planes * block.expansion:
129
+ downsample = nn.Sequential(
130
+ conv1x1(self.inplanes, planes * block.expansion, stride),
131
+ nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
132
+ )
133
+ layers = []
134
+ layers.append(
135
+ block(self.inplanes, planes, stride, downsample, self.groups,
136
+ self.base_width, previous_dilation))
137
+ self.inplanes = planes * block.expansion
138
+ for _ in range(1, blocks):
139
+ layers.append(
140
+ block(self.inplanes,
141
+ planes,
142
+ groups=self.groups,
143
+ base_width=self.base_width,
144
+ dilation=self.dilation))
145
+
146
+ return nn.Sequential(*layers)
147
+
148
+ def forward(self, x):
149
+ with torch.cuda.amp.autocast(self.fp16):
150
+ x = self.conv1(x)
151
+ x = self.bn1(x)
152
+ x = self.prelu(x)
153
+ x = self.layer1(x)
154
+ x = self.layer2(x)
155
+ x = self.layer3(x)
156
+ x = self.layer4(x)
157
+ x = self.bn2(x)
158
+ x = torch.flatten(x, 1)
159
+ x = self.dropout(x)
160
+ x = self.fc(x.float() if self.fp16 else x)
161
+ x = self.features(x)
162
+ return x
163
+
164
+
165
+ def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
166
+ model = IResNet(block, layers, **kwargs)
167
+ if pretrained:
168
+ raise ValueError()
169
+ return model
170
+
171
+
172
+ def iresnet18(pretrained=False, progress=True, **kwargs):
173
+ return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained,
174
+ progress, **kwargs)
175
+
176
+
177
+ def iresnet34(pretrained=False, progress=True, **kwargs):
178
+ return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained,
179
+ progress, **kwargs)
180
+
181
+
182
+ def iresnet50(pretrained=False, progress=True, **kwargs):
183
+ return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained,
184
+ progress, **kwargs)
185
+
186
+
187
+ def iresnet100(pretrained=False, progress=True, **kwargs):
188
+ return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained,
189
+ progress, **kwargs)
190
+
191
+
192
+ def iresnet200(pretrained=False, progress=True, **kwargs):
193
+ return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained,
194
+ progress, **kwargs)
195
+
196
+ def get_model(name, **kwargs):
197
+ # resnet
198
+ if name == "r18":
199
+ return iresnet18(False, **kwargs)
200
+ elif name == "r34":
201
+ return iresnet34(False, **kwargs)
202
+ elif name == "r50":
203
+ return iresnet50(False, **kwargs)
204
+ elif name == "r100":
205
+ return iresnet100(False, **kwargs)
206
+ elif name == "r200":
207
+ return iresnet200(False, **kwargs)
208
+ else:
209
+ raise ValueError
test.jpg ADDED