"""Gradio app: run DETR checkpoints on an image and show detections and attention maps."""

import functools
import pathlib
from typing import Optional

import gradio as gr
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, DetrForObjectDetection

from style import css, description, title
from visualization import visualize_attention_map, visualize_prediction

# Dropdown label that selects a checkpoint from the individual hyper-parameters.
_PARAMETERS_VERIFICATION = "Parameters verification"

# Fixed checkpoints, keyed by the exact dropdown label.  The labels are runtime
# strings shared by the UI and detect_objects(); keep them byte-identical.
_REPRODUCIBILITY_MODEL_REPOS = {
    "Reproducability check (1)": "polejowska/detr-r50-cd45rb-all-2ah",
    "Reproducability check (2)": "polejowska/detr-r50-cd45rb-all-4ah",
    "Reproducability check (3)": "polejowska/detr-r50-cd45rb-all-8ah",
    "Reproducability check (4)": "polejowska/detr-r50-cd45rb-all-16ah",
}

# RESNET-101 checkpoints are selected by encoder/decoder depth only.
_RESNET101_MODEL_REPOS = {
    6: "polejowska/detr-r101-official",
    4: "polejowska/detr-r101-cd45rb-8ah-4l",
    12: "polejowska/detr-r101-cd45rb-8ah-12l",
}


def make_prediction(img, feature_extractor, model):
    """Run ``model`` on ``img`` and return detections plus attention maps.

    Args:
        img: PIL image to analyse (its ``size`` attribute is read).
        feature_extractor: Callable preprocessor that also provides
            ``post_process(outputs, target_sizes)``.
        model: Detection model called with the preprocessed inputs.

    Returns:
        A tuple ``(processed_output, decoder_attentions, encoder_attentions)``
        where ``processed_output`` is the first (only) entry returned by
        ``feature_extractor.post_process``.
    """
    inputs = feature_extractor(img, return_tensors="pt")
    # Inference only: skip autograd bookkeeping.  Attention maps are requested
    # explicitly so the lookups below do not depend on the checkpoint config.
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)
    # PIL reports (width, height); reversed() yields (height, width).
    img_size = torch.tensor([tuple(reversed(img.size))])
    processed_outputs = feature_extractor.post_process(outputs, img_size)
    return (
        processed_outputs[0],
        outputs["decoder_attentions"],
        outputs["encoder_attentions"],
    )


def construct_model_name(
    experiment_type,
    convbase,
    attention_heads_num,
    enc_dec_layers,
    ffn_dim,
    act_func,
    d_model,
    dilation=None,
):
    """Build the model repository id for the selected hyper-parameters.

    Args:
        experiment_type: Unused; kept for interface compatibility.
        convbase: ``"RESNET-50"`` or ``"RESNET-101"``.
        attention_heads_num: Number of attention heads.
        enc_dec_layers: Number of encoder/decoder layers.
        ffn_dim: Feed-forward dimension; 1024 and 4096 add a suffix.
        act_func: Activation name; ``"GeLU"`` adds a suffix.
        d_model: Hidden size; any value other than 256 adds a suffix.
        dilation: The string ``"True"`` adds the dilation suffix.

    Returns:
        The repository id as a string.  For ``"RESNET-101"`` only
        ``enc_dec_layers`` is considered; all other options are ignored.

    Raises:
        ValueError: If ``convbase`` is unknown, or if ``"RESNET-101"`` is
            combined with a layer count other than 4, 6 or 12.
    """
    if convbase == "RESNET-101":
        try:
            return _RESNET101_MODEL_REPOS[enc_dec_layers]
        except KeyError:
            raise ValueError(
                f"No RESNET-101 model with {enc_dec_layers!r} encoder/decoder layers"
            ) from None
    if convbase != "RESNET-50":
        raise ValueError(f"Unsupported convolutional base: {convbase!r}")

    parts = [
        "polejowska/detr-r50",
        "-cd45rb",
        f"-{attention_heads_num}ah",
        f"-{enc_dec_layers}l",
    ]
    if attention_heads_num == 1:
        parts.append("-corrected")
    if d_model != 256:
        parts.append(f"-{d_model}d")
    if ffn_dim == 1024:
        parts.append("-1024ffn")
    elif ffn_dim == 4096:
        # The misspelling is part of the repository id; do not "fix" it here.
        parts.append("-4096ffn-correcetd")
    if act_func == "GeLU":
        parts.append("-gelu-corrected")
    if dilation == "True":
        parts.append("-dilation-corrected")
    return "".join(parts)


