"""Gradio demo for identifying Florissant leaf fossils to plant family.

The script downloads the example images and the fossil dataset from the
Hugging Face Hub (when they are not already on disk), defines the inference /
explanation callbacks, builds the two-tab Gradio UI and launches it.
"""

import os
import sys

from env import config_env

# NOTE(review): config_env() ran before the heavy third-party imports in the
# original file; the order is preserved in case those imports read settings
# that it establishes.
config_env()

import gradio as gr
from huggingface_hub import snapshot_download
import cv2
import dotenv

dotenv.load_dotenv()

import numpy as np
import glob
from inference_sam import segmentation_sam
from explanations import explain
from inference_resnet import get_triplet_model
from inference_resnet_v2 import (
    get_resnet_model,
    inference_resnet_embedding_v2,
    inference_resnet_finer_v2,
)
from inference_beit import get_triplet_model_beit
import pathlib
import tensorflow as tf
from closest_sample import get_images, get_diagram


def _download_dataset_if_missing(repo_id, local_dir):
    """Fetch a Hugging Face dataset repo into *local_dir* unless it exists.

    The access token is read from the ``READ_TOKEN`` environment variable.
    The token value itself is never printed, so it cannot leak into logs.
    """
    if os.path.exists(local_dir):
        return
    token = os.environ.get('READ_TOKEN')
    if token is None:
        print("warning! A read token in env variables is needed for authentication.")
    snapshot_download(repo_id=repo_id, token=token, repo_type='dataset', local_dir=local_dir)


_download_dataset_if_missing('Serrelab/image_examples_gradio', 'images')
_download_dataset_if_missing('Serrelab/Fossils', 'dataset')


HEADER = '''

Official Gradio Demo:

🍁 Identifying Florissant Leaf Fossils to Family using Deep Neural Networks

Code: GitHub. Paper: ArXiv. '''

# TODO: add a short project introduction and "Important Notes" section.

USER_GUIDE = """

❗️ User Guide

Welcome to the interactive fossil exploration tool. Here's how to get started:

Tips

Enjoy exploring! 🌟

"""

TIPS = """ ## Tips - Zoom into images on the workbench for finer details. - Use the examples below as references for what types of images to upload. Enjoy exploring! """

CITATION = ''' 📧 **Contact**
If you have any questions, feel free to contact us at ivan_felipe_rodriguez@brown.edu. '''

# TODO: add the citation (bibtex) and license sections.


def _load_triplet_resnet(n_classes, weights_path):
    """Build the ResNet50V2 triplet model and load weights from *weights_path*."""
    model = get_triplet_model(
        input_shape=(600, 600, 3),
        embedding_units=256,
        embedding_depth=2,
        backbone_class=tf.keras.applications.ResNet50V2,
        nb_classes=n_classes,
        load_weights=False,
        finer_model=True,
        backbone_name='Resnet50v2',
    )
    model.load_weights(weights_path)
    return model


def get_model(model_name):
    """Load the classifier registered under *model_name*.

    Args:
        model_name: One of 'Mummified 170', 'Rock 170' or 'Fossils 142'.

    Returns:
        A ``(model, n_classes)`` tuple.

    Raises:
        ValueError: If *model_name* is not one of the names above.
    """
    # NOTE(review): the model is rebuilt and its weights re-read on every
    # call (classification, embeddings and explanations each call this);
    # consider caching if latency becomes a problem.
    # NOTE(review): 'Fossils BEiT' is still handled by classify_image and
    # get_embeddings but has no loader here, so it raises ValueError.
    if model_name == 'Mummified 170':
        n_classes = 170
        model = _load_triplet_resnet(n_classes, 'model_classification/mummified-170.h5')
    elif model_name == 'Rock 170':
        n_classes = 171
        model = _load_triplet_resnet(n_classes, 'model_classification/rock-170.h5')
    elif model_name == 'Fossils 142':
        n_classes = 142
        model, _, _ = get_resnet_model('model_classification/fossil-model.h5')
    else:
        raise ValueError(f"Model name '{model_name}' is not recognized")
    return model, n_classes


def segment_image(input_image):
    """Return the output of ``segmentation_sam`` for *input_image*."""
    return segmentation_sam(input_image)


