msheriff committed on
Commit
4e0cdd6
1 Parent(s): 400411f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -4
app.py CHANGED
@@ -19,16 +19,94 @@ import gradio as gr
19
 
20
  device = "cuda" if torch.cuda.is_available() else "cpu"
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  def process_pdf():
23
  print('process_pdf')
24
- # cropped_table = detect_and_crop_table(image)
25
  # image, cells = recognize_table(cropped_table)
26
 
27
  # cell_coordinates = get_cell_coordinates_by_row(cells)
28
  # df, data = apply_ocr(cell_coordinates, image)
29
 
 
30
  # return image, df, data
31
- return [], [], []
32
 
33
  title = "Sheriff's Demo: Table Detection & Recognition with Table Transformer (TATR)."
34
  description = """A demo by M Sheriff for table extraction with the Table Transformer.
@@ -39,8 +117,9 @@ after which the detected table is extracted and https://huggingface.co/microsoft
39
  # examples = [['image.png'], ['mistral_paper.png']]
40
 
41
  app = gr.Interface(fn=process_pdf,
42
- inputs=gr.Image(type="pil"),
43
- outputs=[gr.Image(type="pil", label="Detected table"), gr.Dataframe(label="Table as CSV"), gr.JSON(label="Data as JSON")],
 
44
  title=title,
45
  description=description,
46
  # examples=examples
 
19
 
20
  device = "cuda" if torch.cuda.is_available() else "cpu"
21
 
22
+
23
class MaxResize(object):
    """Transform that rescales an image so its longer side equals ``max_size``.

    Aspect ratio is preserved. Note that images smaller than ``max_size``
    are scaled *up* as well (scale factor may exceed 1).
    """

    def __init__(self, max_size=800):
        # Target length in pixels for the longer image side.
        self.max_size = max_size

    def __call__(self, image):
        # `image` is a PIL-style object exposing .size and .resize().
        width, height = image.size
        longest_side = max(width, height)
        factor = self.max_size / longest_side
        new_size = (int(round(factor * width)), int(round(factor * height)))
        return image.resize(new_size)
34
+
35
# Preprocessing pipeline for the table *detection* model: resize so the longer
# side is 800 px, convert to tensor, normalize with ImageNet mean/std.
detection_transform = transforms.Compose([
    MaxResize(800),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Preprocessing pipeline for the table *structure recognition* model: same
# steps but with a larger 1000 px maximum side.
structure_transform = transforms.Compose([
    MaxResize(1000),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
46
+
47
# Load table detection model; the `no_timm` revision avoids the timm dependency.
# processor = TableTransformerImageProcessor(max_size=800)
model = AutoModelForObjectDetection.from_pretrained("microsoft/table-transformer-detection", revision="no_timm").to(device)

# Load table structure recognition model (v1.1, trained on all datasets).
# structure_processor = TableTransformerImageProcessor(max_size=1000)
structure_model = AutoModelForObjectDetection.from_pretrained("microsoft/table-transformer-structure-recognition-v1.1-all").to(device)

# Load EasyOCR reader for extracting cell text (English only).
reader = easyocr.Reader(['en'])
57
+
58
+
59
def outputs_to_objects(outputs, img_size, id2label):
    """Convert raw DETR-style detection outputs into labeled bounding boxes.

    Args:
        outputs: model output with ``logits`` and ``pred_boxes`` (batch of 1).
        img_size: (width, height) of the original image, used to rescale boxes.
        id2label: mapping from class index to label name; entries labeled
            'no object' are dropped.

    Returns:
        List of dicts with keys 'label', 'score' and 'bbox'.
    """
    # Per-query best class and its probability (softmax over the label axis).
    best = outputs.logits.softmax(-1).max(-1)
    labels = list(best.indices.detach().cpu().numpy())[0]
    scores = list(best.values.detach().cpu().numpy())[0]
    raw_boxes = outputs['pred_boxes'].detach().cpu()[0]
    # Rescale normalized boxes back to original-image pixel coordinates.
    boxes = [box.tolist() for box in rescale_bboxes(raw_boxes, img_size)]

    detected = []
    for label, score, bbox in zip(labels, scores, boxes):
        name = id2label[int(label)]
        if name != 'no object':
            detected.append({'label': name, 'score': float(score),
                             'bbox': [float(coord) for coord in bbox]})
    return detected
73
+
74
+
75
def detect_and_crop_table(image):
    """Detect the first table in an image and return it cropped out.

    Args:
        image: PIL.Image to search for tables.

    Returns:
        PIL.Image cropped to the bounding box of the first detected table.

    Raises:
        ValueError: if no table is detected in the image.
    """
    # Prepare image for the model.
    # pixel_values = processor(image, return_tensors="pt").pixel_values
    pixel_values = detection_transform(image).unsqueeze(0).to(device)

    # Forward pass.
    with torch.no_grad():
        outputs = model(pixel_values)

    # Postprocess to get detected tables. Work on a *copy* of id2label: the
    # previous code mutated model.config.id2label in place, appending a fresh
    # "no object" entry on every call and corrupting the shared mapping.
    id2label = dict(model.config.id2label)
    id2label[len(id2label)] = "no object"
    detected_tables = outputs_to_objects(outputs, image.size, id2label)

    # visualize
    # fig = visualize_detected_tables(image, detected_tables)
    # image = fig2img(fig)

    # Fail with a clear message instead of a bare IndexError when no table
    # was found in the image.
    if not detected_tables:
        raise ValueError("No table detected in the supplied image.")

    # Crop first detected table out of image.
    cropped_table = image.crop(detected_tables[0]["bbox"])

    return cropped_table
97
+
98
+
99
def process_pdf(image):
    """Gradio handler: detect and crop the table in the uploaded image.

    Args:
        image: PIL.Image supplied by the gr.Image input component.

    Returns:
        PIL.Image cropped to the detected table.
    """
    print('process_pdf')
    # Bug fix: `image` was an undefined name — the function previously took no
    # parameters even though gr.Interface passes the uploaded image as the
    # first positional argument, so every call raised NameError.
    cropped_table = detect_and_crop_table(image)
    # image, cells = recognize_table(cropped_table)

    # cell_coordinates = get_cell_coordinates_by_row(cells)
    # df, data = apply_ocr(cell_coordinates, image)

    return cropped_table
    # return image, df, data
    # Bug fix: the line below used `//`, which is not a Python comment marker
    # and made the module fail to import with a SyntaxError.
    # return [], [], []
110
 
111
  title = "Sheriff's Demo: Table Detection & Recognition with Table Transformer (TATR)."
112
  description = """A demo by M Sheriff for table extraction with the Table Transformer.
 
117
  # examples = [['image.png'], ['mistral_paper.png']]
118
 
119
  app = gr.Interface(fn=process_pdf,
120
+ inputs=gr.Image(type="pil"),
121
+ outputs=[gr.Image(type="pil")],
122
+ # outputs=[gr.Image(type="pil", label="Detected table"), gr.Dataframe(label="Table as CSV"), gr.JSON(label="Data as JSON")],
123
  title=title,
124
  description=description,
125
  # examples=examples