Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -9,6 +9,7 @@ import timm
|
|
9 |
from tqdm import tqdm
|
10 |
import torch.nn.functional as F
|
11 |
from collections import Counter
|
|
|
12 |
# Your model and other necessary functions
|
13 |
# Assuming model, device, final_conv, fc_params, cls_names, and any other needed components are defined elsewhere
|
14 |
# from your_model_module import model, device, final_conv, fc_params, cls_names, SaveFeatures, getCAM, tensor_2_im
|
@@ -39,27 +40,11 @@ reversed_map = {
|
|
39 |
16: 'Will Smith'
|
40 |
}
|
41 |
|
42 |
-
def extract_and_display_images(zip_file):
|
43 |
-
# Create a directory to store extracted images
|
44 |
-
extract_path = "extracted_images"
|
45 |
-
os.makedirs(extract_path, exist_ok=True)
|
46 |
-
|
47 |
-
# Extract images from the ZIP file
|
48 |
-
with zipfile.ZipFile(zip_file, 'r') as zip_ref:
|
49 |
-
zip_ref.extractall(extract_path)
|
50 |
-
|
51 |
-
# Display each image in the extracted directory
|
52 |
-
image_files = os.listdir(extract_path)
|
53 |
-
for image_file in image_files:
|
54 |
-
image_path = os.path.join(extract_path, image_file)
|
55 |
-
image = Image.open(image_path)
|
56 |
-
st.image(image, caption=image_file, use_column_width=True)
|
57 |
-
|
58 |
model = timm.create_model("rexnet_150", pretrained = True, num_classes = 17)
|
59 |
model.load_state_dict(torch.load('faces_best_model.pth', map_location=torch.device('cpu')))
|
60 |
model.eval()
|
|
|
61 |
left_column, right_column = st.columns(2)
|
62 |
-
|
63 |
with left_column:
|
64 |
# Title of the app
|
65 |
st.title("Original Model")
|
@@ -97,9 +82,8 @@ with left_column:
|
|
97 |
top_three = freq.most_common(3)
|
98 |
for celeb, count in top_three:
|
99 |
st.write(f"{celeb}: {int(count)*2}%")
|
100 |
-
def extract(zip_file):
|
101 |
# Create a directory to store extracted images
|
102 |
-
extract_path = "extracted_images"
|
103 |
os.makedirs(extract_path, exist_ok=True)
|
104 |
|
105 |
# Extract images from the ZIP file
|
@@ -114,14 +98,41 @@ def extract(zip_file):
|
|
114 |
|
115 |
with right_column:
|
116 |
uploaded_file = st.file_uploader("Upload ZIP with images for celebrity to forget.", type="zip")
|
117 |
-
|
118 |
st.write("Uploaded ZIP file details:")
|
119 |
st.write({
|
120 |
"Filename": uploaded_file.name,
|
121 |
})
|
122 |
|
123 |
# Call function to extract and display images
|
124 |
-
extract(uploaded_file)
|
125 |
st.write("Unlearning begins...")
|
126 |
unlearn()
|
127 |
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
from tqdm import tqdm
|
10 |
import torch.nn.functional as F
|
11 |
from collections import Counter
|
12 |
+
from scrub import unlearn
|
13 |
# Your model and other necessary functions
|
14 |
# Assuming model, device, final_conv, fc_params, cls_names, and any other needed components are defined elsewhere
|
15 |
# from your_model_module import model, device, final_conv, fc_params, cls_names, SaveFeatures, getCAM, tensor_2_im
|
|
|
40 |
16: 'Will Smith'
|
41 |
}
|
42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
model = timm.create_model("rexnet_150", pretrained = True, num_classes = 17)
|
44 |
model.load_state_dict(torch.load('faces_best_model.pth', map_location=torch.device('cpu')))
|
45 |
model.eval()
|
46 |
+
extract('celeb-dataset.zip', 'celeb-dataset')
|
47 |
left_column, right_column = st.columns(2)
|
|
|
48 |
with left_column:
|
49 |
# Title of the app
|
50 |
st.title("Original Model")
|
|
|
82 |
top_three = freq.most_common(3)
|
83 |
for celeb, count in top_three:
|
84 |
st.write(f"{celeb}: {int(count)*2}%")
|
85 |
+
def extract(zip_file, extract_path):
|
86 |
# Create a directory to store extracted images
|
|
|
87 |
os.makedirs(extract_path, exist_ok=True)
|
88 |
|
89 |
# Extract images from the ZIP file
|
|
|
98 |
|
99 |
with right_column:
|
100 |
uploaded_file = st.file_uploader("Upload ZIP with images for celebrity to forget.", type="zip")
|
101 |
+
if uploaded_file is not None:
|
102 |
st.write("Uploaded ZIP file details:")
|
103 |
st.write({
|
104 |
"Filename": uploaded_file.name,
|
105 |
})
|
106 |
|
107 |
# Call function to extract and display images
|
108 |
+
extract(uploaded_file, 'forget_set')
|
109 |
st.write("Unlearning begins...")
|
110 |
unlearn()
|
111 |
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
|
112 |
+
# Perform inference
|
113 |
+
st.write("Performing inference...")
|
114 |
+
|
115 |
+
# Transform the image to fit model requirements
|
116 |
+
preprocess = transforms.Compose([
|
117 |
+
transforms.Resize((224, 224)),
|
118 |
+
transforms.ToTensor(),
|
119 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
120 |
+
])
|
121 |
+
|
122 |
+
image_tensor = preprocess(image).unsqueeze(0)
|
123 |
+
|
124 |
+
preds = []
|
125 |
+
with torch.no_grad():
|
126 |
+
for i in range(50):
|
127 |
+
output = model(image_tensor)
|
128 |
+
probabilities = F.softmax(output, dim=1)
|
129 |
+
pred_class = torch.argmax(probabilities, dim=1)
|
130 |
+
pred_label = reversed_map[pred_class.item()]
|
131 |
+
preds.append(pred_label)
|
132 |
+
|
133 |
+
|
134 |
+
freq = Counter(preds)
|
135 |
+
top_three = freq.most_common(3)
|
136 |
+
for celeb, count in top_three:
|
137 |
+
st.write(f"{celeb}: {int(count)*2}%")
|
138 |
+
|