mahmoud669 commited on
Commit
38412b8
·
verified ·
1 Parent(s): 387362f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -21
app.py CHANGED
@@ -9,6 +9,7 @@ import timm
9
  from tqdm import tqdm
10
  import torch.nn.functional as F
11
  from collections import Counter
 
12
  # Your model and other necessary functions
13
  # Assuming model, device, final_conv, fc_params, cls_names, and any other needed components are defined elsewhere
14
  # from your_model_module import model, device, final_conv, fc_params, cls_names, SaveFeatures, getCAM, tensor_2_im
@@ -39,27 +40,11 @@ reversed_map = {
39
  16: 'Will Smith'
40
  }
41
 
42
- def extract_and_display_images(zip_file):
43
- # Create a directory to store extracted images
44
- extract_path = "extracted_images"
45
- os.makedirs(extract_path, exist_ok=True)
46
-
47
- # Extract images from the ZIP file
48
- with zipfile.ZipFile(zip_file, 'r') as zip_ref:
49
- zip_ref.extractall(extract_path)
50
-
51
- # Display each image in the extracted directory
52
- image_files = os.listdir(extract_path)
53
- for image_file in image_files:
54
- image_path = os.path.join(extract_path, image_file)
55
- image = Image.open(image_path)
56
- st.image(image, caption=image_file, use_column_width=True)
57
-
58
  model = timm.create_model("rexnet_150", pretrained = True, num_classes = 17)
59
  model.load_state_dict(torch.load('faces_best_model.pth', map_location=torch.device('cpu')))
60
  model.eval()
 
61
  left_column, right_column = st.columns(2)
62
-
63
  with left_column:
64
  # Title of the app
65
  st.title("Original Model")
@@ -97,9 +82,8 @@ with left_column:
97
  top_three = freq.most_common(3)
98
  for celeb, count in top_three:
99
  st.write(f"{celeb}: {int(count)*2}%")
100
- def extract(zip_file):
101
  # Create a directory to store extracted images
102
- extract_path = "extracted_images"
103
  os.makedirs(extract_path, exist_ok=True)
104
 
105
  # Extract images from the ZIP file
@@ -114,14 +98,41 @@ def extract(zip_file):
114
 
115
  with right_column:
116
  uploaded_file = st.file_uploader("Upload ZIP with images for celebrity to forget.", type="zip")
117
- if uploaded_file is not None:
118
  st.write("Uploaded ZIP file details:")
119
  st.write({
120
  "Filename": uploaded_file.name,
121
  })
122
 
123
  # Call function to extract and display images
124
- extract(uploaded_file)
125
  st.write("Unlearning begins...")
126
  unlearn()
127
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  from tqdm import tqdm
10
  import torch.nn.functional as F
11
  from collections import Counter
12
+ from scrub import unlearn
13
  # Your model and other necessary functions
14
  # Assuming model, device, final_conv, fc_params, cls_names, and any other needed components are defined elsewhere
15
  # from your_model_module import model, device, final_conv, fc_params, cls_names, SaveFeatures, getCAM, tensor_2_im
 
40
  16: 'Will Smith'
41
  }
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  model = timm.create_model("rexnet_150", pretrained = True, num_classes = 17)
44
  model.load_state_dict(torch.load('faces_best_model.pth', map_location=torch.device('cpu')))
45
  model.eval()
46
+ extract('celeb-dataset.zip', 'celeb-dataset')
47
  left_column, right_column = st.columns(2)
 
48
  with left_column:
49
  # Title of the app
50
  st.title("Original Model")
 
82
  top_three = freq.most_common(3)
83
  for celeb, count in top_three:
84
  st.write(f"{celeb}: {int(count)*2}%")
85
+ def extract(zip_file, extract_path):
86
  # Create a directory to store extracted images
 
87
  os.makedirs(extract_path, exist_ok=True)
88
 
89
  # Extract images from the ZIP file
 
98
 
99
  with right_column:
100
  uploaded_file = st.file_uploader("Upload ZIP with images for celebrity to forget.", type="zip")
101
+ if uploaded_file is not None:
102
  st.write("Uploaded ZIP file details:")
103
  st.write({
104
  "Filename": uploaded_file.name,
105
  })
106
 
107
  # Call function to extract and display images
108
+ extract(uploaded_file, 'forget_set')
109
  st.write("Unlearning begins...")
110
  unlearn()
111
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
112
+ # Perform inference
113
+ st.write("Performing inference...")
114
+
115
+ # Transform the image to fit model requirements
116
+ preprocess = transforms.Compose([
117
+ transforms.Resize((224, 224)),
118
+ transforms.ToTensor(),
119
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
120
+ ])
121
+
122
+ image_tensor = preprocess(image).unsqueeze(0)
123
+
124
+ preds = []
125
+ with torch.no_grad():
126
+ for i in range(50):
127
+ output = model(image_tensor)
128
+ probabilities = F.softmax(output, dim=1)
129
+ pred_class = torch.argmax(probabilities, dim=1)
130
+ pred_label = reversed_map[pred_class.item()]
131
+ preds.append(pred_label)
132
+
133
+
134
+ freq = Counter(preds)
135
+ top_three = freq.most_common(3)
136
+ for celeb, count in top_three:
137
+ st.write(f"{celeb}: {int(count)*2}%")
138
+