# RulerNet-Demo / app.py
# Yimu Pan
# Re-add examples as LFS tracked files
# 5b4cc5f
import os
import gradio as gr
import numpy as np
import onnxruntime as ort
from PIL import Image, ImageDraw
from huggingface_hub import hf_hub_download
# Hugging Face access token read from the environment (the model repository is
# private). None when HF_TOKEN is unset, in which case huggingface_hub falls back
# to any locally saved login.
hf_token = os.environ.get("HF_TOKEN")
# Download the private model.onnx file into the local Hugging Face cache and get
# its path. `token` is the documented keyword; `use_auth_token` is deprecated in
# huggingface_hub.
model_path = hf_hub_download(
    repo_id="ymp5078/RulerNet",
    filename="model.onnx",
    token=hf_token,
)
# ---- Load ONNX Model ----
# Single CPU-only inference session, created once at import time and reused by
# every call to infer_and_draw.
ort_session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
# ---- Utility Function ----
def outward_cumsum(initial_point, line_direction, directions, n):
    """Place 2-D points along a line by accumulating steps outward from an anchor.

    Entries of ``directions`` whose index in ``n`` is negative are walked
    backwards from ``initial_point``. The rest are walked forwards. Each step
    moves ``length * line_direction``. When ``n`` is ascending (as produced by
    ``np.arange``), the backward walk starts with the step nearest the anchor.

    Args:
        initial_point: anchor point (NumPy array; its dtype seeds the offsets).
        line_direction: vector every step is taken along.
        directions: step lengths, aligned element-wise with ``n``.
        n: integer step indices; the sign selects the walking direction.

    Returns:
        Array with ``len(directions) + 1`` rows and 2 columns. Rows run from the
        farthest backward point, through ``initial_point``, to the farthest
        forward point.
    """
    origin_row = np.zeros((1, 2), dtype=initial_point.dtype)

    backward_lengths = directions[n < 0][::-1]
    forward_lengths = directions[n >= 0]

    backward_offsets = np.cumsum(
        np.concatenate((origin_row, -backward_lengths[:, np.newaxis] * line_direction)),
        axis=0,
    )
    forward_offsets = np.cumsum(
        np.concatenate((origin_row, forward_lengths[:, np.newaxis] * line_direction)),
        axis=0,
    )

    backward_points = initial_point + backward_offsets
    forward_points = initial_point + forward_offsets
    # Both walks start at the anchor (row 0); keep it only once.
    return np.concatenate((backward_points[::-1], forward_points[1:]))
