import os

import torch
from PIL import Image
from rich import print
from rich.progress import Progress, SpinnerColumn, TextColumn
from torchvision import transforms
from transformers import AutoModelForObjectDetection

class TableDetector:
    _model = None   # Static variable to hold the table detection model
    _device = None  # Static variable to hold the device information

    def __init__(self):
        pass
    class MaxResize:
        """Resize an image so its longer side equals max_size, preserving the aspect ratio."""

        def __init__(self, max_size=800):
            self.max_size = max_size

        def __call__(self, image):
            width, height = image.size
            current_max_size = max(width, height)
            scale = self.max_size / current_max_size
            resized_image = image.resize((int(round(scale * width)), int(round(scale * height))))
            return resized_image
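
    # Example: with max_size=800, a 1600x1200 image gets scale = 800 / 1600 = 0.5
    # and is resized to 800x600, preserving the aspect ratio.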

    @classmethod
    def _initialize_model(cls, invoke_pipeline_step, local):
        """
        Initialize the table detection model once and cache it on the class,
        so repeated detect_tables calls reuse the same weights and device.
        """
        if cls._model is None:
            # Use invoke_pipeline_step to load the model
            cls._model, cls._device = invoke_pipeline_step(
                lambda: cls.load_table_detection_model(),
                "Loading table detection model...",
                local
            )
            print("Table detection model initialized.")

    def detect_tables(self, file_path, local=True, debug_dir=None, debug=False):
        # Ensure the model is initialized using invoke_pipeline_step
        self._initialize_model(self.invoke_pipeline_step, local)

        # Use the cached model and device
        model, device = self._model, self._device

        outputs, image = self.invoke_pipeline_step(
            lambda: self.prepare_image(file_path, model, device),
            "Preparing image for table detection...",
            local
        )

        objects = self.invoke_pipeline_step(
            lambda: self.identify_tables(model, outputs, image),
            "Identifying tables in the image...",
            local
        )

        cropped_tables = self.invoke_pipeline_step(
            lambda: self.crop_tables(file_path, image, objects, debug, debug_dir),
            "Cropping tables from the image...",
            local
        )

        return cropped_tables

    @staticmethod
    def load_table_detection_model():
        # The first call downloads the weights from the Hugging Face Hub and caches them locally
        model = AutoModelForObjectDetection.from_pretrained("microsoft/table-transformer-detection", revision="no_timm")

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)

        return model, device

    def prepare_image(self, file_path, model, device):
        image = Image.open(file_path).convert("RGB")

        detection_transform = transforms.Compose([
            self.MaxResize(800),
            transforms.ToTensor(),
            # Normalize with the standard ImageNet mean and std the model was trained with
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        pixel_values = detection_transform(image).unsqueeze(0)
        pixel_values = pixel_values.to(device)

        with torch.no_grad():
            outputs = model(pixel_values)

        return outputs, image

    def identify_tables(self, model, outputs, image):
        id2label = model.config.id2label
        # DETR-style detectors reserve the last logit index for the "no object" class
        id2label[len(model.config.id2label)] = "no object"

        objects = self.outputs_to_objects(outputs, image.size, id2label)
        return objects

    def crop_tables(self, file_path, image, objects, debug, debug_dir):
        tokens = []
        detection_class_thresholds = {
            "table": 0.5,
            "table rotated": 0.5,
            "no object": 10
        }
        crop_padding = 30

        tables_crops = self.objects_to_crops(image, tokens, objects, detection_class_thresholds, padding=crop_padding)

        cropped_tables = []
        if len(tables_crops) == 0:
            if debug:
                print("No tables detected in:", file_path)
            return None
        elif len(tables_crops) > 1:
            for i, table_crop in enumerate(tables_crops):
                if debug:
                    print("Table detected in:", file_path, "-", i + 1)
                cropped_table = table_crop['image'].convert("RGB")
                cropped_tables.append(cropped_table)
                if debug_dir:
                    file_name_table = self.append_filename(file_path, debug_dir, f"table_cropped_{i + 1}")
                    cropped_table.save(file_name_table)
        else:
            if debug:
                print("Table detected in:", file_path)
            cropped_table = tables_crops[0]['image'].convert("RGB")
            cropped_tables.append(cropped_table)
            if debug_dir:
                file_name_table = self.append_filename(file_path, debug_dir, "table_cropped")
                cropped_table.save(file_name_table)

        return cropped_tables

    # For output bounding box post-processing: convert (center_x, center_y, width, height)
    # boxes to (xmin, ymin, xmax, ymax) corner format.
    @staticmethod
    def box_cxcywh_to_xyxy(x):
        x_c, y_c, w, h = x.unbind(-1)
        b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
        return torch.stack(b, dim=1)
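
    # Example: a normalized box with center (0.5, 0.5), width 0.2 and height 0.1
    # becomes (0.4, 0.45, 0.6, 0.55) in corner format.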

    def rescale_bboxes(self, out_bbox, size):
        img_w, img_h = size
        b = self.box_cxcywh_to_xyxy(out_bbox)
        # Scale normalized coordinates up to absolute pixel coordinates
        b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
        return b

    def outputs_to_objects(self, outputs, img_size, id2label):
        # Take the highest-scoring class per predicted box
        m = outputs.logits.softmax(-1).max(-1)
        pred_labels = list(m.indices.detach().cpu().numpy())[0]
        pred_scores = list(m.values.detach().cpu().numpy())[0]
        pred_bboxes = outputs['pred_boxes'].detach().cpu()[0]
        pred_bboxes = [elem.tolist() for elem in self.rescale_bboxes(pred_bboxes, img_size)]

        objects = []
        for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes):
            class_label = id2label[int(label)]
            if class_label != 'no object':
                objects.append({'label': class_label, 'score': float(score),
                                'bbox': [float(elem) for elem in bbox]})

        return objects
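
    # Example of a returned entry (illustrative values, in pixel coordinates of the
    # original image):
    #   {'label': 'table', 'score': 0.98, 'bbox': [12.0, 34.0, 560.0, 410.0]}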

    def objects_to_crops(self, img, tokens, objects, class_thresholds, padding=10):
        """
        Process the bounding boxes produced by the table detection model into
        cropped table images and cropped tokens.
        """
        table_crops = []
        for obj in objects:
            if obj['score'] < class_thresholds[obj['label']]:
                continue

            cropped_table = {}
            bbox = obj['bbox']
            bbox = [bbox[0] - padding, bbox[1] - padding, bbox[2] + padding, bbox[3] + padding]

            cropped_img = img.crop(bbox)

            # Keep only tokens that lie (mostly) inside the padded table box,
            # and shift their coordinates into the crop's frame
            table_tokens = [token for token in tokens if self.iob(token['bbox'], bbox) >= 0.5]
            for token in table_tokens:
                token['bbox'] = [token['bbox'][0] - bbox[0],
                                 token['bbox'][1] - bbox[1],
                                 token['bbox'][2] - bbox[0],
                                 token['bbox'][3] - bbox[1]]

            # If table is predicted to be rotated, rotate cropped image and tokens/words:
            if obj['label'] == 'table rotated':
                cropped_img = cropped_img.rotate(270, expand=True)
                for token in table_tokens:
                    bbox = token['bbox']
                    bbox = [cropped_img.size[0] - bbox[3] - 1,
                            bbox[0],
                            cropped_img.size[0] - bbox[1] - 1,
                            bbox[2]]
                    token['bbox'] = bbox

            cropped_table['image'] = cropped_img
            cropped_table['tokens'] = table_tokens

            table_crops.append(cropped_table)

        return table_crops

    @staticmethod
    def append_filename(file_path, debug_dir, word):
        directory, filename = os.path.split(file_path)
        name, ext = os.path.splitext(filename)
        new_filename = f"{name}_{word}{ext}"
        return os.path.join(debug_dir, new_filename)

    @staticmethod
    def iob(boxA, boxB):
        # Determine the coordinates of the intersection rectangle
        xA = max(boxA[0], boxB[0])
        yA = max(boxA[1], boxB[1])
        xB = min(boxA[2], boxB[2])
        yB = min(boxA[3], boxB[3])

        # Compute the area of the intersection rectangle
        interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)

        # Compute the area of both the prediction and ground-truth rectangles
        boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
        boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)

        # Compute the intersection over boxA (IoB): the fraction of boxA inside boxB
        iob = interArea / float(boxAArea)

        return iob
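
    # Worked example (the +1 terms treat coordinates as inclusive pixel indices):
    # boxA = [0, 0, 9, 9] has area 100; boxB = [5, 5, 14, 14] intersects it on
    # [5, 5, 9, 9] with area 25, so iob(boxA, boxB) = 25 / 100 = 0.25.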

    @staticmethod
    def invoke_pipeline_step(task_call, task_description, local):
        if local:
            # Show a spinner while the task runs in the local terminal
            with Progress(
                SpinnerColumn(),
                TextColumn("[progress.description]{task.description}"),
                transient=False,
            ) as progress:
                progress.add_task(description=task_description, total=None)
                ret = task_call()
        else:
            # In non-interactive environments, just log the step description
            print(task_description)
            ret = task_call()

        return ret
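
    # Usage sketch (hypothetical call): any callable can be wrapped to get a spinner
    # while it runs, e.g.
    #   TableDetector.invoke_pipeline_step(lambda: sum(range(10)), "Summing...", local=True)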
if __name__ == "__main__": | |
table_detector = TableDetector() | |
# file_path = "/Users/andrejb/Work/katana-git/sparrow/sparrow-ml/llm/data/bonds_table.png" | |
# cropped_tables = table_detector.detect_tables(file_path, local=True, debug_dir="/Users/andrejb/Work/katana-git/sparrow/sparrow-ml/llm/data/", debug=True) | |
# for i, cropped_table in enumerate(cropped_tables): | |
# file_name_table = table_detector.append_filename(file_path, "cropped_" + str(i)) | |
# cropped_table.save(file_name_table) |
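
    # A minimal usage sketch, assuming "table.png" is an image of a document page
    # that contains a table (the path is a placeholder, not a file in this repo):
    #
    # cropped_tables = table_detector.detect_tables("table.png", local=True)
    # if cropped_tables:
    #     for i, cropped_table in enumerate(cropped_tables):
    #         cropped_table.save(f"table_{i + 1}.png")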