File size: 1,288 Bytes
86e7f2a
 
 
 
 
a6909b4
86e7f2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a6909b4
86e7f2a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# import gradio as gr

# gr.Interface.load("models/nvidia/segformer-b0-finetuned-ade-512-512").launch()


import gradio as gr
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
from PIL import Image
import torch

# Hugging Face Hub checkpoint id (SegFormer-B0 fine-tuned on ADE20K, per its
# name). Keeping it in one constant guarantees the processor and the model are
# always loaded from the same checkpoint.
MODEL_CHECKPOINT = "nvidia/segformer-b0-finetuned-ade-512-512"

# Load the processor (input preprocessing) and the model once at import time so
# every request served by the interface reuses them.
processor = SegformerImageProcessor.from_pretrained(MODEL_CHECKPOINT)
model = SegformerForSemanticSegmentation.from_pretrained(MODEL_CHECKPOINT)

# Define the function for segmentation
def segment_image(image):
    """Run semantic segmentation on one image and return its class-index mask.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without providing an image.

    Returns:
        A single-channel PIL image with the same width and height as
        ``image``, where each pixel value is the predicted class index, or
        ``None`` if no image was given.
    """
    # Gradio passes None when the input component is empty; return an empty
    # output instead of crashing inside the processor.
    if image is None:
        return None

    # The processor normalizes with per-channel (RGB) statistics, so RGBA,
    # grayscale or palette images must be converted to 3-channel RGB first.
    image = image.convert("RGB")

    # Preprocess the image
    inputs = processor(images=image, return_tensors="pt")

    # Perform inference
    with torch.no_grad():
        outputs = model(**inputs)

    # The model's logits are at a lower resolution than the input, so a plain
    # argmax would produce a mask that does not line up with the uploaded
    # image. The processor's post-processing upsamples the logits to the given
    # (height, width) before taking the per-pixel argmax. PIL's `size` is
    # (width, height), hence the reversal.
    segmentation_mask = processor.post_process_semantic_segmentation(
        outputs, target_sizes=[image.size[::-1]]
    )[0]

    # Class indices are stored as 8-bit values, as in the original output.
    return Image.fromarray(segmentation_mask.cpu().numpy().astype("uint8"))

# Create the Gradio interface: a PIL image in, the PIL mask image produced by
# segment_image out.
interface = gr.Interface(
    fn=segment_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
)

# Launch the interface only when this file is executed as a script, so that
# importing the module (e.g. from tests or other tooling) does not start and
# block on a web server.
if __name__ == "__main__":
    interface.launch()