File size: 5,436 Bytes
46a90f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb5ef54
8e7b01a
 
 
46a90f0
 
8e7b01a
 
46a90f0
 
8e7b01a
46a90f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import streamlit as st
import torch
from PIL import Image
import matplotlib.pyplot as plt
from safetensors.torch import load_model
from transformers import pipeline
import torch
from torch import nn
from torch.nn import functional as func_nn
from einops import rearrange
from huggingface_hub import PyTorchModelHubMixin
from torchvision import models


# main model network
class SiameseNetwork(nn.Module, PyTorchModelHubMixin):
    """Siamese network that scores two signature images with one shared backbone.

    Both inputs pass through the same MobileNetV2 backbone, whose stock
    classifier Linear layer is replaced by a 512-d projection. A small MLP
    then maps each 512-d vector to a 2-d output.
    """

    def __init__(self):
        super().__init__()

        # Convolutional backbone: torchvision MobileNetV2 initialised from ImageNet.
        # `weights=` replaces the deprecated `pretrained=True`; torchvision maps
        # `pretrained=True` to this same IMAGENET1K_V1 checkpoint.
        self.convnet = models.mobilenet_v2(
            weights=models.MobileNet_V2_Weights.IMAGENET1K_V1
        )
        # Input width of the Linear layer inside the stock classifier head.
        num_ftrs = self.convnet.classifier[1].in_features

        # Swap that Linear layer for a 512-d projection (the backbone's new output).
        self.convnet.classifier[1] = nn.Linear(num_ftrs, 512)

        # Fully connected head applied to the 512-d backbone output.
        self.fc_linear = nn.Sequential(
            nn.Linear(512, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 2),
        )

    def single_pass(self, x) -> torch.Tensor:
        """Run one batch of images through the backbone and the head.

        Args:
            x: channels-last image batch ``(b, h, w, c)``. It is rearranged to
                ``(b, c, h, w)`` because the torchvision backbone expects NCHW.

        Returns:
            Tensor of shape ``(b, 2)``.
        """
        x = rearrange(x, 'b h w c -> b c h w')
        output = self.convnet(x)
        output = self.fc_linear(output)

        return output

    def forward(
        self, input_1: torch.Tensor, input_2: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Embed both images with the shared weights.

        Args:
            input_1: first (reference) image batch, channels-last.
            input_2: second (comparison) image batch, channels-last.

        Returns:
            ``(output_1, output_2)``, one ``single_pass`` result per input.
        """
        # Forward pass of the first image.
        output_1 = self.single_pass(input_1)

        # Forward pass of the second (comparison) image with the same weights.
        output_2 = self.single_pass(input_2)

        return output_1, output_2


# Checkpoint holding the trained SiameseNetwork weights (originally config.safetensor_file)
model_file: str = "best_signature_mobilenet.safetensors"


# Function to compute similarity
def compute_similarity(output1, output2):
    """Return the cosine similarity of two model outputs as a Python float.

    The similarity is measured along dim 1. Because the result is unwrapped
    with ``.item()``, this only succeeds when it holds exactly one element
    (i.e. a batch of one pair).
    """
    score = func_nn.cosine_similarity(output1, output2, dim=1)
    return score.item()

# Function to visualize feature heatmaps
def visualize_heatmap(model, image):
    """Render the channel-averaged backbone activation map of one signature image.

    Args:
        model: SiameseNetwork whose ``convnet`` is a torchvision MobileNetV2.
        image: image tensor in the channels-last layout ``single_pass``
            consumes, either unbatched ``(H, W, C)`` or batched ``(B, H, W, C)``.
            Only the first image of a batch is drawn.

    Returns:
        A standalone matplotlib ``Figure`` showing the heatmap, suitable for
        ``st.pyplot``.
    """
    # A Figure built outside pyplot is not kept in pyplot's global registry,
    # so repeated reruns don't pile up open figures or draw on a shared one.
    from matplotlib.figure import Figure

    model.eval()
    # Add a batch dimension only when the caller did not supply one.
    x = image if image.dim() == 4 else image.unsqueeze(0)
    # (b, h, w, c) -> (b, c, h, w), mirroring the rearrange in single_pass.
    x = x.permute(0, 3, 1, 2)

    with torch.no_grad():
        # convnet(x) ends in the 512-d linear head and has no spatial layout;
        # the convolutional trunk `features` keeps a (b, C, h', w') map.
        features = model.convnet.features(x)

    # Average over channels and keep the first image -> 2-D (h', w') array.
    heatmap = features.mean(dim=1)[0].cpu().numpy()

    fig = Figure()
    ax = fig.subplots()
    ax.imshow(heatmap, cmap="hot")
    ax.axis("off")

    return fig

# Load the pre-trained model from a safetensors file
@st.cache_resource
def load_pipeline():
    """Build SiameseNetwork, load the trained checkpoint and return it for inference.

    Wrapped in ``st.cache_resource`` so Streamlit's script reruns reuse one
    model. Otherwise every click would rebuild the network, and constructing
    SiameseNetwork also initialises the pretrained MobileNetV2 backbone.

    Returns:
        SiameseNetwork loaded from ``model_file``, in eval mode, on CPU.
    """
    model = SiameseNetwork()  # architecture skeleton; weights are filled in below

    # safetensors.torch.load_model() loads the state dict into `model` in place
    # and returns (missing_keys, unexpected_keys), not the model. Assigning its
    # result back to `model` turned it into a tuple, and `.eval()` then raised.
    load_model(model, model_file)
    model.eval()

    return model.to('cpu')

# Streamlit app UI: page header plus the two signature upload widgets
_ACCEPTED_IMAGE_TYPES = ["png", "jpg", "jpeg"]  # shared by both uploaders

st.title("Signature Forgery Detection")
st.write("Application to run/test signature forgery detection model")

st.subheader("Compare signatures")
# File uploaders for the two images (each returns None until a file is chosen)
original_image = st.file_uploader(
    "Upload the original signature", type=_ACCEPTED_IMAGE_TYPES
)
comparison_image = st.file_uploader(
    "Upload the signature to compare", type=_ACCEPTED_IMAGE_TYPES
)

def _read_image(image, size=224):
    """Convert a PIL image into a float32 ``(H, W, 3)`` tensor with values in [0, 1].

    This replaces the ``read_image`` helper the pipeline previously called
    but which was never defined or imported (a NameError on every run).
    The channels-last layout matches what ``SiameseNetwork.single_pass``
    rearranges.

    NOTE(review): the square resize to ``size`` and the plain /255 scaling are
    assumptions. Confirm they match the preprocessing used to train
    ``best_signature_mobilenet.safetensors``, e.g. whether ImageNet mean/std
    normalisation was applied.
    """
    import numpy as np  # not imported at module level in this file

    resized = image.convert("RGB").resize((size, size))
    pixels = np.asarray(resized, dtype=np.float32) / 255.0  # (H, W, 3)
    return torch.from_numpy(pixels)


def run_model_pipeline(model, original_image, comparison_image, threshold=0.5):
    """Compare two uploaded signatures and render the verdict in the Streamlit page.

    Args:
        model: SiameseNetwork in eval mode (as returned by ``load_pipeline``).
        original_image: uploaded reference signature, or None if missing.
        comparison_image: uploaded signature to check, or None if missing.
        threshold: cosine-similarity cut-off; scores below it are labelled
            "Forgery", otherwise "Genuine".
    """
    if original_image is None or comparison_image is None:  # both uploads are required
        st.write("Please upload both the original and comparison signatures.")
        return

    # Decode the uploads to RGB PIL images; these are also what gets displayed.
    pil1 = Image.open(original_image).convert("RGB")
    pil2 = Image.open(comparison_image).convert("RGB")

    # Preprocess into tensors and add the batch dimension -> (1, H, W, C).
    img1_tensor = _read_image(pil1).unsqueeze(0)
    img2_tensor = _read_image(pil2).unsqueeze(0)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output1, output2 = model(img1_tensor, img2_tensor)
    st.success('outputs extracted')

    similarity = compute_similarity(output1, output2)
    is_forgery = similarity < threshold

    # Display results
    st.subheader("Results")
    st.write(f"Similarity: {similarity:.2f}")
    st.write(f"Classification: {'Forgery' if is_forgery else 'Genuine'}")

    # Show the uploaded signatures side by side.
    col1, col2 = st.columns(2)
    with col1:
        st.image(pil1, caption="Original Signature", use_column_width=True)
    with col2:
        st.image(pil2, caption="Comparison Signature", use_column_width=True)

    # Visualize heatmaps from extracted model features
    st.subheader("Feature Heatmaps")
    col3, col4 = st.columns(2)
    with col3:
        fig1 = visualize_heatmap(model, img1_tensor)
        st.pyplot(fig1)
    with col4:
        fig2 = visualize_heatmap(model, img2_tensor)
        st.pyplot(fig2)

# Streamlit reruns this whole script on every widget interaction, and
# st.button() is True only during the single rerun right after its click.
# A "Process Images" button nested inside the "Run Model Pipeline" branch can
# therefore never fire: clicking it starts a rerun in which the outer button
# is False again. Keep the loaded model in session_state so it survives reruns.
if st.button("Run Model Pipeline"):
    st.session_state["signature_model"] = load_pipeline()

model = st.session_state.get("signature_model")
if model is not None:
    # button click to process images
    if st.button("Process Images"):
        run_model_pipeline(model, original_image, comparison_image)