jigara56's picture
Update app.py
86e7f2a verified
raw
history blame
1.29 kB
# import gradio as gr
# gr.Interface.load("models/nvidia/segformer-b0-finetuned-ade-512-512").launch()
import gradio as gr
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
from PIL import Image
import torch
# Hugging Face Hub checkpoint shared by the processor and the model; keeping it in
# one place guarantees the preprocessing always matches the loaded weights.
MODEL_CHECKPOINT = "nvidia/segformer-b0-finetuned-ade-512-512"

# Load the image processor and the segmentation model once, at import time, so that
# every call to segment_image reuses the same instances.
processor = SegformerImageProcessor.from_pretrained(MODEL_CHECKPOINT)
model = SegformerForSemanticSegmentation.from_pretrained(MODEL_CHECKPOINT)
# Define the function for segmentation
def segment_image(image):
# Preprocess the image
inputs = processor(images=image, return_tensors="pt")
# Perform inference
with torch.no_grad():
outputs = model(**inputs)
# Get the segmentation mask
logits = outputs.logits
segmentation_mask = logits.argmax(dim=1).squeeze().cpu().numpy()
# Convert to a PIL image and return it
segmentation_image = Image.fromarray(segmentation_mask.astype("uint8"))
return segmentation_image
# Wrap segment_image in an image-in / image-out Gradio UI: the uploaded picture is
# passed in as a PIL image and the returned class-index mask is displayed as an image.
interface = gr.Interface(
    fn=segment_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
)

# Start the Gradio web server for the interface.
interface.launch()