# ---- Image Preprocessing ----
def preprocess_image(pil_img, target_size=768):
    """Letterbox an image into a square, normalised NCHW float32 tensor.

    The image is converted to RGB and resized with its aspect ratio preserved,
    so that its longer side equals ``target_size``. It is then centred on a
    black ``target_size`` x ``target_size`` canvas.

    Args:
        pil_img: input ``PIL.Image.Image`` in any mode Pillow can convert to RGB.
        target_size: side length of the square canvas. The default of 768 is
            presumably the ONNX model's input size; confirm before changing it.

    Returns:
        ``(input_tensor, (top, left, new_h, new_w))``.

        - ``input_tensor`` is a float32 array of shape
          ``(1, 3, target_size, target_size)`` with values in [0, 1].
        - The tuple locates the resized image inside the canvas.
    """
    # convert("RGB") handles every Pillow mode uniformly:
    # - grayscale is replicated and alpha is dropped;
    # - palette ("P"), bilevel ("1") and CMYK images are mapped to real colours.
    # Raw np.array() checks crash on "LA" images and misread "1", "P" and
    # "CMYK" pixel data.
    rgb = pil_img.convert("RGB")
    w, h = rgb.size
    longest = max(w, h)
    # Integer arithmetic floors exactly (a float scale can give 767.999... -> 767).
    # Clamp to 1 so extreme aspect ratios never request a 0-pixel resize.
    new_w = max(1, w * target_size // longest)
    new_h = max(1, h * target_size // longest)
    # Resize the uint8 image directly, avoiding a float -> uint8 -> float round trip.
    image_resized = np.asarray(rgb.resize((new_w, new_h)), dtype=np.float32) / 255.0

    top = (target_size - new_h) // 2
    left = (target_size - new_w) // 2
    canvas = np.zeros((target_size, target_size, 3), dtype=np.float32)
    canvas[top:top + new_h, left:left + new_w] = image_resized
    # HWC -> CHW plus a leading batch axis; copy() materialises a C-contiguous array.
    input_tensor = np.transpose(canvas, (2, 0, 1))[np.newaxis, ...].copy()
    return input_tensor, (top, left, new_h, new_w)
# ---- Main Inference and Drawing ----
def _predict_ruler_points(image_tensor):
    """Run the ONNX model and return (in-bounds ruler points, progression ratio)."""
    outputs = ort_session.run(None, {"input": image_tensor})
    anchor = outputs[0][0]  # shape: (2,)
    dist = outputs[1][0][0]  # scalar
    ratio = outputs[2][0][0]  # scalar
    direction = outputs[3][0]  # shape: (2,)
    points_info = outputs[4][0]  # [num_points, min_x, min_y, max_x, max_y]
    num_points = int(points_info[0])
    min_x, min_y, max_x, max_y = points_info[1:].tolist()

    # The step with index k (k in [-num_points, num_points]) has length
    # dist * ratio**k, i.e. a geometric progression around the anchor.
    n = np.arange(-num_points, num_points + 1)
    step_lengths = (ratio ** n) * dist
    candidates = outward_cumsum(anchor, direction, step_lengths, n)

    # Keep only points inside the model's predicted bounding box.
    xs, ys = candidates[:, 0], candidates[:, 1]
    inside = (xs >= min_x) & (xs <= max_x) & (ys >= min_y) & (ys <= max_y)
    return candidates[inside], ratio


def _median_spacing(points):
    """Median distance between consecutive points, or 0.0 if fewer than two."""
    if len(points) < 2:
        return 0.0
    gaps = np.linalg.norm(points[:-1] - points[1:], axis=1)
    return np.nanmedian(gaps)


def _draw_overlay(image_tensor, points, pred_pix_cm, grid_spacing=50, radius=3):
    """Draw a pixel grid, the points and a pix/cm label on the model input image."""
    # (1, 3, H, W) float in [0, 1] -> (H, W, 3) uint8 for PIL.
    np_img = (np.transpose(image_tensor[0], (1, 2, 0)) * 255).astype(np.uint8)
    pil_img = Image.fromarray(np_img)
    draw = ImageDraw.Draw(pil_img)

    # Light-grey grid so the pix/cm value can be checked by eye.
    width, height = pil_img.size
    for x in range(0, width, grid_spacing):
        draw.line([(x, 0), (x, height)], fill=(200, 200, 200), width=1)
    for y in range(0, height, grid_spacing):
        draw.line([(0, y), (width, y)], fill=(200, 200, 200), width=1)

    for x, y in points:
        draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill="red")

    # Label in the top-left corner, on a padded white box for legibility.
    text = f"Pix/cm: {pred_pix_cm:.2f}"
    text_position = (10, 10)
    left, top, right, bottom = draw.textbbox(text_position, text)
    padding = 4
    draw.rectangle(
        (left - padding, top - padding, right + padding, bottom + padding),
        fill="white",
    )
    draw.text(text_position, text, fill="black")
    return pil_img


def infer_and_draw(image_pil):
    """Predict ruler points for an image and render them for the Gradio UI.

    The steps are:
    1. Letterbox the image with ``preprocess_image``.
    2. Have the ONNX model predict an anchor, direction, step length and ratio.
    3. Generate points at geometrically growing spacings along that direction.
    4. Keep the points inside the predicted bounding box and draw them on the
       letterboxed image.

    Args:
        image_pil: uploaded ``PIL.Image.Image``. Gradio passes ``None`` when the
            user submits without an image.

    Returns:
        ``(annotated_image, pix_per_cm_text, ratio_text)``. The pix/cm value is
        the median spacing between consecutive kept points. It is measured in
        the letterboxed model-input frame, not in original-image pixels, and is
        0.0 when fewer than two points are kept.

    Raises:
        gr.Error: if no image was provided.
    """
    if image_pil is None:
        # preprocess_image would otherwise fail on None with an opaque error.
        raise gr.Error("Please upload an image first.")

    image_tensor, _ = preprocess_image(image_pil)
    points, ratio = _predict_ruler_points(image_tensor)
    pred_pix_cm = _median_spacing(points)
    annotated = _draw_overlay(image_tensor, points, pred_pix_cm)
    return annotated, f"{pred_pix_cm:.3f}", f"{ratio:.3f}"
if __name__ == '__main__':
    # Build the Gradio UI and start the server; only runs when executed as a script.
    upload_input = gr.Image(type="pil")
    result_outputs = [
        gr.Image(label="Generated Ruler Points"),
        gr.Textbox(label="Predicted median pixel-per-centimeter value (with grid lines every 50 pixels):"),
        gr.Textbox(label="Predicted geometric progression ratio:"),
    ]
    # One single-input example per row, matching the single image input above.
    example_rows = [
        ["examples/sample1.jpg"],
        ["examples/sample2.jpg"],
        ["examples/sample3.jpg"],
        ["examples/sample4.jpg"],
        ["examples/sample5.jpg"],
        ["examples/sample6.jpg"],
    ]
    demo = gr.Interface(
        fn=infer_and_draw,
        inputs=upload_input,
        outputs=result_outputs,
        examples=example_rows,
        title="ONNX (CPU) Ruler Model Visualizer",
        description="Upload an image to visualize ruler points generated by the ONNX model.",
    )
    demo.launch()