# redaction-detector-streamlit / streamlit_app.py
# Alex Strick van Linschoten
# remove comment 32767d3
import os
import pathlib
import tempfile
import time
from io import BytesIO
import pandas as pd
import altair as alt
import fitz
import gradio as gr
import PIL
import skimage
import streamlit as st
from fastai.learner import load_learner
from fastai.vision.all import *
from fpdf import FPDF
from icevision.all import *
from icevision.models.checkpoint import *
from PIL import Image as PILImage
# Location of the icevision checkpoint, relative to the working directory.
CHECKPOINT_PATH = "./allsynthetic-imgsize768.pth"


def load_icevision_model():
    """Load the object-detection checkpoint stored at ``CHECKPOINT_PATH``.

    Returns:
        Whatever ``model_from_checkpoint`` returns; the rest of this file
        indexes it with the keys "model", "model_type", "class_map" and
        "img_size".
    """
    checkpoint = model_from_checkpoint(CHECKPOINT_PATH)
    return checkpoint
def load_fastai_model():
    """Load the exported fastai page classifier from the working directory.

    Returns:
        The learner returned by ``load_learner``.
    """
    exported_learner_path = "fastai-classification-model.pkl"
    return load_learner(exported_learner_path)
# Load the detection checkpoint once at import time and unpack the pieces the
# rest of the script needs.
checkpoint_and_model = load_icevision_model()
model, model_type, class_map, img_size = (
    checkpoint_and_model[key]
    for key in ("model", "model_type", "class_map", "img_size")
)

# Inference-time transforms: resize/pad to the checkpoint's image size, then
# normalise.
valid_tfms = tfms.A.Adapter(
    [*tfms.A.resize_and_pad(img_size), tfms.A.Normalize()]
)

# Page-level classifier and the class names it predicts.
learn = load_fastai_model()
labels = learn.dls.vocab
def get_content_area(pred_dict) -> int:
    """Return the area of the first bounding box labelled "content".

    Args:
        pred_dict: Prediction mapping whose ``"detection"`` entry holds
            parallel ``"labels"`` and ``"bboxes"`` sequences; each bbox
            exposes ``xmin``, ``xmax``, ``ymin`` and ``ymax``.

    Returns:
        Width times height of the first "content" box, or 0 when no
        "content" label is present.
    """
    detection = pred_dict["detection"]
    for position, name in enumerate(detection["labels"]):
        if name == "content":
            # Only the first content box is measured.
            box = detection["bboxes"][position]
            return (box.xmax - box.xmin) * (box.ymax - box.ymin)
    return 0
def get_redaction_area(pred_dict) -> int:
    """Return the summed area of every bounding box labelled "redaction".

    Args:
        pred_dict: Prediction mapping whose ``"detection"`` entry holds
            parallel ``"labels"`` and ``"bboxes"`` sequences; each bbox
            exposes ``xmin``, ``xmax``, ``ymin`` and ``ymax``.

    Returns:
        The sum of width times height over all "redaction" boxes, or 0 when
        there are none. Overlapping boxes are counted once per box.
    """
    detection = pred_dict["detection"]
    total = 0
    for position, name in enumerate(detection["labels"]):
        if name != "redaction":
            continue
        box = detection["bboxes"][position]
        total += (box.xmax - box.xmin) * (box.ymax - box.ymin)
    return total
# Page header: title plus a small illustrative image.
st.title("Redaction Detector")
st.image("./synthetic-redactions.jpg", width=300)

# Single-file PDF uploader; `uploaded_pdf` stays None until a file is chosen
# (the main flow below branches on that).
uploaded_pdf = st.file_uploader(
    "Upload a PDF...",
    help="This application processes PDF files. Please upload a document you believe to contain redactions.",
    type="pdf",
    accept_multiple_files=False,
    on_change=None,
)
# Sidebar: user-adjustable options for the analysis.
st.sidebar.header("Customisation Options")

# Whether to show charts summarising the detected redactions.
graph_checkbox = st.sidebar.checkbox(
    "Show analysis charts",
    value=True,
    help="Display charts analysing the redactions found in the document.",
)

# Whether to build a PDF of the redacted pages with detection overlays.
extract_images_checkbox = st.sidebar.checkbox(
    "Extract redacted images",
    value=True,
    help="Create a PDF file containing the redacted images with an object detection overlay highlighting their locations and the confidence the model had when detecting the redactions.",
)

# Confidence threshold as a percentage; consumers divide by 100 before
# comparing against model probabilities.
confidence = st.sidebar.slider(
    "Confidence level (%)",
    min_value=0,
    max_value=100,
    value=80,
)
def get_pdf_document(input):
    """Save the uploaded PDF to disk and open the saved copy with PyMuPDF.

    The file is written to ``<stem>/output.pdf`` relative to the current
    working directory, where ``<stem>`` is the upload's filename without its
    extension.

    Args:
        input: The uploaded file object. It must expose ``name`` and
            ``getbuffer()``. The parameter name shadows the builtin but is
            kept so existing callers are unaffected.

    Returns:
        The document object returned by ``fitz.open`` for the saved copy.
    """
    # Use the argument that was passed in rather than module-level globals,
    # so the function no longer depends on names bound later in the script.
    output_dir = pathlib.Path(pathlib.Path(input.name).stem)
    output_dir.mkdir(parents=True, exist_ok=True)

    output_path = output_dir / "output.pdf"
    output_path.write_bytes(input.getbuffer())
    return fitz.open(str(output_path))
