MaterialsDatabase / page_files /categorized /Backend /Pdf_ImageExtraction.py
AbhijitClemson's picture
Upload folder using huggingface_hub
1adc2e7 verified
import os
import re
import json
import math
import tempfile
import fitz # PyMuPDF
import cv2
import numpy as np
from PIL import Image
import streamlit as st
# -------------------
# Config
# -------------------
# Rasterization resolution (dots per inch) used when rendering PDF pages.
DPI = 300
# Root directory; each paper gets OUT_DIR/<paper_id>/ for its outputs.
OUT_DIR = "outputs"
# When True, only captions matching SS_KW are processed.
KEEP_ONLY_STRESS_STRAIN = False

# A caption block must *start* with "Fig N", "Fig. N" or "Figure N"
# (case-insensitive). The trailing \b means "Fig. 3a" is NOT matched.
CAP_RE = re.compile(r"^(Fig\.?\s*\d+|Figure\s*\d+)\b", flags=re.IGNORECASE)

# Keywords suggesting a mechanical-test / stress-strain figure caption.
SS_KW = re.compile(
    r"(stress\s*[-–]?\s*strain|stress|strain|tensile|MPa|GPa|kN|yield|elongation)",
    flags=re.IGNORECASE,
)
# -------------------
# Render helpers
# -------------------
def render_page(page, dpi=DPI):
    """Rasterize a PyMuPDF page to an RGB PIL image.

    Args:
        page: PyMuPDF page object to render.
        dpi: Target resolution; the page is scaled by ``dpi / 72``.

    Returns:
        ``(image, mat)``: the rendered PIL RGB image (no alpha) and the
        ``fitz.Matrix`` used, so callers can map PDF coordinates to pixels.
    """
    zoom = dpi / 72  # PDF coordinates are expressed in 1/72-inch points
    mat = fitz.Matrix(zoom, zoom)
    pixmap = page.get_pixmap(matrix=mat, alpha=False)
    size = [pixmap.width, pixmap.height]
    image = Image.frombytes("RGB", size, pixmap.samples)
    return image, mat
def pdf_to_px_bbox(bbox_pdf, mat):
    """Convert a PDF-space bbox to integer pixel coordinates.

    Only the scale terms of ``mat`` are applied: ``mat.a`` scales x and
    ``mat.d`` scales y (any translation/rotation terms are ignored). Results
    are truncated toward zero with ``int()``.

    Args:
        bbox_pdf: ``(x0, y0, x1, y1)`` in PDF points.
        mat: Object exposing ``a`` and ``d`` scale factors (e.g. ``fitz.Matrix``).

    Returns:
        ``(x0, y0, x1, y1)`` as ints in pixel space.
    """
    left, top, right, bottom = bbox_pdf
    scale_x = mat.a
    scale_y = mat.d
    px_left, px_right = int(float(left) * scale_x), int(float(right) * scale_x)
    px_top, px_bottom = int(float(top) * scale_y), int(float(bottom) * scale_y)
    return (px_left, px_top, px_right, px_bottom)
def safe_crop_px(pil_img, box):
    """Crop ``pil_img`` to ``box`` after validating and clamping the box.

    ``box`` is ``(x0, y0, x1, y1)`` in pixels (tuple or list). A box wrapped
    in a one-element sequence, e.g. ``[(x0, y0, x1, y1)]``, is unwrapped.
    Reversed corners are swapped, and every coordinate is clamped to the
    image bounds ``[0, W] x [0, H]`` taken from ``pil_img.size``.

    Returns:
        The result of ``pil_img.crop(...)``, or ``None`` when ``box`` is not a
        4-sequence of scalars, contains values that cannot be converted to a
        finite int (including NaN and +/-inf), or collapses to an empty area
        after clamping.
    """
    if not isinstance(box, (tuple, list)):
        return None
    # Tolerate a box accidentally wrapped in a one-element sequence.
    if len(box) == 1 and isinstance(box[0], (tuple, list)) and len(box[0]) == 4:
        box = box[0]
    if len(box) != 4:
        return None
    if any(isinstance(v, (tuple, list)) for v in box):
        return None
    try:
        x0, y0, x1, y1 = (int(v) for v in box)
    except (TypeError, ValueError, OverflowError):
        # ValueError covers int(nan); OverflowError covers int(+/-inf).
        return None
    if x1 < x0:
        x0, x1 = x1, x0
    if y1 < y0:
        y0, y1 = y1, y0
    W, H = pil_img.size
    x0 = max(0, min(W, x0))
    x1 = max(0, min(W, x1))
    y0 = max(0, min(H, y0))
    y1 = max(0, min(H, y1))
    if x1 <= x0 or y1 <= y0:
        return None
    return pil_img.crop((x0, y0, x1, y1))
