# Spaces: Sleeping (Hugging Face Space status banner captured with this source; not part of the program)
from flask import Flask, request, render_template, redirect, send_from_directory | |
import torch | |
import os | |
import cv2 | |
import numpy as np | |
from ultralytics import YOLO | |
from concurrent.futures import ThreadPoolExecutor, as_completed | |
from fuzzywuzzy import fuzz | |
from transformers import VisionEncoderDecoderModel, TrOCRProcessor, AutoTokenizer, ViTImageProcessor, NllbTokenizer | |
import unicodedata | |
import time | |
from multiprocessing import cpu_count | |
app = Flask(__name__)

# Uploaded originals and annotated outputs live in separate folders,
# resolved relative to the current working directory.
UPLOAD_FOLDER = 'uploads'
RESULT_FOLDER = 'results'
app.config.update(UPLOAD_FOLDER=UPLOAD_FOLDER, RESULT_FOLDER=RESULT_FOLDER)

# Prefer the GPU whenever PyTorch can see one.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    # NOTE(review): these are process-wide settings. Afterwards, floating-point
    # tensors created without an explicit dtype default to float16, and factory
    # calls without an explicit device allocate on the GPU, including calls
    # made inside third-party libraries.
    torch.set_default_dtype(torch.float16)
    torch.set_default_device(device)
    # Let cuDNN benchmark and cache the fastest convolution algorithms.
    torch.backends.cudnn.benchmark = True
# Directory of the fine-tuned recognition checkpoint; both the model weights
# and the image-processor config are loaded from it.
RECOGNITION_CHECKPOINT = 'fine-tuned-small-printed-V2-checkpoint-with100000data/checkpoint-170160'

# Text-detection model (Ultralytics YOLO weights), moved to the selected device.
detection_model = YOLO('train34/best.pt').to(device)

# Text-recognition model (vision encoder-decoder), put in inference mode.
recognition_model = VisionEncoderDecoderModel.from_pretrained(RECOGNITION_CHECKPOINT).to(device)
recognition_model.eval()

if torch.cuda.is_available():
    # Convert both models to float16 only on CUDA. float16 support on CPU is
    # limited in many PyTorch versions, so the detector must not be halved
    # unconditionally.
    detection_model.half()
    recognition_model.half()

# Decoding uses the NLLB tokenizer rather than one stored in the checkpoint
# directory; presumably the decoder was fine-tuned with this vocabulary — confirm.
tokenizer = NllbTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
feature_extractor = ViTImageProcessor.from_pretrained(RECOGNITION_CHECKPOINT)
processor = TrOCRProcessor(image_processor=feature_extractor, tokenizer=tokenizer)
# Create the working folders up front. exist_ok=True avoids the
# check-then-create race of testing os.path.exists first (e.g. several worker
# processes starting at once), and it fails loudly if the path exists but is
# not a directory.
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(RESULT_FOLDER, exist_ok=True)
def normalize_text(text):
    """Return ``text`` in Unicode Normalization Form C (canonical composition).

    The same visible character can be encoded in more than one way (for
    example "é" as a single code point, or as "e" followed by a combining
    accent). NFC maps these to one representation so that strings compare
    equal.
    """
    form = 'NFC'
    if unicodedata.is_normalized(form, text):
        return text
    return unicodedata.normalize(form, text)
def preprocess_image(image_path):
    """Load an image from disk, denoise it, and return it in RGB channel order.

    Args:
        image_path: path of the image file to read.

    Returns:
        The denoised image as a 3-channel array in RGB order.

    Raises:
        ValueError: if OpenCV cannot read the file. ``cv2.imread`` returns
            ``None`` for missing, unreadable, or unsupported files alike.
    """
    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"Image not found at path: {image_path}")
    # Denoise while the image is still in OpenCV's native BGR order.
    # fastNlMeansDenoisingColored converts its input to CIELAB assuming BGR,
    # so passing RGB would compute luminance from swapped channels.
    # Arguments: h=10 (luminance strength), hColor=10 (colour strength),
    # templateWindowSize=7, searchWindowSize=21.
    image = cv2.fastNlMeansDenoisingColored(image, None, 10, 10, 7, 21)
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
def detect_text(image, model, conf_threshold=0.3):
    """Run the text detector and keep boxes scoring at least ``conf_threshold``.

    Args:
        image: image passed unchanged to ``model``.
        model: callable returning a sequence of results whose first element
            exposes ``boxes.xyxy`` and ``boxes.conf`` tensors (the layout used
            by the Ultralytics YOLO model loaded in this file).
        conf_threshold: minimum detection confidence to keep (inclusive).
            Defaults to 0.3, the value that used to be hard-coded.

    Returns:
        List of ``[x1, y1, x2, y2]`` NumPy rows, in detector order.
    """
    detected = model(image)[0].boxes
    xyxy = detected.xyxy.cpu().numpy()
    scores = detected.conf.cpu().numpy()
    return [box for box, score in zip(xyxy, scores) if score >= conf_threshold]
