redaction-detector-streamlit / streamlit_app.py
Alex Strick van Linschoten
update app
b0f2ac0
raw
history blame
No virus
7.29 kB
from io import BytesIO
import os
import pathlib
import tempfile
import time
import fitz
import gradio as gr
import PIL
import skimage
import streamlit as st
from fastai.learner import load_learner
from fastai.vision.all import *
from fpdf import FPDF
from icevision.all import *
from icevision.models.checkpoint import *
from PIL import Image as PILImage
# Location of the IceVision object-detection checkpoint, relative to the
# working directory the app is started from.
CHECKPOINT_PATH = "./allsynthetic-imgsize768.pth"


# allow_output_mutation=True tells Streamlit not to hash the returned object
# on every cache hit. The checkpoint bundles a full model, which is expensive
# (and not always possible) to hash, and inference code may touch its state.
@st.cache(allow_output_mutation=True)
def load_icevision_model():
    """Load the IceVision detection checkpoint, cached across script reruns.

    Returns:
        The mapping produced by ``model_from_checkpoint``. The rest of this
        file reads the keys ``"model"``, ``"model_type"``, ``"class_map"``
        and ``"img_size"`` from it.
    """
    return model_from_checkpoint(CHECKPOINT_PATH)
# allow_output_mutation=True: the learner is a large, stateful object, so
# Streamlit should hand back the cached instance without re-hashing it (and
# without warning if running predictions alters its internal state).
@st.cache(allow_output_mutation=True)
def load_fastai_model():
    """Load the fastai page classifier, cached across script reruns.

    Returns:
        The learner deserialised from ``fastai-classification-model.pkl`` in
        the current working directory.
    """
    # NOTE(review): load_learner unpickles the file, so the .pkl must come
    # from a trusted source.
    return load_learner("fastai-classification-model.pkl")
# --- Object detection model (IceVision) -----------------------------------
# The checkpoint bundles the model with the metadata needed to run it.
checkpoint_and_model = load_icevision_model()
model, model_type, class_map, img_size = (
    checkpoint_and_model[key]
    for key in ("model", "model_type", "class_map", "img_size")
)

# Inference-time transforms: resize/pad to the checkpoint's image size, then
# normalise.
valid_tfms = tfms.A.Adapter(
    [*tfms.A.resize_and_pad(img_size), tfms.A.Normalize()]
)

# --- Page classifier (fastai) ---------------------------------------------
learn = load_fastai_model()
# Class names of the classifier, in the order its probabilities are returned.
labels = learn.dls.vocab
@st.experimental_memo
def get_content_area(pred_dict) -> int:
    """Return the area of the first detected "content" box.

    Args:
        pred_dict: A prediction mapping whose ``"detection"`` entry holds
            parallel ``"labels"`` and ``"bboxes"`` sequences; each box exposes
            ``xmin``, ``xmax``, ``ymin`` and ``ymax``.

    Returns:
        Width times height of the first box labelled ``"content"``, or ``0``
        when no such label is present. Any further content boxes are ignored.
    """
    detection = pred_dict["detection"]
    detected_labels = detection["labels"]
    if "content" not in detected_labels:
        return 0

    boxes = detection["bboxes"]
    matching_boxes = [
        boxes[position]
        for position, name in enumerate(detected_labels)
        if name == "content"
    ]
    first_box = matching_boxes[0]
    box_width = first_box.xmax - first_box.xmin
    box_height = first_box.ymax - first_box.ymin
    return box_width * box_height
@st.experimental_memo
def get_redaction_area(pred_dict) -> int:
    """Return the combined area of every detected "redaction" box.

    Args:
        pred_dict: A prediction mapping whose ``"detection"`` entry holds
            parallel ``"labels"`` and ``"bboxes"`` sequences; each box exposes
            ``xmin``, ``xmax``, ``ymin`` and ``ymax``.

    Returns:
        The sum of width times height over all boxes labelled
        ``"redaction"``, or ``0`` when no such label is present. Overlapping
        boxes are counted once each (overlap is not subtracted).
    """
    detection = pred_dict["detection"]
    detected_labels = detection["labels"]
    if "redaction" not in detected_labels:
        return 0

    boxes = detection["bboxes"]
    total_area = 0
    for position, name in enumerate(detected_labels):
        if name != "redaction":
            continue
        box = boxes[position]
        total_area += (box.xmax - box.xmin) * (box.ymax - box.ymin)
    return total_area
# --- Main page: title, banner image and the PDF upload widget --------------
st.title("Redaction Detector")
st.image(
    "./synthetic-redactions.jpg",
    width=300,
)
uploaded_pdf = st.file_uploader(
    "Upload a PDF...",
    type="pdf",
    accept_multiple_files=False,
    help="This application processes PDF files. Please upload a document you believe to contain redactions.",
    on_change=None,
)

# --- Sidebar: options controlling what the analysis produces ---------------
st.sidebar.header("Customisation Options")
graph_checkbox = st.sidebar.checkbox(
    "Show analysis charts",
    value=True,
    help="Display charts analysing the redactions found in the document.",
)
extract_images_checkbox = st.sidebar.checkbox(
    "Extract redacted images",
    value=True,
    help="Create a PDF file containing the redacted images with an object detection overlay highlighting their locations and the confidence the model had when detecting the redactions.",
)