# -------------------
# Captions
# -------------------
def find_caption_blocks(page):
    """Collect the text blocks of ``page`` that look like figure captions.

    Blocks come from ``page.get_text("blocks")``. Each block's text has its
    whitespace collapsed to single spaces and is then matched against
    ``CAP_RE`` (anchored at the start of the text).

    Returns:
        A list of ``{"bbox": (x0, y0, x1, y1), "text": <normalized text>}``
        dicts, in block order. The bbox is in the page's coordinate space as
        reported by PyMuPDF.
    """
    captions = []
    for block in page.get_text("blocks"):
        bbox = (block[0], block[1], block[2], block[3])
        normalized = " ".join(str(block[4]).split())
        if CAP_RE.match(normalized) is None:
            continue
        captions.append({"bbox": bbox, "text": normalized})
    return captions
# -------------------
# Dedupe: dHash
# -------------------
def dhash64(pil_img):
    """Return a 64-bit difference hash (dHash) of ``pil_img``.

    The image is converted to grayscale and resized to 9x8 with LANCZOS.
    Within each of the 8 rows, each of the 8 adjacent pixel pairs contributes
    one bit: 1 if the left pixel is brighter than its right neighbour. Bits
    are emitted row-major, so the first comparison becomes the most
    significant bit.
    """
    thumb = pil_img.convert("L").resize((9, 8), Image.LANCZOS)
    values = list(thumb.getdata())
    digest = 0
    for idx in range(72):
        if idx % 9 == 8:
            continue  # last column has no right-hand neighbour
        digest = (digest << 1) | int(values[idx] > values[idx + 1])
    return digest
# -------------------
# Rejectors
# -------------------
def has_colorbar_like_strip(pil_img):
    """Heuristically detect a colorbar along the right edge of an image.

    The rightmost strip (``max(18, int(0.07 * W))`` px wide, full height) is
    quantized by integer-dividing each colour channel by 24. If the strip
    contains more than 70 distinct quantized RGB colours, it is treated as a
    colorbar.

    Only the first three channels are used, so RGBA input is handled
    correctly. Input without at least three channels (grayscale, LA) returns
    False, as do images narrower than 250 px or shorter than 150 px.
    """
    img = np.array(pil_img)
    if img.ndim != 3 or img.shape[2] < 3:
        return False
    H, W = img.shape[:2]
    if W < 250 or H < 150:
        return False
    strip_w = max(18, int(0.07 * W))
    # Drop alpha/extra channels so each reshaped row is exactly one RGB pixel;
    # reshaping 4-channel data by 3 would mix channels or raise ValueError.
    strip = img[:, W - strip_w:W, :3]
    q = (strip // 24).reshape(-1, 3)
    uniq = np.unique(q, axis=0)
    return len(uniq) > 70
def texture_score(pil_img):
    """Return the variance of the grayscale Laplacian of ``pil_img`` as a float.

    Larger values mean more high-frequency detail (edges, fine texture).
    ``is_real_plot`` rejects crops scoring above its threshold.
    """
    rgb = np.array(pil_img)
    grayscale = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
    laplacian = cv2.Laplacian(grayscale, cv2.CV_64F)
    variance = laplacian.var()
    return float(variance)
def is_mostly_legend(pil_img):
    """Return True for small, sparsely inked crops (presumably legend boxes).

    The image is Otsu-binarized with dark pixels as foreground and then
    median-filtered (3x3). It qualifies when the foreground fraction is
    strictly between 3% and 18% and the shorter side is under 260 px.
    """
    gray = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2GRAY)
    _, ink = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    ink = cv2.medianBlur(ink, 3)
    height, width = ink.shape
    ink_fraction = float(np.count_nonzero(ink)) / float(height * width)
    is_small = min(height, width) < 260
    return is_small and (0.03 < ink_fraction < 0.18)
