"""Streamlit app that finds redacted pages in an uploaded PDF.

Each page is rendered to an image and scored by a fastai learner; pages
whose "redacted" score exceeds the sidebar confidence threshold are saved
as PNGs.  Optionally an icevision detection model is run on those pages to
measure how much of each page is redacted and to build a PDF report of the
annotated page images.
"""
# NOTE: the import order is kept exactly as in the original file.  The
# wildcard imports and the final ``PILImage`` rebinding are order-sensitive.
import os
import pathlib
import tempfile
import time
from io import BytesIO
import pandas as pd
import altair as alt
import fitz
import gradio as gr
import PIL
import skimage
import streamlit as st
from fastai.learner import load_learner
from fastai.vision.all import *
from fpdf import FPDF
from icevision.all import *
from icevision.models.checkpoint import *
from PIL import Image as PILImage

CHECKPOINT_PATH = "./allsynthetic-imgsize768.pth"


def load_icevision_model():
    """Load the detection checkpoint stored at ``CHECKPOINT_PATH``."""
    return model_from_checkpoint(CHECKPOINT_PATH)


def load_fastai_model():
    """Load the exported fastai learner used to score each page."""
    return load_learner("fastai-classification-model.pkl")


checkpoint_and_model = load_icevision_model()
model = checkpoint_and_model["model"]
model_type = checkpoint_and_model["model_type"]
class_map = checkpoint_and_model["class_map"]
img_size = checkpoint_and_model["img_size"]
valid_tfms = tfms.A.Adapter(
    [*tfms.A.resize_and_pad(img_size), tfms.A.Normalize()]
)

learn = load_fastai_model()
labels = learn.dls.vocab


def _bbox_area(bbox):
    """Return the area of a box exposing ``xmin/xmax/ymin/ymax``."""
    return (bbox.xmax - bbox.xmin) * (bbox.ymax - bbox.ymin)


def _bboxes_with_label(pred_dict, wanted_label):
    """Return the detected boxes whose label equals ``wanted_label``."""
    detection = pred_dict["detection"]
    return [
        bbox
        for bbox, label in zip(detection["bboxes"], detection["labels"])
        if label == wanted_label
    ]


def _percentage(part, whole):
    """Return ``part / whole`` as a percentage rounded to 1 decimal.

    Returns ``0.0`` when ``whole`` is zero (no redacted pages, or no
    "content" box detected) instead of raising ``ZeroDivisionError``.
    """
    if not whole:
        return 0.0
    return round((part / whole) * 100, 1)


def get_content_area(pred_dict) -> int:
    """Return the area of the first "content" box, or 0 if there is none.

    Only the first matching box is measured; any further "content" boxes
    in the same prediction are ignored.
    """
    content_bboxes = _bboxes_with_label(pred_dict, "content")
    if not content_bboxes:
        return 0
    return _bbox_area(content_bboxes[0])


def get_redaction_area(pred_dict) -> int:
    """Return the summed area of every "redaction" box (0 if none).

    Areas are added per box, so overlapping boxes are counted twice.
    """
    return sum(
        _bbox_area(bbox) for bbox in _bboxes_with_label(pred_dict, "redaction")
    )


st.title("Redaction Detector")

st.image(
    "./synthetic-redactions.jpg",
    width=300,
)

uploaded_pdf = st.file_uploader(
    "Upload a PDF...",
    type="pdf",
    accept_multiple_files=False,
    help=(
        "This application processes PDF files. "
        "Please upload a document you believe to contain redactions."
    ),
    on_change=None,
)

# Sidebar options.
st.sidebar.header("Customisation Options")

graph_checkbox = st.sidebar.checkbox(
    "Show analysis charts",
    value=True,
    help="Display charts analysising the redactions found in the document.",
)

extract_images_checkbox = st.sidebar.checkbox(
    "Extract redacted images",
    value=True,
    help=(
        "Create a PDF file containing the redacted images with an object "
        "detection overlay highlighting their locations and the confidence "
        "the model had when detecting the redactions."
    ),
)

# Threshold (in percent) applied to both the page scores and the detector.
confidence = st.sidebar.slider(
    "Confidence level (%)",
    min_value=0,
    max_value=100,
    value=80,
)


def get_pdf_document(input):  # noqa: A002 - name kept for compatibility
    """Write the uploaded file to disk and open it with PyMuPDF.

    The bytes are written to ``<stem>/output.pdf`` relative to the current
    working directory, where ``<stem>`` is the upload's file name without
    its extension.

    Args:
        input: The uploaded file; must expose ``name`` and ``getbuffer()``.

    Returns:
        The document returned by ``fitz.open`` for the written file.
    """
    # ``Path(...).stem`` keeps only the final path component, so directory
    # parts in a client-supplied name cannot steer the output location.
    output_dir = pathlib.Path(pathlib.Path(input.name).stem)
    os.makedirs(str(output_dir), exist_ok=True)
    output_path = str(output_dir / "output.pdf")
    with open(output_path, "wb") as f:
        f.write(input.getbuffer())
    return fitz.open(output_path)


