File size: 5,252 Bytes
c063bb0
a9d81c5
 
b79dbcf
c063bb0
a9d81c5
 
698149b
c063bb0
a9d81c5
738bd96
f6be418
c063bb0
738bd96
a9d81c5
738bd96
a9d81c5
f6be418
a9d81c5
 
738bd96
 
f6be418
a9d81c5
738bd96
f6be418
a9d81c5
 
 
 
738bd96
f6be418
738bd96
a9d81c5
 
 
 
 
e06502f
c063bb0
 
 
 
951d58b
c063bb0
b79dbcf
 
f6be418
c063bb0
7ea1aa1
951d58b
 
04f8aab
c063bb0
04f8aab
 
951d58b
b79dbcf
c063bb0
 
 
 
 
 
9a7a057
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c063bb0
 
b79dbcf
a9d81c5
c063bb0
 
 
789611a
c063bb0
 
 
 
 
 
789611a
c063bb0
 
 
 
 
 
 
 
 
9a7a057
 
 
 
 
086a606
9a7a057
 
c063bb0
 
 
 
 
 
 
 
 
 
 
 
951d58b
 
 
 
c063bb0
60c1d6c
 
f6be418
c063bb0
 
e06502f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import argparse
import os
from functools import partial
import torch

import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download

from die_model import UNetDIEModel
from utils import resize_image, make_image_square, cast_pil_image_to_torch_tensor_with_4_channel_dim, remove_square_padding


def die_inference(image_raw, num_of_die_iterations, die_model, device):
    """
    Applies the DIE model for document enhancement on a provided image.

    :param image_raw: input image as delivered by the ``gr.Image(type="pil")``
        component; ``None`` when the button is clicked before an image is uploaded
    :param num_of_die_iterations: number of DIE iterations; converted with ``int()``
        because the UI value may arrive as a string
    :param die_model: model object exposing ``enhance_document_image``
    :param device: torch device (e.g. ``'cuda'`` / ``'cpu'``) the input tensor is moved to
    :return: the output of ``remove_square_padding`` called with
        ``resize_back_to_original=True`` (presumably the enhanced image at the
        original size - confirm in utils), or ``None`` when no input image was given,
        which leaves the Gradio output component empty
    """
    # Gradio passes None when no image has been uploaded; return early instead of
    # failing inside the preprocessing helpers.
    if image_raw is None:
        return None

    # preprocess: resize (target size 1500, see utils.resize_image), pad to a square,
    # then convert to a tensor on the inference device
    image_raw_resized = resize_image(image_raw, 1500)
    image_raw_resized_square = make_image_square(image_raw_resized)
    image_raw_resized_square_tensor = cast_pil_image_to_torch_tensor_with_4_channel_dim(image_raw_resized_square).to(device)

    # convert string to int
    num_of_die_iterations = int(num_of_die_iterations)

    # inference: gradients are never needed here, so disable autograd bookkeeping to
    # reduce memory use (harmless if the model already does this internally)
    with torch.no_grad():
        image_die = die_model.enhance_document_image(
            image_raw_list=[image_raw_resized_square_tensor],
            num_of_die_iterations=num_of_die_iterations
        )[0]

    # postprocess: strip the square padding added above and map back to the input size
    return remove_square_padding(
        original_image=image_raw,
        square_image=image_die,
        resize_back_to_original=True
    )


def _list_example_image_paths(directory):
    """
    Return the sorted paths of the image files located directly in *directory*.

    ``os.listdir`` returns entries in arbitrary order, so names are sorted to keep
    the example gallery order stable across runs. Sub-directories and files without
    a common image extension (e.g. ``.DS_Store``, ``README.md``) are skipped so a
    stray file cannot make ``Image.open`` fail at startup.

    :param directory: directory to scan (not recursive)
    :return: list of ``os.path.join(directory, name)`` paths, sorted by file name
    """
    image_extensions = {".bmp", ".gif", ".jpeg", ".jpg", ".png", ".tif", ".tiff", ".webp"}
    image_paths = []
    for file_name in sorted(os.listdir(directory)):
        path = os.path.join(directory, file_name)
        # Case-insensitive extension match so e.g. "scan.PNG" is accepted too.
        if os.path.splitext(file_name)[1].lower() in image_extensions and os.path.isfile(path):
            image_paths.append(path)
    return image_paths