# Confidence threshold as a percentage; the rest of the script divides it by
# 100 before comparing it with model probabilities.
confidence = st.sidebar.slider(
    "Confidence level (%)",
    min_value=0,
    max_value=100,
    value=80,
)
# allow_output_mutation=True: return the cached document object as-is rather
# than having Streamlit hash it on every cache hit.
@st.cache(allow_output_mutation=True)
def get_pdf_document(input):  # noqa: A002 - parameter name kept for callers
    """Open an uploaded PDF as a PyMuPDF document.

    The PDF is opened straight from the upload's in-memory bytes. The earlier
    implementation wrote the global ``uploaded_pdf`` to
    ``<name>/output.pdf`` (built with ``str / str``, which raises
    ``TypeError``) and then opened a different path, ``output.pdf``; it also
    ignored its own argument.

    Args:
        input: The uploaded file; must provide ``getvalue()`` returning the
            raw PDF bytes (as ``io.BytesIO``-style uploads do).

    Returns:
        The document returned by ``fitz.open`` for those bytes.
    """
    return fitz.open(stream=input.getvalue(), filetype="pdf")
@st.cache
def get_image_predictions(img):
    """Run end-to-end object detection on a single page image.

    Args:
        img: The page image to analyse.

    Returns:
        Whatever ``model_type.end2end_detect`` returns when asked to also
        return the annotated image; callers in this file read the
        ``"width"``, ``"height"``, ``"img"`` and ``"detection"`` entries.
    """
    # The sidebar slider is a percentage; the detector takes a 0-1 threshold.
    threshold = confidence / 100

    # How detections are drawn onto the returned image.
    overlay_options = {
        "display_label": True,
        "display_bbox": True,
        "return_img": True,
        "font_size": 16,
        "label_color": "#FF59D6",
    }
    return model_type.end2end_detect(
        img,
        valid_tfms,
        model,
        class_map=class_map,
        detection_threshold=threshold,
        **overlay_options,
    )
def _percentage(part, whole):
    """Return ``part`` as a percentage of ``whole``, rounded to one decimal.

    Returns ``0.0`` when ``whole`` is zero, so documents with no redacted
    pages (or no detected content boxes) do not raise ZeroDivisionError.
    """
    if not whole:
        return 0.0
    return round((part / whole) * 100, 1)


if uploaded_pdf is None:
    # Nothing uploaded yet: show the explanatory article instead.
    st.markdown(pathlib.Path("article.md").read_text())
else:
    st.text("Opening PDF...")
    # Use only the final path component without its extension, and fall back
    # to a fixed name so the working directory is never the temp dir itself.
    filename_without_extension = pathlib.Path(uploaded_pdf.name).stem or "abc"
    results = []
    images = []
    redacted_pages = []
    document = get_pdf_document(uploaded_pdf)
    total_image_areas = 0
    total_content_areas = 0
    total_redaction_area = 0
    tmp_dir = tempfile.gettempdir()

    # Create the per-document working directory up front so that later steps
    # (including writing the report) work even when no page is redacted.
    output_dir = os.path.join(tmp_dir, filename_without_extension)
    os.makedirs(output_dir, exist_ok=True)

    # The slider is a percentage; model probabilities are in the 0-1 range.
    threshold = confidence / 100

    # Step 1: classify every page and save the ones considered redacted.
    for page_num, page in enumerate(document, start=1):
        image_pixmap = page.get_pixmap()
        _, _, probs = learn.predict(image_pixmap.tobytes())
        page_result = {
            label: float(prob) for label, prob in zip(labels, probs)
        }
        results.append(page_result)

        # Look the probability up by class name rather than by position so
        # the decision does not depend on the vocabulary's ordering.
        if page_result["redacted"] > threshold:
            page_path = os.path.join(output_dir, f"page-{page_num}.png")
            image_pixmap.save(page_path)
            images.append(
                [
                    f"Redacted page #{len(images) + 1} on page {page_num}",
                    page_path,
                ]
            )
            redacted_pages.append(str(page_num))

    report = os.path.join(output_dir, "redacted_pages.pdf")

    # Step 2: run object detection on the redacted pages and build a report.
    if extract_images_checkbox:
        pdf = FPDF(unit="cm", format="A4")
        pdf.set_auto_page_break(0)
        # Iterate over the pages saved during THIS run, in page order. Listing
        # the directory instead would also pick up files left by earlier
        # reruns (including previously written "pred-" images) and would sort
        # "page-10" before "page-2".
        for _, page_path in images:
            with PILImage.open(page_path) as img:
                width, height = img.size
                pdf.add_page(orientation="L" if width > height else "P")

                pred_dict = get_image_predictions(img)
                total_image_areas += pred_dict["width"] * pred_dict["height"]
                total_content_areas += get_content_area(pred_dict)
                total_redaction_area += get_redaction_area(pred_dict)

                pred_path = os.path.join(
                    output_dir, f"pred-{os.path.basename(page_path)}"
                )
                pred_dict["img"].save(pred_path)
                pdf.image(pred_path, w=pdf.w, h=pdf.h)
        pdf.output(report, "F")

    # Step 3: summarise the findings for the user.
    text_output = f"A total of {len(redacted_pages)} pages were redacted. \n\nThe redacted page numbers were: {', '.join(redacted_pages)}. \n\n"
    if not extract_images_checkbox:
        st.text(text_output)
        # DISPLAY IMAGES
    else:
        total_redaction_proportion = _percentage(
            total_redaction_area, total_image_areas
        )
        content_redaction_proportion = _percentage(
            total_redaction_area, total_content_areas
        )
        redaction_analysis = f"- {total_redaction_proportion}% of the total area of the redacted pages was redacted. \n- {content_redaction_proportion}% of the actual content of those redacted pages was redacted."
        st.text(text_output + redaction_analysis)
        # DISPLAY IMAGES