# fossil_app / closest_sample.py
# Last commit 1662a5d (andy-wyx): update embeddings from resnet model
# File size: 7.46 kB
from sklearn.decomposition import PCA
import pickle as pk
import numpy as np
import pandas as pd
import os
from huggingface_hub import snapshot_download
import requests
import matplotlib.pyplot as plt
from collections import Counter
# Fetch the fossil dataset from the Hugging Face Hub on first run only: the
# presence of a local 'dataset' directory is treated as "already downloaded".
if not os.path.exists('dataset'):
    REPO_ID = 'Serrelab/Fossils'
    token = os.environ.get('READ_TOKEN')
    # Report only whether a token is present; printing the token itself would
    # leak the secret into stdout and any captured logs.
    print(f"Read token found: {token is not None}")
    if token is None:
        print("warning! A read token in env variables is needed for authentication.")
    snapshot_download(repo_id=REPO_ID, token=token, repo_type='dataset', local_dir='dataset')

# Table of fossil samples; its 'file_name' column is indexed by position when
# nearest-neighbour indices are mapped back to image paths.
fossils_pd = pd.read_csv('all_fossils_filtered_100.csv')
def pca_distance(pca, sample, embedding, top_k, *, max_index=2852):
    """Return indices of the dataset embeddings closest to ``sample`` in PCA space.

    Args:
        pca: fitted model exposing ``transform`` (e.g. a fitted PCA).
        sample: embedding of the query; it is reshaped to a single row
            before being projected.
        embedding: embeddings of the dataset; only ``embedding[:, -1]`` (the
            last entry along axis 1 of every row) is projected.
        top_k: controls how many indices are returned (see Returns).
        max_index: largest dataset index allowed in the result. The default
            of 2852 excludes general fossils and keeps Florissant only.

    Returns:
        A 1-D array of dataset indices ordered by increasing Euclidean
        distance to the projected sample. For ``top_k >= 2`` it holds at most
        ``top_k`` indices (fewer when not enough candidates pass the
        ``max_index`` filter).
    """
    projected_sample = pca.transform(sample.reshape(1, -1))
    projected_dataset = pca.transform(embedding[:, -1])
    distances = np.linalg.norm(projected_dataset - projected_sample, axis=1)
    sorted_indices = np.argsort(distances)
    filtered_indices = sorted_indices[sorted_indices <= max_index]
    # NOTE(review): the third-ranked neighbour (position 2) is deliberately
    # skipped here, exactly as in the original code; the reason is not
    # visible in this file -- confirm whether this is still intended.
    top_indices = np.concatenate([filtered_indices[:2], filtered_indices[3:top_k + 1]])
    return top_indices
def return_paths(argsorted, files):
    """Return ``files[i]`` for every index ``i`` in ``argsorted``, in that order."""
    return [files[idx] for idx in argsorted]
def download_public_image(url, destination_path, timeout=30):
    """Download ``url`` and write the response body to ``destination_path``.

    The function is best-effort: failures are reported on stdout and nothing
    is written, rather than raising to the caller.

    Args:
        url: address of the image to fetch.
        destination_path: local file path the image bytes are written to.
        timeout: seconds to wait for the server before giving up. Without a
            timeout ``requests.get`` can block indefinitely.
    """
    try:
        response = requests.get(url, timeout=timeout)
    except requests.RequestException as exc:
        # Connection errors and timeouts are reported the same way as a bad
        # status code so one failed image does not abort the caller.
        print(f"Failed to download image from bucket. Error: {exc}")
        return
    if response.status_code == 200:
        with open(destination_path, 'wb') as f:
            f.write(response.content)
        print(f"Downloaded image to {destination_path}")
    else:
        print(f"Failed to download image from bucket. Status code: {response.status_code}")
def get_images(embedding, model_name):
    """Download the images of the fossils closest to ``embedding``.

    Args:
        embedding: embedding of the query sample, forwarded to ``pca_distance``.
        model_name: one of 'Rock 170', 'Mummified 170' or 'Fossils 142';
            selects which fitted PCA and dataset embeddings are loaded.

    Returns:
        A tuple ``(classes, local_paths)``: for every neighbour whose path
        could be mapped to a public URL, the second-to-last component of that
        URL and the local file name (``image_<rank>.jpg``) the image was
        downloaded to.

    Raises:
        ValueError: if ``model_name`` is not recognized.
    """
    if model_name in ['Rock 170', 'Mummified 170']:
        pca_path = 'pca_fossils_170_finer.pkl'
        embedding_path = 'dataset/embedding_fossils_170_finer.npy'
    elif model_name in ['Fossils 142']:
        # NOTE(review): the PCA file is tagged '_resnet' while the embeddings
        # file is tagged '_finer' -- confirm these two belong together.
        pca_path = 'pca_fossils_142_resnet.pkl'
        embedding_path = 'dataset/embedding_fossils_142_finer.npy'
    else:
        # Fail with a clear error instead of hitting an unbound local below.
        raise ValueError(f'{model_name} not recognized')

    # NOTE: pickle executes arbitrary code on load; these files must only ever
    # be trusted project artifacts.
    with open(pca_path, 'rb') as pca_file:
        pca_fossils = pk.load(pca_file)
    embedding_fossils = np.load(embedding_path)

    pca_d = pca_distance(pca_fossils, embedding, embedding_fossils, top_k=5)
    fossils_paths = fossils_pd['file_name'].values
    paths = return_paths(pca_d, fossils_paths)
    print(paths)

    # Paths in the CSV point at the original cluster storage; map each known
    # sub-folder to the public bucket folder that mirrors it.
    cluster_prefix = '/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/'
    public_folders = {
        'Florissant_Fossil/512/full/jpg/': 'https://storage.googleapis.com/serrelab/prj_fossils/2024/Florissant_Fossil_v2.0/',
        'General_Fossil/512/full/jpg/': 'https://storage.googleapis.com/serrelab/prj_fossils/2024/General_Fossil_v2.0/',
    }

    local_paths = []
    classes = []
    for i, path in enumerate(paths):
        local_file_path = f'image_{i}.jpg'
        for marker, folder in public_folders.items():
            if marker in path:
                public_path = path.replace(cluster_prefix + marker, folder)
                break
        else:
            # Skip this neighbour: continuing would reuse the previous
            # iteration's URL (or an unbound name on the first iteration).
            print("no match found")
            continue
        print(public_path)
        download_public_image(public_path, local_file_path)
        # The class label is the folder that directly contains the image.
        parts = [part for part in public_path.split('/') if part]
        classes.append(parts[-2])
        local_paths.append(local_file_path)
    return classes, local_paths
