# FastShift / app.py
import numpy as np
import cv2 as cv
import gradio as gr
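
# Compare a target image against up to eight candidates by counting SIFT feature
# matches found with a FLANN-based matcher, presented through a Gradio UI.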
def match_features(target_img, comp_img1, comp_img2, comp_img3, comp_img4, comp_img5, comp_img6, comp_img7, comp_img8):
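    """Match SIFT features between the target image and each comparison image.

    Returns a list of 8 annotated match visualisations (one per comparison
    slot) followed by the list of good-match counts, in the order the Gradio
    outputs below expect.
    """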
    # Guard against a missing target image (comparison slots are optional,
    # but matching needs a target)
    if target_img is None:
        raise gr.Error("Please upload a target image.")

    # Accumulators for the per-image results
    result_images = []
    match_counts = []

    # Comparison images (any slot may be None if the user left it empty)
    comparison_imgs = [comp_img1, comp_img2, comp_img3, comp_img4,
                       comp_img5, comp_img6, comp_img7, comp_img8]

    # Convert the target image to a grayscale OpenCV array
    target_cv = np.array(target_img.convert("L"))

    # Extract SIFT keypoints and descriptors from the target image
    sift = cv.SIFT_create()
    kp_target, des_target = sift.detectAndCompute(target_cv, None)
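
    # FLANN (Fast Library for Approximate Nearest Neighbors) with a KD-tree
    # index over the SIFT descriptors; more `trees` and higher `checks` give
    # more accurate matches at the cost of speed.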
    FLANN_INDEX_KDTREE = 1
    index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
    search_params = dict(checks=50)
    flann = cv.FlannBasedMatcher(index_params, search_params)
    # Process each comparison image in turn
    for img in comparison_imgs:
        # Blank canvas used as a placeholder when an image is missing or unusable
        blank_img = np.zeros((400, 800, 3), dtype=np.uint8)

        if img is None:
            cv.putText(blank_img, "No image provided", (250, 200),
                       cv.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
            result_images.append(blank_img)
            match_counts.append(0)
            continue
        # Convert to a grayscale OpenCV array
        img_cv = np.array(img.convert("L"))

        # Extract SIFT keypoints and descriptors
        kp_img, des_img = sift.detectAndCompute(img_cv, None)

        # Skip if either image yielded no descriptors (e.g. a uniform image)
        if des_img is None or des_target is None:
            cv.putText(blank_img, "No features detected", (250, 200),
                       cv.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
            result_images.append(blank_img)
            match_counts.append(0)
            continue
        try:
            # Match features with FLANN
            matches = flann.knnMatch(des_target, des_img, k=2)
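            # Lowe's ratio test: keep a match only if the best neighbour is
            # clearly closer than the second best (factor 0.7 here), which
            # filters out ambiguous correspondences.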
            matchesMask = [[0, 0] for _ in range(len(matches))]
            good_matches = []
            for i, pair in enumerate(matches):
                # knnMatch may return fewer than 2 neighbours for a descriptor
                if len(pair) < 2:
                    continue
                m, n = pair
                if m.distance < 0.7 * n.distance:
                    matchesMask[i] = [1, 0]
                    good_matches.append(m)
            # Count good matches; the count is appended only after drawing
            # succeeds, so result_images and match_counts stay in step
            match_count = len(good_matches)
            # Draw all knn matches, marking ratio-test survivors via matchesMask
            draw_params = dict(
                matchColor=(0, 255, 0),        # good matches drawn in green
                singlePointColor=(255, 0, 0),  # unmatched keypoints
                matchesMask=matchesMask,
                flags=cv.DrawMatchesFlags_DEFAULT,
            )
            result_img = cv.drawMatchesKnn(target_cv, kp_target, img_cv, kp_img,
                                           matches, None, **draw_params)

            # OpenCV draws in BGR; convert to RGB for display
            result_img = cv.cvtColor(result_img, cv.COLOR_BGR2RGB)

            # Overlay the match count on a black band near the bottom edge
            font = cv.FONT_HERSHEY_SIMPLEX
            font_scale = 1
            font_color = (255, 255, 255)
            thickness = 2
            h, w = result_img.shape[:2]
            text = f"Matches: {match_count}"
            (text_width, text_height), _ = cv.getTextSize(text, font, font_scale, thickness)
            x = (w - text_width) // 2
            y = h - 20
            cv.rectangle(result_img, (x - 5, y - text_height - 5),
                         (x + text_width + 5, y + 5), (0, 0, 0), -1)
            cv.putText(result_img, text, (x, y), font, font_scale, font_color, thickness)

            result_images.append(result_img)
            match_counts.append(match_count)
        except Exception as e:
            # Fall back to an error placeholder for this comparison
            error_img = np.zeros((400, 800, 3), dtype=np.uint8)
            cv.putText(error_img, f"Error: {e}", (50, 200),
                       cv.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 1)
            result_images.append(error_img)
            match_counts.append(0)
    # Pad to exactly 8 results (defensive; the loop above always yields 8)
    while len(result_images) < 8:
        blank_img = np.zeros((400, 800, 3), dtype=np.uint8)
        cv.putText(blank_img, "No image provided", (250, 200),
                   cv.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
        result_images.append(blank_img)
        match_counts.append(0)

    # Return the 8 annotated images followed by the list of match counts
    return result_images + [match_counts]

# Gradio interface: one target image and up to 8 comparison images
with gr.Blocks(title="Image Feature Matching Comparison") as iface:
    gr.Markdown("# Image Feature Matching with SIFT + FLANN")
    gr.Markdown("""
Upload a target image and up to 8 comparison images to find feature matches using
SIFT (Scale-Invariant Feature Transform) with FLANN (Fast Library for Approximate
Nearest Neighbors). The number of matches will be displayed for each comparison.
""")
    with gr.Row():
        target_input = gr.Image(type="pil", label="Target Image")
    with gr.Row():
        comp_img1 = gr.Image(type="pil", label="Comparison Image 1")
        comp_img2 = gr.Image(type="pil", label="Comparison Image 2")
    with gr.Row():
        comp_img3 = gr.Image(type="pil", label="Comparison Image 3")
        comp_img4 = gr.Image(type="pil", label="Comparison Image 4")
    with gr.Row():
        comp_img5 = gr.Image(type="pil", label="Comparison Image 5")
        comp_img6 = gr.Image(type="pil", label="Comparison Image 6")
    with gr.Row():
        comp_img7 = gr.Image(type="pil", label="Comparison Image 7")
        comp_img8 = gr.Image(type="pil", label="Comparison Image 8")
    compare_btn = gr.Button("Compare Images")

    with gr.Row():
        result1 = gr.Image(label="Result 1")
        result2 = gr.Image(label="Result 2")
    with gr.Row():
        result3 = gr.Image(label="Result 3")
        result4 = gr.Image(label="Result 4")
    with gr.Row():
        result5 = gr.Image(label="Result 5")
        result6 = gr.Image(label="Result 6")
    with gr.Row():
        result7 = gr.Image(label="Result 7")
        result8 = gr.Image(label="Result 8")

    match_counts_output = gr.JSON(label="Match Counts")
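
    # Wire the button: the inputs list must match match_features' signature,
    # and the outputs list must match its return order (8 images, then counts)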
    compare_btn.click(
        fn=match_features,
        inputs=[
            target_input,
            comp_img1, comp_img2, comp_img3, comp_img4,
            comp_img5, comp_img6, comp_img7, comp_img8,
        ],
        outputs=[
            result1, result2, result3, result4,
            result5, result6, result7, result8,
            match_counts_output,
        ],
    )

# Launch the interface
iface.launch()