# app.py: page header from the hosting site's file view, kept as a comment so the module parses.
# Last commit bf0a9fe ("Update app.py") by mahmoud669; file size 5.56 kB.
import sys
import os
sys.path.append(os.path.dirname(__file__))
import streamlit as st
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import cv2
from torchvision import transforms
import timm
from tqdm import tqdm
import torch.nn.functional as F
from collections import Counter
from scrub import unlearn
import zipfile
import time
# Model setup: the original classifier (`model`) is built further down in this file
# with timm and loaded onto the CPU; `reversed_map` below maps its class indices to
# celebrity names. The template placeholders that follow (device, final_conv,
# fc_params, cls_names, SaveFeatures, getCAM, tensor_2_im) are not referenced by the
# code below.
# from your_model_module import model, device, final_conv, fc_params, cls_names, SaveFeatures, getCAM, tensor_2_im
# model = ...
# device = ...
# final_conv = ...
# fc_params = ...
# cls_names = ...
# Class index -> celebrity name. Position in this list is the index of the
# corresponding classifier output (the models below use num_classes=17).
reversed_map = dict(enumerate([
    'Angelina Jolie',
    'Brad Pitt',
    'Denzel Washington',
    'Hugh Jackman',
    'Jennifer Lawrence',
    'Johnny Depp',
    'Kate Winslet',
    'Leonardo DiCaprio',
    'Megan Fox',
    'Natalie Portman',
    'Nicole Kidman',
    'Robert Downey Jr',
    'Sandra Bullock',
    'Scarlett Johansson',
    'Tom Cruise',
    'Tom Hanks',
    'Will Smith',
]))
def extract(zip_file, extract_path):
    """Extract a ZIP archive and check that the image files in it are readable.

    Args:
        zip_file: Path to a ZIP archive, or a binary file-like object (for
            example a Streamlit upload) -- anything ``zipfile.ZipFile`` accepts.
        extract_path: Directory to extract into; created if it does not exist.

    Returns:
        List of paths of image files found under ``extract_path`` that PIL
        could open and verify. Unreadable images are reported with ``print``
        and left out of the list. Existing callers that ignore the return
        value are unaffected.
    """
    os.makedirs(extract_path, exist_ok=True)
    # NOTE(review): zip_file can be a user upload. ZipFile.extractall strips
    # absolute paths and ".." components, but does not protect against
    # decompression bombs -- consider size limits for untrusted archives.
    with zipfile.ZipFile(zip_file, 'r') as zip_ref:
        zip_ref.extractall(extract_path)

    image_exts = ('.png', '.jpg', '.jpeg', '.gif')  # Adjust file extensions as needed
    image_paths = []
    for root, _dirs, files in os.walk(extract_path):
        for file in files:
            if not file.lower().endswith(image_exts):
                continue
            image_path = os.path.join(root, file)
            try:
                # Image.open alone is lazy (reads only the header) and leaves the
                # file open; verify() actually checks the data, and the context
                # manager releases the handle.
                with Image.open(image_path) as image:
                    image.verify()
            except (OSError, SyntaxError) as e:
                # Pillow raises OSError subclasses for unreadable files and
                # SyntaxError for some corrupt formats (e.g. broken PNG).
                print(f"Error opening image {image_path}: {e}")
            else:
                image_paths.append(image_path)
    return image_paths
@st.cache_resource
def _load_original_model():
    """Build the 17-class rexnet_150 classifier and load the fine-tuned weights.

    Cached with ``st.cache_resource`` so the model is built once per server
    process instead of on every Streamlit rerun (every widget interaction).

    Returns:
        The model in eval mode, with weights loaded onto the CPU.
    """
    # pretrained=False: load_state_dict (strict by default) overwrites every
    # parameter and buffer, so downloading ImageNet weights first is wasted work.
    net = timm.create_model("rexnet_150", pretrained=False, num_classes=17)
    net.load_state_dict(torch.load('faces_best_model.pth', map_location=torch.device('cpu')))
    net.eval()
    return net