@functools.lru_cache(maxsize=2)
def _load_model(model_repo: str):
    """Load (and memoise) the model and feature extractor for ``model_repo``."""
    model = DetrForObjectDetection.from_pretrained(model_repo)
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_repo)
    return model, feature_extractor


def detect_objects(
    experiment_type,
    convbase,
    attention_heads_num,
    enc_dec_layers,
    ffn_dim,
    act_func,
    d_model,
    dilation,
    image_input,
    threshold=0.7,
    display_mask=False,
    img_input_mask=None,
):
    """Run detection for the selected experiment and render the results.

    Args:
        experiment_type: ``"Parameters verification"`` (checkpoint chosen from
            the hyper-parameter arguments) or one of the
            ``"Reproducability check (N)"`` labels (fixed checkpoint).
        convbase, attention_heads_num, enc_dec_layers, ffn_dim, act_func,
            d_model, dilation: Forwarded to ``construct_model_name``.
        image_input: PIL image to analyse.
        threshold: Forwarded to ``visualize_prediction``.
        display_mask: Forwarded to ``visualize_prediction``.
        img_input_mask: Forwarded to ``visualize_prediction`` as ``mask``.

    Returns:
        A tuple ``(detections_image, decoder_attention_image,
        encoder_attention_image)`` as produced by the ``visualization`` helpers.

    Raises:
        ValueError: If no image was provided or ``experiment_type`` is unknown.
    """
    if image_input is None:
        raise ValueError("No input image provided")

    if experiment_type == _PARAMETERS_VERIFICATION:
        model_repo = construct_model_name(
            experiment_type,
            convbase,
            attention_heads_num,
            enc_dec_layers,
            ffn_dim,
            act_func,
            d_model,
            dilation,
        )
    elif experiment_type in _REPRODUCIBILITY_MODEL_REPOS:
        model_repo = _REPRODUCIBILITY_MODEL_REPOS[experiment_type]
    else:
        raise ValueError(f"Unknown experiment type: {experiment_type!r}")

    # Cached so repeated clicks with the same settings do not reload weights.
    model, feature_extractor = _load_model(model_repo)

    (
        processed_outputs,
        decoder_attention_map,
        encoder_attention_map,
    ) = make_prediction(image_input, feature_extractor, model)

    viz_img = visualize_prediction(
        pil_img=image_input,
        output_dict=processed_outputs,
        threshold=threshold,
        id2label=model.config.id2label,
        display_mask=display_mask,
        mask=img_input_mask,
    )
    decoder_attention_map_img = visualize_attention_map(
        image_input, decoder_attention_map
    )
    encoder_attention_map_img = visualize_attention_map(
        image_input, encoder_attention_map
    )

    return (
        viz_img,
        decoder_attention_map_img,
        encoder_attention_map_img,
    )


def set_example_image(example: list):
    """Return updates that load an example pair into the two image inputs.

    Args:
        example: Dataset sample; item 0 is the image value, item 1 the mask.
    """
    return gr.Image.update(value=example[0]), gr.Image.update(value=example[1])


