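"""Flask app: detect text lines in an uploaded image with YOLO, read them with
a fine-tuned TrOCR model, and highlight fuzzy matches for a target word."""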
from flask import Flask, request, render_template, redirect, send_from_directory
from werkzeug.utils import secure_filename
import torch
import os
import cv2
from ultralytics import YOLO
from concurrent.futures import ThreadPoolExecutor, as_completed
from fuzzywuzzy import fuzz
from transformers import VisionEncoderDecoderModel, TrOCRProcessor, ViTImageProcessor, NllbTokenizer
import unicodedata
import time
from multiprocessing import cpu_count

app = Flask(__name__)
UPLOAD_FOLDER = 'uploads'
RESULT_FOLDER = 'results'
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
app.config['RESULT_FOLDER'] = RESULT_FOLDER

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if torch.cuda.is_available():
    # Aggressive FP16 setup: new tensors default to float16 on the GPU. This
    # saves memory but can break ops that expect float32, so scope with care.
    torch.set_default_dtype(torch.float16)
    torch.set_default_device(device)
    # Let cuDNN benchmark kernels; inputs are fixed-size (384x384) line crops
    torch.backends.cudnn.benchmark = True

# Load the fine-tuned YOLO detection model (text-box detector)
detection_model = YOLO('train34/best.pt').to(device)
if torch.cuda.is_available():
    detection_model.half()  # FP16 inference is only reliable on the GPU

# Load the fine-tuned TrOCR recognition model
recognition_model = VisionEncoderDecoderModel.from_pretrained(
    'fine-tuned-small-printed-V2-checkpoint-with100000data/checkpoint-170160'
).to(device)
recognition_model.eval()
if torch.cuda.is_available():
    recognition_model.half()

# Build the processor from an NLLB tokenizer plus the checkpoint's image
# processor, so decoding uses the multilingual NLLB vocabulary
tokenizer = NllbTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
feature_extractor = ViTImageProcessor.from_pretrained('fine-tuned-small-printed-V2-checkpoint-with100000data/checkpoint-170160')
processor = TrOCRProcessor(image_processor=feature_extractor, tokenizer=tokenizer)
# Alternative: load the bundled processor directly from the checkpoint
# processor = TrOCRProcessor.from_pretrained('fine-tuned-small-printed-V2-checkpoint-with100000data/checkpoint-170160')

os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(RESULT_FOLDER, exist_ok=True)

def normalize_text(text):
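    """Normalize Unicode to NFC so visually identical strings compare equal."""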
    return unicodedata.normalize('NFC', text)

def preprocess_image(image_path):
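    """Load an image from disk, convert BGR to RGB, and denoise it."""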
    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"Image not found at path: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # Convert to RGB
    image = cv2.fastNlMeansDenoisingColored(image, None, 10, 10, 7, 21)
    return image

def detect_text(image, model):
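    """Run the YOLO detector and keep xyxy boxes with confidence >= 0.3."""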
    results = model(image)
    boxes = [box for box, conf in zip(results[0].boxes.xyxy.cpu().numpy(), results[0].boxes.conf.cpu().numpy()) if conf >= 0.3]
    return boxes

def calculate_iou(box1, box2):
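    """Intersection-over-union of two [x1, y1, x2, y2] boxes (0 when disjoint)."""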
    x1, y1, x2, y2 = box1
    x3, y3, x4, y4 = box2
    xi1 = max(x1, x3)
    yi1 = max(y1, y3)
    xi2 = min(x2, x4)
    yi2 = min(y2, y4)
    inter_area = max(0, xi2 - xi1) * max(0, yi2 - yi1)
    box1_area = (x2 - x1) * (y2 - y1)
    box2_area = (x4 - x3) * (y4 - y3)
    union_area = box1_area + box2_area - inter_area
    return inter_area / union_area if union_area > 0 else 0

def group_boxes_by_lines(boxes, iou_threshold=0.5):
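    """Greedily group boxes into lines.

    Each group is seeded by the next unclaimed box and collects every
    remaining box whose IoU with the seed exceeds iou_threshold; the scan
    is O(n^2) in the number of detections.
    """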
    lines = []
    boxes = [list(box) for box in boxes]
    while boxes:
        current_box = boxes.pop(0)
        line_group = [current_box]
        for other_box in boxes[:]:
            if calculate_iou(current_box, other_box) > iou_threshold:
                line_group.append(other_box)
                boxes.remove(other_box)
        lines.append(line_group)
    return lines

def concatenate_boxes(line_group):
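    """Merge a group of boxes into their single tight bounding box."""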
    x_min = min(box[0] for box in line_group)
    y_min = min(box[1] for box in line_group)
    x_max = max(box[2] for box in line_group)
    y_max = max(box[3] for box in line_group)
    return [x_min, y_min, x_max, y_max]

def recognize_text_by_lines(image, line_groups, model, processor, target_word, threshold=70):
    """OCR each line crop with TrOCR and keep lines that fuzzily match target_word."""
    detected_boxes = []
    for line_group in line_groups:
        line_box = concatenate_boxes(line_group)
        x1, y1, x2, y2 = map(int, line_box)
        roi = image[y1:y2, x1:x2]
        if roi.size == 0:  # guard against degenerate boxes that would crash cv2.resize
            continue
        roi = cv2.resize(roi, (384, 384))
        pixel_values = processor(images=roi, return_tensors="pt").pixel_values
        # Match the model's device and dtype (float16 on GPU) so the encoder
        # does not hit a float32/float16 mismatch
        pixel_values = pixel_values.to(device=device, dtype=model.dtype)
        generated_ids = model.generate(pixel_values, max_new_tokens=50)
        recognized_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        recognized_text = normalize_text(recognized_text)
        if fuzz.partial_ratio(recognized_text, target_word) >= threshold:
            detected_boxes.append((line_box, target_word))
    return detected_boxes

def detect_and_recognize(image, target_word, threshold):
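    """Detect text, group it into lines, and recognize lines on a thread pool.

    GPU inference serializes on the device, so the threads mainly overlap the
    CPU-side cropping, resizing, and decoding work.
    """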
    detections = detect_text(image, detection_model)
    line_groups = group_boxes_by_lines(detections)
    with ThreadPoolExecutor(max_workers=cpu_count()) as executor:
        futures = [executor.submit(recognize_text_by_lines, image, [line_group], recognition_model, processor, target_word, threshold) for line_group in line_groups]
        matching_detections = [result for future in as_completed(futures) for result in future.result()]
    return matching_detections

def draw_highlight_boxes(image, detections):
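    """Draw a translucent green overlay on every matched line box."""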
    if not detections:
        print("No matching words found.")
    for box, _ in detections:
        x1, y1, x2, y2 = map(int, box)
        overlay = image.copy()
        cv2.rectangle(overlay, (x1, y1), (x2, y2), (0, 255, 0), -1)
        alpha = 0.4
        cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0, image)
    return image

def process_image(image_path, target_word, threshold):
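    """Run the full pipeline for one upload, save the highlighted result
    (converted back to BGR for cv2.imwrite), and return the result path."""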
    image = preprocess_image(image_path)
    matching_detections = detect_and_recognize(image, target_word, threshold)
    result_image = draw_highlight_boxes(image.copy(), matching_detections)
    result_image_path = os.path.join(app.config['RESULT_FOLDER'], os.path.basename(image_path))
    cv2.imwrite(result_image_path, cv2.cvtColor(result_image, cv2.COLOR_RGB2BGR))
    return result_image_path

@app.route('/', methods=['GET', 'POST'])
def index():
    """Upload form on GET; on POST, save the image and highlight the target word."""
    if request.method == 'POST':
        if 'file' not in request.files or request.files['file'].filename == '':
            return redirect(request.url)
        file = request.files['file']
        target_word = request.form.get('target_word', '').strip()
        if file and target_word:
            # Sanitize the user-supplied filename before writing it to disk
            filename = secure_filename(file.filename)
            filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
            file.save(filepath)
            start_time = time.time()
            result_image_path = process_image(filepath, target_word, 70)
            end_time = time.time()
            print(f"Processing time: {end_time - start_time:.2f} seconds")
            return render_template('index.html', filename=os.path.basename(result_image_path), target_word=target_word)
    return render_template('index.html')

@app.route('/result/<filename>')
def result(filename):
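    """Serve a processed image from the results folder."""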
    return send_from_directory(app.config['RESULT_FOLDER'], filename)

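# Example request against a running server (hypothetical file name and word):
#   curl -F "file=@page.jpg" -F "target_word=hello" http://127.0.0.1:5000/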
if __name__ == '__main__':
    app.run(debug=True)