def calculate_iou(box1, box2):
    """Return the intersection-over-union of two ``(x1, y1, x2, y2)`` boxes.

    Each box is given by its top-left and bottom-right corners. For
    well-formed boxes the result lies in ``[0, 1]``. If the union area is not
    positive (for example, two zero-area boxes), 0 is returned.
    """
    ax1, ay1, ax2, ay2 = box1
    bx1, by1, bx2, by2 = box2

    overlap_w = max(0, min(ax2, bx2) - max(ax1, bx1))
    overlap_h = max(0, min(ay2, by2) - max(ay1, by1))
    intersection = overlap_w * overlap_h

    area_a = (ax2 - ax1) * (ay2 - ay1)
    area_b = (bx2 - bx1) * (by2 - by1)
    union = area_a + area_b - intersection

    if union > 0:
        return intersection / union
    return 0
def group_boxes_by_lines(boxes, iou_threshold=0.5):
    """Greedily cluster boxes that overlap a seed box by more than ``iou_threshold``.

    Every box is first converted to a list. The first remaining box seeds a
    new group. Each later remaining box whose IoU with that seed is strictly
    greater than ``iou_threshold`` joins the group; the others carry over, in
    their original order, to the next round.

    Note: despite the name, grouping is by overlap with the seed box, not by
    vertical position. Boxes sitting side by side on one text line generally
    have IoU 0 and therefore end up in separate groups.

    Returns:
        List of groups; each group is a list of boxes (as lists), seed first.
    """
    pending = [list(box) for box in boxes]
    groups = []
    while pending:
        seed, *others = pending
        group = [seed]
        pending = []
        for candidate in others:
            if calculate_iou(seed, candidate) > iou_threshold:
                group.append(candidate)
            else:
                pending.append(candidate)
        groups.append(group)
    return groups
def concatenate_boxes(line_group):
    """Return the smallest ``[x1, y1, x2, y2]`` box enclosing every box in ``line_group``.

    Raises:
        ValueError: if ``line_group`` is empty (``min``/``max`` of nothing).
    """
    lefts = [box[0] for box in line_group]
    tops = [box[1] for box in line_group]
    rights = [box[2] for box in line_group]
    bottoms = [box[3] for box in line_group]
    return [min(lefts), min(tops), max(rights), max(bottoms)]
def recognize_text_by_lines(image, line_groups, model, processor, target_word, threshold=70):
    """OCR each group of boxes and keep the groups whose text fuzzily matches ``target_word``.

    Args:
        image: image array indexed ``[row, column]`` (in this file, the RGB
            array returned by ``preprocess_image``).
        line_groups: iterable of box groups; each group is merged into one
            enclosing box with ``concatenate_boxes``.
        model: vision encoder-decoder model whose ``generate`` produces token ids.
        processor: turns an image into ``pixel_values`` and decodes generated
            ids back to text.
        target_word: word to look for.
        threshold: minimum ``fuzz.partial_ratio`` score (0-100) for a match.

    Returns:
        List of ``(line_box, target_word)`` tuples for matching groups, in
        input order.
    """
    # Normalize the query the same way as the OCR output, so composed and
    # decomposed spellings of the same characters compare equal.
    target = normalize_text(target_word)
    height, width = image.shape[:2]
    detected_boxes = []
    for line_group in line_groups:
        line_box = concatenate_boxes(line_group)
        x1, y1, x2, y2 = map(int, line_box)
        # Clamp to the image: a negative coordinate would otherwise be treated
        # as a negative index and silently wrap around.
        x1, x2 = max(0, x1), min(width, x2)
        y1, y2 = max(0, y1), min(height, y2)
        if x2 <= x1 or y2 <= y1:
            # Degenerate box (e.g. truncated to zero width); cv2.resize would
            # raise on the empty crop.
            continue
        roi = cv2.resize(image[y1:y2, x1:x2], (384, 384))
        # Match the model's device and dtype. The processor yields float32
        # tensors, which would clash with float16 weights on CUDA.
        pixel_values = processor(images=roi, return_tensors="pt").pixel_values.to(
            device=model.device, dtype=model.dtype
        )
        generated_ids = model.generate(pixel_values, max_new_tokens=50)
        recognized_text = normalize_text(
            processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        )
        if fuzz.partial_ratio(recognized_text, target) >= threshold:
            detected_boxes.append((line_box, target_word))
    return detected_boxes
def detect_and_recognize(image, target_word, threshold):
    """Detect text in ``image`` and return the regions whose OCR matches ``target_word``.

    Uses the module-level ``detection_model``, ``recognition_model`` and
    ``processor``. Each box group is recognized as a separate thread-pool task.

    Args:
        image: image array passed to both detection and recognition.
        target_word: word to look for.
        threshold: minimum fuzzy-match score (0-100) for a region to count.

    Returns:
        List of ``(box, target_word)`` tuples in the same order as the box
        groups, so the result does not depend on thread completion order.
    """
    detections = detect_text(image, detection_model)
    line_groups = group_boxes_by_lines(detections)
    if not line_groups:
        # Nothing detected: no need to spin up a thread pool.
        return []

    def recognize_group(line_group):
        # Run recognition on a single group.
        return recognize_text_by_lines(
            image, [line_group], recognition_model, processor, target_word, threshold
        )

    # More workers than groups would only create idle threads.
    with ThreadPoolExecutor(max_workers=min(cpu_count(), len(line_groups))) as executor:
        # Executor.map yields results in submission order and re-raises any
        # worker exception here.
        per_group = list(executor.map(recognize_group, line_groups))
    return [match for matches in per_group for match in matches]
