# Hugging Face Space status: Sleeping
from sklearn.decomposition import PCA | |
import pickle as pk | |
import numpy as np | |
import pandas as pd | |
import os | |
from huggingface_hub import snapshot_download | |
import requests | |
import matplotlib.pyplot as plt | |
from collections import Counter | |
# Fitted PCA models, loaded from pickles in the working directory.
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
with open('pca_fossils_170_finer.pkl', 'rb') as _pkl_file:
    pca_fossils = pk.load(_pkl_file)
with open('pca_leaves_170_finer.pkl', 'rb') as _pkl_file:
    pca_leaves = pk.load(_pkl_file)

# Download the dataset snapshot only once; later runs reuse the local 'dataset' dir.
if not os.path.exists('dataset'):
    REPO_ID = 'Serrelab/Fossils'
    token = os.environ.get('READ_TOKEN')
    # Never echo the secret itself into the logs; only report whether it is present.
    print(f"Read token: {'set' if token is not None else 'not set'}")
    if token is None:
        print("warning! A read token in env variables is needed for authentication.")
    snapshot_download(repo_id=REPO_ID, token=token, repo_type='dataset', local_dir='dataset')

# Dataset embeddings searched by pca_distance (the snapshot download above provides this file).
embedding_fossils = np.load('dataset/embedding_fossils_170_finer.npy')
# embedding_leaves = np.load('embedding_leaves.npy')
# Table whose 'file_name' column maps embedding row indices to image paths.
fossils_pd = pd.read_csv('fossils_paths.csv')
def pca_distance(pca, sample, embedding, top_k, max_index=2852):
    """
    Return indices of the dataset embeddings closest to ``sample`` in PCA space.

    Args:
        pca: fitted PCA model (any object with a ``transform`` method).
        sample: query embedding; flattened to a single row before projection.
        embedding: embeddings of the dataset; ``embedding[:, -1]`` is what gets
            projected (presumably the last slot of an (N, k, D) array -- TODO confirm).
        top_k: maximum number of indices to return.
        max_index: largest dataset index that may be returned. The default 2852
            excludes the general fossils and keeps the Florissant ones only.

    Returns:
        1-D integer array of at most ``top_k`` indices, ordered by increasing
        Euclidean distance to ``sample`` in PCA space.
    """
    query = pca.transform(sample.reshape(1, -1))
    projected = pca.transform(embedding[:, -1])
    distances = np.linalg.norm(projected - query, axis=1)
    order = np.argsort(distances)
    # Exclude general fossils, keep Florissant only.
    order = order[order <= max_index]
    # NOTE(review): the third-closest match (position 2) is deliberately skipped,
    # as in the original implementation; the reason is not visible here -- confirm.
    picked = np.concatenate([order[:2], order[3:top_k + 1]])
    # Cap the result: without this, top_k < 2 would still yield two indices.
    return picked[:max(top_k, 0)]
def return_paths(argsorted, files):
    """
    Map a sequence of indices to the corresponding entries of ``files``.

    Args:
        argsorted: iterable of integer indices into ``files``.
        files: indexable collection of file paths.

    Returns:
        A list holding ``files[i]`` for each ``i`` in ``argsorted``, in order.
    """
    return [files[idx] for idx in argsorted]
def download_public_image(url, destination_path, timeout=30):
    """
    Download ``url`` and write the response body to ``destination_path``.

    Args:
        url: public URL of the image.
        destination_path: local file to create/overwrite with the downloaded bytes.
        timeout: seconds to wait on the server (passed to ``requests.get``);
            without one, a stalled connection blocks forever.

    Returns:
        True if the file was written, False otherwise. Failures are reported with
        a printed message instead of raising, so one bad URL does not abort a batch.
    """
    try:
        response = requests.get(url, timeout=timeout)
    except requests.RequestException as err:
        print(f"Failed to download image from bucket: {err}")
        return False
    if response.status_code != 200:
        print(f"Failed to download image from bucket. Status code: {response.status_code}")
        return False
    with open(destination_path, 'wb') as f:
        f.write(response.content)
    print(f"Downloaded image to {destination_path}")
    return True
