morinop commited on
Commit
9e8b807
1 Parent(s): 6d8baed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -1
app.py CHANGED
@@ -10,13 +10,16 @@ import timm
10
  feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/resnet-101")
11
  model = ResNetForImageClassification.from_pretrained("microsoft/resnet-101")
12
  model.eval()
13
- print(model)
14
 
 
 
15
 
 
16
 
17
  import os
18
 
19
  def print_bn():
 
20
  bn_data = []
21
  for m in model.modules():
22
  if(type(m) is nn.BatchNorm2d):
@@ -25,9 +28,11 @@ def print_bn():
25
  bn_data.extend(m.running_var.data.numpy().tolist())
26
  bn_data.append(m.momentum)
27
  print(len(bn_data))
 
28
  return bn_data
29
 
30
  def update_bn(image):
 
31
  cursor_im = 0
32
  image = T.Resize((90,90))(image)
33
  image = image.reshape(-1)
@@ -50,6 +55,8 @@ def greet(image):
50
  bn_data = print_bn()
51
  return ','.join([f'{x:.2f}' for x in bn_data])
52
  else:
 
 
53
  print(type(image))
54
  image = torch.tensor(image).float()
55
  print(image.min(), image.max())
 
10
  feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/resnet-101")
11
  model = ResNetForImageClassification.from_pretrained("microsoft/resnet-101")
12
  model.eval()
 
13
 
14
+
15
+ print(model)
16
 
17
+ print(model.resnet.embedder.embedder.convolution.weight.data)
18
 
19
  import os
20
 
21
  def print_bn():
22
+
23
  bn_data = []
24
  for m in model.modules():
25
  if(type(m) is nn.BatchNorm2d):
 
28
  bn_data.extend(m.running_var.data.numpy().tolist())
29
  bn_data.append(m.momentum)
30
  print(len(bn_data))
31
+ bn_data.extend(model.resnet.embedder.embedder.convolution.weight.data.numpy().tolist())
32
  return bn_data
33
 
34
  def update_bn(image):
35
+
36
  cursor_im = 0
37
  image = T.Resize((90,90))(image)
38
  image = image.reshape(-1)
 
55
  bn_data = print_bn()
56
  return ','.join([f'{x:.2f}' for x in bn_data])
57
  else:
58
+ conv_layer = model.resnet.embedder.embedder.convolution
59
+ conv_layer.weight.data = torch.ones_like(conv_layer.weight.data)
60
  print(type(image))
61
  image = torch.tensor(image).float()
62
  print(image.min(), image.max())