zphilip48 committed on
Commit
a9ab22c
1 Parent(s): dcc405e

Binarization in fastapi side

Browse files
Files changed (3) hide show
  1. binarization .py +107 -0
  2. nougat_api_app.py +48 -8
  3. predict.ipynb +2 -2
binarization .py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import cv2
3
+ from PIL import Image
4
+ import numpy as np
5
+
6
def integral(img):
    """Compute the integral image and the squared integral image of ``img``.

    Entry ``[r][c]`` of the first result is the sum of all pixels in
    ``img[0..r][0..c]`` (both bounds inclusive); the same entry of the second
    result is the sum of the *squares* of those pixels.  Both tables have the
    same shape as ``img`` (no zero-padded border row/column).

    The accumulators are 64-bit: summing squared 8-bit pixels exceeds the
    int32 range after only a few tens of thousands of pixels.

    :param img: 2-D single-channel image (numpy array).
    :return: ``(integral_sum, integral_sqrt_sum)`` -- the integral image and
        the squared integral image; ``int64`` for integer input, ``float64``
        for floating-point input.
    :raises ValueError: if ``img`` is not two-dimensional.
    """
    pixels = np.asarray(img)
    if pixels.ndim != 2:
        raise ValueError('integral() expects a 2-D single-channel image')

    # Widen before squaring so uint8 arithmetic cannot wrap around.
    wide = np.float64 if np.issubdtype(pixels.dtype, np.floating) else np.int64
    pixels = pixels.astype(wide)

    # Cumulative sum down the rows, then along the columns, yields the
    # inclusive 2-D prefix sum.
    integral_sum = pixels.cumsum(axis=0).cumsum(axis=1)
    integral_sqrt_sum = (pixels * pixels).cumsum(axis=0).cumsum(axis=1)

    return integral_sum, integral_sqrt_sum
31
+
32
def _window_sums(values, top, bottom, left, right):
    """Sum ``values`` over every pixel's clipped window using an integral image.

    ``top``/``left`` are inclusive and ``bottom``/``right`` exclusive window
    bounds, given per row and per column respectively.
    """
    rows, cols = values.shape
    # Zero-padded integral image: table[i, j] == values[:i, :j].sum(), so the
    # border windows need no special-casing.
    table = np.zeros((rows + 1, cols + 1), dtype=np.float64)
    table[1:, 1:] = values.cumsum(axis=0).cumsum(axis=1)
    return (table[np.ix_(bottom, right)] - table[np.ix_(top, right)]
            - table[np.ix_(bottom, left)] + table[np.ix_(top, left)])


def sauvola(img, k=0.1, kernerl=(31, 31)):
    """Binarize a grayscale image with Sauvola's adaptive threshold.

    For every pixel the mean and the (sample) standard deviation of the
    surrounding window are computed, and the pixel is compared against
    ``mean * (1 + k * (std / 128 - 1))``.  Windows are clipped at the image
    border.  Pixels below their threshold become 0, all others 255.

    ``img`` is modified in place and also returned.

    :param img: 2-D single-channel image (numpy array).
    :param k: correction factor, usually ``0 < k < 1``.
    :param kernerl: window size as ``(rows, cols)``; both values must be
        positive and odd.
    :return: ``img`` after thresholding (values 0 or 255).
    :raises ValueError: if a window dimension is even or not positive.
    """
    if kernerl[0]%2!=1 or kernerl[1]%2!=1:
        raise ValueError('kernerl元组中的值必须为奇数, 请检查kernerl[0] or kernerl[1]是否为奇数!!!')
    if kernerl[0] < 1 or kernerl[1] < 1:
        raise ValueError('kernerl values must be positive')

    rows, cols = img.shape
    if rows == 0 or cols == 0:
        return img  # nothing to threshold

    # float64 represents the integer sums exactly and avoids uint8 wrap-around.
    pixels = img.astype(np.float64)

    half_rows = kernerl[0] >> 1
    half_cols = kernerl[1] >> 1

    # Per-row / per-column window bounds, clipped to the image.
    row_idx = np.arange(rows)
    col_idx = np.arange(cols)
    top = np.maximum(row_idx - half_rows, 0)
    bottom = np.minimum(row_idx + half_rows, rows - 1) + 1
    left = np.maximum(col_idx - half_cols, 0)
    right = np.minimum(col_idx + half_cols, cols - 1) + 1

    area = ((bottom - top)[:, None] * (right - left)[None, :]).astype(np.float64)

    total = _window_sums(pixels, top, bottom, left, right)
    sq_total = _window_sums(pixels * pixels, top, bottom, left, right)

    mean = total / area
    # Sample variance: (sum(x^2) - sum(x)^2 / n) / (n - 1).  A one-pixel
    # window has zero variance; the max() guards keep it from dividing by 0
    # and absorb tiny negative rounding residues.
    variance = np.maximum(sq_total - total * total / area, 0.0) / np.maximum(area - 1, 1)
    std = np.sqrt(variance)

    threshold = mean * (1 + k * ((std / 128) - 1))

    img[...] = np.where(pixels < threshold, 0, 255).astype(img.dtype)
    return img
