import cv2 import gradio as gr import matplotlib import matplotlib.pyplot as plt import numpy as np import torch from CCAgT_utils.categories import CategoriesInfos from CCAgT_utils.slice import __create_xy_slice from CCAgT_utils.types.mask import Mask from CCAgT_utils.visualization import plot from PIL import Image from torch import nn from transformers import SegformerFeatureExtractor from transformers import SegformerForSemanticSegmentation from transformers.modeling_outputs import SemanticSegmenterOutput matplotlib.use('Agg') device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model_hub_name = 'lapix/segformer-b3-finetuned-ccagt-400-300' model = SegformerForSemanticSegmentation.from_pretrained( model_hub_name, ).to(device) model.eval() feature_extractor = SegformerFeatureExtractor.from_pretrained( model_hub_name, ) def segment( image: Image.Image, ) -> SemanticSegmenterOutput: inputs = feature_extractor( image, return_tensors='pt', ).to(device) outputs = model(**inputs) return outputs def post_processing( outputs: SemanticSegmenterOutput, target_size: tuple[int, int], ) -> np.ndarray: logits = outputs.logits.cpu() upsampled_logits = nn.functional.interpolate( logits, size=target_size, mode='bilinear', align_corners=False, ) segmentation_mask = upsampled_logits.argmax(dim=1)[0] return np.array(segmentation_mask) def colorize( mask: Mask, ) -> np.ndarray: return mask.colorized(CategoriesInfos()) / 255 def check_and_resize( image: np.ndarray, ) -> np.ndarray: if image.shape[0] > 1200 or image.shape[1] > 1600: r = 1600.0 / image.shape[1] dim = (1600, int(image.shape[0] * r)) return cv2.resize(image, dim, interpolation=cv2.INTER_AREA) return image def process_big_images( image: Image.Image, ) -> Mask: '''Process and post-processing for images bigger than 400x300''' img = check_and_resize(np.asarray(image)) mask = np.zeros(shape=(img.shape[0], img.shape[1]), dtype=np.uint8) for bbox in __create_xy_slice(image.size[1], image.size[0], 300, 400): part = cv2.copyMakeBorder( img, bbox.y_init, bbox.y_end, bbox.x_init, bbox.x_end, cv2.BORDER_REFLECT, ) target_size = (part.shape[0], part.shape[1]) outputs = segment(Image.fromarray(part)) msk = post_processing(outputs, target_size) mask[bbox.slice_y, bbox.slice_x] = msk[bbox.slice_y, bbox.slice_x] return Mask(mask) def image_with_mask( image: Image.Image, mask: Mask, ) -> plt.Figure: fig = plt.figure(dpi=600) plt.imshow(image) plt.imshow( mask.categorical, cmap=mask.cmap(CategoriesInfos()), vmax=max(mask.unique_ids), vmin=min(mask.unique_ids), interpolation='nearest', alpha=0.4, ) plt.axis('off') return fig def categories_map( mask: Mask, ) -> plt.Figure: fig = plt.figure(dpi=600) handles = plot.create_handles( CategoriesInfos(), selected_categories=mask.unique_ids, ) plt.legend(handles=handles, fontsize=24, loc='center') plt.axis('off') return fig def main(image): img = Image.fromarray(image) mask = process_big_images(img) mask_colorized = colorize(mask) fig = image_with_mask(img, mask) return categories_map(mask), mask_colorized, fig title = 'SegFormer (b3) - CCAgT dataset' description = f""" This is demo for the SegFormer fine-tuned on sub-dataset from [CCAgT dataset](https://huggingface.co/datasets/lapix/CCAgT). This model was trained to segment cervical cells silver-stained (AgNOR technique) images with resolution of 400x300. The model was available at HF hub at [{model_hub_name}](https://huggingface.co/{model_hub_name}). """ examples = [ [f'https://hf.co/{model_hub_name}/resolve/main/sampleA.png'], [f'https://hf.co/{model_hub_name}/resolve/main/sampleB.png'], ] + [ [f'https://datasets-server.huggingface.co/assets/lapix/CCAgT/--/semantic_segmentation/test/{x}/image/image.jpg'] for x in {3, 10, 12, 18, 35, 78, 89} ] demo = gr.Interface( main, inputs=[gr.Image()], outputs=[ gr.Plot(label='Categories map'), gr.Image(label='Mask'), gr.Plot(label='Image with mask'), ], title=title, description=description, examples=examples, allow_flagging='never', cache_examples=False, ) if __name__ == '__main__': demo.launch()