Spaces:
Runtime error
Runtime error
Binarization in fastapi side
Browse files- binarization .py +107 -0
- nougat_api_app.py +48 -8
- predict.ipynb +2 -2
binarization .py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import cv2
|
3 |
+
from PIL import Image
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
def integral(img):
    """Compute the integral image and the squared integral image of ``img``.

    Entry ``[r][c]`` of each result holds the sum over the rectangle
    ``img[0:r + 1, 0:c + 1]`` of the pixel values (first result) or of the
    squared pixel values (second result).

    :param img: 2-D array of integer-valued grayscale pixels.
    :return: ``(integral_sum, integral_sqrt_sum)`` -- two ``int64`` arrays
        with the same shape as ``img``. The second name is kept for
        compatibility; it holds sums of *squares*, not of square roots.
    :raises ValueError: if ``img`` is not two-dimensional.
    """
    pixels = np.asarray(img)
    if pixels.ndim != 2:
        raise ValueError(
            'integral() expects a 2-D grayscale image, got shape {}'.format(pixels.shape)
        )

    # Widen before accumulating: uint8 pixels would wrap around, and int32
    # is too small for sums of squares over large images.
    pixels = pixels.astype(np.int64)

    # A cumulative sum along each axis in turn yields the 2-D prefix sums.
    integral_sum = pixels.cumsum(axis=0).cumsum(axis=1)
    integral_sqrt_sum = (pixels * pixels).cumsum(axis=0).cumsum(axis=1)

    return integral_sum, integral_sqrt_sum
|
31 |
+
|
32 |
+
def sauvola(img,k=0.1,kernerl=(31,31)):
    """Binarize a grayscale image with Sauvola's adaptive threshold.

    For every pixel the mean ``m`` and sample standard deviation ``s`` of
    the surrounding window are computed from integral images, and the pixel
    is compared against ``m * (1 + k * (s / 128 - 1))``. Windows are clipped
    at the image border, so border pixels use fewer samples.

    :param img: 2-D grayscale array. It is overwritten in place with the
        result (0 where the pixel is below its threshold, 255 otherwise).
    :param k: correction factor, usually ``0 < k < 1``.
    :param kernerl: window size as ``(rows, cols)``; both values must be odd.
    :return: the same ``img`` object, now holding only 0 and 255.
    :raises ValueError: if a window dimension is even, or ``img`` is not 2-D.
    """
    if kernerl[0]%2!=1 or kernerl[1]%2!=1:
        raise ValueError('kernerl元组中的值必须为奇数, 请检查kernerl[0] or kernerl[1]是否为奇数!!!')

    rows, cols = img.shape
    pixels = np.asarray(img, dtype=np.float64)

    # Integral images padded with a leading zero row and column, so every
    # window sum is a four-corner lookup with no border special cases.
    integral_sum = np.zeros((rows + 1, cols + 1), dtype=np.float64)
    integral_sq_sum = np.zeros((rows + 1, cols + 1), dtype=np.float64)
    integral_sum[1:, 1:] = pixels.cumsum(axis=0).cumsum(axis=1)
    integral_sq_sum[1:, 1:] = (pixels * pixels).cumsum(axis=0).cumsum(axis=1)

    half_rows = kernerl[0] >> 1
    half_cols = kernerl[1] >> 1

    # Window bounds per row / per column, half-open in padded coordinates.
    row_idx = np.arange(rows)
    col_idx = np.arange(cols)
    top = np.maximum(row_idx - half_rows, 0)
    bottom = np.minimum(row_idx + half_rows, rows - 1) + 1
    left = np.maximum(col_idx - half_cols, 0)
    right = np.minimum(col_idx + half_cols, cols - 1) + 1

    area = ((bottom - top)[:, None] * (right - left)[None, :]).astype(np.float64)

    def window_total(table):
        # Sum over each pixel's window via inclusion-exclusion.
        return (
            table[np.ix_(bottom, right)]
            - table[np.ix_(top, right)]
            - table[np.ix_(bottom, left)]
            + table[np.ix_(top, left)]
        )

    total = window_total(integral_sum)
    sq_total = window_total(integral_sq_sum)

    mean = total / area
    # Sample variance: (sum(x^2) - sum(x)^2 / n) / (n - 1). A one-pixel
    # window has no spread, so guard the divisor; clamp tiny negative
    # values caused by floating-point rounding.
    variance = (sq_total - total * total / area) / np.maximum(area - 1, 1)
    std = np.sqrt(np.maximum(variance, 0))

    threshold = mean * (1 + k * ((std / 128) - 1))

    # Thresholds are fully computed from the untouched copy before the
    # caller's array is overwritten.
    img[...] = np.where(pixels < threshold, 0, 255)
    return img
