Spaces:
Runtime error
Runtime error
# Alternative one-liner (disabled): load the hosted model straight into an Interface.
# import gradio as gr
# gr.Interface.load("models/nvidia/segformer-b0-finetuned-ade-512-512").launch()
import gradio as gr | |
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation | |
from PIL import Image | |
import torch | |
# The preprocessing config and the model weights come from the same checkpoint,
# so the identifier is spelled out once and shared by both loaders.
_CHECKPOINT = "nvidia/segformer-b0-finetuned-ade-512-512"

# Image processor that turns input images into model-ready tensors.
processor = SegformerImageProcessor.from_pretrained(_CHECKPOINT)
# Semantic-segmentation model restored from the pretrained checkpoint.
model = SegformerForSemanticSegmentation.from_pretrained(_CHECKPOINT)
# Define the function for segmentation
def segment_image(image):
    """Segment an image and return its per-pixel class-index mask.

    Args:
        image: PIL image to segment (the Gradio input component below is
            configured with ``type="pil"``). ``None`` is accepted and
            short-circuits, since that is what the UI passes when the user
            submits without providing an image.

    Returns:
        A single-channel 8-bit PIL image with the same width and height as
        ``image``, where each pixel value is the index of the predicted
        class; ``None`` when ``image`` is ``None``.
    """
    # Nothing to segment: hand back an empty output instead of letting the
    # processor raise on a missing image.
    if image is None:
        return None

    # Preprocess the image into model-ready tensors.
    inputs = processor(images=image, return_tensors="pt")

    # Perform inference; gradients are not needed for prediction.
    with torch.no_grad():
        outputs = model(**inputs)

    # Per the transformers documentation, SegFormer logits are produced at a
    # reduced resolution, so an argmax taken directly on them yields a mask
    # smaller than the input. The processor's post-processing helper rescales
    # the logits to the requested (height, width) before taking the argmax.
    class_map = processor.post_process_semantic_segmentation(
        outputs, target_sizes=[(image.height, image.width)]
    )[0]

    # Cast to uint8 so PIL builds an 8-bit single-channel image from the
    # 2-D array of class indices.
    mask_array = class_map.cpu().numpy().astype("uint8")
    return Image.fromarray(mask_array)
# Build the Gradio UI: a PIL image goes in, the segmentation mask (also a PIL
# image) comes out.
interface = gr.Interface(
    fn=segment_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
)

# Start serving the app.
interface.launch()