def classify_image(input_image, model_name):
    """Classify *input_image* with the model selected by *model_name*.

    Returns:
        The result of the matching inference helper, or ``None`` when
        *model_name* matches none of the known names.
    """
    if model_name in ('Rock 170', 'Mummified 170'):
        from inference_resnet import inference_resnet_finer
        model, n_classes = get_model(model_name)
        return inference_resnet_finer(input_image, model, size=600, n_classes=n_classes)
    if model_name == 'Fossils BEiT':
        from inference_beit import inference_resnet_finer_beit
        model, n_classes = get_model(model_name)
        return inference_resnet_finer_beit(input_image, model, size=384, n_classes=n_classes)
    if model_name == 'Fossils 142':
        model, n_classes = get_model(model_name)
        return inference_resnet_finer_v2(input_image, model, size=384, n_classes=n_classes)
    return None


def get_embeddings(input_image, model_name):
    """Compute the embedding of *input_image* with the selected model.

    Returns:
        The result of the matching embedding helper, or ``None`` when
        *model_name* matches none of the known names.
    """
    if model_name in ('Rock 170', 'Mummified 170'):
        from inference_resnet import inference_resnet_embedding
        model, n_classes = get_model(model_name)
        return inference_resnet_embedding(input_image, model, size=600, n_classes=n_classes)
    if model_name == 'Fossils BEiT':
        from inference_beit import inference_resnet_embedding_beit
        model, n_classes = get_model(model_name)
        return inference_resnet_embedding_beit(input_image, model, size=384, n_classes=n_classes)
    if model_name == 'Fossils 142':
        model, n_classes = get_model(model_name)
        return inference_resnet_embedding_v2(input_image, model, size=384, n_classes=n_classes)
    return None


def find_closest(input_image, model_name):
    """Return ``(classes, paths)`` from ``get_images`` for the image's embedding."""
    embedding = get_embeddings(input_image, model_name)
    classes, paths = get_images(embedding, model_name)
    return classes, paths


def generate_diagram_closest(input_image, model_name, top_k):
    """Return the diagram produced by ``get_diagram`` for the image's embedding."""
    embedding = get_embeddings(input_image, model_name)
    return get_diagram(embedding, top_k, model_name)


def explain_image(input_image, model_name, explain_method, nb_samples):
    """Run ``explain`` on *input_image* and return ``(classes, exp_list)``.

    The input size passed to ``explain`` is 384 for 'Fossils BEiT' and
    'Fossils 142', and 600 for every other model.
    """
    model, n_classes = get_model(model_name)
    # The original test was `model_name == 'Fossils BEiT' or 'Fossils 142'`,
    # which is always true because a non-empty string is truthy.
    size = 384 if model_name in ('Fossils BEiT', 'Fossils 142') else 600
    classes, exp_list = explain(
        model, input_image, explain_method, nb_samples, size=size, n_classes=n_classes
    )
    print('done')
    return classes, exp_list


def setup_examples():
    """Create the two example galleries bound to the ``input_image`` component.

    Must be called inside the ``gr.Blocks`` context, after ``input_image``
    has been created, because it reads that module-level name.

    Returns:
        ``(examples_fossils, examples_leaves)`` ``gr.Examples`` objects.
    """
    paths = sorted(pathlib.Path('images/').rglob('*.jpg'))

    fossil_samples = [p.as_posix() for p in paths if 'selected fossil examples' in str(p)][:23]
    examples_fossils = gr.Examples(
        fossil_samples,
        inputs=input_image,
        examples_per_page=8,
        label='Fossils Examples from the dataset',
    )

    leaf_samples = [[p.as_posix()] for p in paths if 'leaves' in str(p)][:19]
    examples_leaves = gr.Examples(
        leaf_samples,
        inputs=input_image,
        examples_per_page=8,
        label='Leaves Examples from the dataset',
    )
    return examples_fossils, examples_leaves