100
+
101
def convert_from_cv2_to_image(img: np.ndarray) -> Image.Image:
    """Convert an OpenCV image array into a PIL image.

    Multi-channel arrays are treated as BGR and reordered to RGB.
    Single-channel (2-D) arrays, such as the output of ``sauvola``, have no
    channel order to swap and are wrapped directly.

    :param img: image array as used by OpenCV.
    :return: the corresponding PIL image.
    """
    if img.ndim == 2:
        # cvtColor(BGR2RGB) rejects single-channel input.
        return Image.fromarray(img)
    return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
104
+
105
def convert_from_image_to_cv2(img: Image.Image) -> np.ndarray:
    """Convert a PIL image into an OpenCV BGR array.

    Images that are not already in RGB mode (grayscale, palette, RGBA, ...)
    are converted to RGB first, because the RGB->BGR swap needs a colour
    array.

    :param img: PIL image in any mode that ``Image.convert("RGB")`` accepts.
    :return: 3-channel array in BGR channel order.
    """
    rgb = img if img.mode == "RGB" else img.convert("RGB")
    return cv2.cvtColor(np.array(rgb), cv2.COLOR_RGB2BGR)
nougat_api_app.py CHANGED
@@ -11,7 +11,7 @@ print('GPU Device name:', torch.cuda.get_device_name(torch.cuda.current_device()
11
  import sys
12
  from functools import partial
13
  from http import HTTPStatus
14
-
15
  from fastapi import FastAPI, File, UploadFile, Request,Response, BackgroundTasks, HTTPException
16
  from fastapi import APIRouter, Depends
17
  import os
@@ -70,8 +70,9 @@ from datetime import datetime
70
  from sql_app.db import Base
71
  import psycopg2
72
  import numpy as np
 
73
 
74
- logging.basicConfig(filename='info.log', level=logging.DEBUG)
75
  #logger = logging.getLogger()
76
  #logger.setLevel(logging.INFO)
77
 
@@ -88,6 +89,9 @@ global selected_model_name
88
 
89
  # Load the ML model
90
  def loadModel(checkpoint):
 
 
 
91
  model = NougatModel.from_pretrained(checkpoint).to(torch.bfloat16)
92
  if torch.cuda.is_available():
93
  model.to("cuda")
@@ -242,6 +246,19 @@ async def app_middleware(request: Request, call_next):
242
  headers=dict(response.headers), media_type=response.media_type, background=task)
243
 
244
  ''' debug code for , not test yet
 
 
 
 
 
 
 
 
 
 
 
 
 
245
  from loguru import logger
246
  from starlette.routing import Match
247
 
@@ -401,6 +418,10 @@ def predict_image(model_name, images, batchsize=1, markdown=True, out_path=""):
401
  logging.info("we are under image to mmd convertiong")
402
  #sample = Image.open(images.name).convert('RGB')
403
  sample = images.convert('RGB')
 
 
 
 
404
  im_new = resize_with_padding(sample, (672,896))
405
  img_tensor = prepare(im_new,random_padding=False)
406
  img_tensor = img_tensor.unsqueeze(0)
@@ -422,7 +443,7 @@ def predict_image(model_name, images, batchsize=1, markdown=True, out_path=""):
422
  out = out.replace(r"\(", "$").replace(r'\)', '$').replace(r'\[', '$$').replace(r'\]', '$$')
423
  f.write(out)
424
  else:
425
- logging.debug(out, "\n\n")
426
 
427
  return model_output, [out_path]
428
 
@@ -523,6 +544,7 @@ def read_users(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
523
  async def predict(
524
  request: Request,
525
  selectedModel: str = Form(...),
 
526
  file: UploadFile = File(...),
527
  start: int = None, stop: int = None,
528
  ):
@@ -549,6 +571,7 @@ async def predict(
549
  #parsed_url = urlparse(request)
550
  #model_name = parse_qs(parsed_url.query)['model'][0]
551
  model_name = selectedModel
 
552
  if model_name == None:
553
  model = nougatModel
554
  else:
@@ -573,7 +596,7 @@ async def predict(
573
  try:
574
  with open(dest, 'wb') as f:
575
  logging.info(f"save uploading files to {dest}")
576
- imgbin = await file.read()
577
  f.write(imgbin)
578
  md5 = hashlib.md5(imgbin).hexdigest()
579
  finger_printer = md5
@@ -584,8 +607,16 @@ async def predict(
584
  f.close()
585
  #logging.info(f"input image type is {type(imgbin)}")
586
  #logging.info(f"input image type is {type(f)}")
587
-
588
- img = Image.open(io.BytesIO(imgbin))
 
 
 
 
 
 
 
 
589
  logging.info(f"uploading Image Type: {type(img)}")
590
  if img.format != "PNG":
591
  dest_filename, dest_ext = os.path.splitext(dest)
@@ -594,6 +625,15 @@ async def predict(
594
  # f.seek(0)
595
  # img = Image.open(f)
596
  img = convertImageFormat(img)
 
 
 
 
 
 
 
 
 
597
  img.save( dest_filename + ".png", "PNG")
598
  model_output,_ = predict_image(model_name, images=img)
599
  logging.debug(f"predict output as: {model_output}")
@@ -847,7 +887,7 @@ def get_pdf(pdf_link):
847
 
848
  if response.status_code == 200:
849
  # Save the PDF content to a local file
850
- with open(unique_filename, 'wb',encoding="utf-8") as pdf_file:
851
  pdf_file.write(response.content)
852
  logging.info("PDF downloaded successfully.")
853
  else:
@@ -1134,5 +1174,5 @@ if __name__ == "__main__":
1134
  # ssl_keyfile='/workspace/nougat-latex/lzs.chrdw.ml.key',
1135
  # ssl_certfile='/workspace/nougat-latex/fullchain.cer')
1136
 
1137
- uvicorn.run("__main__:app", host="0.0.0.0", port=8503,log_level="debug", workers=1)
1138
  #demo.launch(debug=True,share=True, server_name="0.0.0.0",server_port=8866)
 
11
  import sys
12
  from functools import partial
13
  from http import HTTPStatus
14
+ import cv2
15
  from fastapi import FastAPI, File, UploadFile, Request,Response, BackgroundTasks, HTTPException
16
  from fastapi import APIRouter, Depends
17
  import os
 
70
  from sql_app.db import Base
71
  import psycopg2
72
  import numpy as np
73
+ from .binarization import sauvola, convert_from_cv2_to_image, convert_from_image_to_cv2
74
 
75
+ logging.basicConfig(filename='info.log', level=logging.INFO)
76
  #logger = logging.getLogger()
77
  #logger.setLevel(logging.INFO)
78
 
 
89
 
90
  # Load the ML model
91
  def loadModel(checkpoint):
92
+ if not checkpoint.exists():
93
+ checkpoint = default_checkpoint_path
94
+ logging.info(f"request checkpoint is not exist, using default {checkpoint_name}")
95
  model = NougatModel.from_pretrained(checkpoint).to(torch.bfloat16)
96
  if torch.cuda.is_available():
97
  model.to("cuda")
 
246
  headers=dict(response.headers), media_type=response.media_type, background=task)
247
 
248
  ''' debug code for , not test yet
249
+ # Exception handlers
250
+ def add_exception_handlers(_app: FastAPI):
251
+ @_app.exception_handler(ApiAuthException)
252
+ async def api_auth_exception_handler(request: Request, exc: ApiAuthException):
253
+ return await handler.api_auth_exception_handler(request, exc)
254
+
255
+ @_app.exception_handler(ApiException)
256
+ async def api_exception_handler(request: Request, exc: ApiException):
257
+ return await handler.api_exception_handler(request, exc)
258
+
259
+ add_exception_handlers(main_app)
260
+ add_exception_handlers(sub_app)
261
+
262
  from loguru import logger
263
  from starlette.routing import Match
264
 
 
418
  logging.info("we are under image to mmd convertiong")
419
  #sample = Image.open(images.name).convert('RGB')
420
  sample = images.convert('RGB')
421
+ gray_image = cv2.cvtColor(images, cv2.COLOR_BGR2GRAY)
422
+ sauvola_img = sauvola(gray_image)
423
+ #convert back to RGB format
424
+ sample = cv2.cvtColor(sauvola_img,cv2.COLOR_GRAY2RGB)
425
  im_new = resize_with_padding(sample, (672,896))
426
  img_tensor = prepare(im_new,random_padding=False)
427
  img_tensor = img_tensor.unsqueeze(0)
 
443
  out = out.replace(r"\(", "$").replace(r'\)', '$').replace(r'\[', '$$').replace(r'\]', '$$')
444
  f.write(out)
445
  else:
446
+ logging.debug(f"the out is {out}")
447
 
448
  return model_output, [out_path]
449
 
 
544
  async def predict(
545
  request: Request,
546
  selectedModel: str = Form(...),
547
+ binarization: str = Form(...),
548
  file: UploadFile = File(...),
549
  start: int = None, stop: int = None,
550
  ):
 
571
  #parsed_url = urlparse(request)
572
  #model_name = parse_qs(parsed_url.query)['model'][0]
573
  model_name = selectedModel
574
+ isBinarized = bool(binarization)
575
  if model_name == None:
576
  model = nougatModel
577
  else:
 
596
  try:
597
  with open(dest, 'wb') as f:
598
  logging.info(f"save uploading files to {dest}")
599
+ imgbin = await file.read()
600
  f.write(imgbin)
601
  md5 = hashlib.md5(imgbin).hexdigest()
602
  finger_printer = md5
 
607
  f.close()
608
  #logging.info(f"input image type is {type(imgbin)}")
609
  #logging.info(f"input image type is {type(f)}")
610
+
611
+ if not isBinarized:
612
+ #binarize image
613
+ image = cv2.imread(dest)
614
+ gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
615
+ sauvola_img = sauvola(gray_image)
616
+ #convert back to RGB format
617
+ img = cv2.cvtColor(sauvola_img,cv2.COLOR_GRAY2RGB)
618
+ else:
619
+ img = Image.open(io.BytesIO(imgbin))
620
  logging.info(f"uploading Image Type: {type(img)}")
621
  if img.format != "PNG":
622
  dest_filename, dest_ext = os.path.splitext(dest)
 
625
  # f.seek(0)
626
  # img = Image.open(f)
627
  img = convertImageFormat(img)
628
+ #convert to cv2 format
629
+ cv2_image = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
630
+ #convert to gray format
631
+ gray_image = cv2.cvtColor(cv2_image, cv2.COLOR_BGR2GRAY)
632
+ #binarize with sauvola algorithm
633
+ sauvola_img = sauvola(gray_image)
634
+ #convert back to RGB format
635
+ img = cv2.cvtColor(sauvola_img,cv2.COLOR_GRAY2RGB)
636
+ img = convert_from_cv2_to_image(img)
637
  img.save( dest_filename + ".png", "PNG")
638
  model_output,_ = predict_image(model_name, images=img)
639
  logging.debug(f"predict output as: {model_output}")
 
887
 
888
  if response.status_code == 200:
889
  # Save the PDF content to a local file
890
+ with open(unique_filename, 'wb') as pdf_file:
891
  pdf_file.write(response.content)
892
  logging.info("PDF downloaded successfully.")
893
  else:
 
1174
  # ssl_keyfile='/workspace/nougat-latex/lzs.chrdw.ml.key',
1175
  # ssl_certfile='/workspace/nougat-latex/fullchain.cer')
1176
 
1177
+ uvicorn.run("__main__:app", host="0.0.0.0", port=8866,log_level="debug", workers=1)
1178
  #demo.launch(debug=True,share=True, server_name="0.0.0.0",server_port=8866)
predict.ipynb CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:38cbc49e9d28eb3db48cff4e23ee23122912778874db53e8891013c4c4b60744
3
- size 36292
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f7512536f5c844b04801b605568df461d30cfaa3151d4dac30878824dbb698aa
3
+ size 4819007