NassimeBejaia committed on
Commit
5555b23
·
1 Parent(s): 3213d15

Upload 4 files

Browse files
Files changed (4) hide show
  1. Utils(2).py +122 -0
  2. app(2).py +379 -0
  3. requirements(1).txt +13 -0
  4. text_line_model(1).h5 +3 -0
Utils(2).py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import Dataset,DataLoader
3
+ from PIL import Image
4
+
5
+ import os
6
+ import cv2
7
+
8
+ import tensorflow as tf
9
+ import numpy as np
10
+ from tensorflow import keras
11
+
12
+
13
+ ############ extractdetection ################
14
+
15
+
16
+ #alpha = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
17
+ alpha =['A', 'A1', 'A2' ,'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L','M',
18
+ 'N','O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
19
+ 'Z1', 'Z2','Z3','Z4','Z5']
20
+
21
+
22
+ ############ decode_detection ################
23
+
24
+ # labels from KHATT dataset
25
+ """characters = [' ', '!', '"', '#', '%', '(', ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '=', '>', '?', '[', '\\', ']', 'x', '\xa0', '×', '،', '؛', '؟', 'ء', 'آ', 'أ', 'ؤ', 'إ', 'ئ', 'ا', 'ب', 'ة', 'ت', 'ث', 'ج', 'ح', 'خ', 'د', 'ذ', 'ر', 'ز', 'س', 'ش', 'ص', 'ض', 'ط', 'ظ', 'ع', 'غ', 'ـ', 'ف', 'ق', 'ك', 'ل', 'م', 'ن', 'ه', 'و', 'ى', 'ي', 'ً', 'ٌ', 'ٍ', 'َ', 'ُ', 'ِ', 'ّ', 'ْ', '–', '‘']
26
+ characters.sort()"""
27
+
28
+
29
+ # labels from old arabic
30
+ characters = [' ', '.', '[', ']', '؟', 'ء', 'آ', 'أ', 'ؤ', 'إ', 'ئ', 'ا', 'ب', 'ة', 'ت', 'ث', 'ج', 'ح', 'خ', 'د', 'ذ', 'ر', 'ز', 'س', 'ش', 'ص', 'ض', 'ط', 'ظ', 'ع', 'غ', 'ـ', 'ف', 'ق', 'ك', 'ل', 'م', 'ن', 'ه', 'و', 'ى', 'ي', 'ً', 'ٌ', 'ٍ', 'َ', 'ُ', 'ِ', 'ّ', 'ْ', 'ٔ', 'ٕ', '١', '٢', '٣', '٤', '٥', '٧', '٨', 'ٮ', 'ٯ', 'ٰ', 'ڡ', 'ک', 'ں', 'ی', '۴', '\u202c', 'ﭐ', 'ﺟ', 'ﺣ', 'ﻛ', '�']
31
# Sort the alphabet once so that every character has a stable index;
# label_to_num / num_to_label translate between characters and their
# position in this list.
characters.sort()

# Length that encoded labels are padded/truncated to, and the maximum
# number of decoded symbols kept per line in decode_predictions.
max_length = 132

# Size passed to tf.image.resize in encode_single_sample (applied after the
# line image has been rotated by 90 degrees).
img_height, img_width = 1056, 64
36
+
37
def label_to_num(label, max_length=max_length):
    """Encode a text label as a fixed-length sequence of character indices.

    Every character of ``label`` is replaced by its position in the
    module-level ``characters`` list. Characters that are not part of the
    list are dropped silently (best-effort encoding).

    Args:
        label: Iterable of single-character strings, typically a ``str``.
        max_length: Length of the returned sequence.

    Returns:
        The single row produced by ``keras.utils.pad_sequences``: an int32
        sequence of length ``max_length``, padded at the end with
        ``len(characters) + 2`` and, when too long, truncated from the front.
    """
    # Explicit lookup table instead of list.index() wrapped in a bare
    # ``except``: unknown characters are skipped on purpose, but unrelated
    # errors are no longer swallowed. setdefault keeps the first occurrence,
    # which is what list.index() would return.
    char_to_index = {}
    for index, ch in enumerate(characters):
        char_to_index.setdefault(ch, index)

    label_num = [char_to_index[ch] for ch in label if ch in char_to_index]

    padded = keras.utils.pad_sequences(
        [label_num],
        maxlen=max_length,
        dtype='int32',
        padding='post',
        truncating='pre',
        value=len(characters) + 2,
    )
    return padded[0]