def preprocess_image(image, output_size=(300, 300)):
    """Pad *image* with black borders to a square, then resize it.

    Args:
        image: Array whose first two dimensions are (height, width).
        output_size: Target size passed to ``cv2.resize``.

    Returns:
        The padded and resized image.
    """
    h, w = image.shape[:2]
    diff = abs(h - w)
    # Split the padding so an odd difference still yields an exact square
    # (padding diff // 2 on both sides would leave it one pixel short).
    pad_before = diff // 2
    pad_after = diff - pad_before
    if h > w:
        image_padded = cv2.copyMakeBorder(
            image, 0, 0, pad_before, pad_after, cv2.BORDER_CONSTANT, value=[0, 0, 0]
        )
    else:
        image_padded = cv2.copyMakeBorder(
            image, pad_before, pad_after, 0, 0, cv2.BORDER_CONSTANT, value=[0, 0, 0]
        )
    return cv2.resize(image_padded, output_size, interpolation=cv2.INTER_AREA)


def increase_brightness(img, value=30):
    """Shift the HSV value channel of a BGR image by *value*, clamped to 0-255.

    Assumes an 8-bit image. Negative *value* darkens the image; the result is
    clamped instead of wrapping around.
    """
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    h, s, v = cv2.split(hsv)
    # Widen before adding so the sum cannot overflow the 8-bit channel.
    v = np.clip(v.astype(np.int16) + value, 0, 255).astype(v.dtype)
    final_hsv = cv2.merge((h, s, v))
    return cv2.cvtColor(final_hsv, cv2.COLOR_HSV2BGR)


