phuochungus's picture
Add redirect to index route
eb371a0
raw
history blame contribute delete
No virus
4.31 kB
import os
import numpy as np
import yolov5
from fastapi import File, UploadFile, status
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse, RedirectResponse
from PIL import Image
from vietocr.tool.config import Cfg
from vietocr.tool.predictor import Predictor
import mmcv
from io import BytesIO
import sources.Controllers.config as cfg
from sources import app
from sources.Controllers import utils
# --- Detection models -------------------------------------------------------
# Three YOLOv5 checkpoints are loaded once at import time: corner, content and
# face detectors (paths come from the project config module).
CORNER_MODEL = yolov5.load(cfg.CORNER_MODEL_PATH)
CONTENT_MODEL = yolov5.load(cfg.CONTENT_MODEL_PATH)
FACE_MODEL = yolov5.load(cfg.FACE_MODEL_PATH)

# Confidence / IoU thresholds for the content detector: drops low-confidence
# and overlapping bounding boxes.
CONTENT_MODEL.conf, CONTENT_MODEL.iou = (
    cfg.CONF_CONTENT_THRESHOLD,
    cfg.IOU_CONTENT_THRESHOLD,
)
# The corner detector keeps the library defaults; cfg.CONF_CORNER_THRESHOLD and
# cfg.IOU_CORNER_THRESHOLD are intentionally not applied here.

# --- Working directories ----------------------------------------------------
UPLOAD_FOLDER = cfg.UPLOAD_FOLDER
SAVE_DIR = cfg.SAVE_DIR
FACE_CROP_DIR = cfg.FACE_DIR

# --- Text recognition for the detected ID fields ----------------------------
# "vgg_seq2seq" favours speed; "vgg_transformer" favours accuracy.
# Alternative (unused): load the config from cfg.OCR_CFG and point
# config["weights"] at cfg.OCR_MODEL_PATH.
config = Cfg.load_config_from_name("vgg_seq2seq")
config["device"] = cfg.DEVICE
config["cnn"]["pretrained"] = False
config["predictor"]["beamsearch"] = False
detector = Predictor(config)
@app.get("/")
async def index():
    """Redirect requests for the root path to ``/docs``."""
    docs_url = "/docs"
    return RedirectResponse(url=docs_url)
def _error_response(status_code, message):
    """Build the ``{"message": ...}`` JSON error body returned by ``upload``."""
    return JSONResponse(status_code=status_code, content={"message": message})


@app.post(
    "/uploader",
)
async def upload(file: UploadFile = File(...)):
    """Detect, align and OCR the fields of an uploaded ID image.

    Pipeline:
      1. Run the corner detector on the uploaded image; exactly four
         detections are required.
      2. Warp the image with ``utils.four_point_transform`` using the centre
         points of the four corner boxes.
      3. Run the content detector on the aligned image, apply non-maximum
         suppression and order the boxes with ``utils.class_Order``.
      4. Crop every box, save it to ``SAVE_DIR`` as ``<index>.jpg`` and run the
         OCR predictor on every crop except the first one.

    Args:
        file: The uploaded image.

    Returns:
        JSONResponse: ``{"data": [<recognised strings>]}`` on success;
        status 401 with ``{"message": ...}`` when corner detection does not
        yield four boxes; status 400 with ``{"message": ...}`` when too few
        content fields are detected.
    """
    missing_fields_msg = "Missing fields! Detecting content failed!"

    img_bytes = await file.read()
    img = mmcv.imfrombytes(img_bytes)

    # ---- Stage 1: corner detection ----
    corner_result = CORNER_MODEL(img)
    corner_result.save(save_dir="results/")
    predictions = corner_result.pred[0]
    categories = predictions[:, 5].tolist()  # column 5 holds the class id
    if len(categories) != 4:
        # NOTE(review): 401 (Unauthorized) is an odd status for a detection
        # failure, but it is kept so existing clients keep working.
        return _error_response(401, "Detecting corner failed!")

    boxes = utils.class_Order(predictions[:, :4].tolist(), categories)
    pil_image = Image.open(BytesIO(img_bytes))
    center_points = [utils.get_center_point(box) for box in boxes]

    # Temporary fix carried over from the original code: move the third and
    # fourth points 30 px along the y axis before warping.
    c2, c3 = center_points[2], center_points[3]
    center_points = np.asarray(
        [
            center_points[0],
            center_points[1],
            (c2[0], c2[1] + 30),
            (c3[0], c3[1] + 30),
        ]
    )
    aligned = Image.fromarray(utils.four_point_transform(pil_image, center_points))

    # ---- Stage 2: content detection ----
    content_result = CONTENT_MODEL(aligned)
    predictions = content_result.pred[0]
    categories = predictions[:, 5].tolist()
    # One extra detection is required when class 7 is present.
    required_count = 10 if 7 in categories else 9
    if len(categories) < required_count:
        return _error_response(status.HTTP_400_BAD_REQUEST, missing_fields_msg)

    # Non-maximum suppression, then put the boxes in class order.
    boxes, categories = utils.non_max_suppression_fast(
        np.array(predictions[:, :4].tolist()), categories, 0.7
    )
    boxes = utils.class_Order(boxes, categories)

    # ---- Stage 3: crop every field to disk ----
    # NOTE(review): SAVE_DIR is shared by all requests, so concurrent uploads
    # can overwrite or delete each other's crops.
    os.makedirs(SAVE_DIR, exist_ok=True)
    for stale_name in os.listdir(SAVE_DIR):
        os.remove(os.path.join(SAVE_DIR, stale_name))

    crop_paths = []
    for index, box in enumerate(boxes):
        left, top, right, bottom = box
        if 5 < index < 9:
            # Boxes 6..8 are widened by 100 px on the right before cropping.
            right = right + 100
        crop_path = os.path.join(SAVE_DIR, f"{index}.jpg")
        aligned.crop((left, top, right, bottom)).save(crop_path)
        crop_paths.append(crop_path)

    # ---- Stage 4: OCR ----
    # Read the crops back in the order they were written. Sorting the
    # directory listing is lexicographic ("10.jpg" < "2.jpg") and would
    # scramble the field order once there are more than ten crops.
    # The first crop is saved but not passed to the OCR predictor.
    fields_detected = []
    for crop_path in crop_paths[1:]:
        with Image.open(crop_path) as crop:
            fields_detected.append(detector.predict(crop))

    if 7 in categories:
        # Suppression may have removed boxes; the merge below needs nine
        # recognised fields, so report the failure instead of raising.
        if len(fields_detected) < 9:
            return _error_response(status.HTTP_400_BAD_REQUEST, missing_fields_msg)
        # Fields 6 and 7 are joined into a single entry.
        merged = fields_detected[6] + ", " + fields_detected[7]
        fields_detected = fields_detected[:6] + [merged, fields_detected[8]]

    return JSONResponse(
        content=jsonable_encoder({"data": fields_detected}),
    )