@st.cache_resource
def _extract_celeb_dataset():
    """Unpack the bundled celebrity dataset once per server process."""
    extract('celeb-dataset.zip', 'celeb-dataset')


model = _load_original_model()
_extract_celeb_dataset()
left_column, right_column = st.columns(2)

with left_column:
    st.title("Original Model")
    # Image to classify with the original (not unlearned) model.
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        image = Image.open(uploaded_file)
        st.image(image, caption='Uploaded Image.', width=300)
        st.write("Performing inference...")
        # Resize to 224x224 and normalize with the ImageNet mean/std.
        preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        # convert("RGB"): uploaded PNGs may be RGBA, grayscale or palette images,
        # whose 4- or 1-channel tensors make the 3-channel Normalize fail.
        image_tensor = preprocess(image.convert("RGB")).unsqueeze(0)
        # A single forward pass is enough: the model is in eval mode, so
        # repeating the pass and voting returns the same class every time and
        # can only ever report one name at "100%". Show the actual top-3
        # softmax probabilities instead.
        with torch.no_grad():
            probabilities = F.softmax(model(image_tensor), dim=1)[0]
        top_probs, top_indices = torch.topk(probabilities, k=3)
        for prob, class_idx in zip(top_probs.tolist(), top_indices.tolist()):
            st.write(f"{reversed_map[class_idx]}: {prob * 100:.1f}%")
@st.cache_resource
def _load_unlearned_model():
    """Build rexnet_150 (17 classes) and load the precomputed unlearned weights.

    Cached with ``st.cache_resource`` so the weights are read once per server
    process rather than on every Streamlit rerun.

    Returns:
        The model in eval mode, with weights loaded onto the CPU.
    """
    # pretrained=False: load_state_dict overwrites every parameter anyway.
    net = timm.create_model("rexnet_150", pretrained=False, num_classes=17)
    net.load_state_dict(torch.load('celeb-model-unlearned.pth', map_location=torch.device('cpu')))
    net.eval()
    return net


with right_column:
    st.title("Unlearned Model")
    options = ['SSD', 'SCRUB', 'UNSIR', 'Incompetent Teacher', 'Mislabel']
    # Display dropdown and store selected value.
    # NOTE(review): selected_option is not used below -- the same precomputed
    # weights are loaded whichever method is chosen.
    selected_option = st.selectbox('Select an option:', options)
    uploaded_file = st.file_uploader("Upload ZIP with images for celebrity to forget.", type="zip")
    if uploaded_file is not None:
        # NOTE(review): extracts a user-supplied archive into a shared
        # 'forget_set' directory on every rerun; files accumulate across uploads.
        extract(uploaded_file, 'forget_set')
        st.write("Unlearning...")
        # NOTE(review): unlearn() is disabled; the sleep only simulates the
        # unlearning step before a precomputed model is loaded.
        #unlearn()
        time.sleep(5)
        model_s = _load_unlearned_model()
        uploaded_file2 = st.file_uploader("Choose image...", type=["jpg", "jpeg", "png"])
        if uploaded_file2 is not None:
            image2 = Image.open(uploaded_file2)
            st.image(image2, caption='Uploaded Image.', width=300)
            st.write("Performing inference...")
            # Resize to 224x224 and normalize with the ImageNet mean/std.
            preprocess = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])
            # convert("RGB"): RGBA/grayscale/palette PNGs would otherwise yield a
            # channel count the 3-channel Normalize cannot handle.
            image_tensor = preprocess(image2.convert("RGB")).unsqueeze(0)
            # One deterministic eval-mode pass; repeated passes would all agree,
            # so report the real top-3 softmax probabilities instead of a vote.
            with torch.no_grad():
                probabilities = F.softmax(model_s(image_tensor), dim=1)[0]
            top_probs, top_indices = torch.topk(probabilities, k=3)
            for prob, class_idx in zip(top_probs.tolist(), top_indices.tolist()):
                st.write(f"{reversed_map[class_idx]}: {prob * 100:.1f}%")