# Uploaded by syntheticbot ("Upload 5 files"), commit 9a5479a (verified).
import pandas as pd
import torch
from PIL import Image
from sklearn.metrics import classification_report, accuracy_score
from transformers import CLIPImageProcessor
import os
from tqdm import tqdm
# IMPORTANT: This line imports your custom model class from the training script.
# Ensure 'train_clip.py' is in the same directory.
from train_clip import MultiTaskClipVisionModel
# --- 1. Configuration ---
# MODEL_PATH must name the directory that holds the fine-tuned model's
# 'pytorch_model.bin' and 'preprocessor_config.json'.
MODEL_PATH = "./clip-fairface-finetuned/best_model"  # or "./clip-fairface-finetuned/checkpoint-XXXX"
VAL_CSV = './fairface_label_val.csv'  # labelled validation split
BASE_PATH = './'  # image paths in the CSV 'file' column are joined onto this

# Prefer the GPU when one is visible to PyTorch.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

print(f"Using device: {DEVICE}")
print(f"Loading model from: {MODEL_PATH}")
# --- 2. Load Label Mappings (must be identical to training) ---
# The mappings are rebuilt from the TRAIN csv so that the class ids here line
# up with the ids the model was trained against.
train_df = pd.read_csv('./fairface_label_train.csv')
age_labels = sorted(train_df['age'].unique())
gender_labels = sorted(train_df['gender'].unique())
race_labels = sorted(train_df['race'].unique())

# label -> class id for each task; ids follow the sorted label order above.
label_mappings = {
    task: {label: idx for idx, label in enumerate(labels)}
    for task, labels in (
        ('age', age_labels),
        ('gender', gender_labels),
        ('race', race_labels),
    )
}

# class id -> human-readable label, i.e. the inverse of label_mappings.
id_mappings = {
    task: {idx: label for label, idx in mapping.items()}
    for task, mapping in label_mappings.items()
}

# Class count per task; passed to MultiTaskClipVisionModel as `num_labels`.
NUM_LABELS = {'age': len(age_labels), 'gender': len(gender_labels), 'race': len(race_labels)}
# --- 3. Load Model and Processor ---
print("Loading processor and model...")
processor = CLIPImageProcessor.from_pretrained(MODEL_PATH)
model = MultiTaskClipVisionModel(num_labels=NUM_LABELS)
# Load the saved weights.
#  - `map_location` lets a checkpoint saved on GPU be restored on a CPU-only
#    machine.
#  - `weights_only=True` restricts unpickling to tensors and plain containers,
#    so a tampered checkpoint file cannot run arbitrary code while loading
#    (supported since PyTorch 1.13).
# The state dict is passed inline so no extra module-level reference keeps a
# second copy of the weights alive after they are copied into the model.
model.load_state_dict(
    torch.load(
        os.path.join(MODEL_PATH, 'pytorch_model.bin'),
        map_location=torch.device(DEVICE),
        weights_only=True,
    )
)
model.to(DEVICE)
model.eval()  # Inference mode (affects layers such as dropout).
print("Model loaded successfully.")
# --- 4. Evaluation on Validation Set ---
def evaluate_on_dataset(csv_path=None):
    """Evaluate the loaded model on a labelled CSV and print per-task reports.

    Each CSV row must provide a 'file' column (an image path joined onto
    ``BASE_PATH``) plus 'age', 'gender' and 'race' columns whose values occur
    in the training-set ``label_mappings``.

    Args:
        csv_path: CSV file to evaluate. Defaults to ``VAL_CSV``.

    Returns:
        dict mapping each task name to its overall accuracy. The dict is
        empty when the CSV has no rows.

    Raises:
        KeyError: if a row carries a label that does not appear in
            ``label_mappings`` (i.e. was not seen in the training CSV).
    """
    if csv_path is None:
        csv_path = VAL_CSV
    print(f"\nEvaluating on validation data from: {csv_path}")
    val_df = pd.read_csv(csv_path)
    if val_df.empty:
        # Nothing to score; avoid handing empty inputs to the sklearn metrics.
        print(f"No rows found in {csv_path}; nothing to evaluate.")
        return {}

    tasks = ('age', 'gender', 'race')
    all_preds = {task: [] for task in tasks}
    all_true = {task: [] for task in tasks}

    # Inference only, so no autograd graph is needed.
    with torch.no_grad():
        for _, row in tqdm(val_df.iterrows(), total=val_df.shape[0], desc="Evaluating"):
            image_path = os.path.join(BASE_PATH, row['file'])
            # The context manager releases the file handle once the RGB copy
            # has been made (convert() returns a new image).
            with Image.open(image_path) as img:
                image = img.convert("RGB")
            inputs = processor(images=image, return_tensors="pt").to(DEVICE)
            logits = model(pixel_values=inputs['pixel_values'])['logits']
            for task in tasks:
                all_preds[task].append(torch.argmax(logits[task], dim=-1).item())
                all_true[task].append(label_mappings[task][row[task]])

    print("\n--- Evaluation Results ---")
    accuracies = {}
    for task in tasks:
        # List every class id explicitly. Without `labels=`, sklearn infers the
        # class set from y_true/y_pred only. If some class never occurs, that
        # set is smaller than target_names and classification_report raises
        # ValueError.
        class_ids = list(range(len(id_mappings[task])))
        target_names = [id_mappings[task][i] for i in class_ids]
        accuracy = accuracy_score(all_true[task], all_preds[task])
        report = classification_report(
            all_true[task],
            all_preds[task],
            labels=class_ids,
            target_names=target_names,
            zero_division=0,
        )
        print(f"\n--- {task.upper()} CLASSIFICATION REPORT ---")
        print(f"Overall Accuracy: {accuracy:.4f}")
        print(report)
        accuracies[task] = accuracy
    return accuracies
# --- 5. Function for Single Image Prediction ---
def predict_single_image(image_path):
    """Classify one image with the loaded model and print the predictions.

    Args:
        image_path: Path to the image file to classify.

    Returns:
        dict mapping 'age', 'gender' and 'race' to the predicted
        human-readable label, or None when ``image_path`` does not exist.
    """
    print(f"\n--- Predicting for single image: {image_path} ---")
    if not os.path.exists(image_path):
        print(f"Error: Image path not found at '{image_path}'")
        return None

    rgb_image = Image.open(image_path).convert("RGB")
    pixel_batch = processor(images=rgb_image, return_tensors="pt").to(DEVICE)

    with torch.no_grad():
        task_logits = model(pixel_values=pixel_batch['pixel_values'])['logits']

    # Pick the highest-scoring class per task and translate its id to a label.
    predictions = {
        task: id_mappings[task][torch.argmax(task_logits[task], dim=-1).item()]
        for task in ('age', 'gender', 'race')
    }

    print("Predictions:")
    for task in predictions:
        print(f" - {task.capitalize()}: {predictions[task]}")
    return predictions
if __name__ == "__main__":
    import sys

    # Full evaluation on the validation CSV configured above.
    evaluate_on_dataset()

    # Single-image demo. Pass an image path as the first command-line
    # argument to test a different image; without one, the default example
    # path below is used.
    sample_image_path = sys.argv[1] if len(sys.argv) > 1 else 'val/1.jpg'
    predict_single_image(sample_image_path)