File size: 2,299 Bytes
1dd6eaf
0d5c920
 
1dd6eaf
0d5c920
 
 
1dd6eaf
0de8628
1dd6eaf
 
 
 
0d5c920
1dd6eaf
0d5c920
1dd6eaf
 
 
 
 
 
 
 
 
 
 
 
 
0d5c920
 
1dd6eaf
 
0d5c920
 
1dd6eaf
 
0d5c920
0de8628
1dd6eaf
0de8628
1dd6eaf
 
 
 
 
 
0d5c920
1dd6eaf
0d5c920
1dd6eaf
0d5c920
 
c09e4cb
 
0d5c920
c09e4cb
 
 
 
1dd6eaf
0202ce7
 
 
8c4da2d
0202ce7
 
 
 
 
0d5c920
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import gradio as gr

# -------- CONFIG --------
# Resolve the checkpoint next to this script instead of relative to the
# current working directory, so the weights are found no matter which
# directory the app is launched from.
checkpoint_path = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "age_prediction_model2.pth"
)
# Prefer the GPU when one is visible; otherwise run inference on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# -------- SIMPLE CNN MODEL --------
class SimpleCNN(nn.Module):
    """Three-stage convolutional regressor mapping an RGB image to one scalar.

    Each stage is Conv(3x3, padding=1) -> ReLU -> MaxPool(2), so spatial size
    halves three times. The head's first Linear layer expects a 128 x 16 x 16
    feature map, i.e. a 128 x 128 input image.
    """

    def __init__(self):
        super().__init__()
        # NOTE: module order (and therefore Sequential indices) must not change:
        # state_dict keys such as "features.3.weight" are derived from it and
        # saved checkpoints are loaded by those keys.
        stages = []
        in_channels = 3
        for out_channels in (32, 64, 128):
            stages += [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2),  # halves height and width
            ]
            in_channels = out_channels
        self.features = nn.Sequential(*stages)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 16 * 16, 256),
            nn.ReLU(),
            nn.Linear(256, 1),  # single regression output (age)
        )

    def forward(self, x):
        """Apply the conv feature extractor, then the MLP head; returns (N, 1)."""
        return self.classifier(self.features(x))

# -------- LOAD MODEL --------
model = SimpleCNN().to(device)

# Load the trained weights; map_location remaps tensors that were saved on a
# different device (e.g. a CUDA-trained checkpoint on a CPU-only host).
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
try:
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
except FileNotFoundError:
    print(
        f"Error: Checkpoint file not found at {checkpoint_path}. Please check the path. "
        "Predictions will come from randomly initialised weights."
    )
else:
    print(f"Model loaded from {checkpoint_path}")

# This app only ever runs inference, so switch to eval mode unconditionally —
# previously a missing checkpoint left the model in training mode.
model.eval()

# -------- PREPROCESSING --------
# Resize to the fixed 128x128 input that SimpleCNN's classifier is sized for,
# convert the PIL image to a float tensor in [0, 1], then map every channel
# to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

def predict_age(image: Image.Image) -> float:
    """Predict the age of the face in ``image``.

    Args:
        image: PIL image from the Gradio ``Image`` input. Gradio passes
            ``None`` when the user submits without uploading an image.

    Returns:
        The model's scalar output rounded to two decimal places.

    Raises:
        gr.Error: if no image was provided, so the UI shows a readable
            message instead of a traceback from the transform pipeline.
    """
    if image is None:
        raise gr.Error("Please upload a face image.")
    # Normalize expects exactly 3 channels. Gradio's image_mode="RGB" already
    # ensures this, but direct callers may pass RGBA or grayscale images.
    image_tensor = transform(image.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        # Batch of one with a single output unit -> .item() yields a Python float.
        age = model(image_tensor).item()
    return round(age, 2)


# -------- GRADIO APP --------
# Bound to a module-level name so tooling such as Gradio's reload mode
# (`gradio app.py`), which looks for `demo`, can pick it up.
# `gr` is already imported at the top of the file.
demo = gr.Interface(
    fn=predict_age,
    inputs=gr.Image(type="pil", image_mode="RGB"),
    outputs=gr.Number(label="Predicted age"),
    title="Age Prediction from Face",
    description="Upload a face image and get the predicted age.",
)

# Only start the server when run as a script, so importing this module
# (e.g. from tests) does not block on launch().
if __name__ == "__main__":
    demo.launch()