# Alex Strick van Linschoten
# fix zero division error
# 8b969e2
# raw history blame
# No virus
# 6.84 kB
import os
import tempfile
import fitz
import gradio as gr
import PIL
import skimage
from fastai.learner import load_learner
from fastai.vision.all import *
from fpdf import FPDF
from huggingface_hub import hf_hub_download
from icevision.all import *
from icevision.models.checkpoint import *
from PIL import Image as PILImage
# Checkpoint bundle consumed by the detection step in ``predict``.
checkpoint_path = "./allsynthetic-imgsize768.pth"
checkpoint_and_model = model_from_checkpoint(checkpoint_path)

# Pull out the entries of the bundle that the rest of the module reads.
model, model_type, class_map, img_size = (
    checkpoint_and_model[key]
    for key in ("model", "model_type", "class_map", "img_size")
)

# Transforms handed to ``end2end_detect``: the resize-and-pad transforms for
# the checkpoint's image size, followed by normalisation.
_resize_and_pad_tfms = tfms.A.resize_and_pad(img_size)
valid_tfms = tfms.A.Adapter([*_resize_and_pad_tfms, tfms.A.Normalize()])

# Page-level classifier, fetched from the Hugging Face Hub at import time.
_classifier_path = hf_hub_download(
    "strickvl/redaction-classifier-fastai", "model.pkl"
)
learn = load_learner(_classifier_path)

# Class names of the classifier; ``predict`` pairs them with probabilities.
labels = learn.dls.vocab
def get_content_area(pred_dict) -> int:
    """Return the area of the first bounding box labelled "content".

    Args:
        pred_dict: Mapping with a "labels" sequence and a parallel "bboxes"
            sequence whose items expose ``xmin``, ``xmax``, ``ymin`` and
            ``ymax`` attributes.

    Returns:
        Width times height of the first "content" box, or 0 when no label
        equals "content".
    """
    positions = [
        pos
        for pos, name in enumerate(pred_dict["labels"])
        if name == "content"
    ]
    if not positions:
        return 0

    all_boxes = pred_dict["bboxes"]
    content_boxes = [all_boxes[pos] for pos in positions]

    # Only the first content box contributes to the area.
    first = content_boxes[0]
    width = first.xmax - first.xmin
    height = first.ymax - first.ymin
    return width * height
def get_redaction_area(pred_dict) -> int:
    """Return the summed area of every bounding box labelled "redaction".

    Args:
        pred_dict: Mapping with a "labels" sequence and a parallel "bboxes"
            sequence whose items expose ``xmin``, ``xmax``, ``ymin`` and
            ``ymax`` attributes.

    Returns:
        The sum of width times height over all "redaction" boxes, or 0 when
        no label equals "redaction". Overlapping boxes are each counted in
        full.
    """
    # Collect the matching boxes first so that a bad index surfaces before
    # any area is computed.
    matching = []
    for pos, name in enumerate(pred_dict["labels"]):
        if name == "redaction":
            matching.append(pred_dict["bboxes"][pos])

    total = 0
    for box in matching:
        total = total + (box.xmax - box.xmin) * (box.ymax - box.ymin)
    return total
def _percentage(part, whole):
    """Return ``part / whole`` as a percentage rounded to one decimal place.

    Returns 0.0 when ``whole`` is zero (for example when no "content" box was
    detected on any page) instead of raising ZeroDivisionError.
    """
    if not whole:
        return 0.0
    return round((part / whole) * 100, 1)


