# Provenance (residue from the web file viewer this file was copied from):
# author mahmoud669, commit 8a811c4 ("Update app.py", verified), raw / history / blame,
# "No virus", 4.3 kB.
import os
import zipfile
from collections import Counter

import cv2
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import timm
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from tqdm import tqdm
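# Model and helper components.
# `model` is built further down in this script from 'faces_best_model.pth'. The other names in the
# commented-out scaffold below (device, final_conv, fc_params, cls_names, SaveFeatures, getCAM,
# tensor_2_im) are not used anywhere in this script yet.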
# from your_model_module import model, device, final_conv, fc_params, cls_names, SaveFeatures, getCAM, tensor_2_im
# Set up your model and device here
# model = ...
# device = ...
# final_conv = ...
# fc_params = ...
# cls_names = ...
# Maps each classifier output index (0..16) to the celebrity name it stands for;
# the position in this list is the class index.
reversed_map = dict(enumerate([
    'Angelina Jolie',
    'Brad Pitt',
    'Denzel Washington',
    'Hugh Jackman',
    'Jennifer Lawrence',
    'Johnny Depp',
    'Kate Winslet',
    'Leonardo DiCaprio',
    'Megan Fox',
    'Natalie Portman',
    'Nicole Kidman',
    'Robert Downey Jr',
    'Sandra Bullock',
    'Scarlett Johansson',
    'Tom Cruise',
    'Tom Hanks',
    'Will Smith',
]))
def extract_and_display_images(zip_file):
    """Extract *zip_file* into ``extracted_images`` and show each image with ``st.image``.

    Args:
        zip_file: A path or binary file-like object accepted by
            ``zipfile.ZipFile`` (for example a Streamlit ``UploadedFile``).

    Sub-directories inside the archive are walked recursively, and files that
    Pillow cannot open (e.g. ``__MACOSX`` metadata or non-image members) are
    skipped instead of aborting the page. The target directory is not cleared
    first, so files left over from earlier uploads are shown as well.
    """
    extract_path = "extracted_images"
    os.makedirs(extract_path, exist_ok=True)
    with zipfile.ZipFile(zip_file, 'r') as zip_ref:
        zip_ref.extractall(extract_path)

    # os.walk instead of os.listdir: archives commonly contain a top-level
    # folder, and Image.open on a directory raises IsADirectoryError.
    for root, dirs, files in os.walk(extract_path):
        dirs.sort()  # deterministic traversal order
        for image_file in sorted(files):
            image_path = os.path.join(root, image_file)
            try:
                # copy() forces a full decode, so the file handle can be closed
                # by the with-block while the in-memory image stays usable.
                with Image.open(image_path) as opened:
                    image = opened.copy()
            except OSError:
                # Covers PIL.UnidentifiedImageError and truncated files.
                continue
            st.image(image, caption=image_file, use_column_width=True)
# Face classifier: RexNet-150 with one output per entry of reversed_map,
# loaded on CPU from the fine-tuned checkpoint and switched to inference mode.
# pretrained=False: load_state_dict (strict by default) overwrites every
# parameter and buffer, so downloading ImageNet weights first is wasted network
# I/O and an extra way for the app to fail at startup.
# NOTE(review): Streamlit re-runs this script on every interaction, so the
# checkpoint is re-read each time; consider st.cache_resource once it is clear
# whether unlearn() mutates `model` in place.
model = timm.create_model("rexnet_150", pretrained=False, num_classes=len(reversed_map))
model.load_state_dict(torch.load('faces_best_model.pth', map_location=torch.device('cpu')))
model.eval()
left_column, right_column = st.columns(2)

with left_column:
    # Left pane: classify a single uploaded photo with the current model.
    st.title("Original Model")
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        # Force 3 channels: PNGs are often RGBA/palette and some JPEGs are
        # grayscale, and either would break the 3-channel Normalize below.
        image = Image.open(uploaded_file).convert("RGB")
        st.image(image, caption='Uploaded Image.', width=300)

        st.write("Performing inference...")
        # Resize to the 224x224 input and normalise with the ImageNet mean/std.
        preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        image_tensor = preprocess(image).unsqueeze(0)  # add a batch dimension

        # A single forward pass is enough. The model is in eval mode and runs
        # under no_grad, so it is deterministic. The previous 50-run majority
        # vote therefore always reported one class at "100%" and never showed
        # the model's actual confidence.
        with torch.no_grad():
            probabilities = F.softmax(model(image_tensor), dim=1)[0]
        top_probs, top_indices = torch.topk(probabilities, k=min(3, probabilities.numel()))
        for prob, idx in zip(top_probs.tolist(), top_indices.tolist()):
            st.write(f"{reversed_map[idx]}: {prob * 100:.1f}%")
def extract(zip_file):
    """Extract *zip_file* into ``extracted_images`` and return the readable image paths.

    Args:
        zip_file: A path or binary file-like object accepted by
            ``zipfile.ZipFile`` (for example a Streamlit ``UploadedFile``).

    Returns:
        List of paths under ``extracted_images`` whose files Pillow can open.
        The walk order is deterministic: directories and file names are sorted.
        Sub-directories are searched recursively, and unreadable or non-image
        files are skipped. The directory is not cleared first, so files left
        over from earlier uploads are included too. Nothing is displayed.
    """
    extract_path = "extracted_images"
    os.makedirs(extract_path, exist_ok=True)
    with zipfile.ZipFile(zip_file, 'r') as zip_ref:
        zip_ref.extractall(extract_path)

    image_paths = []
    # os.walk instead of os.listdir: archives commonly contain a top-level
    # folder, and Image.open on a directory raises IsADirectoryError.
    for root, dirs, files in os.walk(extract_path):
        dirs.sort()
        for name in sorted(files):
            path = os.path.join(root, name)
            try:
                # Image.open only parses the header; the with-block closes the
                # file handle promptly instead of leaving it to the GC.
                with Image.open(path):
                    pass
            except OSError:
                # Covers PIL.UnidentifiedImageError (e.g. __MACOSX metadata).
                continue
            image_paths.append(path)
    return image_paths
with right_column:
    # Right pane: upload a ZIP of photos of the celebrity to forget, then run unlearning.
    uploaded_file = st.file_uploader("Upload ZIP with images for celebrity to forget.", type="zip")
    if uploaded_file is not None:
        st.write("Uploaded ZIP file details:")
        st.write({
            "Filename": uploaded_file.name,
        })
        # Extract the archive into "extracted_images" (extract() does not display anything).
        extract(uploaded_file)
        st.write("Unlearning begins...")
        # NOTE(review): `unlearn` is not defined or imported in this file as shown;
        # this raises NameError unless it is provided elsewhere.
        unlearn()
        # A distinct key is required: the left column already has a file_uploader
        # with this exact label and type, and two identical widgets with no key
        # make Streamlit raise DuplicateWidgetID.
        # NOTE(review): this uploader's original nesting was lost with the file's
        # indentation; it is placed after unlearn(), where it reads as "test
        # the unlearned model" — confirm.
        uploaded_file = st.file_uploader(
            "Choose an image...",
            type=["jpg", "jpeg", "png"],
            key="unlearned_model_image",
        )