mahmoud669 commited on
Commit
8a811c4
1 Parent(s): 2068125

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -33
app.py CHANGED
@@ -39,44 +39,89 @@ reversed_map = {
39
  16: 'Will Smith'
40
  }
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  model = timm.create_model("rexnet_150", pretrained = True, num_classes = 17)
43
  model.load_state_dict(torch.load('faces_best_model.pth', map_location=torch.device('cpu')))
44
  model.eval()
 
45
 
 
46
  # Title of the app
47
- st.title("Image Upload and Inference Example")
48
-
49
- # File uploader for images
50
- uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
51
-
52
- if uploaded_file is not None:
53
- # Open and display the image
54
- image = Image.open(uploaded_file)
55
- st.image(image, caption='Uploaded Image.', use_column_width=True)
56
-
57
- # Perform inference
58
- st.write("Performing inference...")
59
 
60
- # Transform the image to fit model requirements
61
- preprocess = transforms.Compose([
62
- transforms.Resize((224, 224)),
63
- transforms.ToTensor(),
64
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
65
- ])
66
-
67
- image_tensor = preprocess(image).unsqueeze(0)
68
-
69
- preds = []
70
- with torch.no_grad():
71
- for i in range(50):
72
- output = model(image_tensor)
73
- probabilities = F.softmax(output, dim=1)
74
- pred_class = torch.argmax(probabilities, dim=1)
75
- pred_label = reversed_map[pred_class.item()]
76
- preds.append(pred_label)
 
 
 
 
 
 
 
 
 
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
- freq = Counter(preds)
80
- top_three = freq.most_common(3)
81
- for celeb, count in top_three:
82
- st.write(f"{celeb}: {int(count)*2}%")
 
 
39
  16: 'Will Smith'
40
  }
41
 
42
+ def extract_and_display_images(zip_file):
43
+ # Create a directory to store extracted images
44
+ extract_path = "extracted_images"
45
+ os.makedirs(extract_path, exist_ok=True)
46
+
47
+ # Extract images from the ZIP file
48
+ with zipfile.ZipFile(zip_file, 'r') as zip_ref:
49
+ zip_ref.extractall(extract_path)
50
+
51
+ # Display each image in the extracted directory
52
+ image_files = os.listdir(extract_path)
53
+ for image_file in image_files:
54
+ image_path = os.path.join(extract_path, image_file)
55
+ image = Image.open(image_path)
56
+ st.image(image, caption=image_file, use_column_width=True)
57
+
58
  model = timm.create_model("rexnet_150", pretrained = True, num_classes = 17)
59
  model.load_state_dict(torch.load('faces_best_model.pth', map_location=torch.device('cpu')))
60
  model.eval()
61
+ left_column, right_column = st.columns(2)
62
 
63
+ with left_column:
64
  # Title of the app
65
+ st.title("Original Model")
66
+ # File uploader for images
67
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
 
 
 
 
 
 
 
 
 
68
 
69
+ if uploaded_file is not None:
70
+ # Open and display the image
71
+ image = Image.open(uploaded_file)
72
+ st.image(image, caption='Uploaded Image.', width=300)
73
+
74
+ # Perform inference
75
+ st.write("Performing inference...")
76
+
77
+ # Transform the image to fit model requirements
78
+ preprocess = transforms.Compose([
79
+ transforms.Resize((224, 224)),
80
+ transforms.ToTensor(),
81
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
82
+ ])
83
+
84
+ image_tensor = preprocess(image).unsqueeze(0)
85
+
86
+ preds = []
87
+ with torch.no_grad():
88
+ for i in range(50):
89
+ output = model(image_tensor)
90
+ probabilities = F.softmax(output, dim=1)
91
+ pred_class = torch.argmax(probabilities, dim=1)
92
+ pred_label = reversed_map[pred_class.item()]
93
+ preds.append(pred_label)
94
+
95
 
96
+ freq = Counter(preds)
97
+ top_three = freq.most_common(3)
98
+ for celeb, count in top_three:
99
+ st.write(f"{celeb}: {int(count)*2}%")
100
+ def extract(zip_file):
101
+ # Create a directory to store extracted images
102
+ extract_path = "extracted_images"
103
+ os.makedirs(extract_path, exist_ok=True)
104
+
105
+ # Extract images from the ZIP file
106
+ with zipfile.ZipFile(zip_file, 'r') as zip_ref:
107
+ zip_ref.extractall(extract_path)
108
+
109
+ # Display each image in the extracted directory
110
+ image_files = os.listdir(extract_path)
111
+ for image_file in image_files:
112
+ image_path = os.path.join(extract_path, image_file)
113
+ image = Image.open(image_path)
114
+
115
+ with right_column:
116
+ uploaded_file = st.file_uploader("Upload ZIP with images for celebrity to forget.", type="zip")
117
+ if uploaded_file is not None:
118
+ st.write("Uploaded ZIP file details:")
119
+ st.write({
120
+ "Filename": uploaded_file.name,
121
+ })
122
 
123
+ # Call function to extract and display images
124
+ extract(uploaded_file)
125
+ st.write("Unlearning begins...")
126
+ unlearn()
127
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])