faranbutt789 committed
Commit 5177c9a · verified · 1 Parent(s): efb54b9

Update app.py

Files changed (1)
app.py: +36 -77
app.py CHANGED
@@ -8,17 +8,16 @@ import torch
 import torch.nn as nn
 import torchvision.models as models
 import torchvision.transforms as T
-from PIL import ImageFont, ImageDraw, Image
+from PIL import Image
 import numpy as np
 
 # Device
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-# --- Model definition (must match your saved state) ---
+# --- Model definition ---
 class AgeGenderClassifier(nn.Module):
     def __init__(self):
         super(AgeGenderClassifier, self).__init__()
-        # classifier expected input dim 2048 (as in your training run)
         self.intermediate = nn.Sequential(
             nn.Linear(2048, 512),
             nn.ReLU(),
@@ -48,15 +47,15 @@ class AgeGenderClassifier(nn.Module):
 def build_model(weights_path: str):
     """Rebuild VGG16 backbone + custom avgpool/classifier then load weights."""
     backbone = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
-    # freeze all then fine-tune later if needed (same as training script)
+
+    # freeze all layers
     for p in backbone.parameters():
         p.requires_grad = False
-
-    # allow last block to be trainable if desired (kept same as your training code)
+    # optionally allow last block to be trainable
     for p in backbone.features[24:].parameters():
         p.requires_grad = True
 
-    # replace avgpool with the same block used during training (conv->maxpool->relu->flatten)
+    # replace avgpool
    backbone.avgpool = nn.Sequential(
        nn.Conv2d(512, 512, kernel_size=3),
        nn.MaxPool2d(2),
@@ -73,11 +72,9 @@ def build_model(weights_path: str):
         raise FileNotFoundError(f"Model weights not found at {weights_path}")
 
     state = torch.load(weights_path, map_location=device)
-    # If saved state was model.state_dict(), load directly
     try:
         model.load_state_dict(state)
     except Exception:
-        # if state is a dict with other keys, try common wrappers
         if "model_state_dict" in state:
             model.load_state_dict(state["model_state_dict"])
         else:
@@ -98,81 +95,50 @@ transform = T.Compose([
 INV_AGE_SCALE = 80  # training used age/80 normalization
 
 
-def draw_caption_on_image(image, caption):
-    if image.mode != "RGBA":
-        image = image.convert("RGBA")
-    draw = ImageDraw.Draw(image)
-    font = ImageFont.load_default()
-
-    bbox = draw.textbbox((0,0), caption, font=font)
-    text_w = bbox[2] - bbox[0]
-    text_h = bbox[3] - bbox[1]
-
-    # semi-transparent rectangle
-    overlay = Image.new("RGBA", image.size)
-    overlay_draw = ImageDraw.Draw(overlay)
-    overlay_draw.rectangle([0,0,text_w+20,text_h+20], fill=(0,0,0,127))
-    image = Image.alpha_composite(image, overlay)
-
-    draw = ImageDraw.Draw(image)
-    draw.text((10,10), caption, font=font, fill="white")
-    return image.convert("RGB")
-
-
-
-# --- Prediction function for multiple images ---
+# --- Prediction function ---
+def predict_images_with_text(images: List[Image.Image], model):
+    """Return original images and captions for each."""
+    if not images:
+        return [], []
 
-def predict_images(images: List[Image.Image], model) -> List[Image.Image]:
-    """Takes a list of PIL images and returns list of PIL images annotated with predictions."""
-    if images is None or len(images) == 0:
-        return []
-
-    # preprocess all images into a batch
     tensors = []
     for im in images:
         if im.mode != "RGB":
             im = im.convert("RGB")
-        t = transform(im)
-        tensors.append(t)
+        tensors.append(transform(im))
 
     batch = torch.stack(tensors).to(device)
 
     with torch.no_grad():
         pred_age, pred_gender = model(batch)
-    # ensure shapes (N,1)
     pred_age = pred_age.squeeze(-1).cpu().numpy()
     pred_gender = pred_gender.squeeze(-1).cpu().numpy()
 
-    outputs = []
+    output_images = []
+    captions = []
+
     for img, pa, pg in zip(images, pred_age, pred_gender):
         age_val = int(np.clip(pa, 0.0, 1.0) * INV_AGE_SCALE)
         gender_label = "Female" if pg > 0.5 else "Male"
         gender_emoji = "👩" if pg > 0.5 else "👨"
         conf = float(pg if pg > 0.5 else 1 - pg)
 
-        caption = f"{gender_emoji} {gender_label} ({conf:.2f}) • 🎂 Age ≈ {age_val}"
-        out_img = draw_caption_on_image(img, caption)
-        outputs.append(out_img)
+        output_images.append(np.array(img))
+        captions.append(f"{gender_emoji} {gender_label} ({conf:.2f}) • 🎂 Age ≈ {age_val}")
 
-    return outputs
+    return output_images, captions
 
 
-# --- Load model once on startup ---
+# --- Load model ---
 MODEL_WEIGHTS = os.environ.get("MODEL_PATH", "age_gender_model.pth")
 model = build_model(MODEL_WEIGHTS)
 
+
 # --- Gradio UI ---
 with gr.Blocks(title="FairFace Age & Gender — Multi-image Demo") as demo:
     gr.Markdown("""
     # 🧠 FairFace Multi-task Age & Gender Predictor
-    Upload **one or more** images (JPG/PNG). The app will predict **gender** and **age** for each image and display results right on the picture.
-
-    **How to use**
-    1. Click **Browse** or drag & drop multiple images. ✅
-    2. Click **Run**. The model processes images and shows results below. ⚡
-    3. Use the download button on the output images if you want to save them.
-
-    *Note:* Age is estimated (approx.). This model was trained on the FairFace dataset.
+    Upload **one or more** images (JPG/PNG). The app will predict **gender** and **age** for each image and display results below the image.
     """)
 
     with gr.Row():
@@ -180,41 +146,34 @@ with gr.Blocks(title="FairFace Age & Gender — Multi-image Demo") as demo:
     run_btn = gr.Button("Run ▶️")
 
     gallery = gr.Gallery(
-        label="Predictions",
-        show_label=True,
-        elem_id="gallery",
-        columns=3, # 3 images per row
+        label="Uploaded Images",
+        columns=3,
         height="auto"
     )
 
+    captions = gr.HTML(label="Predictions")
+
     def run_and_predict(files):
-        # files is list of uploaded file dicts or file paths depending on environment
         if not files:
-            return []
-
+            return [], ""
         pil_imgs = []
-        # if File component returns list of dicts in HF spaces, handle both
        for f in files:
-            # f might be a path string or dict-like
-            if isinstance(f, dict) and "name" in f and "data" in f:
-                # web upload format
-                im = Image.open(io.BytesIO(f["data"]))
-            else:
-                path = f if isinstance(f, str) else f.name
-                im = Image.open(path)
-            pil_imgs.append(im.convert("RGB"))
+            path = f if isinstance(f, str) else f.name
+            pil_imgs.append(Image.open(path).convert("RGB"))
 
-        return predict_images(pil_imgs, model)
+        imgs, texts = predict_images_with_text(pil_imgs, model)
+        captions_html = "<br>".join([f"<h2>{t}</h2>" for t in texts])
+        return imgs, captions_html
 
-    run_btn.click(fn=run_and_predict, inputs=[img_input], outputs=[gallery])
+    run_btn.click(fn=run_and_predict, inputs=[img_input], outputs=[gallery, captions])
 
     gr.Markdown("""
     ---
     **Tips & Notes**
-    - The model outputs age normalized to 0–80 years (approx).
-    - If results look odd, try a clearer, frontal face image.
-    - This demo is for research / demo purposes only — be mindful of privacy. 🙏
+    - Age is normalized to 0–80 years (approx.).
+    - For best results, upload clear frontal face images.
+    - This is a demo — respect privacy when using photos. 🙏
     """)
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
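
Two side notes on the code in this diff.

First, the custom avgpool block is what makes the classifier's nn.Linear(2048, 512) input dimension work out. Assuming the standard 224x224 VGG16 input (the transform definition sits outside these hunks), backbone.features emits a 512x7x7 map; Conv2d(512, 512, kernel_size=3) shrinks it to 512x5x5, MaxPool2d(2) to 512x2x2, and flattening gives 512 * 2 * 2 = 2048. A minimal sketch checking that arithmetic with random weights, filling in the ReLU/Flatten tail that the removed "(conv->maxpool->relu->flatten)" comment describes but the hunk truncates:

import torch
import torch.nn as nn
import torchvision.models as models

# Random-weight backbone; no age_gender_model.pth needed for a shape check.
backbone = models.vgg16(weights=None)
backbone.avgpool = nn.Sequential(
    nn.Conv2d(512, 512, kernel_size=3),  # 7x7 -> 5x5
    nn.MaxPool2d(2),                     # 5x5 -> 2x2
    nn.ReLU(),
    nn.Flatten(),                        # 512 * 2 * 2 = 2048
)

x = torch.randn(1, 3, 224, 224)   # assumed input size; confirm against transform in app.py
feats = backbone.features(x)      # -> (1, 512, 7, 7)
flat = backbone.avgpool(feats)    # -> (1, 2048), matching nn.Linear(2048, 512)
print(feats.shape, flat.shape)

Second, the new return contract (numpy arrays for the gallery plus caption strings) can be exercised outside Gradio. A sketch assuming app.py is importable, the real weights file is present, and hypothetical local files face1.jpg and face2.jpg exist:

from PIL import Image
from app import model, predict_images_with_text  # loads age_gender_model.pth at import

imgs = [Image.open(p).convert("RGB") for p in ("face1.jpg", "face2.jpg")]
arrays, captions = predict_images_with_text(imgs, model)
for c in captions:
    print(c)  # e.g. "👩 Female (0.97) • 🎂 Age ≈ 34"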