kerzel commited on
Commit
f45fd07
·
1 Parent(s): 7b764b0

update example image

Browse files
X4-Aligned_cropped_upperleft_small.png → SE_001_cut.png RENAMED
File without changes
app.py CHANGED
@@ -18,7 +18,7 @@ from tensorflow import keras
18
  # --- Constants and Model Loading ---
19
  IMAGE_PATH = "classified_damage_sites.png"
20
  CSV_PATH = "classified_damage_sites.csv"
21
- DEFAULT_IMAGE_PATH = "X4-Aligned_cropped_upperleft_small.png"
22
 
23
  model1_windowsize = [250,250]
24
  #model1_threshold = 0.7
 
18
  # --- Constants and Model Loading ---
19
  IMAGE_PATH = "classified_damage_sites.png"
20
  CSV_PATH = "classified_damage_sites.csv"
21
+ DEFAULT_IMAGE_PATH = "SE_001_cut.png"
22
 
23
  model1_windowsize = [250,250]
24
  #model1_threshold = 0.7
app.py.save DELETED
@@ -1,192 +0,0 @@
1
- import gradio as gr
2
- import numpy as np
3
- import pandas as pd
4
-
5
- # our own helper tools
6
- import clustering
7
- import utils
8
-
9
- import logging
10
- logging.getLogger().setLevel(logging.INFO)
11
-
12
- from tensorflow import keras
13
-
14
-
15
- IMAGE_PATH = 'classified_damage_sites.png'
16
- CSV_PATH = 'classified_damage_sites.csv'
17
-
18
-
19
- #image_threshold = 20
20
-
21
- model1_windowsize = [250,250]
22
- #model1_threshold = 0.7
23
-
24
- model1 = keras.models.load_model('rwthmaterials_dp800_network1_inclusion.h5')
25
- model1.compile()
26
-
27
- damage_classes = {3: "Martensite",2: "Interface",0:"Notch",1:"Shadowing"}
28
-
29
- model2_windowsize = [100,100]
30
- #model2_threshold = 0.5
31
-
32
- model2 = keras.models.load_model('rwthmaterials_dp800_network2_damage.h5')
33
- model2.compile()
34
-
35
-
36
-
37
- ##
38
- ## Function to do the actual damage classification
39
- ##
40
- def damage_classification(SEM_image,image_threshold, model1_threshold, model2_threshold):
41
- if SEM_image is None:
42
- logging.error('No image provided')
43
- return None
44
-
45
- damage_sites = {}
46
- ##
47
- ## clustering
48
- ##
49
- logging.debug('---------------: clustering :=====================')
50
- all_centroids = clustering.get_centroids(SEM_image, image_threshold=image_threshold,
51
- fill_holes=True, filter_close_centroids=True)
52
-
53
- for i in range(len(all_centroids)) :
54
- key = (all_centroids[i][0],all_centroids[i][1])
55
- damage_sites[key] = 'Not Classified'
56
-
57
- ##
58
- ## Inclusions vs the rest
59
- ##
60
- logging.debug('---------------: prepare model 1 :=====================')
61
- images_model1 = utils.prepare_classifier_input(SEM_image, all_centroids, window_size=model1_windowsize)
62
-
63
- logging.debug('---------------: run model 1 :=====================')
64
- y1_pred = model1.predict(np.asarray(images_model1, float))
65
-
66
- logging.debug('---------------: model1 threshold :=====================')
67
- inclusions = y1_pred[:,0].reshape(len(y1_pred),1)
68
- inclusions = np.where(inclusions > model1_threshold)
69
-
70
- logging.debug('---------------: model 1 update dict :=====================')
71
- for i in range(len(inclusions[0])):
72
- centroid_id = inclusions[0][i]
73
- coordinates = all_centroids[centroid_id]
74
- key = (coordinates[0], coordinates[1])
75
- damage_sites[key] = 'Inclusion'
76
- logging.debug('Damage sites after model 1')
77
- logging.debug(damage_sites)
78
-
79
- ##
80
- ## Martensite cracking, etc
81
- ##
82
- logging.debug('---------------: prepare model 2 :=====================')
83
- centroids_model2 = []
84
- for key, value in damage_sites.items():
85
- if value == 'Not Classified':
86
- coordinates = list([key[0],key[1]])
87
- centroids_model2.append(coordinates)
88
- logging.debug('Centroids model 2')
89
- logging.debug(centroids_model2)
90
-
91
- logging.debug('---------------: prepare model 2 :=====================')
92
- images_model2 = utils.prepare_classifier_input(SEM_image, centroids_model2, window_size=model2_windowsize)
93
- logging.debug('Images model 2')
94
- logging.debug(images_model2)
95
-
96
- logging.debug('---------------: run model 2 :=====================')
97
- y2_pred = model2.predict(np.asarray(images_model2, float))
98
-
99
- damage_index = np.asarray(y2_pred > model2_threshold).nonzero()
100
-
101
-
102
- for i in range(len(damage_index[0])):
103
- index = damage_index[0][i]
104
- identified_class = damage_index[1][i]
105
- label = damage_classes[identified_class]
106
- coordinates = centroids_model2[index]
107
- #print('Damage {} \t identified as {}, \t coordinates {}'.format(i, label, coordinates))
108
- key = (coordinates[0], coordinates[1])
109
- damage_sites[key] = label
110
-
111
- ##
112
- ## show the damage sites on the image
113
- ##
114
- logging.debug("-----------------: final damage sites :=================")
115
- logging.debug(damage_sites)
116
-
117
- image_path = 'classified_damage_sites.png'
118
- image = utils.show_boxes(SEM_image, damage_sites,
119
- save_image=True,
120
- image_path=image_path)
121
-
122
- ##
123
- ## export data
124
- ##
125
- csv_path = 'classified_damage_sites.csv'
126
- cols = ['x', 'y', 'damage_type']
127
-
128
- data = []
129
- for key, value in damage_sites.items():
130
- data.append([key[0], key[1], value])
131
-
132
- df = pd.DataFrame(columns=cols, data=data)
133
-
134
- df.to_csv(csv_path)
135
-
136
-
137
- return image# , image_path, csv_path
138
-
139
- ## ---------------------------------------------------------------------------------------------------------------
140
- ## main app interface
141
- ## -----------------------------------------------------------------------------------------------------------------
142
- with gr.Blocks() as app:
143
- gr.Markdown('# Damage Classification in Dual Phase Steels')
144
- gr.Markdown('This app classifies damage types in dual phase steels. Two models are used. The first model is used to identify inclusions in the steel. The second model is used to identify the remaining damage types: Martensite cracking, Interface Decohesion, Notch effect and Shadows.')
145
-
146
- gr.Markdown('The models used in this app are based on the following papers:')
147
- gr.Markdown('Kusche, C., Reclik, T., Freund, M., Al-Samman, T., Kerzel, U., & Korte-Kerzel, S. (2019). Large-area, high-resolution characterisation and classification of damage mechanisms in dual-phase steel using deep learning. PloS one, 14(5), e0216493. [Link](https://doi.org/10.1371/journal.pone.0216493)')
148
- #gr.Markdown('Medghalchi, S., Kusche, C. F., Karimi, E., Kerzel, U., & Korte-Kerzel, S. (2020). Damage analysis in dual-phase steel using deep learning: transfer from uniaxial to biaxial straining conditions by image data augmentation. Jom, 72, 4420-4430. [Link](https://link.springer.com/article/10.1007/s11837-020-04404-0)')
149
- gr.Markdown('Setareh Medghalchi, Ehsan Karimi, Sang-Hyeok Lee, Benjamin Berkels, Ulrich Kerzel, Sandra Korte-Kerzel, Three-dimensional characterisation of deformation-induced damage in dual phase steel using deep learning, Materials & Design, Volume 232, 2023, 112108, ISSN 0264-1275, [link] (https://doi.org/10.1016/j.matdes.2023.112108')
150
- gr.Markdown('Original data and code, including the network weights, can be found at Zenodo [link](https://zenodo.org/records/8065752)')
151
-
152
- #image_input = gr.Image(value='data/X4-Aligned_cropped_upperleft_small.png', label='Example SEM Image (DP800 steel)',)
153
- image_input = gr.Image()
154
-
155
- with gr.Row():
156
- cluster_threshold_input = gr.Number(label='Cluster Threshold', value = 20,
157
- info='Grayscale value at which a pixel is attributed to a potential damage site')
158
- model1_threshold_input = gr.Number(label='Model 1 Threshold', value = 0.7, info='Threshold for the model identifying inclusions')
159
- model2_threshold_input = gr.Number(label='Model 2 Threshold', value = 0.5, info='Thrshold for the model identifying the remaining damage types')
160
-
161
-
162
- button = gr.Button("Classify")
163
- #output_image = gr.Image()
164
-
165
- download_image_btn = gr.DownloadButton(label="Download Image", value=IMAGE_PATH, interactive=False)
166
- download_csv_btn = gr.DownloadButton(label="Download Damage List", value=CSV_PATH, interactive=False)
167
-
168
-
169
-
170
- button.click(
171
- damage_classification,
172
- inputs=[image_input, cluster_threshold_input, model1_threshold_input, model2_threshold_input],
173
- outputs=gr.Image(label="Output Image")
174
- )
175
-
176
-
177
-
178
-
179
-
180
-
181
- # simple interface, no title, etc
182
- # app = gr.Interface(damage_classification,
183
- # inputs=[gr.Image(),
184
- # gr.Number(label='Cluster Threshold', value = 20, info='Grayscale value at which a pixel is attributed to a potential damage site'),
185
- # gr.Number(label='Model 1 Threshold', value = 0.7, info='Threshold for the model identifying inclusions'),
186
- # gr.Number(label='Model 2 Threshold', value = 0.5, info='Thrshold for the model identifying the remaining damage types')
187
- # ],
188
- # outputs=[gr.Image(),
189
- # gr.DownloadButton(label='Download Image'),
190
- # gr.DownloadButton(label='Download Damage List')])
191
- if __name__ == "__main__":
192
- app.launch()