def get_image_predictions(img):
    """Run the detection model on one page image.

    The detection threshold comes from the sidebar ``confidence`` slider
    (a percentage), and the annotated image is requested alongside the
    detections.

    Args:
        img: The page image to analyse (callers in this file pass an opened
            PIL image).

    Returns:
        The result of ``model_type.end2end_detect``; callers in this file
        read its "width", "height", "detection" and "img" entries.
    """
    detect_options = {
        "class_map": class_map,
        "detection_threshold": confidence / 100,
        "display_label": True,
        "display_bbox": True,
        "return_img": True,
        "font_size": 16,
        "label_color": "#FF59D6",
    }
    return model_type.end2end_detect(img, valid_tfms, model, **detect_options)
if uploaded_pdf is None:
    # No upload yet: render the accompanying article as the landing page.
    st.markdown(pathlib.Path("article.md").read_text())
else:
    st.text("Opening PDF...")
    # get_pdf_document() may read this module-level name, so bind it before
    # the call. Path.stem replaces the fixed four-character slice and also
    # drops directory components from the client-supplied filename.
    filename_without_extension = pathlib.Path(uploaded_pdf.name).stem

    results = []  # per-page {label: probability} from the classifier
    images = []  # [caption, path] for each page classed as redacted
    total_image_areas = 0
    total_content_areas = 0
    total_redaction_area = 0
    threshold = confidence / 100

    # Create the per-document working directory up front so every later
    # path (page images, overlays, report) has somewhere to live even when
    # no page is classed as redacted.
    tmp_dir = tempfile.gettempdir()
    work_dir = os.path.join(tmp_dir, filename_without_extension)
    os.makedirs(work_dir, exist_ok=True)

    document = get_pdf_document(uploaded_pdf)
    try:
        for page_num, page in enumerate(document, start=1):
            image_pixmap = page.get_pixmap()
            image = image_pixmap.tobytes()
            _, _, probs = learn.predict(image)
            page_scores = {
                label: float(probs[i]) for i, label in enumerate(labels)
            }
            results.append(page_scores)
            # Decide on the named "redacted" probability (not positional
            # index 0) so this agrees with `redacted_pages` below.
            if page_scores["redacted"] > threshold:
                page_image_path = os.path.join(
                    work_dir, f"page-{page_num}.png"
                )
                image_pixmap.save(page_image_path)
                images.append(
                    [
                        f"Redacted page #{len(images) + 1} on page {page_num}",
                        page_image_path,
                    ]
                )
    finally:
        # Release the PDF handle; the script is re-executed on every widget
        # interaction, so unclosed documents would accumulate.
        document.close()

    redacted_pages = [
        str(page_num)
        for page_num, page_scores in enumerate(results, start=1)
        if page_scores["redacted"] > threshold
    ]
    report = os.path.join(work_dir, "redacted_pages.pdf")

    if extract_images_checkbox and images:
        with st.spinner('Calculating redaction proportions...'):
            pdf = FPDF(unit="cm", format="A4")
            pdf.set_auto_page_break(0)
            # Iterate the pages collected in THIS run, in page order. Listing
            # the directory instead would pick up stale page images and
            # "pred-*.png" overlays from earlier runs and sort "page-10"
            # before "page-2".
            for _, page_image_path in images:
                with PILImage.open(page_image_path) as img:
                    width, height = img.size
                    pdf.add_page(orientation="L" if width > height else "P")
                    pred_dict = get_image_predictions(img)
                total_image_areas += pred_dict["width"] * pred_dict["height"]
                total_content_areas += get_content_area(pred_dict)
                total_redaction_area += get_redaction_area(pred_dict)
                pred_image_path = os.path.join(
                    work_dir, f"pred-{os.path.basename(page_image_path)}"
                )
                pred_dict["img"].save(pred_image_path)
                pdf.image(pred_image_path, w=pdf.w, h=pdf.h)
            pdf.output(report, "F")
        st.success('Image predictions complete!')

    text_output = f"A total of {len(redacted_pages)} pages were redacted. \n\nThe redacted page numbers were: {', '.join(redacted_pages)}. \n\n"
    st.balloons()
    if not (extract_images_checkbox and images):
        # Either extraction is switched off or nothing was classed as
        # redacted, so there are no area statistics to report.
        st.text(text_output)
        # DISPLAY IMAGES
    else:
        # Guard both ratios against a zero denominator.
        total_redaction_proportion = (
            round((total_redaction_area / total_image_areas) * 100, 1)
            if total_image_areas
            else 0.0
        )
        analysis_lines = [
            f"- {total_redaction_proportion}% of the total area of the redacted pages was redacted. "
        ]
        if total_content_areas:
            content_redaction_proportion = round(
                (total_redaction_area / total_content_areas) * 100, 1
            )
            analysis_lines.append(
                f"- {content_redaction_proportion}% of the actual content of those redacted pages was redacted."
            )
        # When no "content" box was detected the second bullet is omitted
        # rather than reporting a made-up percentage.
        redaction_analysis = "\n".join(analysis_lines)

        if graph_checkbox:
            # Honour the "Show analysis charts" sidebar option.
            source = pd.DataFrame(
                {
                    "category": ["Unredacted", "Redacted"],
                    "value": [
                        100 - total_redaction_proportion,
                        total_redaction_proportion,
                    ],
                }
            )
            c = (
                alt.Chart(source)
                .mark_arc()
                .encode(
                    theta=alt.Theta(field="value", type="quantitative"),
                    color=alt.Color(field="category", type="nominal"),
                )
            )
            st.altair_chart(c, use_container_width=True)
        st.text(text_output + redaction_analysis)
        # DISPLAY IMAGES