47
+
48
+
49
def num_to_label(num):
    """Decode a sequence of character indices back into text.

    Decoding stops at the first padding marker, which is either ``-1`` or
    ``len(characters) + 2`` (the pad value used by ``label_to_num``).
    Indices that do not address an entry of ``characters`` are skipped.

    Args:
        num: Iterable of values convertible with ``int()``.

    Returns:
        The decoded string (empty when ``num`` is empty or starts with a
        padding marker).
    """
    pad_value = len(characters) + 2
    decoded = []
    for ch in num:
        index = int(ch)
        if index == -1 or index == pad_value:  # padding symbol: end of text
            break
        # Explicit range check instead of a bare ``except``: it also stops
        # negative values such as -2 from wrapping around to the end of the
        # list and producing a spurious character.
        if 0 <= index < len(characters):
            decoded.append(characters[index])
    return "".join(decoded)
61
+
62
+
63
+
64
def decode_predictions(pred, greedy=True):
    """Turn a batch of model outputs into a list of decoded strings.

    The batch is CTC-decoded with ``keras.backend.ctc_decode``, every decoded
    row is cut to at most ``max_length`` symbols and then converted to text
    with ``num_to_label``.

    Args:
        pred: Batch of predictions; ``pred.shape[0]`` is used as the batch
            size and ``pred.shape[1]`` as the sequence length of every sample.
        greedy: Forwarded to ``ctc_decode``.

    Returns:
        A list with one decoded string per sample in the batch.
    """
    sample_count = pred.shape[0]
    step_count = pred.shape[1]
    # Every sample is decoded over the full sequence length.
    input_len = np.ones(sample_count) * step_count

    decoded = keras.backend.ctc_decode(pred, input_length=input_len, greedy=greedy)
    best_paths = decoded[0][0]
    results = best_paths[:, :max_length]

    return [num_to_label(res) for res in results]
75
+
76
+
77
+ ############ dataloader ################
78
+
79
+
80
+
81
+ # def encode_single_sample(path_dir,label = None):
82
+ # img = tf.io.read_file(path_dir)
83
+ # img = tf.io.decode_jpeg( img, name=None)
84
+ # img.set_shape([img.shape[0], img.shape[1],img.shape[-1]])
85
+ # img = tf.image.rot90(img, k=1, name=None)
86
+ # img = tf.image.resize(img, [img_height, img_width])
87
+ # img=img/255.0
88
+ # #img=tf.concat([img,img,img],axis=-1)
89
+
90
+ # return img
91
+
92
def encode_single_sample(path_dir, label=None):
    """Load one line image from disk and prepare it as model input.

    The JPEG file is decoded, rotated by 90 degrees (``tf.image.rot90`` with
    ``k=1``), resized to ``[img_height, img_width]``, converted to a single
    grayscale channel and scaled by 1/255.

    Args:
        path_dir: Path of the JPEG file to read.
        label: Unused; kept so existing callers that pass it keep working.

    Returns:
        The preprocessed image tensor.
    """
    img = tf.io.read_file(path_dir)
    # Always decode to 3 channels: with the default (channels=0) a grayscale
    # JPEG yields a 1-channel tensor, which rgb_to_grayscale below rejects.
    img = tf.io.decode_jpeg(img, channels=3)
    img = tf.image.rot90(img, k=1)
    img = tf.image.resize(img, [img_height, img_width])
    img = tf.image.rgb_to_grayscale(img)
    img = img / 255.0
    return img
101
+
102
+
103
+ batch_size = 16
104
+
105
def Loadlines(path_lines):
    """Build the batched ``tf.data`` input pipeline for a list of line images.

    Note that ``path_lines`` is sorted in place, so the caller's list is
    reordered to match the order in which samples are produced.

    Args:
        path_lines: Mutable list of image file paths.

    Returns:
        A dataset that maps every path through ``encode_single_sample``,
        groups the results into batches of ``batch_size`` (the last, smaller
        batch is kept) and prefetches them.
    """
    path_lines.sort()

    autotune = tf.data.experimental.AUTOTUNE
    dataset = tf.data.Dataset.from_tensor_slices(path_lines)
    dataset = dataset.map(encode_single_sample, num_parallel_calls=autotune)
    dataset = dataset.batch(batch_size, drop_remainder=False)
    return dataset.prefetch(buffer_size=autotune)
