import random
import numpy as np
import gradio as gr
from huggingface_hub import from_pretrained_fastai
from PIL import Image
from groundingdino.util.inference import load_model
from groundingdino.util.inference import predict as grounding_dino_predict
import groundingdino.datasets.transforms as T
import torch
from torchvision.ops import box_convert
from torchvision.transforms.functional import to_tensor
from torchvision.transforms import GaussianBlur
from Ambrosia import pre_process_image


# Define a custom transform for Gaussian blur
def gaussian_blur(x, p=0.5, kernel_size_min=3, kernel_size_max=20, sigma_min=0.1, sigma_max=3):
    if x.ndim == 4:
        for i in range(x.shape[0]):
            if random.random() < p:
                kernel_size = random.randrange(kernel_size_min, kernel_size_max + 1, 2)
                sigma = random.uniform(sigma_min, sigma_max)
                x[i] = GaussianBlur(kernel_size=kernel_size, sigma=sigma)(x[i])
    return x


# Custom label function: the label is the directory name two levels above the image folder
def custom_label_func(fpath):
    label = fpath.parents[2].name
    return label


# This function describes how much a single value in a list stands out:
# - if all values in the list are uniformly high or low, it returns 1.
# - the smaller the proportion of dissimilar values relative to the more similar values, the lower the result.
# - the larger the gap between the dissimilar and the similar values, the lower the result.
# It can only interpret probabilities, i.e. values between 0 and 1.
# The output is an estimate of the inverse of the classification confidence,
# based on the probabilities of all the classes.
# The wedge threshold splits the data at a threshold, with a positive-integer
# magnitude, to force a ledge/peak in the data.
def unknown_prob_calc(probs, wedge_threshold, wedge_magnitude=1, wedge='strict'):
    if wedge == 'strict':
        increase_var = 1 / wedge_magnitude
        decrease_var = wedge_magnitude
    elif wedge == 'dynamic':
        # points further from the threshold are moved less; points closer are moved more
        increase_var = 1 / (wedge_magnitude * (1 - np.abs(probs - wedge_threshold)))
        decrease_var = wedge_magnitude * (1 - np.abs(probs - wedge_threshold))
    else:
        raise ValueError("use 'strict' (default) or 'dynamic' as options for the wedge parameter!")
    probs = np.where(probs >= wedge_threshold, probs**increase_var, probs)
    probs = np.where(probs <= wedge_threshold, probs**decrease_var, probs)
    diff_matrix = np.abs(probs[:, np.newaxis] - probs)
    diff_matrix_sum = np.sum(diff_matrix)
    probs_sum = np.sum(probs)
    class_val = diff_matrix_sum / probs_sum
    max_class_val = (len(probs) - 1) * 2
    known_prob = class_val / max_class_val
    unknown_prob = 1 - known_prob
    return unknown_prob
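
# Quick sanity check (illustrative values only, not part of the original pipeline):
# a confident probability vector should give a low unknown probability, while a
# near-uniform vector gives an unknown probability of 1 (no value stands out).
_confident = np.array([0.95, 0.02, 0.02, 0.01])
_uniform = np.array([0.25, 0.25, 0.25, 0.25])
print(unknown_prob_calc(probs=_confident, wedge_threshold=0.85, wedge_magnitude=5, wedge='dynamic'))  # low (~0.1)
print(unknown_prob_calc(probs=_uniform, wedge_threshold=0.85, wedge_magnitude=5, wedge='dynamic'))    # exactly 1.0
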
insect", device="cpu"): # TEXT_PROMPT = prompt # BOX_TRESHOLD = 0.35 # TEXT_TRESHOLD = 0.25 # DEVICE = device # cuda or cpu # # Convert numpy array to PIL Image if needed # if isinstance(og_image, np.ndarray): # og_image_obj = Image.fromarray(og_image) # else: # og_image_obj = og_image # Assuming og_image is already a PIL Image # # Transform the image # image_transformed = load_image(image_source = og_image_obj) # # Your model prediction code here... # boxes, logits, phrases = grounding_dino_predict( # model=model, # image=image_transformed, # caption=TEXT_PROMPT, # box_threshold=BOX_TRESHOLD, # text_threshold=TEXT_TRESHOLD, # device=DEVICE) # # Use og_image_obj directly for further processing # height, width = og_image_obj.size # boxes_norm = boxes * torch.Tensor([height, width, height, width]) # xyxy = box_convert( # boxes=boxes_norm, # in_fmt="cxcywh", # out_fmt="xyxy").numpy() # img_lst = [] # for i in range(len(boxes_norm)): # crop_img = og_image_obj.crop((xyxy[i])) # img_lst.append(crop_img) # return (img_lst) # load beetle classifier model repo_id="ChristopherMarais/beetle-model-mini" bc_model = from_pretrained_fastai(repo_id) # get class names labels = np.append(np.array(bc_model.dls.vocab), "Unknown") print("Classification model loaded") def predict_beetle(img): print("Detecting & classifying beetles...") # Split image into smaller images of detected objects # image_lst = detect_objects(og_image=img, model=od_model, prompt="bug . insect", device="cpu") pre_process = pre_process_image(manual_thresh_buffer=0.15, image = img) # use image_dir if directory of image used pre_process.segment(cluster_num=2, image_edge_buffer=50) image_lst = pre_process.col_image_lst print("Objects detected") # get predictions for all segments conf_dict_lst = [] output_lst = [] img_cnt = len(image_lst) for i in range(0,img_cnt): prob_ar = np.array(bc_model.predict(image_lst[i])[2]) print(f"Beetle classified - {i}") unkown_prob = unkown_prob_calc(probs=prob_ar, wedge_threshold=0.85, wedge_magnitude=5, wedge='dynamic') prob_ar = np.append(prob_ar, unkown_prob) prob_ar = np.around(prob_ar*100, decimals=1) # only show the top 5 predictions # Sorting the dictionary by value in descending order and taking the top items top_num = 3 conf_dict = {labels[i]: float(prob_ar[i]) for i in range(len(prob_ar))} conf_dict = dict(sorted(conf_dict.items(), key=lambda item: item[1], reverse=True)[:top_num]) conf_dict_lst.append(str(conf_dict)[1:-1]) # remove dictionary brackets result = list(zip(image_lst, conf_dict_lst)) print(f"Classification processed - {i}") result = list(zip([img], ["labelzzzzz"])) return(result) # gradio app css = """ button { width: auto; /* Set your desired width */ } """ with gr.Blocks(css=css) as demo: with gr.Column(variant="panel"): with gr.Row(variant="compact"): inputs = gr.Image() # Use the `full_width` parameter directly btn = gr.Button("Classify") # Set the gallery layout and height directly in the constructor gallery = gr.Gallery(label="Show images", show_label=True, elem_id="gallery", columns=8, height="auto") btn.click(predict_beetle, inputs, gallery) demo.launch(debug=True, show_error=True)