metadata
license: mit
tags:
  - donut
  - image-to-text
  - vision
datasets:
  - shreyanshu09/Block_Diagram
  - shreyanshu09/BD-EnKo
language:
  - en
  - ko

Block Diagram Global Information Extractor

This model was introduced in the paper "Unveiling the Power of Integration: Block Diagram Summarization through Local-Global Fusion", accepted at ACL 2024.

Model description

This model uses a Transformer encoder-decoder architecture, following the configuration specified in Donut, to generate an overall summary of a block diagram image. It consists of a visual encoder module and a text decoder module, both based on the Transformer architecture, and it supports English and Korean.

Training dataset

  • 41,933 samples from the synthetic and real-world English block diagrams (BD-EnKo)
  • 33,101 samples from the synthetic and real-world Korean block diagrams (BD-EnKo)
  • 396 samples from the real-world English block diagram dataset (CBD)
  • 357 samples from the handwritten English block diagram dataset (FC_A)
  • 476 samples from the handwritten English block diagram dataset (FC_B)
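
As a quick sanity check on the overall training set size, the counts listed above can be summed. This is a minimal sketch that assumes the five subsets are disjoint and are all used for training; the dictionary keys are illustrative labels, not official dataset identifiers.

```python
# Sample counts copied from the list above (labels are illustrative)
training_samples = {
    "BD-EnKo (English)": 41_933,
    "BD-EnKo (Korean)": 33_101,
    "CBD (English)": 396,
    "FC_A (English)": 357,
    "FC_B (English)": 476,
}

# Total across all subsets, then a split by language
total = sum(training_samples.values())
korean = training_samples["BD-EnKo (Korean)"]
english = total - korean

print(total)    # 76263
print(english)  # 43162
print(korean)   # 33101
```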

How to use

Here is how to use this model in PyTorch:

import os
from PIL import Image
import torch
from donut import DonutModel

# Load the pre-trained model
model = DonutModel.from_pretrained("shreyanshu09/block_diagram_global_information") 

# Move the model to GPU if available
if torch.cuda.is_available():
    model.half()
    device = torch.device("cuda:0")
    model.to(device)
    
# Function to process a single image
def process_image(image_path):
    # Load and process the image
    image = Image.open(image_path)
    # os.path.basename only parses the string, so this folder does not need to exist on disk
    task_name = os.path.basename('/block_diagram_global_information/dataset/c2t_data/')
    result = model.inference(image=image, prompt=f"<s_{task_name}>")["predictions"][0]

    # Extract the relevant information from the result
    if 'c2t' in result:
        return result['c2t']
    else:
        return result['text_sequence']

# Example usage
image_path = 'image.png'  # Input image file
result = process_image(image_path)
print(result)
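
The prompt construction and result handling in the snippet above can be checked without loading the model. The sketch below isolates those two steps in plain Python; `task_prompt` and `extract_summary` are hypothetical helper names introduced here for illustration and are not part of the `donut` package. It also shows that `os.path.basename` returns an empty string for a path ending in a slash, so the path used above yields the prompt `<s_>`. The source does not state which prompt token the checkpoint was trained with, so treat the second call as an alternative to try, not a confirmed fix.

```python
import os


def task_prompt(dataset_path):
    """Build a Donut-style start prompt from a dataset path (hypothetical helper)."""
    # basename returns the text after the final slash; nothing is read from disk
    task_name = os.path.basename(dataset_path)
    return f"<s_{task_name}>"


def extract_summary(prediction):
    """Return the 'c2t' field when present, else the raw 'text_sequence' (hypothetical helper)."""
    if "c2t" in prediction:
        return prediction["c2t"]
    return prediction["text_sequence"]


# A trailing slash leaves an empty task name
print(task_prompt("/block_diagram_global_information/dataset/c2t_data/"))  # <s_>
# Without the trailing slash, the last folder name is used
print(task_prompt("/block_diagram_global_information/dataset/c2t_data"))   # <s_c2t_data>

# Dummy prediction dictionaries standing in for model output
print(extract_summary({"c2t": "parsed summary"}))         # parsed summary
print(extract_summary({"text_sequence": "raw decoded"}))  # raw decoded
```

If the summaries returned by the model look truncated or malformed, comparing the output under both prompts is a cheap first check.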