File size: 5,252 Bytes
c063bb0 a9d81c5 b79dbcf c063bb0 a9d81c5 698149b c063bb0 a9d81c5 738bd96 f6be418 c063bb0 738bd96 a9d81c5 738bd96 a9d81c5 f6be418 a9d81c5 738bd96 f6be418 a9d81c5 738bd96 f6be418 a9d81c5 738bd96 f6be418 738bd96 a9d81c5 e06502f c063bb0 951d58b c063bb0 b79dbcf f6be418 c063bb0 7ea1aa1 951d58b 04f8aab c063bb0 04f8aab 951d58b b79dbcf c063bb0 9a7a057 c063bb0 b79dbcf a9d81c5 c063bb0 789611a c063bb0 789611a c063bb0 9a7a057 086a606 9a7a057 c063bb0 951d58b c063bb0 60c1d6c f6be418 c063bb0 e06502f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import argparse
import os
from functools import partial
import torch
import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download
from die_model import UNetDIEModel
from utils import resize_image, make_image_square, cast_pil_image_to_torch_tensor_with_4_channel_dim, remove_square_padding
def die_inference(image_raw, num_of_die_iterations, die_model, device):
    """
    Applies the DIE model for document enhancement on a provided image.

    :param image_raw: input image (the demo wires this to a ``gr.Image(type="pil")``
        component); ``None`` is accepted and short-circuits the call
    :param num_of_die_iterations: number of enhancement iterations; any value
        accepted by ``int()`` (the dropdown may deliver it as a string)
    :param die_model: object exposing
        ``enhance_document_image(image_raw_list=..., num_of_die_iterations=...)``
    :param device: torch device the preprocessed tensor is moved to
    :return: the enhanced image with the square padding removed and resized back
        to the original size, or ``None`` when no input image was given
    """
    # No image supplied (e.g. the button was clicked before uploading anything):
    # return an empty result instead of failing inside the preprocessing helpers.
    if image_raw is None:
        return None

    # preprocess: bound the size, pad to a square, convert to a tensor on the target device
    image_raw_resized = resize_image(image_raw, 1500)
    image_raw_resized_square = make_image_square(image_raw_resized)
    image_raw_resized_square_tensor = cast_pil_image_to_torch_tensor_with_4_channel_dim(
        image_raw_resized_square
    ).to(device)

    # convert string to int
    num_of_die_iterations = int(num_of_die_iterations)

    # inference: the model takes a list of inputs; only one is passed, so take the first result
    image_die = die_model.enhance_document_image(
        image_raw_list=[image_raw_resized_square_tensor],
        num_of_die_iterations=num_of_die_iterations
    )[0]

    # postprocess: undo the square padding and restore the original resolution
    return remove_square_padding(
        original_image=image_raw,
        square_image=image_die,
        resize_back_to_original=True
    )
def _list_example_images(example_dir):
    """
    Opens every regular file in ``example_dir`` as a PIL image.

    Entries are visited in sorted filename order so the example gallery is
    deterministic, and sub-directories are skipped because ``Image.open``
    cannot read them.

    :param example_dir: directory containing the example images
    :return: list of single-element rows, the format ``gr.Examples`` expects
        for one input component
    """
    example_rows = []
    for file_name in sorted(os.listdir(example_dir)):
        file_path = os.path.join(example_dir, file_name)
        if os.path.isfile(file_path):
            example_rows.append([Image.open(file_path)])
    return example_rows


def main():
    """
    Main function to set up and run the Gradio demo.

    Parses the command-line arguments, downloads the model checkpoint from the
    Hugging Face Hub (authenticated with the ``DIE_TOKEN`` environment variable),
    builds the model, assembles the Gradio interface and launches it.
    """
    args = parse_arguments()
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Set up model: the CLI value is a filename inside the Hub repository and is
    # replaced by the local path of the downloaded file.
    die_token = os.getenv("DIE_TOKEN")
    args.die_model_path = hf_hub_download(
        repo_id="gabar92/die",
        filename=args.die_model_path,
        # `token` replaces the deprecated `use_auth_token` keyword
        token=die_token,
    )
    die_model = UNetDIEModel(args=args)

    # Prepare example images
    example_image_list = _list_example_images(args.example_image_path)

    description = (
        "Welcome to the Document Image Enhancement (DIE) model demo on Hugging Face!\n\n"
        "This interactive application showcases a specialized AI model developed by "
        "the [Artificial Intelligence group](https://ai.renyi.hu) at the [Alfréd Rényi Institute of Mathematics](https://renyi.hu).\n\n"
        "Our DIE model is designed to enhance and restore archival and aged document images "
        "by removing various types of degradation, thereby making historical documents more legible "
        "and suitable for Optical Character Recognition (OCR) processing.\n\n"
        "The model effectively tackles 20-30 types of domain-specific noise found in historical records, "
        "such as scribbles, bleed-through text, faded or worn text, blurriness, textured noise, "
        "and unwanted background elements. "
        "By applying deep learning techniques, specifically a U-Net-based architecture, "
        "the model accurately cleans and clarifies text while preserving original details. "
        "This improved clarity dramatically boosts OCR accuracy, making it an ideal "
        "pre-processing tool in digitization workflows.\n\n"
        "If you’re interested in learning more about the model’s capabilities or potential applications, "
        "please contact us at: gabar92@renyi.hu.\n"
    )

    # Partial function for inference with model and device arguments,
    # so the Gradio callback only receives the two UI inputs
    partial_die_inference = partial(die_inference, die_model=die_model, device=args.device)

    with gr.Blocks() as demo:
        # Title row
        with gr.Row():
            with gr.Column():
                gr.Markdown("## Document Image Enhancement (DIE) model")

        # Description next to the QR code
        with gr.Row():
            with gr.Column():
                gr.Markdown(description)
            with gr.Column():
                # Display QR code as an image in Gradio
                gr.Image(value=Image.open("logo/qr-code.png").resize((400, 400)), label="QR Code")

        # Input controls on the left, enhanced result on the right
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="pil", label="Upload Degraded Document Image")
                num_iterations = gr.Dropdown([1, 2, 3], label="Number of DIE Iterations", value=1)
                run_button = gr.Button("Enhance Image")
            with gr.Column():
                output_image = gr.Image(type="pil", label="Enhanced Document Image")

        # Display example images
        gr.Examples(
            examples=example_image_list,
            inputs=[input_image],
            label="Example Images - Source: National Archives of Hungary and Budapest City Archives",
        )

        # Button trigger for inference
        run_button.click(partial_die_inference, [input_image, num_iterations], output_image)

    demo.launch()
def parse_arguments():
    """
    Builds the demo's command-line parser and parses the process arguments.

    :return: argument namespace with ``die_model_path`` and ``example_image_path``
    """
    # (flag, default) pairs; every option is a plain string option
    option_defaults = (
        ("--die_model_path", "2024_08_09_model_epoch_89.pt"),
        ("--example_image_path", "example_images"),
    )

    cli_parser = argparse.ArgumentParser()
    for flag, default_value in option_defaults:
        cli_parser.add_argument(flag, default=default_value)

    return cli_parser.parse_args()
# Start the demo only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
|