Spaces:
Paused
Paused
import os | |
import numpy as np | |
import yolov5 | |
from fastapi import File, UploadFile, status | |
from fastapi.encoders import jsonable_encoder | |
from fastapi.responses import JSONResponse, RedirectResponse | |
from PIL import Image | |
from vietocr.tool.config import Cfg | |
from vietocr.tool.predictor import Predictor | |
import mmcv | |
from io import BytesIO | |
import sources.Controllers.config as cfg | |
from sources import app | |
from sources.Controllers import utils | |
# Load the three yolov5 detectors (corner, content, face) once at import time.
CORNER_MODEL, CONTENT_MODEL, FACE_MODEL = (
    yolov5.load(weights_path)
    for weights_path in (
        cfg.CORNER_MODEL_PATH,
        cfg.CONTENT_MODEL_PATH,
        cfg.FACE_MODEL_PATH,
    )
)

# Confidence / IoU thresholds for the content model, used to discard
# low-confidence and overlapping bounding boxes. The corner model's
# counterparts (cfg.CONF_CORNER_THRESHOLD, cfg.IOU_CORNER_THRESHOLD) are not
# applied here.
CONTENT_MODEL.iou = cfg.IOU_CONTENT_THRESHOLD
CONTENT_MODEL.conf = cfg.CONF_CONTENT_THRESHOLD

# Working directories, taken from the controller config.
UPLOAD_FOLDER, SAVE_DIR, FACE_CROP_DIR = (
    cfg.UPLOAD_FOLDER,
    cfg.SAVE_DIR,
    cfg.FACE_DIR,
)

# ---- Recognition of the detected parts of the ID ----
# Per the original author's note: "vgg_transformer" favours accuracy,
# "vgg_seq2seq" favours speed. Loading a local config / weights file
# (cfg.OCR_CFG, cfg.OCR_MODEL_PATH) is a disabled alternative.
config = Cfg.load_config_from_name("vgg_seq2seq")
config["device"] = cfg.DEVICE
config["cnn"]["pretrained"] = False
config["predictor"]["beamsearch"] = False
detector = Predictor(config)
async def index():
    """Redirect the caller to ``/docs``."""
    docs_redirect = RedirectResponse(url="/docs")
    return docs_redirect
async def upload(file: UploadFile = File(...)):
    """Detect, align and read the text fields of an uploaded ID image.

    Pipeline:
      1. Run ``CORNER_MODEL`` on the upload; exactly four detections are
         required.
      2. Warp the image with ``utils.four_point_transform`` using the centre
         points of the four corner boxes.
      3. Run ``CONTENT_MODEL`` on the aligned image, apply non-max
         suppression and order the boxes with ``utils.class_Order``.
      4. Save every crop to ``SAVE_DIR`` as ``<index>.jpg`` and run the OCR
         ``detector`` on crops 1..N-1.

    Args:
        file: The uploaded image.

    Returns:
        JSONResponse: ``{"data": [...]}`` with the recognised strings on
        success; status 401 with a ``message`` when corner detection does not
        yield four boxes; status 400 with a ``message`` when too few content
        fields are detected.
    """
    img_bytes = await file.read()

    # Stage 1 - corner detection.
    corner_result = CORNER_MODEL(mmcv.imfrombytes(img_bytes))
    corner_result.save(save_dir="results/")
    corner_pred = corner_result.pred[0]
    corner_classes = corner_pred[:, 5].tolist()  # column 5 is read as the class
    if len(corner_classes) != 4:
        # NOTE(review): 401 (Unauthorized) is an unusual status for a
        # detection failure; kept so existing clients see the same response.
        return JSONResponse(
            status_code=401, content={"message": "Detecting corner failed!"}
        )
    corner_boxes = utils.class_Order(corner_pred[:, :4].tolist(), corner_classes)

    # Stage 2 - perspective alignment on the corner-box centre points.
    source_image = Image.open(BytesIO(img_bytes))
    centers = [utils.get_center_point(box) for box in corner_boxes]
    third, fourth = centers[2], centers[3]
    # "Temporary fixing" from the original author: the y coordinate of the
    # third and fourth points is increased by 30 before warping.
    warp_points = np.asarray(
        [
            centers[0],
            centers[1],
            (third[0], third[1] + 30),
            (fourth[0], fourth[1] + 30),
        ]
    )
    aligned = Image.fromarray(utils.four_point_transform(source_image, warp_points))

    # Stage 3 - content detection. One extra box is required when class 7 is
    # present (its text is later merged into the preceding field).
    content_pred = CONTENT_MODEL(aligned).pred[0]
    categories = content_pred[:, 5].tolist()
    if len(categories) < (10 if 7 in categories else 9):
        return _missing_fields_response()

    boxes, categories = utils.non_max_suppression_fast(
        np.array(content_pred[:, :4].tolist()), categories, 0.7
    )
    boxes = utils.class_Order(boxes, categories)

    # Stage 4 - crop to disk, then OCR the saved crops in numeric order.
    crop_count = _save_field_crops(aligned, boxes)
    fields = _recognize_saved_crops(crop_count)

    if 7 in categories:
        # Suppression may have removed boxes after the count check above;
        # report missing fields instead of failing on fields[8].
        if len(fields) < 9:
            return _missing_fields_response()
        fields = fields[:6] + [fields[6] + ", " + fields[7]] + [fields[8]]

    return JSONResponse(content=jsonable_encoder({"data": fields}))


def _missing_fields_response():
    """Build the 400 response used when content detection is incomplete."""
    return JSONResponse(
        status_code=status.HTTP_400_BAD_REQUEST,
        content={"message": "Missing fields! Detecting content failed!"},
    )


def _save_field_crops(aligned, boxes):
    """Empty SAVE_DIR, write one ``<index>.jpg`` crop per box, return the count."""
    # makedirs also creates missing parent directories, unlike os.mkdir.
    os.makedirs(SAVE_DIR, exist_ok=True)
    for stale_name in os.listdir(SAVE_DIR):
        os.remove(os.path.join(SAVE_DIR, stale_name))

    saved = 0
    for crop_index, (left, top, right, bottom) in enumerate(boxes):
        if 5 < crop_index < 9:
            # Crops 6-8 are widened by 100 px on the right.
            right = right + 100
        crop = aligned.crop((left, top, right, bottom))
        crop.save(os.path.join(SAVE_DIR, f"{crop_index}.jpg"))
        saved = crop_index + 1
    return saved


def _recognize_saved_crops(crop_count):
    """Run the OCR detector on crops 1..crop_count-1 and return their texts.

    Files are opened by explicit index rather than by listing the directory:
    a lexicographic sort of the file names would place ``10.jpg`` before
    ``2.jpg`` and would also pick up unrelated files.
    """
    texts = []
    # Crop 0 is skipped, as in the original code (presumably it is not a text
    # field - confirm against utils.class_Order).
    for crop_index in range(1, crop_count):
        crop_path = os.path.join(SAVE_DIR, f"{crop_index}.jpg")
        with Image.open(crop_path) as crop:
            texts.append(detector.predict(crop))
    return texts