def predict(pdf, confidence, generate_file):
    """Classify every page of a PDF and optionally build a redaction report.

    Each page is rendered to an image and scored by the page classifier
    (``learn``). Pages whose "redacted" probability exceeds the threshold are
    saved as PNG files in a per-document folder under the system temp
    directory. When ``generate_file`` is true, those pages are additionally
    run through the detection model and the annotated images are assembled
    into a PDF report.

    Args:
        pdf: Uploaded file object; only its ``name`` attribute (the path of
            the PDF on disk) is read.
        confidence: Threshold as a percentage (the UI slider goes 0-100). It
            is divided by 100 and used both as the classifier threshold and
            as the detection threshold.
        generate_file: Whether to build the annotated PDF report and the
            area statistics.

    Returns:
        A ``(text, images, report)`` tuple. ``text`` is the summary string,
        ``images`` is a list of ``[caption, png_path]`` pairs for the
        redacted pages, and ``report`` is the path of the generated PDF, or
        ``None`` when no report was requested or no page was classified as
        redacted.
    """
    # splitext instead of a fixed [:-4] slice so the extension length is not
    # hard-coded.
    filename_without_extension = os.path.splitext(pdf.name)[0]
    tmp_dir = tempfile.gettempdir()
    output_dir = os.path.join(tmp_dir, filename_without_extension)
    threshold = confidence / 100

    images = []
    redacted_pages = []

    document = fitz.open(pdf.name)
    try:
        for page_num, page in enumerate(document, start=1):
            image_pixmap = page.get_pixmap()
            _, _, probs = learn.predict(image_pixmap.tobytes())
            scores = {
                label: float(prob) for label, prob in zip(labels, probs)
            }
            # One test drives both the extracted images and the reported
            # page numbers, so the two outputs can never disagree.
            if scores["redacted"] > threshold:
                os.makedirs(output_dir, exist_ok=True)
                page_path = os.path.join(output_dir, f"page-{page_num}.png")
                image_pixmap.save(page_path)
                images.append(
                    [
                        f"Redacted page #{len(images) + 1} on page {page_num}",
                        page_path,
                    ]
                )
                redacted_pages.append(str(page_num))
    finally:
        document.close()

    text_output = f"A total of {len(redacted_pages)} pages were redacted. \n\n The redacted page numbers were: {', '.join(redacted_pages)}. "

    # Nothing to assemble: either no report was requested or no page crossed
    # the threshold (in which case the output folder may not even exist).
    if not generate_file or not images:
        return text_output, images, None

    report = os.path.join(output_dir, "redacted_pages.pdf")
    report_pdf = FPDF()
    report_pdf.set_auto_page_break(0)

    total_image_areas = 0
    total_content_areas = 0
    total_redaction_area = 0

    # Iterate over the pages saved during THIS call, in page order, rather
    # than listing the folder: a listing would also pick up stale
    # "page-*.png" / "pred-*.png" files left behind by earlier runs.
    for _, page_path in images:
        annotated_path = os.path.join(
            output_dir, f"pred-{os.path.basename(page_path)}"
        )
        with PILImage.open(page_path) as img:
            width, height = img.size
            report_pdf.add_page("L" if width > height else "P")
            pred_dict = model_type.end2end_detect(
                img,
                valid_tfms,
                model,
                class_map=class_map,
                detection_threshold=threshold,
                display_label=True,
                display_bbox=True,
                return_img=True,
                font_size=16,
                label_color="#FF59D6",
            )
            total_image_areas += pred_dict["width"] * pred_dict["height"]
            total_content_areas += get_content_area(pred_dict)
            total_redaction_area += get_redaction_area(pred_dict)
            pred_dict["img"].save(annotated_path)
        # TODO: resize image such that it fits the pdf
        report_pdf.image(annotated_path)
    report_pdf.output(report, "F")

    total_redaction_proportion = _percentage(
        total_redaction_area, total_image_areas
    )
    content_redaction_proportion = _percentage(
        total_redaction_area, total_content_areas
    )
    redaction_analysis = f"{total_redaction_proportion}% of the total area of the redacted pages was redacted. \n {content_redaction_proportion}% of the actual content of those redacted pages was redacted."
    return text_output + redaction_analysis, images, report
# --- Static text shown around the demo -------------------------------------
title = "Redaction Detector"
description = "A classifier trained on publicly released redacted (and unredacted) FOIA documents, using [fastai](https://github.com/fastai/fastai)."
with open("article.md") as f:
    article = f.read()

# --- Interface options ------------------------------------------------------
# Each example row supplies one value per input: PDF path, confidence, and
# the "extract images" checkbox.
examples = [["test1.pdf", 80, False], ["test2.pdf", 80, False]]
interpretation = "default"
enable_queue = True
theme = "grass"
allow_flagging = "never"

# --- Components, in the order ``predict`` takes and returns them ------------
input_components = [
    gr.inputs.File(label="PDF file", file_count="single"),
    gr.inputs.Slider(
        minimum=0,
        maximum=100,
        step=None,
        default=80,
        label="Confidence",
        optional=False,
    ),
    gr.inputs.Checkbox(label="Extract redacted images"),
]
output_components = [
    gr.outputs.Textbox(label="Document Analysis"),
    gr.outputs.Carousel(["text", "image"], label="Redacted pages"),
    gr.outputs.File(label="Download redacted pages"),
]

demo = gr.Interface(
    fn=predict,
    inputs=input_components,
    outputs=output_components,
    title=title,
    description=description,
    article=article,
    theme=theme,
    allow_flagging=allow_flagging,
    examples=examples,
    interpretation=interpretation,
)

# Start the app as soon as the module is executed.
demo.launch(cache_examples=True, enable_queue=enable_queue)