initial commit

Files changed:
- .gitignore +8 -0
- app.py +121 -0
- inference_beit.py +0 -0
- inference_diffuser.py +0 -0
- inference_resnet.py +167 -0
- inference_sam.py +175 -0
- labels.py +175 -0
- pre-requirements.txt +6 -0
.gitignore
ADDED
@@ -0,0 +1,8 @@
.env
venv/
images/
*.pyc
*.pyo
*.pyd
*.swp
*.__pycache__

app.py
ADDED
@@ -0,0 +1,121 @@
import subprocess
import os
if os.getenv('SYSTEM') == 'spaces':
    subprocess.call('pip install tensorflow==2.9'.split())
    subprocess.call('pip install keras==2.9'.split())
    subprocess.call('pip install git+https://github.com/facebookresearch/segment-anything.git'.split())
    subprocess.call('pip install opencv-python-headless==4.5.5.64'.split())
    subprocess.call('pip install git+https://github.com/cocodataset/panopticapi.git'.split())

import gradio as gr
from huggingface_hub import snapshot_download
import cv2
import dotenv
dotenv.load_dotenv()
import numpy as np
import gradio as gr
import glob
from inference_sam import segmentation_sam

import pathlib

if not os.path.exists('images'):
    REPO_ID = 'Serrelab/image_examples_gradio'
    snapshot_download(repo_id=REPO_ID, token=os.environ.get('READ_TOKEN'), repo_type='dataset', local_dir='images')


def segment_image(input_image):
    img = segmentation_sam(input_image)
    return img

def classify_image(input_image, model_name):
    if 'Rock 170' == model_name:
        from inference_resnet import inference_resnet_finer
        result = inference_resnet_finer(input_image, model_name, n_classes=171)
        return result
    elif 'Mummified 170' == model_name:
        from inference_resnet import inference_resnet_finer
        result = inference_resnet_finer(input_image, model_name, n_classes=170)
        return result
    if 'Fossils 19' == model_name:
        from inference_beit import inference_dino
        return inference_dino(input_image, model_name)
    return None

def find_closest(input_image):
    return None


with gr.Blocks(theme='sudeepshouche/minimalist') as demo:

    with gr.Tab(" 19 Classes Support"):

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input")
                classify_image_button = gr.Button("Classify Image")

            with gr.Column():
                segmented_image = gr.outputs.Image(label="SAM output", type='numpy')
                segment_button = gr.Button("Segment Image")
                #classify_segmented_button = gr.Button("Classify Segmented Image")

            with gr.Column():
                drop_2 = gr.Dropdown(
                    ["Mummified 170", "Rock 170", "Fossils 19"],
                    multiselect=False,
                    value=["Rock 170"],
                    label="Model",
                    interactive=True,
                )
                class_predicted = gr.Label(label='Class Predicted', num_top_classes=10)

        with gr.Row():

            paths = sorted(pathlib.Path('images/').rglob('*.jpg'))
            samples = [[path.as_posix()] for path in paths if 'fossils' in str(path)][:19]
            examples_fossils = gr.Examples(samples, inputs=input_image, examples_per_page=10, label='Fossils Examples from the dataset')
            samples = [[path.as_posix()] for path in paths if 'leaves' in str(path)][:19]
            examples_leaves = gr.Examples(samples, inputs=input_image, examples_per_page=5, label='Leaves Examples from the dataset')

        with gr.Accordion("Using Diffuser"):
            with gr.Column():
                prompt = gr.Textbox(lines=1, label="Prompt")
                output_image = gr.Image(label="Output")
                generate_button = gr.Button("Generate Leave")
            with gr.Column():
                class_predicted2 = gr.Label(label='Class Predicted from diffuser')
                classify_button = gr.Button("Classify Image")

        with gr.Accordion("Explanations "):
            gr.Markdown("Computing Explanations from the model")
            with gr.Row():
                original_input = gr.Image(label="Original Frame")
                saliency = gr.Image(label="saliency")
                gradcam = gr.Image(label='gradcam')
                guided_gradcam = gr.Image(label='guided gradcam')
                guided_backprop = gr.Image(label='guided backprop')
            generate_explanations = gr.Button("Generate Explanations")

        with gr.Accordion('Closest Images'):
            gr.Markdown("Finding the closest images in the dataset")
            with gr.Row():
                closest_image_0 = gr.Image(label='Closest Image')
                closest_image_1 = gr.Image(label='Second Closest Image')
                closest_image_2 = gr.Image(label='Third Closest Image')
                closest_image_3 = gr.Image(label='Forth Closest Image')
                closest_image_4 = gr.Image(label='Fifth Closest Image')
            find_closest_btn = gr.Button("Find Closest Images")

        segment_button.click(segment_image, inputs=input_image, outputs=segmented_image)
        classify_image_button.click(classify_image, inputs=[input_image, drop_2], outputs=class_predicted)
        #classify_segmented_button.click(classify_image, inputs=[segmented_image,drop_2], outputs=class_predicted)


demo.launch(debug=True)

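Note on app.py: the UI wires only two callbacks, segment_button to segment_image and classify_image_button to classify_image. A minimal smoke-test sketch for those callbacks outside Gradio, assuming the Space's weights and READ_TOKEN are available locally; the image path and the use of PIL are hypothetical, not part of this commit:

import numpy as np
from PIL import Image

img = np.array(Image.open('images/fossils/sample.jpg'))  # hypothetical example path
print(classify_image(img, 'Rock 170'))   # dict of top-10 {family: probability}
print(segment_image(img).shape)          # RGB rendering of the SAM mask overlay
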
inference_beit.py
ADDED
File without changes
inference_diffuser.py
ADDED
File without changes
inference_resnet.py
ADDED
@@ -0,0 +1,167 @@
import tensorflow as tf
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(gpu_devices[0], True)
from keras.applications import resnet
import tensorflow.keras.layers as L
import os

from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
import matplotlib.pyplot as plt
from typing import Tuple
from huggingface_hub import snapshot_download
from labels import lookup_170
import numpy as np


REPO_ID = 'Serrelab/fossil_classification_models'
snapshot_download(repo_id=REPO_ID, token=os.environ.get('READ_TOKEN'), repo_type='model', local_dir='model_classification')


def get_model(base_arch='Nasnet', weights='imagenet', input_shape=(600, 600, 3), classes=64500):

    if base_arch == 'Nasnet':
        base_model = tf.keras.applications.NASNetLarge(
            input_shape=input_shape,
            include_top=False,
            weights=weights,
            input_tensor=None,
            pooling=None,
        )
    elif base_arch == 'Resnet50v2':
        base_model = tf.keras.applications.ResNet50V2(weights=weights,
                                                      include_top=False,
                                                      pooling='avg',
                                                      input_shape=input_shape)
    elif base_arch == 'Resnet50v2_finer':
        base_model = tf.keras.applications.ResNet50V2(weights=weights,
                                                      include_top=False,
                                                      pooling='avg',
                                                      input_shape=input_shape)
        base_model = resnet.stack2(base_model.output, 512, 2, name="conv6")
        base_model = resnet.stack2(base_model, 512, 2, name="conv7")
        base_model = tf.keras.Model(base_model.input, base_model)

    model = tf.keras.Sequential([
        base_model,
        L.Dense(classes, activation='softmax')
    ])

    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  )

    return model


def get_triplet_model(input_shape=(600, 600, 3),
                      embedding_units=256,
                      embedding_depth=2,
                      backbone_class=tf.keras.applications.ResNet50V2,
                      nb_classes=19, load_weights=False, finer_model=False, backbone_name='Resnet50v2'):

    backbone = backbone_class(input_shape=input_shape, include_top=False)
    if load_weights:
        model = get_model(backbone_name, input_shape=input_shape)
        model.load_weights('/users/irodri15/data/irodri15/Fossils/Models/pretrained-herbarium/Resnet50v2_NO_imagenet_None_best_1600.h5')
        trw = model.layers[0].get_weights()
        backbone.set_weights(trw)
    if finer_model:
        base_model = resnet.stack2(backbone.output, 512, 2, name="conv6")
        base_model = resnet.stack2(base_model, 512, 2, name="conv7")
        backbone = tf.keras.Model(backbone.input, base_model)

    features = GlobalAveragePooling2D()(backbone.output)

    embedding_head = features
    for embed_i in range(embedding_depth):
        embedding_head = Dense(embedding_units, activation="relu" if embed_i < embedding_depth-1 else "linear")(embedding_head)
    embedding_head = tf.nn.l2_normalize(embedding_head, -1, epsilon=1e-5)

    logits_head = Dense(nb_classes)(features)

    model = tf.keras.Model(backbone.input, [embedding_head, logits_head])
    model.compile(loss='cce', metrics=['accuracy'])
    #model.summary()

    return model

load_size = 600
crop_size = 600
def _clever_crop(img: tf.Tensor,
                 target_size: Tuple[int] = (128, 128),
                 grayscale: bool = False
                 ) -> tf.Tensor:
    """[summary]
    Args:
        img (tf.Tensor): [description]
        target_size (Tuple[int], optional): [description]. Defaults to (128,128).
        grayscale (bool, optional): [description]. Defaults to False.
    Returns:
        tf.Tensor: [description]
    """
    maxside = tf.math.maximum(tf.shape(img)[0], tf.shape(img)[1])
    minside = tf.math.minimum(tf.shape(img)[0], tf.shape(img)[1])
    new_img = img

    if tf.math.divide(maxside, minside) > 1.2:
        repeating = tf.math.floor(tf.math.divide(maxside, minside))
        new_img = img
        if tf.math.equal(tf.shape(img)[1], minside):
            for _ in range(int(repeating)):
                new_img = tf.concat((new_img, img), axis=1)

        if tf.math.equal(tf.shape(img)[0], minside):
            for _ in range(int(repeating)):
                new_img = tf.concat((new_img, img), axis=0)
            new_img = tf.image.rot90(new_img)
    else:
        new_img = img
        repeating = 0
    img = tf.image.resize(new_img, target_size)
    if grayscale:
        img = tf.image.rgb_to_grayscale(img)
        img = tf.image.grayscale_to_rgb(img)

    return img, repeating

def preprocess(img, size=600):
    img = np.array(img, np.float32) / 255.0
    img = tf.image.resize(img, (size, size))
    return np.array(img, np.float32)


def select_top_n(preds, n=10):
    top_n = np.argsort(preds)[-n:][::-1]
    return top_n


def parse_results(top_n, logits):
    results = {}
    for n in top_n:
        label = lookup_170[n]
        results[label] = float(logits[n])
    return results

def inference_resnet_finer(x, type_model, size=576, n_classes=170, n_top=10):

    model = get_triplet_model(input_shape=(size, size, 3),
                              embedding_units=256,
                              embedding_depth=2,
                              backbone_class=tf.keras.applications.ResNet50V2,
                              nb_classes=n_classes, load_weights=False, finer_model=True, backbone_name='Resnet50v2')
    if type_model == 'Mummified 170':
        model.load_weights('model_classification/mummified-170.h5')
    elif type_model == 'Rock 170':
        model.load_weights('model_classification/rock-170.h5')
    else:
        return 'Error'
    cropped = _clever_crop(x, (size, size))[0]
    prep = preprocess(cropped, size=size)
    logits = tf.nn.softmax(model.predict(np.array([prep]))[1][0]).cpu().numpy()
    top_n = select_top_n(logits, n=n_top)

    return parse_results(top_n, logits)

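Note on inference_resnet.py: select_top_n and parse_results are the only pieces that can be exercised without the downloaded weights. A minimal sketch with synthetic scores, assuming the module imports cleanly (importing it runs snapshot_download and indexes gpu_devices[0], so a visible GPU and a READ_TOKEN are expected):

import numpy as np
from inference_resnet import select_top_n, parse_results

scores = np.random.rand(171)
scores = scores / scores.sum()       # stand-in for the softmax output
top_n = select_top_n(scores, n=5)    # indices of the five largest scores
print(parse_results(top_n, scores))  # e.g. {'Fagaceae': 0.012, ...}
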
inference_sam.py
ADDED
@@ -0,0 +1,175 @@
import torch
torch.cuda.set_per_process_memory_fraction(0.3, device=0)
import tensorflow as tf
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(gpu_devices[0], True)

from segment_anything import SamPredictor, sam_model_registry
import matplotlib.pyplot as plt
import cv2
import numpy as np
from math import ceil
import os
from huggingface_hub import snapshot_download

REPO_ID = 'Serrelab/SAM_Leaves'
snapshot_download(repo_id=REPO_ID, token=os.environ.get('READ_TOKEN'), repo_type='model', local_dir='model')

sam = sam_model_registry["default"]("model/sam_02-06_dice_mse_0.pth")
sam.cuda()
predictor = SamPredictor(sam)


from torch.nn import functional as F


def pad_gt(x):
    h, w = x.shape[-2:]
    padh = sam.image_encoder.img_size - h
    padw = sam.image_encoder.img_size - w
    x = F.pad(x, (0, padw, 0, padh))
    return x

def preprocess(img):

    img = np.array(img).astype(np.uint8)

    #assert img.max() > 127.0

    img_preprocess = predictor.transform.apply_image(img)
    intermediate_shape = img_preprocess.shape

    img_preprocess = torch.as_tensor(img_preprocess).cuda()
    img_preprocess = img_preprocess.permute(2, 0, 1).contiguous()[None, :, :, :]

    img_preprocess = sam.preprocess(img_preprocess)
    if len(intermediate_shape) == 3:
        intermediate_shape = intermediate_shape[:2]
    elif len(intermediate_shape) == 4:
        intermediate_shape = intermediate_shape[1:3]

    return img_preprocess, intermediate_shape


def normalize(img):
    img = img - tf.math.reduce_min(img)
    img = img / tf.math.reduce_max(img)
    img = img * 2.0 - 1.0
    return img

def resize(img):
    # default resize function for all pi outputs
    return tf.image.resize(img, (SIZE, SIZE), method="bicubic")

def smooth_mask(mask, ds=20):
    shape = tf.shape(mask)
    w, h = shape[0], shape[1]
    return tf.image.resize(tf.image.resize(mask, (ds, ds), method="bicubic"), (w, h), method="bicubic")

def pi(img, mask):
    img = tf.cast(img, tf.float32)

    shape = tf.shape(img)
    w, h = tf.cast(shape[0], tf.int64), tf.cast(shape[1], tf.int64)

    mask = smooth_mask(mask.cpu().numpy().astype(float))
    mask = tf.reduce_mean(mask, -1)

    img = img * tf.cast(mask > 0.01, tf.float32)[:, :, None]

    img_resize = tf.image.resize(img, (SIZE, SIZE), method="bicubic", antialias=True)
    img_pad = tf.image.resize_with_pad(img, SIZE, SIZE, method="bicubic", antialias=True)

    # building 2 anchors
    anchors = tf.where(mask > 0.15)
    anchor_xmin = tf.math.reduce_min(anchors[:, 0])
    anchor_xmax = tf.math.reduce_max(anchors[:, 0])
    anchor_ymin = tf.math.reduce_min(anchors[:, 1])
    anchor_ymax = tf.math.reduce_max(anchors[:, 1])

    if anchor_xmax - anchor_xmin > 50 and anchor_ymax - anchor_ymin > 50:

        img_anchor_1 = resize(img[anchor_xmin:anchor_xmax, anchor_ymin:anchor_ymax])

        delta_x = (anchor_xmax - anchor_xmin) // 4
        delta_y = (anchor_ymax - anchor_ymin) // 4
        img_anchor_2 = img[anchor_xmin+delta_x:anchor_xmax-delta_x,
                           anchor_ymin+delta_y:anchor_ymax-delta_y]
        img_anchor_2 = resize(img_anchor_2)
    else:
        img_anchor_1 = img_resize
        img_anchor_2 = img_pad

    # building the anchors max
    anchor_max = tf.where(mask == tf.math.reduce_max(mask))[0]
    anchor_max_x, anchor_max_y = anchor_max[0], anchor_max[1]

    img_max_zoom1 = img[tf.math.maximum(anchor_max_x-SIZE, 0): tf.math.minimum(anchor_max_x+SIZE, w),
                        tf.math.maximum(anchor_max_y-SIZE, 0): tf.math.minimum(anchor_max_y+SIZE, h)]

    img_max_zoom1 = resize(img_max_zoom1)
    img_max_zoom2 = img[anchor_max_x-SIZE//2:anchor_max_x+SIZE//2,
                        anchor_max_y-SIZE//2:anchor_max_y+SIZE//2]
    #img_max_zoom2 = img[tf.math.maximum(anchor_max_x-SIZE//2, 0): tf.math.minimum(anchor_max_x+SIZE//2, w),
    #                    tf.math.maximum(anchor_max_y-SIZE//2, 0): tf.math.minimum(anchor_max_y+SIZE//2, h)]
    #tf.print(img_max_zoom2.shape)
    #img_max_zoom2 = resize(img_max_zoom2)
    return tf.cast([
        img_resize,
        #img_pad,
        img_anchor_1,
        img_anchor_2,
        img_max_zoom1,
        #img_max_zoom2,
    ], tf.float32)

def one_step_inference(x):
    if len(x.shape) == 3:
        original_size = x.shape[:2]
    elif len(x.shape) == 4:
        original_size = x.shape[1:3]

    x, intermediate_shape = preprocess(x)

    with torch.no_grad():
        image_embedding = sam.image_encoder(x)

    with torch.no_grad():
        sparse_embeddings, dense_embeddings = sam.prompt_encoder(points=None, boxes=None, masks=None)
        low_res_masks, iou_predictions = sam.mask_decoder(
            image_embeddings=image_embedding,
            image_pe=sam.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
        )
    if len(x.shape) == 3:
        input_size = tuple(x.shape[:2])
    elif len(x.shape) == 4:
        input_size = tuple(x.shape[-2:])

    #upscaled_masks = sam.postprocess_masks(low_res_masks, input_size, original_size).cuda()
    mask = F.interpolate(low_res_masks, (1024, 1024))[:, :, :intermediate_shape[0], :intermediate_shape[1]]
    mask = F.interpolate(mask, (original_size[0], original_size[1]))

    return mask

def segmentation_sam(x, SIZE=384):

    x = tf.image.resize_with_pad(x, SIZE, SIZE)
    predicted_mask = one_step_inference(x)
    fig, ax = plt.subplots()
    img = x.cpu().numpy()
    mask = predicted_mask.cpu().numpy()[0][0] > 0.2
    ax.imshow(img)
    ax.imshow(mask, cmap='jet', alpha=0.4)
    plt.savefig('test.png')
    ax.axis('off')
    fig.canvas.draw()
    # Now we can save it to a numpy array.
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close()
    return data

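Note on inference_sam.py: segmentation_sam is the public entry point; it pads the input to SIZE x SIZE, runs a prompt-free SAM forward pass, and returns a matplotlib rendering of the mask overlay as an RGB array. A hypothetical call, assuming a CUDA GPU, a READ_TOKEN, and the downloaded checkpoint; the image path is illustrative only:

import cv2
from inference_sam import segmentation_sam

img = cv2.cvtColor(cv2.imread('images/leaves/sample.jpg'), cv2.COLOR_BGR2RGB)  # hypothetical path
overlay = segmentation_sam(img, SIZE=384)                # (H, W, 3) uint8 overlay image
cv2.imwrite('overlay.png', cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
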
labels.py
ADDED
@@ -0,0 +1,175 @@
lookup_170 = {0: 'Anacardiaceae',
              1: 'Betulaceae',
              2: 'Cornaceae',
              3: 'Cunoniaceae',
              4: 'Euphorbiaceae',
              5: 'Fabaceae',
              6: 'Fagaceae',
              7: 'Juglandaceae',
              8: 'Lauraceae',
              9: 'Malvaceae',
              10: 'Meliaceae',
              11: 'Menispermaceae',
              12: 'Myrtaceae',
              13: 'Proteaceae',
              14: 'Rhamnaceae',
              15: 'Rosaceae',
              16: 'Salicaceae',
              17: 'Sapindaceae',
              18: 'Ulmaceae',
              19: 'Acanthaceae',
              20: 'Achariaceae',
              21: 'Achatocarpaceae',
              22: 'Actinidiaceae',
              23: 'Adoxaceae',
              24: 'Altingiaceae',
              25: 'Amaranthaceae',
              26: 'Ancistrocladaceae',
              27: 'Anisophylleaceae',
              28: 'Annonaceae',
              29: 'Apiaceae',
              30: 'Apocynaceae',
              31: 'Berberidaceae',
              32: 'Bignoniaceae',
              33: 'Bixaceae',
              34: 'Bonnetiaceae',
              35: 'Boraginaceae',
              36: 'Brunelliaceae',
              37: 'Burseraceae',
              38: 'Buxaceae',
              39: 'Calophyllaceae',
              40: 'Calycanthaceae',
              41: 'Campanulaceae',
              42: 'Canellaceae',
              43: 'Cannabaceae',
              44: 'Capparaceae',
              45: 'Caprifoliaceae',
              46: 'Cardiopteridaceae',
              47: 'Caricaceae',
              48: 'Caryocaraceae',
              49: 'Celastraceae',
              50: 'Centroplacaceae',
              51: 'Cercidiphyllaceae',
              52: 'Chloranthaceae',
              53: 'Chrysobalanaceae',
              54: 'Clethraceae',
              55: 'Clusiaceae',
              56: 'Combretaceae',
              57: 'Connaraceae',
              58: 'Coriariaceae',
              59: 'Crassulaceae',
              60: 'Crossosomataceae',
              61: 'Cucurbitaceae',
              62: 'Dichapetalaceae',
              63: 'Dilleniaceae',
              64: 'Dipterocarpaceae',
              65: 'Ebenaceae',
              66: 'Elaeocarpaceae',
              67: 'Ericaceae',
              68: 'Erythroxylaceae',
              69: 'Escalloniaceae',
              70: 'Eucommiaceae',
              71: 'Garryaceae',
              72: 'Gentianaceae',
              73: 'Geraniaceae',
              74: 'Gesneriaceae',
              75: 'Gnetaceae',
              76: 'Grossulariaceae',
              77: 'Gunneraceae',
              78: 'Hamamelidaceae',
              79: 'Hernandiaceae',
              80: 'Humiriaceae',
              81: 'Hydrangeaceae',
              82: 'Hypericaceae',
              83: 'Icacinaceae',
              84: 'Irvingiaceae',
              85: 'Iteaceae',
              86: 'Ixonanthaceae',
              87: 'Lamiaceae',
              88: 'Lardizabalaceae',
              89: 'Lecythidaceae',
              90: 'Liliaceae',
              91: 'Linaceae',
              92: 'Loganiaceae',
              93: 'Loranthaceae',
              94: 'Lythraceae',
              95: 'Magnoliaceae',
              96: 'Malpighiaceae',
              97: 'Marantaceae',
              98: 'Marcgraviaceae',
              99: 'Melastomataceae',
              100: 'Melianthaceae',
              101: 'Monimiaceae',
              102: 'Moraceae',
              103: 'Myricaceae',
              104: 'Myristicaceae',
              105: 'Nitrariaceae',
              106: 'Nothofagaceae',
              107: 'Nyctaginaceae',
              108: 'Ochnaceae',
              109: 'Olacaceae',
              110: 'Oleaceae',
              111: 'Onagraceae',
              112: 'Opiliaceae',
              113: 'Orchidaceae',
              114: 'Orobanchaceae',
              115: 'Oxalidaceae',
              116: 'Pandaceae',
              117: 'Papaveraceae',
              118: 'Paracryphiaceae',
              119: 'Passifloraceae',
              120: 'Pedaliaceae',
              121: 'Penaeaceae',
              122: 'Pentaphylacaceae',
              123: 'Peridiscaceae',
              124: 'Phyllanthaceae',
              125: 'Phytolaccaceae',
              126: 'Picramniaceae',
              127: 'Picrodendraceae',
              128: 'Piperaceae',
              129: 'Pittosporaceae',
              130: 'Platanaceae',
              131: 'Polemoniaceae',
              132: 'Polygalaceae',
              133: 'Polygonaceae',
              134: 'Primulaceae',
              135: 'Ranunculaceae',
              136: 'Rhabdodendraceae',
              137: 'Rhizophoraceae',
              138: 'Rubiaceae',
              139: 'Rutaceae',
              140: 'Sabiaceae',
              141: 'Santalaceae',
              142: 'Sapotaceae',
              143: 'Sarcolaenaceae',
              144: 'Saxifragaceae',
              145: 'Schisandraceae',
              146: 'Schoepfiaceae',
              147: 'Scrophulariaceae',
              148: 'Simaroubaceae',
              149: 'Siparunaceae',
              150: 'Smilacaceae',
              151: 'Solanaceae',
              152: 'Sphaerosepalaceae',
              153: 'Stachyuraceae',
              154: 'Staphyleaceae',
              155: 'Stegnospermataceae',
              156: 'Stemonuraceae',
              157: 'Styracaceae',
              158: 'Symplocaceae',
              159: 'Theaceae',
              160: 'Thymelaeaceae',
              161: 'Trigoniaceae',
              162: 'Trochodendraceae',
              163: 'Urticaceae',
              164: 'Verbenaceae',
              165: 'Violaceae',
              166: 'Vitaceae',
              167: 'Vochysiaceae',
              168: 'Winteraceae',
              169: 'Zygophyllaceae',
              170: 'Araceae'}

dict_lu = {}
for i in range(171):
    dict_lu[i] = lookup_170[i]

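Note on labels.py: lookup_170 maps the classifier's integer class indices to plant-family names (171 entries, indices 0-170), and dict_lu is an identical copy built in a loop. For example:

from labels import lookup_170, dict_lu

print(lookup_170[6])          # 'Fagaceae'
print(len(lookup_170))        # 171
print(dict_lu == lookup_170)  # True
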
pre-requirements.txt
ADDED
@@ -0,0 +1,6 @@
numpy==1.22.4
opencv-python-headless==4.5.5.64
openmim==0.1.5
torch==1.11.0
torchvision==0.12.0
tensorflow==2.8