# -------------------
# Plot detection
# -------------------
def detect_axes_lines(pil_img):
    """Find the longest near-horizontal and near-vertical line segments.

    Runs Canny edge detection (thresholds 50/150) and then a probabilistic
    Hough transform, with ``minLineLength`` set to 28% of the shorter image
    side. A segment is:

    * horizontal when its vertical drift is < 18 px and it spans > 35% of the width;
    * vertical when its horizontal drift is < 18 px and it spans > 35% of the height.

    Returns:
        ``(x_axis, y_axis)``, each the longest ``(x1, y1, x2, y2)`` segment
        of its kind, or ``(None, None)`` if either kind is missing.
    """
    gray = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2GRAY)
    height, width = gray.shape
    segments = cv2.HoughLinesP(
        cv2.Canny(gray, 50, 150),
        1,
        np.pi / 180,
        threshold=90,
        minLineLength=int(0.28 * min(height, width)),
        maxLineGap=14,
    )
    if segments is None:
        return None, None

    flat, upright = [], []
    for seg in segments[:, 0]:
        x1, y1, x2, y2 = seg
        span_x = abs(x2 - x1)
        span_y = abs(y2 - y1)
        length = math.hypot(span_x, span_y)
        coords = (x1, y1, x2, y2)
        # Independent checks: a segment may qualify for both lists.
        if span_y < 18 and span_x > 0.35 * width:
            flat.append((length, coords))
        if span_x < 18 and span_y > 0.35 * height:
            upright.append((length, coords))

    if not (flat and upright):
        return None, None
    # max() keeps the first of equally long segments, like a stable descending sort.
    longest_flat = max(flat, key=lambda item: item[0])[1]
    longest_upright = max(upright, key=lambda item: item[0])[1]
    return longest_flat, longest_upright
def axis_intersection_ok(x_axis, y_axis, W, H):
    """Check that the detected axes sit at plausible positions in a W x H image.

    The x-axis row is the rounded mean of its two y-coordinates. The y-axis
    column is the rounded mean of its two x-coordinates. Both must lie inside
    the image. The y-axis must not be right of ``int(0.95 * W)``, and the
    x-axis must not be above ``int(0.05 * H)``.
    """
    x_axis_row = int(round((x_axis[1] + x_axis[3]) / 2))
    y_axis_col = int(round((y_axis[0] + y_axis[2]) / 2))
    inside = (0 <= x_axis_row < H) and (0 <= y_axis_col < W)
    too_far_right = y_axis_col > int(0.95 * W)
    too_high = x_axis_row < int(0.05 * H)
    return inside and not (too_far_right or too_high)
def tick_text_presence_score(pil_img, x_axis, y_axis):
    """Count label-sized blobs near the axes (presumably tick labels).

    The image is Otsu-binarized (dark = foreground) and median-filtered.
    Two bands are examined:

    * around the x-axis: 40 px above to 110 px below its row, full width;
    * around the y-axis: 180 px left to 50 px right of its column, full height.

    Connected components 4-150 px in both width and height, with an area of
    20-5000 px, are counted in each band. The bands overlap near the axis
    corner, so blobs there may be counted twice.

    Returns:
        The combined count as an int.
    """
    gray = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2GRAY)
    _, ink = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    ink = cv2.medianBlur(ink, 3)
    H, W = gray.shape
    axis_row = int(round((x_axis[1] + x_axis[3]) / 2))
    axis_col = int(round((y_axis[0] + y_axis[2]) / 2))
    x_band = ink[max(0, axis_row - 40):min(H, axis_row + 110), :W]
    y_band = ink[:H, max(0, axis_col - 180):min(W, axis_col + 50)]

    def label_sized(mask):
        stats = cv2.connectedComponentsWithStats(mask, connectivity=8)[2]
        # Row 0 is the background component; skip it.
        return sum(
            1
            for _x, _y, w, h, area in stats[1:]
            if 4 <= w <= 150 and 4 <= h <= 150 and 20 <= area <= 5000
        )

    return label_sized(x_band) + label_sized(y_band)
def is_real_plot(pil_img):
    """Decide whether ``pil_img`` looks like an axis-based plot.

    The checks run in order, and the first failure returns False:

    1. no colorbar-like strip on the right edge and not legend-like;
    2. a long horizontal and a long vertical segment are detected;
    3. those axes sit at plausible positions (``axis_intersection_ok``);
    4. Laplacian texture score is at most 2200 (presumably to reject photos
       or micrographs);
    5. at least 18 label-sized blobs are found near the axes.
    """
    if has_colorbar_like_strip(pil_img) or is_mostly_legend(pil_img):
        return False
    x_axis, y_axis = detect_axes_lines(pil_img)
    if x_axis is None or y_axis is None:
        return False
    height, width = np.array(pil_img).shape[:2]
    if not axis_intersection_ok(x_axis, y_axis, width, height):
        return False
    if texture_score(pil_img) > 2200:
        return False
    return tick_text_presence_score(pil_img, x_axis, y_axis) >= 18
