Files changed (1) hide show
  1. app.py +36 -1
app.py CHANGED
@@ -1,3 +1,38 @@
 
 
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- gr.Interface.load("models/nvidia/segformer-b0-finetuned-ade-512-512").launch()
 
 
1
+ # import gradio as gr
2
+
3
+ # gr.Interface.load("models/nvidia/segformer-b0-finetuned-ade-512-512").launch()
4
+
5
+
6
  import gradio as gr
7
+ from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
8
+ from PIL import Image
9
+ import torch
10
+
11
+ # Load the processor and model
12
+ processor = SegformerImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
13
+ model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
14
+
15
+ # Define the function for segmentation
16
+ def segment_image(image):
17
+ # Preprocess the image
18
+ inputs = processor(images=image, return_tensors="pt")
19
+
20
+ # Perform inference
21
+ with torch.no_grad():
22
+ outputs = model(**inputs)
23
+
24
+ # Get the segmentation mask
25
+ logits = outputs.logits
26
+ segmentation_mask = logits.argmax(dim=1).squeeze().cpu().numpy()
27
+
28
+ # Convert to a PIL image and return it
29
+ segmentation_image = Image.fromarray(segmentation_mask.astype("uint8"))
30
+ return segmentation_image
31
+
32
+ # Create the Gradio interface
33
+ interface = gr.Interface(fn=segment_image,
34
+ inputs=gr.Image(type="pil"), # Using gr.Image for Gradio
35
+ outputs=gr.Image(type="pil")) # Using gr.Image for Gradio
36
 
37
+ # Launch the interface
38
+ interface.launch()