"""Flask service that solves arithmetic image captchas with a YOLOv8 ONNX detector.

Endpoints:
    GET  /              -- liveness check, echoes the caller's IP.
    POST /update/token  -- replace the shared API token.
    POST /api/solve     -- decode a base64 image, detect the expression, return its value.
"""
import atexit
import base64
import contextlib
import hashlib
import hmac
import os
import uuid

import numpy as np
from apscheduler.schedulers.background import BackgroundScheduler
from flask import Flask, jsonify, request, logging as flog
from flask_limiter.util import get_remote_address
from ultralytics import YOLO

app = Flask(__name__)

MODEL_PATH = 'numbers_yolov8s.onnx'
TOKEN_FILE = "token"
DEFAULT_TOKEN = "init_token"

# Class index -> label. Indices 0-19 are the numbers, 20 is 'equal',
# 21 is '-' and 22 is '+'.
CLASS_LABELS = (
    '1', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '2',
    '20', '3', '4', '5', '6', '7', '8', '9', 'equal', '-', '+',
)
_OPERAND_CLASS_LIMIT = 20      # class indices below this are numeric operands
_OPERATOR_CLASSES = (21, 22)   # class indices of '-' and '+'


def load_model():
    """(Re)load the detector and the label table into the module-level globals.

    Called once at import time and then periodically by the background
    scheduler configured below.
    """
    global onnx_model, cls_type_array
    onnx_model = YOLO(MODEL_PATH, task="detect")
    cls_type_array = np.array(CLASS_LABELS)


# Initial load; defines `onnx_model` and `cls_type_array`.
load_model()

scheduler = BackgroundScheduler()
scheduler.add_job(func=load_model, trigger="interval", seconds=3600)
scheduler.start()


def shutdown_scheduler():
    """Stop the background scheduler when the interpreter exits."""
    scheduler.shutdown()


atexit.register(shutdown_scheduler)


def generate_hashed_uuid():
    """Return the SHA-256 hex digest (64 characters) of a fresh random UUID4."""
    return hashlib.sha256(str(uuid.uuid4()).encode()).hexdigest()


def _evaluate_expression(operands, operators):
    """Evaluate ``a op b op c ...`` left to right for '+' and '-' only.

    Replaces ``eval`` on the recognised text: the expression is built from
    integer labels joined by '+'/'-', so plain left-to-right arithmetic gives
    the same value without executing a string as code.
    """
    total = int(operands[0])
    for operator, operand in zip(operators, operands[1:]):
        if operator == '+':
            total += int(operand)
        else:
            total -= int(operand)
    return total


def ocr_png(filename):
    """Detect the arithmetic expression in an image and compute its value.

    Args:
        filename: Path of the image file passed to the detector.

    Returns:
        dict: ``{"ocr": "<a><op><b><op><c>", "result": <int>}``.

    Raises:
        ValueError: If fewer than three numbers or fewer than two operators
            were detected (the original code failed with an IndexError here).
    """
    results = onnx_model(filename)[0]

    # Pair each detection's class index with the x coordinate of its box's
    # first corner, then order the detections from left to right.
    detections = sorted(
        (
            (box[0].item(), int(cls.item()))
            for cls, box in zip(results.boxes.cls, results.boxes.xyxy)
        ),
        key=lambda detection: detection[0],
    )

    operands = [
        str(cls_type_array[cls_index])
        for _, cls_index in detections
        if cls_index < _OPERAND_CLASS_LIMIT
    ]
    operators = [
        str(cls_type_array[cls_index])
        for _, cls_index in detections
        if cls_index in _OPERATOR_CLASSES
    ]

    if len(operands) < 3 or len(operators) < 2:
        raise ValueError(
            f"incomplete expression: {len(operands)} number(s), "
            f"{len(operators)} operator(s) detected"
        )

    # Only the first three numbers and first two operators are used;
    # any extra detections are ignored.
    operands, operators = operands[:3], operators[:2]
    result = operands[0] + operators[0] + operands[1] + operators[1] + operands[2]
    return {"ocr": result, "result": _evaluate_expression(operands, operators)}