# -------------------
# Candidate boxes in a region
# -------------------
def connected_components_boxes(pil_img):
    """Return bounding boxes of ink blobs in ``pil_img``, largest area first.

    Pixels with gray value < 245 count as ink. A 7x7 morphological closing
    (two iterations) merges nearby strokes before 8-connected components are
    labelled.

    Returns:
        A list of ``(area, (x0, y0, x1, y1))`` tuples of Python ints, sorted
        by area in descending order (ties keep label order). The background
        component is excluded.
    """
    bgr = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
    gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
    ink = (gray < 245).astype(np.uint8) * 255
    kernel = np.ones((7, 7), np.uint8)
    ink = cv2.morphologyEx(ink, cv2.MORPH_CLOSE, kernel, iterations=2)
    stats = cv2.connectedComponentsWithStats(ink, connectivity=8)[2]
    boxes = [
        (int(area), (int(x), int(y), int(x + w), int(y + h)))
        for x, y, w, h, area in stats[1:]
    ]
    return sorted(boxes, key=lambda item: item[0], reverse=True)
def expand_box(box, W, H, left=0.10, right=0.06, top=0.06, bottom=0.18):
    """Grow ``box`` by fractions of its own size, clamped to ``[0, W] x [0, H]``.

    Each side moves outward by its fraction of the box width (``left``,
    ``right``) or height (``top``, ``bottom``). The defaults grow the bottom
    and left sides most. Coordinates are truncated with ``int()`` before
    clamping.

    Returns:
        ``(x0, y0, x1, y1)`` as ints.
    """
    x0, y0, x1, y1 = box
    width, height = x1 - x0, y1 - y0
    grown = (
        x0 - left * width,
        y0 - top * height,
        x1 + right * width,
        y1 + bottom * height,
    )
    return (
        max(0, int(grown[0])),
        max(0, int(grown[1])),
        min(W, int(grown[2])),
        min(H, int(grown[3])),
    )
# -------------------
# Crop plot from caption
# -------------------
def crop_plot_from_caption(page_img, cap_bbox_pdf, mat):
    """Find the plot associated with a caption on a rendered page.

    The search window spans the full page width. Vertically, it runs from
    95% of the page height above the caption's top edge to 20% of the page
    height below its bottom edge, clamped to the page. Inside that window,
    up to the 35 largest ink components are considered. Components smaller
    than 220x180 px are skipped. The rest are expanded with ``expand_box``
    and kept only if ``is_real_plot`` accepts them.

    Args:
        page_img: Rendered page (PIL image).
        cap_bbox_pdf: Caption bbox in PDF coordinates.
        mat: Render matrix returned by ``render_page``.

    Returns:
        The accepted crop whose component has the largest area, or ``None``.
    """
    _, cap_top, _, cap_bottom = pdf_to_px_bbox(cap_bbox_pdf, mat)
    page_w, page_h = page_img.size
    window = (
        0,
        max(0, cap_top - int(0.95 * page_h)),
        page_w,
        min(page_h, cap_bottom + int(0.20 * page_h)),
    )
    region = safe_crop_px(page_img, window)
    if region is None:
        return None

    best_plot, best_area = None, -1
    for area, (x0, y0, x1, y1) in connected_components_boxes(region)[:35]:
        if (x1 - x0) < 220 or (y1 - y0) < 180:
            continue
        grown = expand_box((x0, y0, x1, y1), region.size[0], region.size[1])
        candidate = safe_crop_px(region, grown)
        if candidate is None or not is_real_plot(candidate):
            continue
        if area > best_area:
            best_plot, best_area = candidate, area
    return best_plot