# NOTE(review): the layout nesting below was reconstructed from a source whose
# whitespace had been collapsed -- confirm Row/Column placement against the UI.
with gr.Blocks(css=css) as app:
    gr.Markdown(title)

    with gr.Tabs():
        with gr.TabItem("Image upload and detections visualization"):
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        experiment_type = gr.Dropdown(
                            value=_PARAMETERS_VERIFICATION,
                            choices=[
                                _PARAMETERS_VERIFICATION,
                                *_REPRODUCIBILITY_MODEL_REPOS,
                            ],
                            label="Select an experiment type",
                            show_label=True,
                        )
                    with gr.Row():
                        convbase = gr.Dropdown(
                            value="RESNET-50",
                            choices=[
                                "RESNET-50",
                                "RESNET-101",
                            ],
                            label="Select a base model for convolution part",
                            show_label=True,
                        )
                    with gr.Row():
                        attention_heads_num = gr.Dropdown(
                            value=8,
                            choices=[1, 2, 4, 8, 16],
                            label="The number of attention heads in encoder and decoder",
                            show_label=True,
                        )
                    with gr.Row():
                        enc_dec_layers = gr.Dropdown(
                            value=6,
                            choices=[4, 6, 12],
                            label="The number of layers in encoder and decoder",
                            show_label=True,
                        )
                    with gr.Row():
                        ffn_dim = gr.Dropdown(
                            value=2048,
                            choices=[1024, 2048, 4096],
                            label="Select FFN dimension",
                            show_label=True,
                        )
                    with gr.Row():
                        act_func = gr.Dropdown(
                            value="ReLU",
                            choices=[
                                "ReLU",
                                "GeLU",
                            ],
                            label="Select an activation function",
                            show_label=True,
                        )
                    with gr.Row():
                        d_model = gr.Dropdown(
                            value=256,
                            choices=[128, 256, 512],
                            label="Select a hidden size",
                            show_label=True,
                        )
                    with gr.Row():
                        dilation = gr.Dropdown(
                            value="False",
                            choices=[
                                "True",
                                "False",
                            ],
                            label="Use dilation",
                            show_label=True,
                        )
                    with gr.Row():
                        slider_input = gr.Slider(
                            minimum=0.2,
                            maximum=1,
                            value=0.7,
                            label="Prediction threshold",
                        )
                with gr.Column():
                    with gr.Row():
                        img_input = gr.Image(type="pil")
                        img_input_mask = gr.Image(type="pil", visible=False)
                    with gr.Row():
                        example_images = gr.Dataset(
                            components=[img_input, img_input_mask],
                            samples=[
                                [
                                    path.as_posix(),
                                    path.as_posix().replace("_HE", "_mask"),
                                ]
                                for path in sorted(
                                    pathlib.Path("cd45rb_test_imgs").rglob("*_HE.png")
                                )
                            ],
                            samples_per_page=2,
                        )
                    with gr.Row():
                        # `value` is the supported keyword for the initial state.
                        display_mask = gr.Checkbox(
                            label="Display masks", value=False
                        )
                    with gr.Row():
                        detect_button = gr.Button("Detect leukocytes")
            with gr.Row():
                with gr.Column():
                    img_output_from_upload = gr.Image(shape=(900, 900))

        with gr.TabItem("Attentions visualization"):
            gr.Markdown("""Encoder attentions""")
            with gr.Row():
                encoder_att_map_output = gr.Image(shape=(850, 850))
            gr.Markdown("""Decoder attentions""")
            with gr.Row():
                decoder_att_map_output = gr.Image(shape=(850, 850))

        with gr.TabItem("Dataset details"):
            with gr.Row():
                gr.Markdown(description)

    # Output order matches the tuple returned by detect_objects().
    detect_button.click(
        detect_objects,
        inputs=[
            experiment_type,
            convbase,
            attention_heads_num,
            enc_dec_layers,
            ffn_dim,
            act_func,
            d_model,
            dilation,
            img_input,
            slider_input,
            display_mask,
            img_input_mask,
        ],
        outputs=[
            img_output_from_upload,
            decoder_att_map_output,
            encoder_att_map_output,
        ],
        queue=True,
    )

    example_images.click(
        fn=set_example_image,
        inputs=[example_images],
        outputs=[img_input, img_input_mask],
        show_progress=True,
    )

app.launch(enable_queue=True)