def get_ipaddr():
    """Return the caller's IP: first ``request.access_route`` entry, else the remote address."""
    # NOTE(review): access_route presumably reflects forwarded-for headers,
    # which a client can set itself unless a trusted proxy rewrites them.
    if request.access_route:
        print(request.access_route[0])
        return request.access_route[0]
    return request.remote_addr or '127.0.0.1'


handler = flog.default_handler


def get_token():
    """Return the API token stored in ``TOKEN_FILE``, or the default if the file is absent."""
    try:
        with open(TOKEN_FILE, "r") as token_file:
            return token_file.read().strip()
    except FileNotFoundError:
        return DEFAULT_TOKEN


def check_request(required_data, data):
    """Validate a parsed JSON request body and its token.

    Args:
        required_data: Keys that must be present in ``data``.
        data: Parsed JSON body (``None`` when parsing failed).

    Returns:
        bool: True only if ``data`` is a non-empty JSON object containing
        every required key and its ``"token"`` equals the stored token.
    """
    expected = get_token()
    # NOTE(review): the prints below dump the whole payload, including any
    # submitted token, to stdout.
    if (
        not data
        or not isinstance(data, dict)
        or any(key not in data for key in required_data)
    ):
        print("Error:Invalid Request Data\n" + str(data))
        return False

    supplied = data.get("token")
    # Compare as bytes: compare_digest rejects non-ASCII str, and a
    # constant-time comparison avoids leaking the token through timing.
    if not isinstance(supplied, str) or not hmac.compare_digest(
        supplied.encode("utf-8"), expected.encode("utf-8")
    ):
        print("Error:Invalid Token\n" + str(data))
        return False
    return True


@app.errorhandler(429)
def rate_limit_exceeded(e):
    """Log the remote address and answer HTTP 429 with a JSON message."""
    print(get_remote_address())
    return jsonify(msg="Too many request"), 429


@app.errorhandler(405)
def method_not_allowed(e):
    """Log the remote address and answer HTTP 405 with a JSON message."""
    print(get_remote_address())
    return jsonify(msg="Unauthorized Request"), 405


@app.route("/", methods=["GET"])
def index():
    """Liveness endpoint; reports the caller's IP address."""
    return jsonify(status_code=200, ip=get_ipaddr())


@app.route("/update/token", methods=["POST"])
def update_token():
    """Replace the stored API token.

    Expects JSON ``{"token": <current token>, "new_token": <str>}``.
    Returns 403 for a malformed or unauthorised request and 400 when
    ``new_token`` is not a non-blank string.
    """
    require_data = ["token", "new_token"]
    data = request.get_json(force=True, silent=True)
    if not check_request(require_data, data):
        return jsonify(msg="Unauthorized Request"), 403

    new_token = data["new_token"]
    # Validate before opening the file: opening for write truncates it, so a
    # failed write would otherwise leave an empty token file behind.
    if not isinstance(new_token, str) or not new_token.strip():
        return jsonify(msg="Invalid new token"), 400

    with open(TOKEN_FILE, "w") as token_file:
        token_file.write(new_token)
    return jsonify(msg="Token updated successfully", success=True)


@app.route("/api/solve", methods=["POST"])
def solver_captcha():
    """Solve a captcha image.

    Expects JSON ``{"token": <token>, "data": <base64 image>}``. The image is
    written to a uniquely named temporary ``.png`` in the working directory,
    run through ``ocr_png`` and removed again. Returns the ``ocr_png`` dict on
    success, 403 for an unauthorised request, or ``"error"`` with status 500.
    """
    require_data = ["token", "data"]
    data = request.get_json(force=True, silent=True)
    if not check_request(require_data, data):
        return jsonify(msg="Unauthorized Request"), 403

    image_path = f"{generate_hashed_uuid()}.png"
    try:
        image_data = base64.b64decode(data["data"])
        with open(image_path, "wb") as image_file:
            image_file.write(image_data)
        return ocr_png(image_path)
    except Exception as e:  # request boundary: report and answer 500
        print(e)
        return "error", 500
    finally:
        # The file does not exist if decoding failed before it was written;
        # a raise here would replace the response chosen above.
        with contextlib.suppress(FileNotFoundError):
            os.remove(image_path)

# Start the development server only when this file is executed directly, so
# that importing the module (e.g. from a WSGI server) does not block on it.
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=8081)