from typing import Tuple

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
from einops import rearrange
from huggingface_hub import PyTorchModelHubMixin
from PIL import Image
from safetensors.torch import load_model
from torch import nn
from torch.nn import functional as func_nn
from torchvision import models
from transformers import pipeline  # NOTE(review): currently unused in this file


# main model network
class SiameseNetwork(nn.Module, PyTorchModelHubMixin):
    """Siamese network: a shared MobileNetV2 backbone followed by a small MLP head.

    Both images of a pair are passed through the same weights; ``forward``
    returns the two 2-d output vectors, which the app compares with cosine
    similarity.
    """

    def __init__(self):
        super().__init__()
        # convolutional backbone; pretrained=True fetches torchvision's ImageNet
        # weights, which load_pipeline replaces with the saved checkpoint
        self.convnet = models.mobilenet_v2(pretrained=True)
        num_ftrs = self.convnet.classifier[1].in_features  # input width of the stock classifier head
        self.convnet.classifier[1] = nn.Linear(num_ftrs, 512)  # swap the head for a 512-d layer
        # fully connected projection head applied to the backbone output
        self.fc_linear = nn.Sequential(
            nn.Linear(512, 128),
            nn.ReLU(inplace=True),  # activation layer
            nn.Linear(128, 2),
        )

    def single_pass(self, x: torch.Tensor) -> torch.Tensor:
        """Run one batch of channels-last images (b, h, w, c) through the network.

        Returns a (b, 2) tensor produced by ``fc_linear``.
        """
        # torchvision convolutions expect channels-first input
        x = rearrange(x, 'b h w c -> b c h w')
        output = self.convnet(x)
        output = self.fc_linear(output)
        return output

    def forward(
        self, input_1: torch.Tensor, input_2: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Embed both images of a pair with shared weights and return both outputs."""
        output_1 = self.single_pass(input_1)  # first (reference) image
        output_2 = self.single_pass(input_2)  # second (comparison) image
        return output_1, output_2


# path to the trained Siamese weights (safetensors format)
model_file = 'model.safetensors'


def compute_similarity(output1, output2):
    """Return the cosine similarity of two model outputs as a Python float.

    ``.item()`` requires the result to hold a single element, i.e. both
    inputs must have batch size 1 (as produced by ``run_model_pipeline``).
    """
    return func_nn.cosine_similarity(output1, output2).item()


# Side length signatures are resized to before inference.
# NOTE(review): 224 is MobileNetV2's conventional input size; confirm it
# matches the preprocessing used when model.safetensors was trained.
IMAGE_SIZE = 224


def read_image(image: Image.Image, size: int = IMAGE_SIZE) -> np.ndarray:
    """Convert a PIL image into the per-image array layout the model consumes.

    The image is converted to RGB, resized to ``size`` x ``size`` and scaled
    from [0, 255] to [0.0, 1.0].
    NOTE(review): no ImageNet mean/std normalization is applied — confirm this
    matches training.

    Args:
        image: source PIL image.
        size: output height and width in pixels.

    Returns:
        float32 array of shape (size, size, 3), channels-last, matching the
        ``h w c`` layout that ``SiameseNetwork.single_pass`` rearranges from.
    """
    resized = image.convert("RGB").resize((size, size))
    # float32 so the resulting tensor matches the model's parameter dtype
    return np.asarray(resized, dtype=np.float32) / 255.0


def visualize_heatmap(model, image):
    """Render a channel-averaged activation map from the MobileNetV2 trunk.

    Args:
        model: a ``SiameseNetwork`` whose ``convnet`` is a torchvision MobileNetV2.
        image: image tensor in (h, w, c) or (1, h, w, c) channels-last layout,
            the same layout ``SiameseNetwork.single_pass`` accepts.

    Returns:
        A new matplotlib Figure containing the heatmap (caller should close it).
    """
    model.eval()
    # add a batch dimension only when the caller passed a single image
    x = image if image.dim() == 4 else image.unsqueeze(0)
    x = rearrange(x, 'b h w c -> b c h w')  # same channel order as single_pass
    with torch.no_grad():
        # convnet(x) ends in the 512-d linear head, which has no spatial layout;
        # the convolutional trunk (`features`) still yields (b, C, h', w') maps.
        feature_maps = model.convnet.features(x)
    # average over channels -> one (h', w') map for the first image in the batch
    heatmap = feature_maps.mean(dim=1)[0].cpu().numpy()

    # a dedicated figure per call, so the two heatmaps never share global pyplot state
    fig, ax = plt.subplots()
    ax.imshow(heatmap, cmap="hot")
    ax.axis("off")
    return fig


@st.cache_resource
def load_pipeline(model_path: str = model_file) -> SiameseNetwork:
    """Build the Siamese network and load its weights from a safetensors file.

    Cached with ``st.cache_resource`` so the weights are loaded once and reused
    across Streamlit reruns instead of on every button click.

    Args:
        model_path: path to the ``.safetensors`` checkpoint.

    Returns:
        The model in eval mode on CPU.
    """
    model = SiameseNetwork()  # model class/skeleton
    # load_model fills `model` in place and returns (missing, unexpected) key
    # lists — it does not return the model, so the result must not be reassigned.
    load_model(model, model_path)
    model.eval()
    return model.to('cpu')


# Streamlit app UI template
st.title("Signature Forgery Detection")
st.write('Application to run/test signature forgery detection model')
st.subheader('Compare signatures')

# File uploaders for the two images
original_image = st.file_uploader(
    "Upload the original signature", type=["png", "jpg", "jpeg"]
)
comparison_image = st.file_uploader(
    "Upload the signature to compare", type=["png", "jpg", "jpeg"]
)


def run_model_pipeline(model, original_image, comparison_image, threshold=0.5):
    """Compare two uploaded signatures and render the verdict in the Streamlit UI.

    Args:
        model: a loaded ``SiameseNetwork`` (see ``load_pipeline``).
        original_image: uploaded reference signature (file-like), or None.
        comparison_image: uploaded signature to verify (file-like), or None.
        threshold: cosine-similarity cut-off; pairs scoring below it are
            labelled "Forgery".
    """
    if original_image is None or comparison_image is None:
        st.write("Please upload both the original and comparison signatures.")
        return

    # decode the uploads into PIL images (kept for display)
    img1 = Image.open(original_image).convert("RGB")
    img2 = Image.open(comparison_image).convert("RGB")

    # preprocess and add a batch dimension: (1, h, w, c)
    img1_tensor = torch.as_tensor(read_image(img1)).unsqueeze(0)
    img2_tensor = torch.as_tensor(read_image(img2)).unsqueeze(0)

    # inference only: skip autograd bookkeeping
    with torch.no_grad():
        output1, output2 = model(img1_tensor, img2_tensor)
    st.success('outputs extracted')

    similarity = compute_similarity(output1, output2)
    # cosine similarity lies in [-1, 1]; low similarity is treated as forgery
    is_forgery = similarity < threshold

    # Display results
    st.subheader("Results")
    st.write(f"Similarity: {similarity:.2f}")
    st.write(f"Classification: {'Forgery' if is_forgery else 'Genuine'}")

    # Display the uploaded images side by side
    col1, col2 = st.columns(2)
    with col1:
        st.image(img1, caption="Original Signature", use_column_width=True)
    with col2:
        st.image(img2, caption="Comparison Signature", use_column_width=True)

    # Visualize heatmaps from the backbone's feature maps
    st.subheader("Feature Heatmaps")
    col3, col4 = st.columns(2)
    with col3:
        fig1 = visualize_heatmap(model, img1_tensor)
        st.pyplot(fig1)
        plt.close(fig1)  # release the figure once rendered
    with col4:
        fig2 = visualize_heatmap(model, img2_tensor)
        st.pyplot(fig2)
        plt.close(fig2)


# Streamlit reruns this script top to bottom on every interaction, and
# st.button() is True only during the rerun caused by its own click. With
# "Process Images" nested inside "Run Model Pipeline" the inner branch could
# never execute. load_pipeline is cached, so each button works on its own.
if st.button("Run Model Pipeline"):
    load_pipeline()  # load (and cache) the weights ahead of processing

if st.button("Process Images"):
    run_model_pipeline(load_pipeline(), original_image, comparison_image)