Spaces:
Sleeping
Sleeping
import streamlit as st | |
from PIL import Image | |
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer | |
import torch | |
# Initialize the image-to-text pipeline and models | |
@st.cache_resource
def load_models():
    """Load the OCR pipeline and the phishing classifier with its tokenizer.

    Streamlit re-executes the whole script on every widget interaction, so
    the loader is wrapped in ``st.cache_resource``: the models are built once
    and the same objects are handed back on later reruns instead of being
    re-instantiated each time.

    Returns:
        A 3-tuple ``(image_pipeline, phishing_model, phishing_tokenizer)``:
        the ``image-to-text`` pipeline, the sequence-classification model
        (configured with two labels), and the tokenizer used to encode URLs
        for that model.
    """
    image_pipeline = pipeline("image-to-text", model="microsoft/trocr-large-printed")
    phishing_model = AutoModelForSequenceClassification.from_pretrained(
        "kithangw/phishing_link_detection", num_labels=2
    )
    # NOTE(review): the tokenizer comes from "bert-base-uncased", not from the
    # classifier's own repository -- confirm the vocabularies match.
    phishing_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    return image_pipeline, phishing_model, phishing_tokenizer
# Define the phishing check function | |
def check_phishing(phishing_model, phishing_tokenizer, url_for_recognize):
    """Classify a URL and return a one-sentence, human-readable verdict.

    Args:
        phishing_model: Callable classifier; invoked with the tokenizer's
            output as keyword arguments and expected to return an object
            exposing ``.logits``.
        phishing_tokenizer: Callable that encodes the URL into PyTorch
            tensors (called with ``return_tensors='pt'``).
        url_for_recognize: The URL text to classify.

    Returns:
        A sentence naming the URL, the predicted label ('Not Phishing' for
        class 0, 'Phishing' for class 1) and the softmax probability of that
        label formatted to two decimals.
    """
    class_names = ('Not Phishing', 'Phishing')

    encoded = phishing_tokenizer(
        url_for_recognize,
        max_length=512,
        padding=True,
        truncation=True,
        return_tensors='pt',
    )

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = phishing_model(**encoded).logits

    scores = torch.softmax(logits, dim=-1)
    best_index = scores.argmax(dim=-1).item()
    confidence = scores[0, best_index].item()
    verdict = class_names[best_index]

    return (
        f"The URL '{url_for_recognize}' is classified as '{verdict}' "
        f"with a probability of {confidence:.2f}."
    )
def main():
    """Render the page: OCR a URL from an uploaded image, then classify it.

    Flow:
      1. The user uploads a PNG/JPEG image showing a URL.
      2. The OCR text (spaces removed, lower-cased) pre-fills an editable
         text box so the user can correct recognition mistakes.
      3. Pressing "Detect Phishing" classifies the text-box value.

    The confirmed URL is mirrored in ``st.session_state['verified_url']``.
    The key is removed whenever no usable URL is on screen (no upload, OCR
    failure, or an emptied text box) so the button can never classify a
    stale URL left over from an earlier image.
    """
    image_pipeline, phishing_model, phishing_tokenizer = load_models()

    st.title("Phishing URL Detection from Image")

    uploaded_image = st.file_uploader(
        "Upload an image of the URL", type=["png", "jpg", "jpeg"]
    )

    verified_url = None
    if uploaded_image is not None:
        try:
            # Opening and rendering are inside the try so that a corrupt or
            # truncated upload is reported like any other processing error.
            image = Image.open(uploaded_image)
            st.image(image, caption='Uploaded URL Image', use_column_width=True)
            generated = image_pipeline(image)[0]['generated_text']
            ocr_result = generated.replace(" ", "").lower()
        except Exception as e:
            # UI boundary: surface the failure to the user instead of crashing.
            st.error(f"An error occurred during image processing: {e}")
        else:
            verified_url = st.text_input("Recognized URL", ocr_result)

    # Keep session state in sync with what is currently displayed.
    if verified_url:
        st.session_state['verified_url'] = verified_url
    elif 'verified_url' in st.session_state:
        del st.session_state['verified_url']

    if st.button('Detect Phishing'):
        if verified_url:
            result = check_phishing(phishing_model, phishing_tokenizer, verified_url)
            st.write(result)
        else:
            st.error("Please upload an image to detect the URL and check for phishing.")
# Run the main function | |
if __name__ == "__main__":
    # Build the page only when this file is executed as a script, not when
    # it is imported as a module.
    main()