andy-wyx committed on
Commit
0b77991
1 Parent(s): fc50afd

update closest_sample to use new embeddings

Browse files
Files changed (4) hide show
  1. app.py +3 -3
  2. check_arch.py +0 -21
  3. closest_sample.py +3 -3
  4. test.py +29 -0
app.py CHANGED
@@ -110,21 +110,21 @@ def get_model(model_name):
110
  backbone_class=tf.keras.applications.ResNet50V2,
111
  nb_classes = n_classes,load_weights=False,finer_model=True,backbone_name ='Resnet50v2')
112
  model.load_weights('model_classification/rock-170.h5')
113
- # elif model_name == 'Fossils 142':
114
  # n_classes = 142
115
  # model = get_triplet_model_beit(input_shape = (384, 384, 3),
116
  # embedding_units = 256,
117
  # embedding_depth = 2,
118
  # n_classes = n_classes)
119
  # model.load_weights('model_classification/fossil-142.h5')
120
- # elif model_name == 'Fossils new':
121
  # n_classes = 142
122
  # model = get_triplet_model_beit(input_shape = (384, 384, 3),
123
  # embedding_units = 256,
124
  # embedding_depth = 2,
125
  # n_classes = n_classes)
126
  # model.load_weights('model_classification/fossil-new.h5')
127
- elif model_name == 'Fossils 142':
128
  n_classes = 142
129
  model,_,_ = get_resnet_model('model_classification/fossil-model.h5')
130
  else:
 
110
  backbone_class=tf.keras.applications.ResNet50V2,
111
  nb_classes = n_classes,load_weights=False,finer_model=True,backbone_name ='Resnet50v2')
112
  model.load_weights('model_classification/rock-170.h5')
113
+ # elif model_name == 'Fossils 142': #BEiT
114
  # n_classes = 142
115
  # model = get_triplet_model_beit(input_shape = (384, 384, 3),
116
  # embedding_units = 256,
117
  # embedding_depth = 2,
118
  # n_classes = n_classes)
119
  # model.load_weights('model_classification/fossil-142.h5')
120
+ # elif model_name == 'Fossils new': # BEiT-v2
121
  # n_classes = 142
122
  # model = get_triplet_model_beit(input_shape = (384, 384, 3),
123
  # embedding_units = 256,
124
  # embedding_depth = 2,
125
  # n_classes = n_classes)
126
  # model.load_weights('model_classification/fossil-new.h5')
127
+ elif model_name == 'Fossils 142': # new resnet
128
  n_classes = 142
129
  model,_,_ = get_resnet_model('model_classification/fossil-model.h5')
130
  else:
check_arch.py DELETED
@@ -1,21 +0,0 @@
1
- import h5py
2
-
3
- def print_model_details(file_path):
4
- with h5py.File(file_path, 'r') as f:
5
- print(f.keys()) # Print layers
6
- print(len(f.keys()))
7
- print("")
8
- for key in f.keys():
9
- print(len(list(f[key].keys())))
10
- print(f"{key}: {list(f[key].keys())}") # Print details of each layer)
11
- print('rock')
12
- print_model_details('model_classification/rock-170.h5')
13
- print('mummified-170')
14
- print_model_details('model_classification/mummified-170.h5')
15
- print('BEiT')
16
- print_model_details('model_classification/fossil-142.h5')
17
- print('BEiT New')
18
- print_model_details('model_classification/fossil-new.h5')
19
- print("Newest:")
20
- print_model_details('model_classification/fossil-model.h5')
21
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
closest_sample.py CHANGED
@@ -9,8 +9,8 @@ import matplotlib.pyplot as plt
9
  from collections import Counter
10
 
11
 
12
- pca_fossils = pk.load(open('pca_fossils_170_finer.pkl','rb'))
13
- pca_leaves = pk.load(open('pca_leaves_170_finer.pkl','rb'))
14
 
15
  if not os.path.exists('dataset'):
16
  REPO_ID='Serrelab/Fossils'
@@ -20,7 +20,7 @@ if not os.path.exists('dataset'):
20
  print("warning! A read token in env variables is needed for authentication.")
21
  snapshot_download(repo_id=REPO_ID, token=token,repo_type='dataset',local_dir='dataset')
22
 
23
- embedding_fossils = np.load('dataset/embedding_fossils_170_finer.npy')
24
  #embedding_leaves = np.load('embedding_leaves.npy')
25
 
26
  fossils_pd= pd.read_csv('fossils_paths.csv')
 
9
  from collections import Counter
10
 
11
 
12
+ pca_fossils = pk.load(open('pca_fossils_142_resnet.pkl','rb'))
13
+ pca_leaves = pk.load(open('pca_leaves_142_resnet.pkl','rb'))
14
 
15
  if not os.path.exists('dataset'):
16
  REPO_ID='Serrelab/Fossils'
 
20
  print("warning! A read token in env variables is needed for authentication.")
21
  snapshot_download(repo_id=REPO_ID, token=token,repo_type='dataset',local_dir='dataset')
22
 
23
+ embedding_fossils = np.load('dataset/embedding_leaves_142_finer.npy')
24
  #embedding_leaves = np.load('embedding_leaves.npy')
25
 
26
  fossils_pd= pd.read_csv('fossils_paths.csv')
test.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import h5py
2
+
3
+ # def print_model_details(file_path):
4
+ # with h5py.File(file_path, 'r') as f:
5
+ # print(f.keys()) # Print layers
6
+ # print(len(f.keys()))
7
+ # print("")
8
+ # for key in f.keys():
9
+ # print(len(list(f[key].keys())))
10
+ # print(f"{key}: {list(f[key].keys())}") # Print details of each layer)
11
+ # print('rock')
12
+ # print_model_details('model_classification/rock-170.h5')
13
+ # print('mummified-170')
14
+ # print_model_details('model_classification/mummified-170.h5')
15
+ # print('BEiT')
16
+ # print_model_details('model_classification/fossil-142.h5')
17
+ # print('BEiT New')
18
+ # print_model_details('model_classification/fossil-new.h5')
19
+ # print("Newest:")
20
+ # print_model_details('model_classification/fossil-model.h5')
21
+
22
+
23
+ import numpy as np
24
+
25
+ # Load the .npy file
26
+ embedding = np.load('embedding.npy')
27
+
28
+ # Check the shape of the array
29
+ print(embedding.shape)