def draw_highlight_boxes(image, detections, *, color=(0, 255, 0), alpha=0.4):
    """Tint each detected box with a translucent filled rectangle.

    ``image`` is modified in place (``cv2.addWeighted`` writes into it) and is
    also returned. Each box is blended separately, so regions where boxes
    overlap are tinted more strongly.

    Args:
        image: image array to draw on; ``color`` uses the same channel order.
        detections: iterable of ``(box, label)`` pairs with ``box`` given as
            ``(x1, y1, x2, y2)``; the label is ignored.
        color: fill colour. The default green ``(0, 255, 0)`` is the same in
            RGB and BGR order.
        alpha: opacity of the fill, from 0 (invisible) to 1 (solid); default 0.4.

    Returns:
        The same ``image`` object, annotated.
    """
    if not detections:
        print("No matching words found.")
    for box, _ in detections:
        x1, y1, x2, y2 = map(int, box)
        overlay = image.copy()
        # thickness=-1 draws a filled rectangle.
        cv2.rectangle(overlay, (x1, y1), (x2, y2), color, -1)
        # Blend back into ``image``: alpha * overlay + (1 - alpha) * image.
        cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0, image)
    return image
def process_image(image_path, target_word, threshold):
    """Highlight occurrences of ``target_word`` in an image and save the result.

    Args:
        image_path: path of the uploaded image.
        target_word: word to highlight.
        threshold: minimum fuzzy-match score (0-100) for a region to be highlighted.

    Returns:
        Path of the annotated image in ``app.config['RESULT_FOLDER']``. It
        keeps the input's basename, so a later image with the same name
        overwrites it.

    Raises:
        ValueError: if the input image cannot be read.
        OSError: if OpenCV reports that the annotated image could not be written.
    """
    image = preprocess_image(image_path)
    matching_detections = detect_and_recognize(image, target_word, threshold)
    result_image = draw_highlight_boxes(image.copy(), matching_detections)
    result_image_path = os.path.join(app.config['RESULT_FOLDER'], os.path.basename(image_path))
    # cv2.imwrite signals some failures (e.g. a missing or unwritable folder)
    # by returning False rather than raising. Surface that instead of
    # returning a path to a file that does not exist.
    if not cv2.imwrite(result_image_path, cv2.cvtColor(result_image, cv2.COLOR_RGB2BGR)):
        raise OSError(f"Could not write result image to: {result_image_path}")
    return result_image_path
# NOTE(review): no route decorator survives in the pasted source. '/' with GET
# and POST is inferred from the request.method / request.url handling below.
@app.route('/', methods=['GET', 'POST'])
def index():
    """Show the upload form, or process an uploaded image and show the result.

    On POST, the multipart form must carry a ``file`` (the image) and a
    non-blank ``target_word``; otherwise the client is redirected back to the
    form. The image is saved under ``UPLOAD_FOLDER`` and processed with a
    match threshold of 70. The page is then rendered with the result image's
    file name.
    """
    if request.method != 'POST':
        return render_template('index.html')

    file = request.files.get('file')
    target_word = request.form.get('target_word', '')
    # Keep only the last path component of the client-supplied name, treating
    # '\\' as a separator too. Otherwise a name such as '../../app.py' would
    # make file.save() write outside UPLOAD_FOLDER.
    filename = os.path.basename(file.filename.replace('\\', '/')) if file else ''
    if filename in ('', '.', '..') or not target_word.strip():
        return redirect(request.url)

    filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
    file.save(filepath)
    # perf_counter is monotonic and meant for measuring intervals.
    start_time = time.perf_counter()
    result_image_path = process_image(filepath, target_word, 70)
    print(f"Processing time: {time.perf_counter() - start_time} seconds")
    return render_template('index.html', filename=os.path.basename(result_image_path), target_word=target_word)
# NOTE(review): no route decorator survives in the pasted source. The URL path
# below is an assumption and must match what index.html links to (using
# url_for('result', ...) there works with any path).
@app.route('/results/<filename>')
def result(filename):
    """Serve a processed image from ``RESULT_FOLDER``.

    ``send_from_directory`` refuses paths that would escape the folder, so the
    client-supplied ``filename`` can be passed through directly.
    """
    return send_from_directory(app.config['RESULT_FOLDER'], filename)
if __name__ == '__main__':
    # The Werkzeug debugger allows arbitrary code execution from the browser,
    # so it is opt-in via FLASK_DEBUG=1 rather than always on. The reloader is
    # disabled because it re-runs this script in a child process, which would
    # load both models into memory a second time.
    app.run(debug=os.environ.get('FLASK_DEBUG') == '1', use_reloader=False)