def get_diagram(embedding, top_k, model_name):
    """Plot how the closest fossils to ``embedding`` are distributed by class.

    Args:
        embedding: embedding of the query sample, forwarded to ``pca_distance``.
        top_k: number of closest samples requested from ``pca_distance``; also
            shown in the chart title.
        model_name: one of 'Rock 170', 'Mummified 170' or 'Fossils 142';
            selects which fitted PCA and dataset embeddings are loaded.

    Returns:
        Path of the saved bar chart ('class_distribution_chart.png'), with
        one bar per class ordered by decreasing frequency.

    Raises:
        ValueError: if ``model_name`` is not recognized.
    """
    if model_name in ['Rock 170', 'Mummified 170']:
        pca_path = 'pca_fossils_170_finer.pkl'
        embedding_path = 'dataset/embedding_fossils_170_finer.npy'
    elif model_name in ['Fossils 142']:
        # NOTE(review): the PCA file is tagged '_resnet' while the embeddings
        # file is tagged '_finer' -- confirm these two belong together.
        pca_path = 'pca_fossils_142_resnet.pkl'
        embedding_path = 'dataset/embedding_fossils_142_finer.npy'
    else:
        # Fail with a clear error instead of hitting an unbound local below.
        raise ValueError(f'{model_name} not recognized')

    # NOTE: pickle executes arbitrary code on load; these files must only ever
    # be trusted project artifacts.
    with open(pca_path, 'rb') as pca_file:
        pca_fossils = pk.load(pca_file)
    embedding_fossils = np.load(embedding_path)

    pca_d = pca_distance(pca_fossils, embedding, embedding_fossils, top_k=top_k)
    fossils_paths = fossils_pd['file_name'].values
    paths = return_paths(pca_d, fossils_paths)

    # Paths in the CSV point at the original cluster storage; map each known
    # sub-folder to the public bucket folder that mirrors it.
    cluster_prefix = '/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/'
    public_folders = {
        'Florissant_Fossil/512/full/jpg/': 'https://storage.googleapis.com/serrelab/prj_fossils/2024/Florissant_Fossil_v2.0/',
        'General_Fossil/512/full/jpg/': 'https://storage.googleapis.com/serrelab/prj_fossils/2024/General_Fossil_v2.0/',
    }

    classes = []
    for path in paths:
        for marker, folder in public_folders.items():
            if marker in path:
                public_path = path.replace(cluster_prefix + marker, folder)
                break
        else:
            # Skip this neighbour: continuing would count the previous
            # iteration's class again (or hit an unbound name).
            print("no match found")
            continue
        print(public_path)
        # The class label is the folder that directly contains the image.
        parts = [part for part in public_path.split('/') if part]
        classes.append(parts[-2])

    # most_common() orders by decreasing count, keeping first-seen order for ties.
    sorted_class_counts = Counter(classes).most_common()
    sorted_classes = [name for name, _ in sorted_class_counts]
    sorted_frequencies = [count for _, count in sorted_class_counts]

    fig, ax = plt.subplots()
    if sorted_classes:
        colors = plt.cm.viridis(np.linspace(0, 1, len(sorted_classes)))
        positions = np.arange(len(sorted_classes))
        ax.bar(positions, sorted_frequencies, color=colors)
        # Label ticks from the same sorted sequence the bars were drawn from;
        # using the Counter's insertion order here would mislabel the bars.
        ax.set_xticks(positions)
        ax.set_xticklabels(sorted_classes, rotation=45, ha='right')
    ax.set_xlabel('Plant Family')
    ax.set_ylabel('Frequency')
    ax.set_title('Distribution of Plant Family of ' + str(top_k) + ' Closest Samples')

    # Save the diagram to a file
    diagram_path = 'class_distribution_chart.png'
    fig.tight_layout()  # Adjust layout to make room for rotated x-axis labels
    fig.savefig(diagram_path)
    plt.close(fig)  # Close the figure to free up memory
    return diagram_path