Spaces:
Running
Running
import sys | |
import os | |
sys.path.append(os.path.dirname(__file__)) | |
import streamlit as st | |
import torch | |
import numpy as np | |
from PIL import Image | |
import matplotlib.pyplot as plt | |
import cv2 | |
from torchvision import transforms | |
import timm | |
from tqdm import tqdm | |
import torch.nn.functional as F | |
from collections import Counter | |
from scrub import unlearn | |
import zipfile | |
import time | |
# Lookup from classifier output index to celebrity name, used to label
# predictions. There are 17 entries, matching num_classes=17 of the
# rexnet_150 classifiers built in this file.
# NOTE(review): the order must match the label encoding used when the
# checkpoints were trained — confirm against the training code.
reversed_map = dict(enumerate([
    'Angelina Jolie',
    'Brad Pitt',
    'Denzel Washington',
    'Hugh Jackman',
    'Jennifer Lawrence',
    'Johnny Depp',
    'Kate Winslet',
    'Leonardo DiCaprio',
    'Megan Fox',
    'Natalie Portman',
    'Nicole Kidman',
    'Robert Downey Jr',
    'Sandra Bullock',
    'Scarlett Johansson',
    'Tom Cruise',
    'Tom Hanks',
    'Will Smith',
]))
def extract(zip_file, extract_path, image_extensions=('.png', '.jpg', '.jpeg', '.gif')):
    """Extract a ZIP archive and check that the image files under it can be opened.

    Args:
        zip_file: Path to a ZIP archive, or a binary file-like object
            (anything ``zipfile.ZipFile`` accepts, e.g. an uploaded file).
        extract_path: Directory to extract into; created if it does not exist.
        image_extensions: File suffixes treated as images (compared
            case-insensitively).

    Returns:
        List of image paths under ``extract_path`` (including files that were
        already there before this call) that PIL could open, in ``os.walk``
        order. Files that fail to open are reported on stdout and skipped.
    """
    os.makedirs(extract_path, exist_ok=True)
    # NOTE(security): archives may be user-uploaded. extractall() strips
    # absolute paths and ".." components from member names, but it does not
    # limit the total extracted size (zip bombs).
    with zipfile.ZipFile(zip_file, 'r') as zip_ref:
        zip_ref.extractall(extract_path)

    suffixes = tuple(ext.lower() for ext in image_extensions)
    valid_images = []
    for root, _dirs, files in os.walk(extract_path):
        for file in files:
            if not file.lower().endswith(suffixes):
                continue
            image_path = os.path.join(root, file)
            try:
                # Image.open is lazy and keeps the file handle open; the
                # context manager releases it once the header has been read.
                with Image.open(image_path):
                    pass
            except OSError as e:  # IOError is an alias of OSError
                print(f"Error opening image {image_path}: {e}")
            else:
                valid_images.append(image_path)
    return valid_images
@st.cache_resource
def _load_classifier(weights_path):
    """Build the 17-class rexnet_150 classifier, load ``weights_path`` on CPU, return it in eval mode.

    Cached so that Streamlit, which re-executes this script on every widget
    interaction, builds the network and reads the checkpoint only once.
    NOTE(review): the cached instance is shared by all sessions; do not
    mutate it in place (e.g. an in-place unlearn()).
    """
    # pretrained=False: load_state_dict (strict by default) overwrites every
    # parameter, so downloading ImageNet weights first is wasted work.
    net = timm.create_model("rexnet_150", pretrained=False, num_classes=17)
    net.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
    net.eval()
    return net


@st.cache_resource
def _prepare_dataset(zip_path, extract_path):
    """Extract (and scan) the bundled dataset once per server process rather than on every rerun."""
    extract(zip_path, extract_path)
    return extract_path


model = _load_classifier('faces_best_model.pth')
_prepare_dataset('celeb-dataset.zip', 'celeb-dataset')
left_column, right_column = st.columns(2)

with left_column:
    # Classify an uploaded image with the original (not unlearned) model.
    st.title("Original Model")
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        # Show the image exactly as uploaded.
        image = Image.open(uploaded_file)
        st.image(image, caption='Uploaded Image.', width=300)
        st.write("Performing inference...")
        # Resize to 224x224 and apply ImageNet mean/std normalisation.
        preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        # PNG uploads can be RGBA, greyscale or palette images; the 3-channel
        # Normalize above requires RGB input.
        image_tensor = preprocess(image.convert("RGB")).unsqueeze(0)
        # The model is in eval() mode and runs under no_grad, so repeated
        # passes over the same tensor give identical predictions. One pass
        # suffices, and the real softmax probabilities are the meaningful
        # confidence to report.
        with torch.no_grad():
            probabilities = F.softmax(model(image_tensor), dim=1)[0]
        top_probs, top_idx = torch.topk(probabilities, k=3)
        for prob, idx in zip(top_probs.tolist(), top_idx.tolist()):
            st.write(f"{reversed_map[idx]}: {prob:.1%}")
with right_column:
    # Upload a "forget set" ZIP, obtain the unlearned model, then classify an
    # uploaded image with it.
    st.title("Unlearned Model")
    options = ['SSD', 'SCRUB', 'UNSIR', 'Incompetent Teacher', 'Mislabel']
    # NOTE(review): selected_option is not consumed yet — unlearn() is
    # disabled and a precomputed checkpoint is loaded instead.
    selected_option = st.selectbox('Select an option:', options)
    uploaded_file = st.file_uploader("Upload ZIP with images for celebrity to forget.", type="zip")
    if uploaded_file is not None:
        # Streamlit re-executes this script on every widget interaction. Only
        # extract / "unlearn" / load the checkpoint when a different ZIP
        # arrives; otherwise the 10 s delay and the reload would repeat each
        # time the image uploader below changes.
        forget_key = (uploaded_file.name, uploaded_file.size)
        if st.session_state.get("forget_zip_key") != forget_key:
            # The ZIP is user-supplied (untrusted) input.
            extract(uploaded_file, 'forget_set')
            st.write("Unlearning...")
            # unlearn()
            time.sleep(10)
            # pretrained=False: the strict load_state_dict overwrites every
            # parameter, so the ImageNet download would be wasted.
            unlearned = timm.create_model("rexnet_150", pretrained=False, num_classes=17)
            unlearned.load_state_dict(torch.load('celeb-model-unlearned.pth', map_location=torch.device('cpu')))
            unlearned.eval()
            st.session_state["unlearned_model"] = unlearned
            st.session_state["forget_zip_key"] = forget_key
        model_s = st.session_state["unlearned_model"]

        uploaded_file2 = st.file_uploader("Choose image...", type=["jpg", "jpeg", "png"])
        # Nothing to classify until an image is chosen; Image.open(None) would raise.
        if uploaded_file2 is not None:
            image2 = Image.open(uploaded_file2)
            st.image(image2, caption='Uploaded Image.', width=300)
            st.write("Performing inference...")
            # Resize to 224x224 and apply ImageNet mean/std normalisation.
            preprocess = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])
            # Classify the image uploaded here (image2), converted to RGB so
            # RGBA/greyscale PNGs match the 3-channel normalisation.
            image_tensor = preprocess(image2.convert("RGB")).unsqueeze(0)
            # eval() + no_grad makes a forward pass deterministic; report the
            # real softmax probabilities of the three most likely classes.
            with torch.no_grad():
                probabilities = F.softmax(model_s(image_tensor), dim=1)[0]
            top_probs, top_idx = torch.topk(probabilities, k=3)
            for prob, idx in zip(top_probs.tolist(), top_idx.tolist()):
                st.write(f"{reversed_map[idx]}: {prob:.1%}")