116
+
117
+
118
+ ############ load model ################
119
+
120
def load_model(model_path='/home/user/app/text_line_model.h5'):
    """Load the fine-tuned text-line recognition model.

    Args:
        model_path: Location of the saved Keras model. Defaults to the path
            the application has always used, so existing calls without
            arguments behave as before; pass another path to run the code
            outside that directory layout.

    Returns:
        The object returned by ``keras.models.load_model``.
    """
    return keras.models.load_model(model_path)
app(2).py ADDED
@@ -0,0 +1,379 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import numpy as np
4
+
5
+ import cv2
6
+
7
+ from PIL import Image
8
+
9
+ import os
10
+
11
+ import zipfile
12
+ import gdown
13
+
14
+ from tempfile import TemporaryDirectory
15
+ from Utils import Loadlines, decode_predictions, load_model
16
+
17
+ import shutil
18
+ import tempfile
19
+ import subprocess
20
+
21
+
22
+ # @st.cache(allow_output_mutation=True)
23
+ # # @st.cache_resource(allow_output_mutation=True)
24
+ # def get_model():
25
+ # return load_model()
26
+
27
+ # Load the OCR model outside the function to prevent reloading it every time
28
+ ocr_model = load_model()
29
+
30
+
31
+
32
+
33
def main():
    """Render the Streamlit page and run the detection + OCR pipeline.

    The page offers up to three sample images (one per column) and a file
    uploader. Once an image is chosen it is thresholded for display, passed
    through YOLOv3 line detection and, if an annotation file was produced,
    every detected line is cropped from the original image and recognised.
    An uploaded image takes precedence over a sample selected in the same run.
    """
    st.title("Arabic Manuscript OCR")

    # Load sample images. Sorted so that "Image 1/2/3" always refer to the
    # same files; os.listdir() returns entries in arbitrary order.
    sample_images_dir = "sample_images"
    sample_images = sorted(
        os.path.join(sample_images_dir, name)
        for name in os.listdir(sample_images_dir)
        if name.lower().endswith(('.png', '.jpg', '.jpeg'))
    )

    # Normalized size for the thumbnails
    display_size = (200, 300)

    selected_image_pil = None
    # Source handed to display_detected_lines for cropping: the sample path
    # or the uploaded file object. Remembering it here avoids re-opening and
    # comparing every sample image to find out which one was selected.
    original_img_path = None

    # zip() stops at the shorter sequence, so fewer than three sample images
    # simply leave the remaining columns empty instead of raising IndexError.
    columns = st.columns(3)
    for number, (column, sample_path) in enumerate(zip(columns, sample_images), start=1):
        with column:
            if st.button(f"Select Image {number}", key=f"img{number}"):
                selected_image_pil = display_image(sample_path)
                original_img_path = sample_path
            column.image(resize_image(sample_path, display_size), use_column_width=True)

    # Option to upload a new image
    uploaded_image = st.file_uploader(
        "Or upload a new image of the Arabic manuscript",
        type=["jpg", "jpeg", "png"],
    )
    if uploaded_image:
        selected_image_pil = Image.open(uploaded_image)
        original_img_path = uploaded_image
        st.image(selected_image_pil, caption="Uploaded Image", use_column_width=True)

    if selected_image_pil is None:
        return

    thresh_pil = process_image(selected_image_pil)
    st.image(thresh_pil, caption="Thresholded Image", use_column_width=True)

    processed_img_path = process_with_yolo(selected_image_pil)
    if not processed_img_path:
        st.error("Error displaying the processed image.")
        return

    st.image(processed_img_path, caption="Processed with YOLO", use_column_width=True)

    # YOLO writes one label file per image, named after the image file.
    label_name = os.path.splitext(os.path.basename(processed_img_path))[0] + ".txt"
    txt_file_path = os.path.join('yolov3/runs/detect/mainlinedection/labels', label_name)
    if os.path.exists(txt_file_path):
        display_detected_lines(original_img_path, processed_img_path)
    else:
        st.error("Annotation file (.txt) not found!")
98
+
99
+ # display_files_in_directory('/home/user/app/detected_lines/')
100
+
101
+
102
+
103
+ # Function to display files in a directory
104
def display_files_in_directory(path="."):
    """Write the names of all entries of ``path`` to the Streamlit page.

    A notice is written instead when ``path`` does not exist.
    """
    if not os.path.exists(path):
        st.write(f"Directory {path} does not exist!")
        return

    st.write(f"Files in directory: {path}")
    for entry in os.listdir(path):
        st.write(entry)
