SerdarHelli kadirnar commited on
Commit
6f4e6ed
1 Parent(s): cb733ec

Upload 6 files (#1)

Browse files

- Upload 6 files (4792ad8931ac2c0630332b275b58cb77e23bc263)
- Update app.py (4551826032cdb26b2c75b62ba0475c2c6e23af5b)
- Update requirements.txt (ac6852893d6075b22d895dacf24e62a09ace7ef1)


Co-authored-by: Kadir Nar <kadirnar@users.noreply.huggingface.co>

Files changed (6) hide show
  1. app.py +86 -3
  2. data/1.png +0 -0
  3. data/2.png +0 -0
  4. data/3.png +0 -0
  5. data/4.png +0 -0
  6. requirements.txt +7 -0
app.py CHANGED
@@ -1,7 +1,90 @@
 
 
 
1
  import gradio as gr
 
 
 
 
2
 
3
 
 
 
 
 
 
 
4
 
5
- #examples=["example_Benign1.png" ,"example_Benign2.png", "example_Malign2.png"]
6
- gr.Interface.load("huggingface/deprem-ml/deprem_satellite_semantic_whu",
7
- title="deprem_satellite_semantic_whu",cache_examples=False).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import SegformerForSemanticSegmentation
2
+ from transformers import SegformerImageProcessor
3
+ from PIL import Image
4
  import gradio as gr
5
+ import numpy as np
6
+ import random
7
+ import cv2
8
+ import torch
9
 
10
 
11
+ image_list = [
12
+ "data/1.png",
13
+ "data/2.png",
14
+ "data/3.png",
15
+ "data/4.png",
16
+ ]
17
 
18
+ model_path = ['deprem-ml/deprem_satellite_semantic_whu']
19
+
20
+ def visualize_instance_seg_mask(mask):
21
+ # Initialize image with zeros with the image resolution
22
+ # of the segmentation mask and 3 channels
23
+ image = np.zeros((mask.shape[0], mask.shape[1], 3))
24
+
25
+ # Create labels
26
+ labels = np.unique(mask)
27
+ label2color = {
28
+ label: (
29
+ random.randint(0, 255),
30
+ random.randint(0, 255),
31
+ random.randint(0, 255),
32
+ )
33
+ for label in labels
34
+ }
35
+
36
+ for height in range(image.shape[0]):
37
+ for width in range(image.shape[1]):
38
+ image[height, width, :] = label2color[mask[height, width]]
39
+
40
+ image = image / 255
41
+ return image
42
+
43
+
44
+ def Segformer_Segmentation(image_path, model_id):
45
+ output_save = "output.png"
46
+
47
+ test_image = Image.open(image_path)
48
+
49
+ model = SegformerForSemanticSegmentation.from_pretrained(model_id)
50
+ proccessor = SegformerImageProcessor(model_id)
51
+
52
+ inputs = proccessor(images=test_image, return_tensors="pt")
53
+ with torch.no_grad():
54
+ outputs = model(**inputs)
55
+
56
+ result = proccessor.post_process_semantic_segmentation(outputs)[0]
57
+ result = np.array(result)
58
+ result = visualize_instance_seg_mask(result)
59
+ cv2.imwrite(output_save, result*255)
60
+
61
+ return image_path, output_save
62
+
63
+ examples = [[image_list[0], "deprem-ml/deprem_satellite_semantic_whu"],
64
+ [image_list[1], "deprem-ml/deprem_satellite_semantic_whu"],
65
+ [image_list[2], "deprem-ml/deprem_satellite_semantic_whu"],
66
+ [image_list[3], "deprem-ml/deprem_satellite_semantic_whu"]]
67
+
68
+ title = "Deprem ML - Segformer Semantic Segmentation"
69
+
70
+ app = gr.Blocks()
71
+ with app:
72
+ gr.HTML("<h1 style='text-align: center'>{}</h1>".format(title))
73
+ with gr.Row():
74
+ with gr.Column():
75
+ gr.Markdown("Video")
76
+ input_video = gr.Image(type='filepath')
77
+ model_id = gr.Dropdown(value=model_path[0], choices=model_path)
78
+ input_video_button = gr.Button(value="Predict")
79
+
80
+ with gr.Column():
81
+ output_orijinal_image = gr.Image(type='filepath')
82
+
83
+ with gr.Column():
84
+ output_mask_image = gr.Image(type='filepath')
85
+
86
+
87
+ gr.Examples(examples, inputs=[input_video, model_id], outputs=[output_orijinal_image, output_mask_image], fn=Segformer_Segmentation, cache_examples=True)
88
+ input_video_button.click(Segformer_Segmentation, inputs=[input_video, model_id], outputs=[output_orijinal_image, output_mask_image])
89
+
90
+ app.launch()
data/1.png ADDED
data/2.png ADDED
data/3.png ADDED
data/4.png ADDED
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio==3.18.0
2
+ matplotlib==3.6.2
3
+ numpy==1.24.2
4
+ Pillow==9.4.0
5
+ torch==1.12.1
6
+ transformers==4.26.0
7
+ opencv-python