"""Gradio demo: semantic segmentation of CCAgT images with a fine-tuned SegFormer-b3."""
import gradio as gr
import numpy as np
import torch
from CCAgT_utils.types.mask import Mask
from PIL import Image
from torch import nn
from transformers import SegformerFeatureExtractor
from transformers import SegformerForSemanticSegmentation

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model_hub_name = 'lapix/segformer-b3-finetuned-ccagt-400-300'

model = SegformerForSemanticSegmentation.from_pretrained(
    model_hub_name,
).to(device)
# Inference only: make sure dropout and similar layers are in eval behaviour.
model.eval()

feature_extractor = SegformerFeatureExtractor.from_pretrained(
    model_hub_name,
)


def query_image(image):
    """Segment one image and return its colorized class mask.

    Args:
        image: Image delivered by the ``gr.Image`` input component (array-like
            or PIL image). ``None`` when the user submits without an image.

    Returns:
        ``Mask(...).colorized() / 255`` for the per-pixel argmax class map,
        computed at the input's height and width. Returns ``None`` when no
        image was provided.
    """
    if image is None:
        # Nothing to segment; Gradio renders ``None`` as an empty output.
        return None

    image = np.array(image)
    # NumPy shape starts with (height, width), which is the order
    # ``interpolate`` expects for ``size``.
    height, width = image.shape[:2]

    # The feature extractor returns a dict-like BatchFeature; the model's
    # ``pixel_values`` argument must be the tensor inside it.
    pixel_values = feature_extractor(
        image,
        return_tensors='pt',
    ).pixel_values.to(device)

    with torch.no_grad():
        logits = model(pixel_values=pixel_values).logits

    # The logits are produced at a lower resolution than the input, so resize
    # them back to the original image size before taking the per-pixel argmax.
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=(height, width),
        mode='bilinear',
        align_corners=False,
    )

    # argmax over the class dimension, first (only) batch item -> (H, W) ids.
    # Move to CPU and convert to NumPy: a CUDA tensor cannot be turned into an
    # array implicitly, and Mask is presumably NumPy-based (verify in CCAgT_utils).
    segmentation_mask = upsampled_logits.argmax(dim=1)[0].cpu().numpy()

    return Mask(segmentation_mask).colorized() / 255


title = 'SegFormer (b3) - CCAgT dataset'
description = f"""
This is a demo of SegFormer fine-tuned on a sub-dataset of the
[CCAgT dataset](https://huggingface.co/datasets/lapix/CCAgT). The model was
trained to segment images of silver-stained (AgNOR technique) cervical cells
at a resolution of 400x300.

The model is available on the HF Hub at
[{model_hub_name}](https://huggingface.co/{model_hub_name}).
"""

examples = [
    [f'https://hf.co/{model_hub_name}/resolve/main/sampleA.png'],
    [f'https://hf.co/{model_hub_name}/resolve/main/sampleB.png'],
]

demo = gr.Interface(
    query_image,
    inputs=[gr.Image()],
    outputs='image',
    title=title,
    description=description,
    examples=examples,
)

demo.launch()