AAAAAAyq commited on
Commit
4d26566
1 Parent(s): d910d42

Add application file

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.jpg filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ultralytics import YOLO
2
+ from PIL import Image
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ import gradio as gr
6
+ import io
7
+ # import cv2
8
+
9
+ model = YOLO('checkpoints/FastSAM.pt') # load a custom model
10
+
11
+ def show_mask(annotation, ax, random_color=False, bbox=None, points=None):
12
+ if random_color : # random mask color
13
+ color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
14
+ else:
15
+ color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
16
+ if type(annotation) == dict:
17
+ annotation = annotation['segmentation']
18
+ mask = annotation
19
+ h, w = mask.shape[-2:]
20
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
21
+ # draw box
22
+ if bbox is not None:
23
+ x1, y1, x2, y2 = bbox
24
+ ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
25
+ # draw point
26
+ if points is not None:
27
+ ax.scatter([point[0] for point in points], [point[1] for point in points], s=10, c='g')
28
+ ax.imshow(mask_image)
29
+ return mask_image
30
+
31
+ def post_process(annotations, image, mask_random_color=False, bbox=None, points=None):
32
+ # image = cv2.imread(image_path)
33
+ # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
34
+ plt.figure(figsize=(10, 10))
35
+ plt.imshow(image)
36
+ for i, mask in enumerate(annotations):
37
+ show_mask(mask, plt.gca(),random_color=mask_random_color,bbox=bbox,points=points)
38
+ plt.axis('off')
39
+
40
+ # create a BytesIO object
41
+ buf = io.BytesIO()
42
+
43
+ # save plot to buf
44
+ plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0.0)
45
+ # plt.savefig('buffer/tmp.png', bbox_inches='tight', pad_inches=0.0)
46
+
47
+ # use PIL to open the image
48
+ img = Image.open(buf)
49
+
50
+ # don't forget to close the buffer
51
+ buf.close()
52
+ return img
53
+
54
+
55
+ # def show_mask(annotation, ax, random_color=False):
56
+ # if random_color : # 掩膜颜色是否随机决定
57
+ # color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
58
+ # else:
59
+ # color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
60
+ # mask = annotation.cpu().numpy()
61
+ # h, w = mask.shape[-2:]
62
+ # mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
63
+ # ax.imshow(mask_image)
64
+
65
+ # def post_process(annotations, image):
66
+ # plt.figure(figsize=(10, 10))
67
+ # plt.imshow(image)
68
+ # for i, mask in enumerate(annotations):
69
+ # show_mask(mask.data, plt.gca(),random_color=True)
70
+ # plt.axis('off')
71
+
72
+ # 获取渲染后的像素数据并转换为PIL图像
73
+
74
+ return pil_image
75
+
76
+
77
+ # post_process(results[0].masks, Image.open("../data/cake.png"))
78
+
79
+ def predict(inp):
80
+ results = model(inp, device='0', retina_masks=True, iou=0.7, conf=0.25, imgsz=1024)
81
+ pil_image = post_process(results[0].masks, inp)
82
+ return pil_image
83
+
84
+
85
+ demo = gr.Interface(fn=predict,
86
+ inputs=gr.inputs.Image(type='pil'),
87
+ outputs=gr.outputs.Image(type='pil'),
88
+ examples=[["assets/sa_192.jpg"], ["assets/sa_414.jpg"],
89
+ ["assets/sa_561.jpg"], ["assets/sa_862.jpg"],
90
+ ["assets/sa_1309.jpg"], ["assets/sa_8776.jpg"],
91
+ ["assets/sa_10039.jpg"], ["assets/sa_11025.jpg"],],
92
+ )
93
+
94
+ demo.launch()
assets/sa_10039.jpg ADDED

Git LFS Details

  • SHA256: 4a9735583a997fa08e5eb36b3ba8bf17a31771bb2aea71e6d51ab9824c1d141e
  • Pointer size: 131 Bytes
  • Size of remote file: 391 kB
assets/sa_11025.jpg ADDED

Git LFS Details

  • SHA256: b7edd63aa5121414bc29a760770606d09387561ff990c89f9b82c35803bd20aa
  • Pointer size: 131 Bytes
  • Size of remote file: 988 kB
assets/sa_1309.jpg ADDED

Git LFS Details

  • SHA256: b1012cbfd3ffe4ee0da940dc45961fbd1ce7546bea566f650514ec56d72b0460
  • Pointer size: 132 Bytes
  • Size of remote file: 1.11 MB
assets/sa_192.jpg ADDED

Git LFS Details

  • SHA256: dcec4fce91382cbfeb2711fff3caeae183c23cb6d8a6c9e2ca0cd2e8eac39512
  • Pointer size: 132 Bytes
  • Size of remote file: 1.21 MB
assets/sa_414.jpg ADDED

Git LFS Details

  • SHA256: 69dbead40b43e54d3bb80fb372c2e241b0f3ff2159d32525433a75153e067c65
  • Pointer size: 132 Bytes
  • Size of remote file: 2.23 MB
assets/sa_561.jpg ADDED

Git LFS Details

  • SHA256: 837d725885e427534623dcc7d82ea846fffea046877c94e2e9c5b027d593796b
  • Pointer size: 131 Bytes
  • Size of remote file: 822 kB
assets/sa_862.jpg ADDED

Git LFS Details

  • SHA256: 06efc970f0d95faa6e8c69ee73f2032627569dde1c28bc783faebdaefa5eb2a8
  • Pointer size: 132 Bytes
  • Size of remote file: 1.56 MB
assets/sa_8776.jpg ADDED

Git LFS Details

  • SHA256: 7d71aea32d9f14122378a0707a4243de968d87b292a20a905351b5eacd924212
  • Pointer size: 131 Bytes
  • Size of remote file: 471 kB
checkpoints/FastSAM.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c0be4e7ddbe4c15333d15a859c676d053c486d0a746a3be6a7a9790d52a9b6d7
3
+ size 144943063
requirements.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Base-----------------------------------
2
+ matplotlib>=3.2.2
3
+ opencv-python>=4.6.0
4
+ Pillow>=7.1.2
5
+ PyYAML>=5.3.1
6
+ requests>=2.23.0
7
+ scipy>=1.4.1
8
+ torch>=1.7.0
9
+ torchvision>=0.8.1
10
+ tqdm>=4.64.0
11
+
12
+ pandas>=1.1.4
13
+ seaborn>=0.11.0
14
+
15
+ # Ultralytics-----------------------------------
16
+ ultralytics
17
+