# -------------------
# Extraction pipeline and Streamlit UI
# -------------------
def run_extraction(pdf_path, paper_id="uploaded_paper"):
    """Extract captioned plots from a PDF and write them to disk.

    Pages are only rendered (at ``DPI``) if they contain at least one
    caption-like text block. ``crop_plot_from_caption`` is then tried once
    per caption. When ``KEEP_ONLY_STRESS_STRAIN`` is true, captions that do
    not match ``SS_KW`` are skipped. Each accepted crop is trimmed by 2 px
    per side (when larger than 8x8 px), de-duplicated across the whole
    document by its 64-bit dHash, and saved as PNG.

    Args:
        pdf_path: Path of the PDF to open with PyMuPDF.
        paper_id: Sub-directory name under ``OUT_DIR`` for this paper.

    Returns:
        ``(results, out_json)``. ``results`` is a list of
        ``{"page": <1-based page>, "caption": <text>, "image": <png path>}``
        dicts. ``out_json`` is the path of the JSON file holding that list,
        ``OUT_DIR/<paper_id>/plots_with_axes.json``. Images are written to
        ``OUT_DIR/<paper_id>/plots_with_axes/``.
    """
    out_paper = os.path.join(OUT_DIR, paper_id)
    out_imgs = os.path.join(out_paper, "plots_with_axes")
    os.makedirs(out_imgs, exist_ok=True)

    results = []
    seen = set()  # dHashes of crops already saved (document-wide dedupe)
    doc = fitz.open(pdf_path)
    try:
        for p in range(len(doc)):
            page = doc[p]
            caps = find_caption_blocks(page)
            if not caps:
                continue  # avoid rendering pages without figure captions
            page_img, mat = render_page(page, dpi=DPI)
            for cap in caps:
                cap_text = cap["text"]
                if KEEP_ONLY_STRESS_STRAIN and not SS_KW.search(cap_text):
                    continue
                fig = crop_plot_from_caption(page_img, cap["bbox"], mat)
                if fig is None:
                    continue
                if fig.size[0] > 8 and fig.size[1] > 8:
                    # Shave a 2 px border off the expanded crop.
                    fig = fig.crop((2, 2, fig.size[0] - 2, fig.size[1] - 2))
                try:
                    h = dhash64(fig)
                except Exception:
                    # Best-effort: a crop that cannot be hashed is skipped
                    # rather than aborting the whole document.
                    continue
                if h in seen:
                    continue
                seen.add(h)
                # Running index across the document == number saved so far.
                img_name = f"p{p + 1:02d}_{len(results):04d}.png"
                img_path = os.path.join(out_imgs, img_name)
                fig.save(img_path)
                results.append({
                    "page": p + 1,
                    "caption": cap_text,
                    "image": img_path,
                })
    finally:
        # Release the PDF file handle even if extraction raises, so the
        # file can be removed afterwards (e.g. from a temporary directory).
        doc.close()

    out_json = os.path.join(out_paper, "plots_with_axes.json")
    with open(out_json, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    return results, out_json
def main():
    """Streamlit entry point: upload a PDF, extract plots, display and export them.

    The upload is written to a temporary directory and passed to
    ``run_extraction``. Extracted images and captions are shown, followed by
    the JSON results and a download button for the JSON file.
    """
    st.set_page_config(page_title="Research Paper Plot Extractor", layout="wide")
    st.title(" Plot Extractor (Upload PDF)")
    uploaded = st.file_uploader("Upload a research paper PDF", type=["pdf"])
    if not uploaded:
        st.info("Upload a PDF to extract plots.")
        return

    # The file name is client-supplied. Keep only its last path component
    # so an absolute path or ".." can escape neither tmpdir nor OUT_DIR.
    raw_name = os.path.basename(uploaded.name.replace("\\", "/"))
    file_name = raw_name if raw_name not in ("", ".", "..") else "uploaded.pdf"
    stem = os.path.splitext(file_name)[0].replace(" ", "_")
    paper_id = stem if stem not in ("", ".", "..") else "uploaded_paper"

    with tempfile.TemporaryDirectory() as tmpdir:
        pdf_path = os.path.join(tmpdir, file_name)
        with open(pdf_path, "wb") as f:
            f.write(uploaded.read())
        with st.spinner("Extracting plots..."):
            results, out_json = run_extraction(pdf_path, paper_id=paper_id)

    st.success(f"Extracted {len(results)} plots.")

    # Show images + captions
    for r in results:
        st.markdown(f"**Page {r['page']}** — {r['caption']}")
        st.image(r["image"], use_container_width=True)
        st.divider()

    # JSON viewer + download
    st.subheader("JSON Output")
    st.json(results)
    with open(out_json, "rb") as f:
        st.download_button(
            "Download JSON",
            data=f,
            file_name=os.path.basename(out_json),
            mime="application/json",
        )


if __name__ == "__main__":
    main()