def get_images(embedding, top_k=5):
    """
    Download the fossil images closest to ``embedding`` and report their classes.

    Args:
        embedding: query embedding; projected with ``pca_fossils`` and compared
            against ``embedding_fossils`` via ``pca_distance``.
        top_k: number of neighbours to fetch (default 5, the previously fixed value).

    Returns:
        ``(classes, local_paths)``: for every neighbour, the name of its parent
        directory in the public bucket URL (used as the plant-family label) and the
        local file it was saved to (``image_<i>.jpg``). Paths that match neither
        known dataset folder are reported and skipped.
    """
    pca_d = pca_distance(pca_fossils, embedding, embedding_fossils, top_k=top_k)
    fossils_paths = fossils_pd['file_name'].values
    paths = return_paths(pca_d, fossils_paths)
    print(paths)
    # (marker in the stored local path, public bucket folder), checked in order.
    url_roots = (
        ('Florissant_Fossil/512/full/jpg/',
         'https://storage.googleapis.com/serrelab/prj_fossils/2024/Florissant_Fossil_v2.0/'),
        ('General_Fossil/512/full/jpg/',
         'https://storage.googleapis.com/serrelab/prj_fossils/2024/General_Fossil_v2.0/'),
    )
    local_paths = []
    classes = []
    for i, path in enumerate(paths):
        local_file_path = f'image_{i}.jpg'
        public_path = None
        for marker, folder in url_roots:
            if marker in path:
                # Re-root everything after the marker under the public bucket folder.
                public_path = folder + path.split(marker, 1)[1]
                break
        if public_path is None:
            # Previously this fell through with an undefined (or stale) public_path.
            print("no match found")
            continue
        print(public_path)
        download_public_image(public_path, local_file_path)
        parts = [part for part in public_path.split('/') if part]
        classes.append(parts[-2])
        local_paths.append(local_file_path)
    return classes, local_paths
def _plot_class_distribution(classes, top_k, diagram_path):
    """Save a bar chart of class frequencies (most frequent first) to ``diagram_path``."""
    # most_common() sorts by count, descending; ties keep first-seen order.
    counts = Counter(classes).most_common()
    labels = [label for label, _ in counts]
    frequencies = [freq for _, freq in counts]
    positions = np.arange(len(labels))
    colors = plt.cm.viridis(np.linspace(0, 1, len(labels)))
    fig, ax = plt.subplots()
    if labels:
        ax.bar(positions, frequencies, color=colors)
    ax.set_xticks(positions)
    # Label the bars in the same (frequency-sorted) order they are drawn in.
    ax.set_xticklabels(labels, rotation=45, ha='right')
    ax.set_xlabel('Plant Family')
    ax.set_ylabel('Frequency')
    ax.set_title(f'Distribution of Plant Family of {top_k} Closest Samples')
    fig.tight_layout()  # make room for the rotated x-axis labels
    fig.savefig(diagram_path)
    plt.close(fig)  # free the figure's memory
    return diagram_path


def get_diagram(embedding, top_k):
    """
    Plot the plant-family distribution of the fossils closest to ``embedding``.

    Args:
        embedding: query embedding; projected with ``pca_fossils`` and compared
            against ``embedding_fossils`` via ``pca_distance``.
        top_k: number of neighbours to include (also shown in the chart title).

    Returns:
        Path of the saved PNG chart, ``'class_distribution_chart.png'``.
        A neighbour's class is the parent directory name in its public bucket URL.
        Paths that match neither known dataset folder are reported and skipped.
    """
    pca_d = pca_distance(pca_fossils, embedding, embedding_fossils, top_k=top_k)
    fossils_paths = fossils_pd['file_name'].values
    paths = return_paths(pca_d, fossils_paths)
    # (marker in the stored local path, public bucket folder), checked in order.
    url_roots = (
        ('Florissant_Fossil/512/full/jpg/',
         'https://storage.googleapis.com/serrelab/prj_fossils/2024/Florissant_Fossil_v2.0/'),
        ('General_Fossil/512/full/jpg/',
         'https://storage.googleapis.com/serrelab/prj_fossils/2024/General_Fossil_v2.0/'),
    )
    classes = []
    for path in paths:
        public_path = None
        for marker, folder in url_roots:
            if marker in path:
                # Re-root everything after the marker under the public bucket folder.
                public_path = folder + path.split(marker, 1)[1]
                break
        if public_path is None:
            # Previously this fell through with an undefined (or stale) public_path.
            print("no match found")
            continue
        print(public_path)
        parts = [part for part in public_path.split('/') if part]
        classes.append(parts[-2])
    return _plot_class_distribution(classes, top_k, 'class_distribution_chart.png')