|
100 |
+
|
101 |
+
def convert_from_cv2_to_image(img: np.ndarray) -> Image.Image:
    """Convert an OpenCV-style array into a PIL image.

    The channels are reordered with ``cv2.COLOR_BGR2RGB`` before the array
    is wrapped by ``Image.fromarray``.

    :param img: colour image array in BGR channel order.
    :return: the equivalent PIL image in RGB channel order.
    """
    rgb_pixels = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return Image.fromarray(rgb_pixels)
|
104 |
+
|
105 |
+
def convert_from_image_to_cv2(img: Image.Image) -> np.ndarray:
    """Convert a PIL image into an OpenCV-style array.

    The image is copied into a NumPy array and its channels are reordered
    with ``cv2.COLOR_RGB2BGR``.

    :param img: PIL image in RGB channel order.
    :return: pixel array in BGR channel order.
    """
    rgb_pixels = np.array(img)
    return cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2BGR)
|
nougat_api_app.py
CHANGED
@@ -11,7 +11,7 @@ print('GPU Device name:', torch.cuda.get_device_name(torch.cuda.current_device()
|
|
11 |
import sys
|
12 |
from functools import partial
|
13 |
from http import HTTPStatus
|
14 |
-
|
15 |
from fastapi import FastAPI, File, UploadFile, Request,Response, BackgroundTasks, HTTPException
|
16 |
from fastapi import APIRouter, Depends
|
17 |
import os
|
@@ -70,8 +70,9 @@ from datetime import datetime
|
|
70 |
from sql_app.db import Base
|
71 |
import psycopg2
|
72 |
import numpy as np
|
|
|
73 |
|
74 |
-
logging.basicConfig(filename='info.log', level=logging.
|
75 |
#logger = logging.getLogger()
|
76 |
#logger.setLevel(logging.INFO)
|
77 |
|
@@ -88,6 +89,9 @@ global selected_model_name
|
|
88 |
|
89 |
# Load the ML model
|
90 |
def loadModel(checkpoint):
|
|
|
|
|
|
|
91 |
model = NougatModel.from_pretrained(checkpoint).to(torch.bfloat16)
|
92 |
if torch.cuda.is_available():
|
93 |
model.to("cuda")
|
@@ -242,6 +246,19 @@ async def app_middleware(request: Request, call_next):
|
|
242 |
headers=dict(response.headers), media_type=response.media_type, background=task)
|
243 |
|
244 |
''' debug code for , not test yet
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
245 |
from loguru import logger
|
246 |
from starlette.routing import Match
|
247 |
|
@@ -401,6 +418,10 @@ def predict_image(model_name, images, batchsize=1, markdown=True, out_path=""):
|
|
401 |
logging.info("we are under image to mmd convertiong")
|
402 |
#sample = Image.open(images.name).convert('RGB')
|
403 |
sample = images.convert('RGB')
|
|
|
|
|
|
|
|
|
404 |
im_new = resize_with_padding(sample, (672,896))
|
405 |
img_tensor = prepare(im_new,random_padding=False)
|
406 |
img_tensor = img_tensor.unsqueeze(0)
|
@@ -422,7 +443,7 @@ def predict_image(model_name, images, batchsize=1, markdown=True, out_path=""):
|
|
422 |
out = out.replace(r"\(", "$").replace(r'\)', '$').replace(r'\[', '$$').replace(r'\]', '$$')
|
423 |
f.write(out)
|
424 |
else:
|
425 |
-
logging.debug(out
|
426 |
|
427 |
return model_output, [out_path]
|
428 |
|
@@ -523,6 +544,7 @@ def read_users(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
|
|
523 |
async def predict(
|
524 |
request: Request,
|
525 |
selectedModel: str = Form(...),
|
|
|
526 |
file: UploadFile = File(...),
|
527 |
start: int = None, stop: int = None,
|
528 |
):
|
@@ -549,6 +571,7 @@ async def predict(
|
|
549 |
#parsed_url = urlparse(request)
|
550 |
#model_name = parse_qs(parsed_url.query)['model'][0]
|
551 |
model_name = selectedModel
|
|
|
552 |
if model_name == None:
|
553 |
model = nougatModel
|
554 |
else:
|
@@ -573,7 +596,7 @@ async def predict(
|
|
573 |
try:
|
574 |
with open(dest, 'wb') as f:
|
575 |
logging.info(f"save uploading files to {dest}")
|
576 |
-
imgbin = await file.read()
|
577 |
f.write(imgbin)
|
578 |
md5 = hashlib.md5(imgbin).hexdigest()
|
579 |
finger_printer = md5
|
@@ -584,8 +607,16 @@ async def predict(
|
|
584 |
f.close()
|
585 |
#logging.info(f"input image type is {type(imgbin)}")
|
586 |
#logging.info(f"input image type is {type(f)}")
|
587 |
-
|
588 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
589 |
logging.info(f"uploading Image Type: {type(img)}")
|
590 |
if img.format != "PNG":
|
591 |
dest_filename, dest_ext = os.path.splitext(dest)
|
@@ -594,6 +625,15 @@ async def predict(
|
|
594 |
# f.seek(0)
|
595 |
# img = Image.open(f)
|
596 |
img = convertImageFormat(img)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
597 |
img.save( dest_filename + ".png", "PNG")
|
598 |
model_output,_ = predict_image(model_name, images=img)
|
599 |
logging.debug(f"predict output as: {model_output}")
|
@@ -847,7 +887,7 @@ def get_pdf(pdf_link):
|
|
847 |
|
848 |
if response.status_code == 200:
|
849 |
# Save the PDF content to a local file
|
850 |
-
with open(unique_filename, 'wb'
|
851 |
pdf_file.write(response.content)
|
852 |
logging.info("PDF downloaded successfully.")
|
853 |
else:
|
@@ -1134,5 +1174,5 @@ if __name__ == "__main__":
|
|
1134 |
# ssl_keyfile='/workspace/nougat-latex/lzs.chrdw.ml.key',
|
1135 |
# ssl_certfile='/workspace/nougat-latex/fullchain.cer')
|
1136 |
|
1137 |
-
uvicorn.run("__main__:app", host="0.0.0.0", port=
|
1138 |
#demo.launch(debug=True,share=True, server_name="0.0.0.0",server_port=8866)
|
|
|
11 |
import sys
|
12 |
from functools import partial
|
13 |
from http import HTTPStatus
|
14 |
+
import cv2
|
15 |
from fastapi import FastAPI, File, UploadFile, Request,Response, BackgroundTasks, HTTPException
|
16 |
from fastapi import APIRouter, Depends
|
17 |
import os
|
|
|
70 |
from sql_app.db import Base
|
71 |
import psycopg2
|
72 |
import numpy as np
|
73 |
+
from .binarization import sauvola, convert_from_cv2_to_image, convert_from_image_to_cv2
|
74 |
|
75 |
+
logging.basicConfig(filename='info.log', level=logging.INFO)
|
76 |
#logger = logging.getLogger()
|
77 |
#logger.setLevel(logging.INFO)
|
78 |
|
|
|
89 |
|
90 |
# Load the ML model
|
91 |
def loadModel(checkpoint):
|
92 |
+
if not checkpoint.exists():
|
93 |
+
checkpoint = default_checkpoint_path
|
94 |
+
logging.info(f"request checkpoint is not exist, using default {checkpoint_name}")
|
95 |
model = NougatModel.from_pretrained(checkpoint).to(torch.bfloat16)
|
96 |
if torch.cuda.is_available():
|
97 |
model.to("cuda")
|
|
|
246 |
headers=dict(response.headers), media_type=response.media_type, background=task)
|
247 |
|
248 |
''' debug code for , not test yet
|
249 |
+
# Exception handlers
|
250 |
+
def add_exception_handlers(_app: FastAPI):
|
251 |
+
@_app.exception_handler(ApiAuthException)
|
252 |
+
async def api_auth_exception_handler(request: Request, exc: ApiAuthException):
|
253 |
+
return await handler.api_auth_exception_handler(request, exc)
|
254 |
+
|
255 |
+
@_app.exception_handler(ApiException)
|
256 |
+
async def api_exception_handler(request: Request, exc: ApiException):
|
257 |
+
return await handler.api_exception_handler(request, exc)
|
258 |
+
|
259 |
+
add_exception_handlers(main_app)
|
260 |
+
add_exception_handlers(sub_app)
|
261 |
+
|
262 |
from loguru import logger
|
263 |
from starlette.routing import Match
|
264 |
|
|
|
418 |
logging.info("we are under image to mmd convertiong")
|
419 |
#sample = Image.open(images.name).convert('RGB')
|
420 |
sample = images.convert('RGB')
|
421 |
+
gray_image = cv2.cvtColor(images, cv2.COLOR_BGR2GRAY)
|
422 |
+
sauvola_img = sauvola(gray_image)
|
423 |
+
#convert back to RGB format
|
424 |
+
sample = cv2.cvtColor(sauvola_img,cv2.COLOR_GRAY2RGB)
|
425 |
im_new = resize_with_padding(sample, (672,896))
|
426 |
img_tensor = prepare(im_new,random_padding=False)
|
427 |
img_tensor = img_tensor.unsqueeze(0)
|
|
|
443 |
out = out.replace(r"\(", "$").replace(r'\)', '$').replace(r'\[', '$$').replace(r'\]', '$$')
|
444 |
f.write(out)
|
445 |
else:
|
446 |
+
logging.debug(f"the out is {out}")
|
447 |
|
448 |
return model_output, [out_path]
|
449 |
|
|
|
544 |
async def predict(
|
545 |
request: Request,
|
546 |
selectedModel: str = Form(...),
|
547 |
+
binarization: str = Form(...),
|
548 |
file: UploadFile = File(...),
|
549 |
start: int = None, stop: int = None,
|
550 |
):
|
|
|
571 |
#parsed_url = urlparse(request)
|
572 |
#model_name = parse_qs(parsed_url.query)['model'][0]
|
573 |
model_name = selectedModel
|
574 |
+
isBinarized = bool(binarization)
|
575 |
if model_name == None:
|
576 |
model = nougatModel
|
577 |
else:
|
|
|
596 |
try:
|
597 |
with open(dest, 'wb') as f:
|
598 |
logging.info(f"save uploading files to {dest}")
|
599 |
+
imgbin = await file.read()
|
600 |
f.write(imgbin)
|
601 |
md5 = hashlib.md5(imgbin).hexdigest()
|
602 |
finger_printer = md5
|
|
|
607 |
f.close()
|
608 |
#logging.info(f"input image type is {type(imgbin)}")
|
609 |
#logging.info(f"input image type is {type(f)}")
|
610 |
+
|
611 |
+
if not isBinarized:
|
612 |
+
#binarize image
|
613 |
+
image = cv2.imread(dest)
|
614 |
+
gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
615 |
+
sauvola_img = sauvola(gray_image)
|
616 |
+
#convert back to RGB format
|
617 |
+
img = cv2.cvtColor(sauvola_img,cv2.COLOR_GRAY2RGB)
|
618 |
+
else:
|
619 |
+
img = Image.open(io.BytesIO(imgbin))
|
620 |
logging.info(f"uploading Image Type: {type(img)}")
|
621 |
if img.format != "PNG":
|
622 |
dest_filename, dest_ext = os.path.splitext(dest)
|
|
|
625 |
# f.seek(0)
|
626 |
# img = Image.open(f)
|
627 |
img = convertImageFormat(img)
|
628 |
+
#convert to cv2 format
|
629 |
+
cv2_image = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
|
630 |
+
#convert to gray format
|
631 |
+
gray_image = cv2.cvtColor(cv2_image, cv2.COLOR_BGR2GRAY)
|
632 |
+
#binarize with sauvola algorithm
|
633 |
+
sauvola_img = sauvola(gray_image)
|
634 |
+
#convert back to RGB format
|
635 |
+
img = cv2.cvtColor(sauvola_img,cv2.COLOR_GRAY2RGB)
|
636 |
+
img = convert_from_cv2_to_image(img)
|
637 |
img.save( dest_filename + ".png", "PNG")
|
638 |
model_output,_ = predict_image(model_name, images=img)
|
639 |
logging.debug(f"predict output as: {model_output}")
|
|
|
887 |
|
888 |
if response.status_code == 200:
|
889 |
# Save the PDF content to a local file
|
890 |
+
with open(unique_filename, 'wb') as pdf_file:
|
891 |
pdf_file.write(response.content)
|
892 |
logging.info("PDF downloaded successfully.")
|
893 |
else:
|
|
|
1174 |
# ssl_keyfile='/workspace/nougat-latex/lzs.chrdw.ml.key',
|
1175 |
# ssl_certfile='/workspace/nougat-latex/fullchain.cer')
|
1176 |
|
1177 |
+
uvicorn.run("__main__:app", host="0.0.0.0", port=8866,log_level="debug", workers=1)
|
1178 |
#demo.launch(debug=True,share=True, server_name="0.0.0.0",server_port=8866)
|
predict.ipynb
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f7512536f5c844b04801b605568df461d30cfaa3151d4dac30878824dbb698aa
|
3 |
+
size 4819007
|