# Earlier approach: load the hosted model's default demo in one line
# import gradio as gr
# gr.Interface.load("models/nvidia/segformer-b0-finetuned-ade-512-512").launch()
import gradio as gr
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
from PIL import Image
import torch
# Load the processor and model
processor = SegformerImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
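# segformer-b0 fine-tuned on ADE20K predicts one of 150 scene-parsing classes per pixel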
# Define the function for segmentation
def segment_image(image):
    # Preprocess: resize/normalize the image to the model's expected input
    inputs = processor(images=image, return_tensors="pt")
    # Perform inference
    with torch.no_grad():
        outputs = model(**inputs)
    # SegFormer emits logits at 1/4 of the input resolution, so upsample
    # them back to the original image size before taking the argmax
    logits = outputs.logits
    upsampled = torch.nn.functional.interpolate(
        logits, size=image.size[::-1], mode="bilinear", align_corners=False
    )
    segmentation_mask = upsampled.argmax(dim=1).squeeze().cpu().numpy()
    # Convert the per-pixel class indices (0-149 for ADE20K) to a PIL image;
    # mapping indices through a color palette (or model.config.id2label)
    # would give a more readable visualization
    segmentation_image = Image.fromarray(segmentation_mask.astype("uint8"))
    return segmentation_image
# Create the Gradio interface
interface = gr.Interface(fn=segment_image,
                         inputs=gr.Image(type="pil"),   # accept a PIL image
                         outputs=gr.Image(type="pil"))  # return the mask as a PIL image
# Launch the interface
interface.launch()