112
+
113
def resize_image(image_path, size):
    """Open the image at ``image_path`` and return it resized to ``size``."""
    with Image.open(image_path) as source:
        resized = source.resize(size)
    return resized
118
+
119
def display_image(img):
    """Show an image on the page and return it.

    ``img`` may be a file path (``str``), in which case it is opened first,
    or anything ``st.image`` accepts directly.
    """
    picture = Image.open(img) if isinstance(img, str) else img
    st.image(picture, caption="Selected Image", use_column_width=True)
    return picture
126
+
127
+
128
+ # --------------------------------------------
129
def process_image(selected_image):
    """Binarise an image with Gaussian blur followed by Otsu thresholding.

    Args:
        selected_image: PIL image to threshold.

    Returns:
        A PIL image holding the thresholded (black/white) result.
    """
    # Images with an alpha channel, a palette or CMYK data do not map onto
    # the "3 channels = RGB, otherwise grayscale" logic below (an RGBA array
    # would be treated as grayscale and break Otsu thresholding, which needs
    # a single channel), so bring them to plain RGB first.
    if selected_image.mode in ("RGBA", "LA", "P", "PA", "CMYK"):
        selected_image = selected_image.convert("RGB")

    # Convert PIL image to an array usable by OpenCV
    opencv_image = np.array(selected_image)

    if opencv_image.ndim == 3 and opencv_image.shape[2] == 3:  # RGB image
        gray_image = cv2.cvtColor(opencv_image, cv2.COLOR_RGB2GRAY)
    else:  # image is already single-channel
        gray_image = opencv_image

    # Otsu needs 8-bit input.
    # NOTE(review): the scaling assumes non-uint8 data lies in [0, 1]
    # (e.g. boolean or normalised float images) - confirm for other modes.
    if gray_image.dtype != np.uint8:
        gray_image = (gray_image * 255).astype(np.uint8)

    # Smooth before thresholding
    blurred_img = cv2.GaussianBlur(gray_image, (5, 5), 0)

    # With THRESH_OTSU the threshold argument (0) is ignored and computed
    # from the image.
    _, thresh = cv2.threshold(blurred_img, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    # Back to PIL so Streamlit can display it
    return Image.fromarray(thresh)
153
+
154
+
155
+
156
+
157
+
158
+
159
+
160
+
161
+
162
+
163
+ # ---------------------------------------------------------------------------------
164
+
165
def download_ultralytics_yolov3_folder():
    """Fetch the YOLOv3 archive from Google Drive and unpack it here.

    The zip file is downloaded with ``gdown`` into the current working
    directory, extracted in place and deleted afterwards. Progress messages
    are written to the Streamlit page.
    """
    st.text("Downloading Ultralytics YOLOv3 folder from Google Drive. This may take a while...")

    archive_name = 'ultralytics_yolov3.zip'
    gdown.download(
        'https://drive.google.com/uc?id=1n6YcqHl5Y2xRpWw7DQPrZ2FoAqfcf12I',
        archive_name,
        quiet=False,
    )

    # Unpack next to the application code
    with zipfile.ZipFile(archive_name, 'r') as archive:
        archive.extractall('.')
        st.text("Folder extraction complete!")

    # The archive itself is not needed once extracted
    os.remove(archive_name)
    st.text("Download and extraction completed!")
181
+
182
+
183
def get_detected_boxes(txt_path, img_width, img_height):
    """Read a YOLO label file and return its boxes in pixel coordinates.

    Each line is expected to look like
    ``class x_center y_center width height [confidence]`` with the four box
    values relative to the image size. Blank lines and lines with fewer than
    five fields are skipped.

    Args:
        txt_path: Path of the label file.
        img_width: Width in pixels used to scale ``x_center`` and ``width``.
        img_height: Height in pixels used to scale ``y_center`` and ``height``.

    Returns:
        A list of ``[x_center, y_center, width, height]`` lists in pixels,
        in file order.

    Raises:
        ValueError: If one of the four box fields is not a number.
    """
    boxes = []
    with open(txt_path, 'r') as f:
        for line in f:
            fields = line.split()
            if len(fields) < 5:
                # A trailing blank line or truncated record used to abort the
                # whole parse with an unpacking error.
                continue
            x_center, y_center, width, height = map(float, fields[1:5])
            boxes.append([
                x_center * img_width,
                y_center * img_height,
                width * img_width,
                height * img_height,
            ])
    return boxes
201
+
202
+
203
+
204
def process_with_yolo(img_pil):
    """Run YOLOv3 line detection on an image and return the annotated result.

    The image is written to a temporary JPEG, ``yolov3/detect.py`` is run on
    it in a subprocess (saving label files and confidences under the run name
    ``mainlinedection``), and the temporary file is removed again.

    Args:
        img_pil: PIL image to analyse.

    Returns:
        Path of the annotated image written by YOLO inside
        ``yolov3/runs/detect/mainlinedection``, or ``None`` (after showing an
        error on the page) when no such file was produced.
    """
    run_dir = 'yolov3/runs/detect/mainlinedection'
    weights_path = 'yolov3/runs/train/mainline/weights/best.pt'

    # Fetch the YOLOv3 folder only when it is missing; downloading it on
    # every call made each run pay for the full download again.
    # NOTE(review): assumes the archive unpacks into ./yolov3 - the paths
    # used below already depend on that.
    if not (os.path.exists('yolov3/detect.py') and os.path.exists(weights_path)):
        with st.spinner('Downloading YOLOv3 folder...'):
            download_ultralytics_yolov3_folder()

    # Reserve a unique file name, then release the handle so the image can
    # be written to that path.
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".jpg")
    temp_file.close()
    try:
        # JPEG cannot store alpha or palette data (e.g. RGBA/P PNG uploads).
        if img_pil.mode not in ("L", "RGB"):
            img_pil = img_pil.convert("RGB")
        img_pil.save(temp_file.name)

        # Start from a clean run directory so stale results are never picked up
        if os.path.exists(run_dir):
            shutil.rmtree(run_dir)

        cmd = [
            'python', 'yolov3/detect.py',
            '--source', temp_file.name,
            '--weights', weights_path,
            '--save-txt',
            '--save-conf',
            '--imgsz', '672',
            '--name', 'mainlinedection'
        ]
        completed = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        if completed.returncode != 0:
            st.error(f"YOLOv3 command failed with code {completed.returncode}.")
            if completed.stderr:
                st.error(completed.stderr.decode(errors="replace"))
    finally:
        # Previously this cleanup sat after the return statements and never
        # ran, leaking one temporary file per call.
        os.unlink(temp_file.name)

    # YOLO stores the annotated copy under the same file name as its input.
    output_path = os.path.join(run_dir, os.path.basename(temp_file.name))
    if os.path.exists(output_path):
        return output_path

    st.error("Processed image not found!")
    return None
