Spaces:
Sleeping
Sleeping
update closest_sample to use new embeddings
Browse files- app.py +3 -3
- check_arch.py +0 -21
- closest_sample.py +3 -3
- test.py +29 -0
app.py
CHANGED
@@ -110,21 +110,21 @@ def get_model(model_name):
|
|
110 |
backbone_class=tf.keras.applications.ResNet50V2,
|
111 |
nb_classes = n_classes,load_weights=False,finer_model=True,backbone_name ='Resnet50v2')
|
112 |
model.load_weights('model_classification/rock-170.h5')
|
113 |
-
# elif model_name == 'Fossils 142':
|
114 |
# n_classes = 142
|
115 |
# model = get_triplet_model_beit(input_shape = (384, 384, 3),
|
116 |
# embedding_units = 256,
|
117 |
# embedding_depth = 2,
|
118 |
# n_classes = n_classes)
|
119 |
# model.load_weights('model_classification/fossil-142.h5')
|
120 |
-
# elif model_name == 'Fossils new':
|
121 |
# n_classes = 142
|
122 |
# model = get_triplet_model_beit(input_shape = (384, 384, 3),
|
123 |
# embedding_units = 256,
|
124 |
# embedding_depth = 2,
|
125 |
# n_classes = n_classes)
|
126 |
# model.load_weights('model_classification/fossil-new.h5')
|
127 |
-
elif model_name == 'Fossils 142':
|
128 |
n_classes = 142
|
129 |
model,_,_ = get_resnet_model('model_classification/fossil-model.h5')
|
130 |
else:
|
|
|
110 |
backbone_class=tf.keras.applications.ResNet50V2,
|
111 |
nb_classes = n_classes,load_weights=False,finer_model=True,backbone_name ='Resnet50v2')
|
112 |
model.load_weights('model_classification/rock-170.h5')
|
113 |
+
# elif model_name == 'Fossils 142': #BEiT
|
114 |
# n_classes = 142
|
115 |
# model = get_triplet_model_beit(input_shape = (384, 384, 3),
|
116 |
# embedding_units = 256,
|
117 |
# embedding_depth = 2,
|
118 |
# n_classes = n_classes)
|
119 |
# model.load_weights('model_classification/fossil-142.h5')
|
120 |
+
# elif model_name == 'Fossils new': # BEiT-v2
|
121 |
# n_classes = 142
|
122 |
# model = get_triplet_model_beit(input_shape = (384, 384, 3),
|
123 |
# embedding_units = 256,
|
124 |
# embedding_depth = 2,
|
125 |
# n_classes = n_classes)
|
126 |
# model.load_weights('model_classification/fossil-new.h5')
|
127 |
+
elif model_name == 'Fossils 142': # new resnet
|
128 |
n_classes = 142
|
129 |
model,_,_ = get_resnet_model('model_classification/fossil-model.h5')
|
130 |
else:
|
check_arch.py
DELETED
@@ -1,21 +0,0 @@
|
|
1 |
-
import h5py
|
2 |
-
|
3 |
-
def print_model_details(file_path):
|
4 |
-
with h5py.File(file_path, 'r') as f:
|
5 |
-
print(f.keys()) # Print layers
|
6 |
-
print(len(f.keys()))
|
7 |
-
print("")
|
8 |
-
for key in f.keys():
|
9 |
-
print(len(list(f[key].keys())))
|
10 |
-
print(f"{key}: {list(f[key].keys())}") # Print details of each layer)
|
11 |
-
print('rock')
|
12 |
-
print_model_details('model_classification/rock-170.h5')
|
13 |
-
print('mummified-170')
|
14 |
-
print_model_details('model_classification/mummified-170.h5')
|
15 |
-
print('BEiT')
|
16 |
-
print_model_details('model_classification/fossil-142.h5')
|
17 |
-
print('BEiT New')
|
18 |
-
print_model_details('model_classification/fossil-new.h5')
|
19 |
-
print("Newest:")
|
20 |
-
print_model_details('model_classification/fossil-model.h5')
|
21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
closest_sample.py
CHANGED
@@ -9,8 +9,8 @@ import matplotlib.pyplot as plt
|
|
9 |
from collections import Counter
|
10 |
|
11 |
|
12 |
-
pca_fossils = pk.load(open('
|
13 |
-
pca_leaves = pk.load(open('
|
14 |
|
15 |
if not os.path.exists('dataset'):
|
16 |
REPO_ID='Serrelab/Fossils'
|
@@ -20,7 +20,7 @@ if not os.path.exists('dataset'):
|
|
20 |
print("warning! A read token in env variables is needed for authentication.")
|
21 |
snapshot_download(repo_id=REPO_ID, token=token,repo_type='dataset',local_dir='dataset')
|
22 |
|
23 |
-
embedding_fossils = np.load('dataset/
|
24 |
#embedding_leaves = np.load('embedding_leaves.npy')
|
25 |
|
26 |
fossils_pd= pd.read_csv('fossils_paths.csv')
|
|
|
9 |
from collections import Counter
|
10 |
|
11 |
|
12 |
+
pca_fossils = pk.load(open('pca_fossils_142_resnet.pkl','rb'))
|
13 |
+
pca_leaves = pk.load(open('pca_leaves_142_resnet.pkl','rb'))
|
14 |
|
15 |
if not os.path.exists('dataset'):
|
16 |
REPO_ID='Serrelab/Fossils'
|
|
|
20 |
print("warning! A read token in env variables is needed for authentication.")
|
21 |
snapshot_download(repo_id=REPO_ID, token=token,repo_type='dataset',local_dir='dataset')
|
22 |
|
23 |
+
embedding_fossils = np.load('dataset/embedding_leaves_142_finer.npy')
|
24 |
#embedding_leaves = np.load('embedding_leaves.npy')
|
25 |
|
26 |
fossils_pd= pd.read_csv('fossils_paths.csv')
|
test.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# import h5py
|
2 |
+
|
3 |
+
# def print_model_details(file_path):
|
4 |
+
# with h5py.File(file_path, 'r') as f:
|
5 |
+
# print(f.keys()) # Print layers
|
6 |
+
# print(len(f.keys()))
|
7 |
+
# print("")
|
8 |
+
# for key in f.keys():
|
9 |
+
# print(len(list(f[key].keys())))
|
10 |
+
# print(f"{key}: {list(f[key].keys())}") # Print details of each layer)
|
11 |
+
# print('rock')
|
12 |
+
# print_model_details('model_classification/rock-170.h5')
|
13 |
+
# print('mummified-170')
|
14 |
+
# print_model_details('model_classification/mummified-170.h5')
|
15 |
+
# print('BEiT')
|
16 |
+
# print_model_details('model_classification/fossil-142.h5')
|
17 |
+
# print('BEiT New')
|
18 |
+
# print_model_details('model_classification/fossil-new.h5')
|
19 |
+
# print("Newest:")
|
20 |
+
# print_model_details('model_classification/fossil-model.h5')
|
21 |
+
|
22 |
+
|
23 |
+
import numpy as np
|
24 |
+
|
25 |
+
# Load the .npy file
|
26 |
+
embedding = np.load('embedding.npy')
|
27 |
+
|
28 |
+
# Check the shape of the array
|
29 |
+
print(embedding.shape)
|