def update_display(image):
    """Prepare the workbench after the user clicks 'Process Image'.

    Returns a 12-tuple matching the ``outputs`` of ``process_button.click``:
    the untouched image, the preprocessed image (twice: input preview and
    workbench), the instruction text, the default control values, and
    ``None`` for each result component so stale results are cleared.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Without this guard preprocess_image fails on `image.shape`.
        raise gr.Error("Please upload or choose an image before clicking 'Process Image'.")

    original_image = image
    processed_image = preprocess_image(image)
    instruction = "Image ready. Please switch to the 'Specimen Workbench' tab to check out further analysis and outputs."

    # Default values for the workbench controls.
    model_name = "Fossils 142"
    explain_method = "Rise"
    sampling_size = 10
    top_k = 50

    # Cleared result components.
    class_predicted = None
    exp_gallery = None
    closest_gallery = None
    diagram = None

    return (
        original_image,
        processed_image,
        processed_image,
        instruction,
        model_name,
        explain_method,
        sampling_size,
        top_k,
        class_predicted,
        exp_gallery,
        closest_gallery,
        diagram,
    )


def update_slider_visibility(explain_method):
    """Show the sampling-size slider only when the 'Rise' method is selected."""
    is_rise = explain_method == "Rise"
    # NOTE(review): this also resets the slider to range 1-5000 / value 2000,
    # which differs from its initial configuration (10-3000 / value 10).
    return {
        sampling_size: gr.Slider(
            1, 5000, value=2000, label="Sampling Size in Rise", visible=is_rise, interactive=True
        )
    }


def update_exp_outputs(input_image, model_name, explain_method, nb_samples):
    """Return up to five ``(image, caption)`` pairs for the explanation gallery."""
    labels, images = explain_image(input_image, model_name, explain_method, nb_samples)
    # Pair first, then truncate, so fewer than five results cannot raise
    # IndexError.
    top = list(zip(images, labels))[:5]
    return [
        (image, f"Predicted Plant Family {rank}: {label}")
        for rank, (image, label) in enumerate(top)
    ]


def update_closest_outputs(input_image, model_name):
    """Return up to five ``(image, label)`` pairs for the closest-images gallery."""
    labels, images = find_closest(input_image, model_name)
    return list(zip(images, labels))[:5]


# minimalist theme
with gr.Blocks(theme='sudeepshouche/minimalist') as demo:

    with gr.Tab(" Florrissant Fossils"):
        gr.Markdown(HEADER)
        with gr.Row():
            with gr.Column():
                gr.Markdown(USER_GUIDE)
            with gr.Column(scale=2):
                with gr.Column(scale=2):
                    instruction_text = gr.Textbox(
                        label="Instructions",
                        value="Upload/Choose an image and click 'Process Image'.",
                    )
                    input_image = gr.Image(label="Input", width="100%", container=True)
                    process_button = gr.Button("Process Image")
                with gr.Column(scale=1):
                    examples_fossils, examples_leaves = setup_examples()
        gr.Markdown(CITATION)

    with gr.Tab("Specimen Workbench"):
        with gr.Row():
            with gr.Column():
                # Hidden holder for the unprocessed upload; the callbacks
                # below read the image from here.
                original_image = gr.Image(visible=False)
                workbench_image = gr.Image(label="Workbench Image")
                classify_image_button = gr.Button("Classify Image")

            with gr.Column():
                # "Mummified 170", "Rock 170" and "Fossils BEiT" were removed
                # from the choices.
                model_name = gr.Dropdown(
                    ["Fossils 142"],
                    multiselect=False,
                    value="Fossils 142",  # default option
                    label="Model",
                    interactive=True,
                    info="Choose the model you'd like to use",
                )
                explain_method = gr.Dropdown(
                    ["Sobol", "HSIC", "Rise", "Saliency"],
                    multiselect=False,
                    value="Rise",  # default option
                    label="Explain method",
                    interactive=True,
                    info="Choose one method to explain the model",
                )
                sampling_size = gr.Slider(
                    10,
                    3000,
                    value=10,
                    label="Sampling Size in Rise",
                    interactive=True,
                    visible=True,
                    info="Choose between 10 and 3000",
                )
                top_k = gr.Slider(
                    10,
                    200,
                    value=50,
                    label="Number of Closest Samples for Distribution Chart",
                    interactive=True,
                    info="Choose between 10 and 200",
                )
                explain_method.change(
                    fn=update_slider_visibility,
                    inputs=explain_method,
                    outputs=sampling_size,
                )

        with gr.Row():
            with gr.Column(scale=1):
                class_predicted = gr.Label(label='Plant Family Predicted', num_top_classes=10)

            with gr.Column(scale=4):
                with gr.Accordion("Explanations "):
                    gr.Markdown("Computing Explanations from the model for Top 5 Predicted Plant Families")
                    with gr.Column():
                        with gr.Row():
                            exp_gallery = gr.Gallery(
                                label="Explanation Heatmaps for top 5 predicted classes",
                                show_label=False,
                                elem_id="gallery",
                                columns=[5],
                                rows=[1],
                                height='auto',
                                allow_preview=True,
                                preview=None,
                            )
                        generate_explanations = gr.Button("Generate Explanations")

                with gr.Accordion('Closest Fossil Images'):
                    gr.Markdown("Finding 5 closest images in the dataset")
                    with gr.Row():
                        closest_gallery = gr.Gallery(
                            label="Closest Images",
                            show_label=False,
                            elem_id="gallery",
                            columns=[5],
                            rows=[1],
                            height='auto',
                            allow_preview=True,
                            preview=None,
                        )
                    find_closest_btn = gr.Button("Find Closest Images")

                with gr.Accordion("Family Distribution of Closest Samples "):
                    gr.Markdown("Visualize plant family distribution of top-k closest samples in our dataset")
                    with gr.Column():
                        with gr.Row():
                            diagram = gr.Image(label='Bar Chart')
                        generate_diagram = gr.Button("Generate Diagram")

        # Event wiring. Every analysis callback reads the hidden
        # `original_image`, not the resized preview.
        classify_image_button.click(
            classify_image,
            inputs=[original_image, model_name],
            outputs=class_predicted,
        )
        generate_explanations.click(
            fn=update_exp_outputs,
            inputs=[original_image, model_name, explain_method, sampling_size],
            outputs=[exp_gallery],
        )
        find_closest_btn.click(
            fn=update_closest_outputs,
            inputs=[original_image, model_name],
            outputs=[closest_gallery],
        )
        generate_diagram.click(
            generate_diagram_closest,
            inputs=[original_image, model_name, top_k],
            outputs=diagram,
        )
        process_button.click(
            fn=update_display,
            inputs=input_image,
            outputs=[
                original_image,
                input_image,
                workbench_image,
                instruction_text,
                model_name,
                explain_method,
                sampling_size,
                top_k,
                class_predicted,
                exp_gallery,
                closest_gallery,
                diagram,
            ],
        )

demo.queue()  # manage multiple incoming requests

if os.getenv('SYSTEM') == 'spaces':
    demo.launch(width='40%')
else:
    demo.launch()