252
+
253
+
254
+
255
+ # def display_detected_lines(original_path, output_path):
256
+ # # Derive the txt_path from the output_path
257
+ # txt_path = os.path.join('yolov3/runs/detect/mainlinedection/labels', os.path.basename(output_path).replace(".jpg", ".txt"))
258
+
259
+ # if os.path.exists(txt_path):
260
+ # # Use the original image for cropping, NOT the output_path
261
+ # original_image = Image.open(original_path)
262
+
263
+ # boxes = get_detected_boxes(txt_path, original_image.width, original_image.height)
264
+ # if not boxes:
265
+ # st.warning("No lines detected by YOLOv3.")
266
+ # return
267
+
268
+ # # Ensure the detected_lines directory exists, if not, create it
269
+ # detected_lines_dir = 'detected_lines'
270
+ # if not os.path.exists(detected_lines_dir):
271
+ # os.makedirs(detected_lines_dir)
272
+
273
+ # for index, box in enumerate(boxes):
274
+ # x_center, y_center, width, height = box
275
+ # x_min = int(x_center - (width / 2))
276
+ # y_min = int(y_center - (height / 2))
277
+ # x_max = int(x_center + (width / 2))
278
+ # y_max = int(y_center + (height / 2))
279
+
280
+ # extracted_line = original_image.crop((x_min, y_min, x_max, y_max))
281
+
282
+ # # Save the extracted line to the detected_lines directory
283
+ # extracted_line.save(os.path.join(detected_lines_dir, f"line_{index}.jpg"))
284
+
285
+ # st.image(extracted_line, caption="Detected Line", use_column_width=True)
286
+ # else:
287
+ # st.error("Annotation file (.txt) not found!")
288
+
289
+
290
+
291
def display_detected_lines(original_path, output_path):
    """Crop every detected line, run OCR on it and show image plus text.

    The YOLO label file belonging to ``output_path`` is read, each box is
    cropped from the original image into a temporary directory, the crops are
    passed to ``perform_ocr_on_detected_lines`` and the results are rendered
    line by line, ordered from the top of the page to the bottom.

    Args:
        original_path: Path or file object of the original image; crops are
            taken from this image, not from the annotated YOLO output.
        output_path: Path of the annotated image written by YOLO; only its
            file name is used, to locate the label file.
    """
    # Derive the label file from the output image name
    label_name = os.path.splitext(os.path.basename(output_path))[0] + ".txt"
    txt_path = os.path.join('yolov3/runs/detect/mainlinedection/labels', label_name)

    if not os.path.exists(txt_path):
        st.error("Annotation file (.txt) not found!")
        return

    if original_path is None:
        st.error("Original image not available for cropping.")
        return

    original_image = Image.open(original_path)
    # Crops are stored as JPEG, which cannot hold alpha/palette data.
    if original_image.mode not in ("L", "RGB"):
        original_image = original_image.convert("RGB")

    boxes = get_detected_boxes(txt_path, original_image.width, original_image.height)
    if not boxes:
        st.warning("No lines detected by YOLOv3.")
        return

    # Present lines in reading order (top to bottom by vertical centre)
    # rather than in the order YOLO happened to write them.
    boxes = sorted(boxes, key=lambda box: box[1])

    # Create a temporary directory to store the detected lines
    with TemporaryDirectory() as temp_dir:
        detected_line_paths = []

        for index, (x_center, y_center, width, height) in enumerate(boxes):
            # Clamp to the image so boxes reaching past the border do not
            # add black padding to the crop.
            x_min = max(0, int(x_center - (width / 2)))
            y_min = max(0, int(y_center - (height / 2)))
            x_max = min(original_image.width, int(x_center + (width / 2)))
            y_max = min(original_image.height, int(y_center + (height / 2)))
            if x_max <= x_min or y_max <= y_min:
                continue  # degenerate box: nothing to crop or save

            extracted_line = original_image.crop((x_min, y_min, x_max, y_max))

            # Zero-padded index: the OCR loader sorts the paths as strings,
            # and without padding "line_10" would come before "line_2".
            detected_line_path = os.path.join(temp_dir, f"detected_line_{index:05d}.jpg")
            extracted_line.save(detected_line_path)
            detected_line_paths.append(detected_line_path)

        if not detected_line_paths:
            st.warning("No lines detected by YOLOv3.")
            return

        # Perform OCR on detected lines
        recognized_texts = perform_ocr_on_detected_lines(detected_line_paths)

        # Display the results while the temporary files still exist
        for img_path, text in zip(detected_line_paths, recognized_texts):
            st.image(img_path, use_column_width=True)
            st.markdown(
                f"<p style='font-size: 18px; font-weight: bold;'>{text}</p>",
                unsafe_allow_html=True
            )
            # Add a small break for better spacing
            st.markdown("<br>", unsafe_allow_html=True)