def get_image_predictions(img):
    """Run the detection model on ``img`` at the sidebar confidence level.

    Returns the dictionary produced by ``model_type.end2end_detect``; the
    rest of this script reads its ``"detection"``, ``"width"``,
    ``"height"`` and ``"img"`` entries.
    """
    return model_type.end2end_detect(
        img,
        valid_tfms,
        model,
        class_map=class_map,
        detection_threshold=confidence / 100,
        display_label=True,
        display_bbox=True,
        return_img=True,
        font_size=16,
        label_color="#FF59D6",
    )


if uploaded_pdf is None:
    st.markdown(pathlib.Path("article.md").read_text())
else:
    st.text("Opening PDF...")
    filename_without_extension = pathlib.Path(uploaded_pdf.name).stem
    threshold = confidence / 100

    results = []
    images = []
    document = get_pdf_document(uploaded_pdf)
    total_image_areas = 0
    total_content_areas = 0
    total_redaction_area = 0

    tmp_dir = tempfile.gettempdir()
    work_dir = os.path.join(tmp_dir, filename_without_extension)
    # Create the directory up front so that saving page images and writing
    # the report both work even when no page is classified as redacted.
    os.makedirs(work_dir, exist_ok=True)

    for page_num, page in enumerate(document, start=1):
        image_pixmap = page.get_pixmap()
        image = image_pixmap.tobytes()
        _, _, probs = learn.predict(image)
        page_scores = {
            label: float(probs[i]) for i, label in enumerate(labels)
        }
        results.append(page_scores)

        # Use the same "redacted" score that ``redacted_pages`` uses below,
        # so saved images and reported page numbers always agree.
        if page_scores["redacted"] > threshold:
            page_png = os.path.join(work_dir, f"page-{page_num}.png")
            image_pixmap.save(page_png)
            images.append(
                [
                    f"Redacted page #{len(images) + 1} on page {page_num}",
                    page_png,
                ]
            )

    redacted_pages = [
        str(page_idx)
        for page_idx, scores in enumerate(results, start=1)
        if scores["redacted"] > threshold
    ]
    report = os.path.join(work_dir, "redacted_pages.pdf")

    if extract_images_checkbox:
        with st.spinner('Calculating redaction proportions...'):
            pdf = FPDF(unit="cm", format="A4")
            pdf.set_auto_page_break(0)
            # Iterate the pages saved during THIS run, in document order.
            # Listing the directory would also pick up ``pred-*.png`` and
            # ``page-*.png`` files left behind by earlier runs.
            for _, page_png in images:
                with PILImage.open(page_png) as img:
                    width, height = img.size
                    pdf.add_page(orientation="L" if width > height else "P")
                    pred_dict = get_image_predictions(img)

                    total_image_areas += (
                        pred_dict["width"] * pred_dict["height"]
                    )
                    total_content_areas += get_content_area(pred_dict)
                    total_redaction_area += get_redaction_area(pred_dict)

                    pred_png = os.path.join(
                        work_dir, f"pred-{os.path.basename(page_png)}"
                    )
                    pred_dict["img"].save(pred_png)
                    pdf.image(pred_png, w=pdf.w, h=pdf.h)
            pdf.output(report, "F")
        st.success('Image predictions complete!')

    text_output = (
        f"A total of {len(redacted_pages)} pages were redacted. \n\n"
        f"The redacted page numbers were: {', '.join(redacted_pages)}. \n\n"
    )
    st.balloons()

    if not extract_images_checkbox:
        st.text(text_output)
        # DISPLAY IMAGES
    else:
        total_redaction_proportion = _percentage(
            total_redaction_area, total_image_areas
        )
        content_redaction_proportion = _percentage(
            total_redaction_area, total_content_areas
        )

        redaction_analysis = (
            f"- {total_redaction_proportion}% of the total area of the "
            "redacted pages was redacted. \n"
            f"- {content_redaction_proportion}% of the actual content of "
            "those redacted pages was redacted."
        )

        # Honour the "Show analysis charts" sidebar option.
        if graph_checkbox:
            source = pd.DataFrame(
                {
                    "category": ["Unredacted", "Redacted"],
                    "value": [
                        100 - total_redaction_proportion,
                        total_redaction_proportion,
                    ],
                }
            )
            c = (
                alt.Chart(source)
                .mark_arc()
                .encode(
                    theta=alt.Theta(field="value", type="quantitative"),
                    color=alt.Color(field="category", type="nominal"),
                )
            )
            st.altair_chart(c, use_container_width=True)

        st.text(text_output + redaction_analysis)
        # DISPLAY IMAGES