def main():
    """
    Main function to set up and run the Gradio demo.

    Parses the CLI arguments, downloads the model weights from the ``gabar92/die``
    Hugging Face Hub repository (passing the ``DIE_TOKEN`` environment variable,
    ``None`` if unset, as the access token), builds the Gradio Blocks UI and
    launches it.
    """
    args = parse_arguments()

    # The model is configured from ``args``; presumably it reads ``device`` and
    # ``die_model_path`` from it - confirm in UNetDIEModel.
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Set up model: replace the repo file name in ``args.die_model_path`` with the
    # local path of the downloaded (cached) weights file.
    die_token = os.getenv("DIE_TOKEN")

    args.die_model_path = hf_hub_download(
        repo_id="gabar92/die",
        filename=args.die_model_path,
        # ``token`` supersedes the deprecated ``use_auth_token`` argument.
        token=die_token,
    )

    die_model = UNetDIEModel(args=args)

    # Prepare example images: one single-element row per image, matching the single
    # ``inputs`` component of gr.Examples below.
    example_image_list = [
        [Image.open(image_path)]
        for image_path in _list_example_image_paths(args.example_image_path)
    ]

    description = (
        "Welcome to the Document Image Enhancement (DIE) model demo on Hugging Face!\n\n"
        "This interactive application showcases a specialized AI model developed by "
        "the [Artificial Intelligence group](https://ai.renyi.hu) at the [Alfréd Rényi Institute of Mathematics](https://renyi.hu).\n\n"
        "Our DIE model is designed to enhance and restore archival and aged document images "
        "by removing various types of degradation, thereby making historical documents more legible "
        "and suitable for Optical Character Recognition (OCR) processing.\n\n"
        "The model effectively tackles 20-30 types of domain-specific noise found in historical records, "
        "such as scribbles, bleed-through text, faded or worn text, blurriness, textured noise, "
        "and unwanted background elements. "
        "By applying deep learning techniques, specifically a U-Net-based architecture, "
        "the model accurately cleans and clarifies text while preserving original details. "
        "This improved clarity dramatically boosts OCR accuracy, making it an ideal "
        "pre-processing tool in digitization workflows.\n\n"
        "If you’re interested in learning more about the model’s capabilities or potential applications, "
        "please contact us at: gabar92@renyi.hu.\n"
    )

    # Bind model and device so Gradio only has to supply the image and iteration count.
    partial_die_inference = partial(die_inference, die_model=die_model, device=args.device)

    # Load the QR code; the ``with`` closes the file once the resized copy exists.
    with Image.open("logo/qr-code.png") as qr_code:
        qr_code_image = qr_code.resize((400, 400))

    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                gr.Markdown("## Document Image Enhancement (DIE) model")

        with gr.Row():
            with gr.Column():
                gr.Markdown(description)
            with gr.Column():
                # Display QR code as an image in Gradio
                gr.Image(value=qr_code_image, label="QR Code")

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="pil", label="Upload Degraded Document Image")
                num_iterations = gr.Dropdown([1, 2, 3], label="Number of DIE Iterations", value=1)
                run_button = gr.Button("Enhance Image")

            with gr.Column():
                output_image = gr.Image(type="pil", label="Enhanced Document Image")

        # Display example images
        gr.Examples(
            examples=example_image_list,
            inputs=[input_image],
            label="Example Images - Source: National Archives of Hungary and Budapest City Archives",
        )

        # Button trigger for inference
        run_button.click(partial_die_inference, [input_image, num_iterations], output_image)

    demo.launch()


def parse_arguments(argv=None):
    """
    Parses command-line arguments.

    :param argv: list of argument strings to parse; ``None`` (the default) parses
        ``sys.argv[1:]``, exactly like the original zero-argument call
    :return: argument namespace with ``die_model_path`` and ``example_image_path``
    """
    parser = argparse.ArgumentParser(
        description="Gradio demo for the Document Image Enhancement (DIE) model."
    )

    parser.add_argument(
        "--die_model_path",
        default="2024_08_09_model_epoch_89.pt",
        help="file name of the model weights in the gabar92/die Hugging Face Hub "
             "repository (default: %(default)s)",
    )

    parser.add_argument(
        "--example_image_path",
        default="example_images",
        help="directory containing the example images shown in the demo "
             "(default: %(default)s)",
    )

    return parser.parse_args(argv)
    

if __name__ == '__main__':
    # Script entry point only; importing this module does not start the demo.
    main()