344
+
345
+
346
+
347
+
348
def crop_detected_line(img_path, bounding_box):
    """Open the image at ``img_path`` and return the region ``bounding_box``.

    ``bounding_box`` is passed unchanged to ``Image.crop``.
    """
    with Image.open(img_path) as source:
        region = source.crop(bounding_box)
    return region
352
+
353
+
354
def perform_ocr_on_detected_lines(detected_line_paths):
    """Recognise the text in each detected line image.

    The paths are turned into batches by ``Loadlines``; every batch is passed
    through the module-level ``ocr_model`` and decoded with
    ``decode_predictions``.

    Args:
        detected_line_paths: List of paths to the detected line images.

    Returns:
        A list of recognised strings, in the order the batches were produced.
    """
    recognized = []
    for line_batch in Loadlines(detected_line_paths):
        recognized += decode_predictions(ocr_model(line_batch))
    return recognized
376
+
377
+
378
+ if __name__ == "__main__":
379
+ main()
requirements(1).txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ tensorflow
4
+ opencv-python
5
+
6
+ Pillow
7
+ gdown
8
+ PyYAML
9
+ requests
10
+ scipy
11
+ tqdm
12
+
13
+ seaborn
text_line_model(1).h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:40d7ed7446227c627fd5ff11d4950ec02db8b77d9c9ef87a0472f826ae086377
3
+ size 4309376