from concurrent.futures import ProcessPoolExecutor
from typing import List, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm

from surya.input.processing import convert_if_not_rgb, get_total_splits, prepare_image_detection, split_image
from surya.model.detection.segformer import SegformerForRegressionMask
from surya.postprocessing.affinity import get_vertical_lines
from surya.postprocessing.heatmap import get_and_clean_boxes
from surya.schema import TextDetectionResult
from surya.settings import settings


def get_batch_size():
    batch_size = settings.DETECTOR_BATCH_SIZE
    if batch_size is None:
        batch_size = 6
        if settings.TORCH_DEVICE_MODEL == "cuda":
            batch_size = 24
    return batch_size
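
# Note: the device-based default above can be overridden via settings. surya's settings
# are pydantic-based and read from the environment, so exporting DETECTOR_BATCH_SIZE=12
# (an illustrative value) makes get_batch_size() return 12.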


def batch_detection(images: List, model: SegformerForRegressionMask, processor, batch_size=None) -> Tuple[List[List[np.ndarray]], List[Tuple[int, int]]]:
    assert all([isinstance(image, Image.Image) for image in images])
    if batch_size is None:
        batch_size = get_batch_size()
    heatmap_count = model.config.num_labels

    images = [image.convert("RGB") for image in images]  # also copies the images
    orig_sizes = [image.size for image in images]
    # Tall pages are split into several model-sized crops, so batches are budgeted
    # by split count rather than by image count.
    splits_per_image = [get_total_splits(size, processor) for size in orig_sizes]

    # Greedily pack image indices into batches without exceeding batch_size splits per batch.
    batches = []
    current_batch_size = 0
    current_batch = []
    for i in range(len(images)):
        if current_batch_size + splits_per_image[i] > batch_size:
            if len(current_batch) > 0:
                batches.append(current_batch)
            current_batch = []
            current_batch_size = 0
        current_batch.append(i)
        current_batch_size += splits_per_image[i]

    if len(current_batch) > 0:
        batches.append(current_batch)
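
    # e.g. with batch_size = 6 and splits_per_image = [2, 3, 4], images 0 and 1 share a
    # batch (5 splits) and image 2 starts a new one: batches == [[0, 1], [2]]. An image
    # whose split count alone exceeds batch_size still gets a batch of its own.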

    all_preds = []
    for batch_idx in tqdm(range(len(batches)), desc="Detecting bboxes"):
        batch_image_idxs = batches[batch_idx]
        batch_images = convert_if_not_rgb([images[j] for j in batch_image_idxs])

        # Split each image into crops, tracking which image each crop came from and
        # the (pre-padding) height of each crop.
        split_index = []
        split_heights = []
        image_splits = []
        for image_idx, image in enumerate(batch_images):
            image_parts, split_height = split_image(image, processor)
            image_splits.extend(image_parts)
            split_index.extend([image_idx] * len(image_parts))
            split_heights.extend(split_height)

        image_splits = [prepare_image_detection(image, processor) for image in image_splits]
        # Batch images in dim 0
        batch = torch.stack(image_splits, dim=0).to(model.dtype).to(model.device)

        with torch.inference_mode():
            pred = model(pixel_values=batch)

        logits = pred.logits
        correct_shape = [processor.size["height"], processor.size["width"]]
        current_shape = list(logits.shape[2:])
        if current_shape != correct_shape:
            # Upsample the mask logits back to the processor's input resolution.
            logits = F.interpolate(logits, size=correct_shape, mode="bilinear", align_corners=False)

        logits = logits.cpu().detach().numpy().astype(np.float32)
        preds = []
        for i, (idx, height) in enumerate(zip(split_index, split_heights)):
            # A split whose image index is past the end of preds starts a new image;
            # otherwise, stack it below the heatmaps already collected for that image.
            if len(preds) <= idx:
                preds.append([logits[i][k] for k in range(heatmap_count)])
            else:
                heatmaps = preds[idx]
                pred_heatmaps = [logits[i][k] for k in range(heatmap_count)]

                if height < processor.size["height"]:
                    # Cut off padding to recover the split's original height
                    pred_heatmaps = [pred_heatmap[:height, :] for pred_heatmap in pred_heatmaps]

                for k in range(heatmap_count):
                    heatmaps[k] = np.vstack([heatmaps[k], pred_heatmaps[k]])
                preds[idx] = heatmaps

        all_preds.extend(preds)

    assert len(all_preds) == len(images)
    assert all([len(pred) == heatmap_count for pred in all_preds])

    return all_preds, orig_sizes
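

def _stitch_example():
    # Illustration (not from the original source): the stitching in batch_detection
    # amounts to trimming padded rows off each split's heatmap and vertically
    # concatenating. The sizes below are made up.
    top = np.zeros((896, 768), dtype=np.float32)     # first split, full model height
    bottom = np.zeros((896, 768), dtype=np.float32)  # last split, bottom-padded
    real_height = 300                                # unpadded height of the last split
    stitched = np.vstack([top, bottom[:real_height, :]])
    assert stitched.shape == (1196, 768)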


def parallel_get_lines(preds, orig_sizes):
    heatmap, affinity_map = preds
    heat_img = Image.fromarray((heatmap * 255).astype(np.uint8))
    aff_img = Image.fromarray((affinity_map * 255).astype(np.uint8))
    # Convert numpy (height, width) shapes into the (width, height) order PIL uses.
    affinity_size = list(reversed(affinity_map.shape))
    heatmap_size = list(reversed(heatmap.shape))
    bboxes = get_and_clean_boxes(heatmap, heatmap_size, orig_sizes)
    vertical_lines = get_vertical_lines(affinity_map, affinity_size, orig_sizes)

    result = TextDetectionResult(
        bboxes=bboxes,
        vertical_lines=vertical_lines,
        heatmap=heat_img,
        affinity_map=aff_img,
        image_bbox=[0, 0, orig_sizes[0], orig_sizes[1]]
    )
    return result
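

# e.g. a heatmap of numpy shape (896, 768) (rows, cols) yields heatmap_size == [768, 896]
# above, matching PIL's Image.size convention of (width, height).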


def batch_text_detection(images: List, model, processor, batch_size=None) -> List[TextDetectionResult]:
    preds, orig_sizes = batch_detection(images, model, processor, batch_size=batch_size)
    results = []
    # Don't spawn worker processes under Streamlit, or when there are too few images
    # for a process pool to pay for itself.
    if settings.IN_STREAMLIT or len(images) < settings.DETECTOR_MIN_PARALLEL_THRESH:
        for i in range(len(images)):
            result = parallel_get_lines(preds[i], orig_sizes[i])
            results.append(result)
    else:
        max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            results = list(executor.map(parallel_get_lines, preds, orig_sizes))

    return results
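

# Minimal usage sketch (not part of the original module). The loader imports are an
# assumption based on surya's public API; check your installed version for the exact
# entry points.
if __name__ == "__main__":
    from surya.model.detection.segformer import load_model, load_processor  # assumed loaders

    model = load_model()
    processor = load_processor()
    image = Image.open("page.png")  # hypothetical input path
    [result] = batch_text_detection([image], model, processor)
    print(f"Detected {len(result.bboxes)} text boxes and {len(result.vertical_lines)} vertical lines")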