|
import pandas as pd |
|
import torch |
|
from PIL import Image |
|
from sklearn.metrics import classification_report, accuracy_score |
|
from transformers import CLIPImageProcessor |
|
import os |
|
from tqdm import tqdm |
|
|
|
|
|
|
|
from train_clip import MultiTaskClipVisionModel |
|
|
|
|
|
|
|
|
|
# Fine-tuned checkpoint directory; both the image processor config and the
# `pytorch_model.bin` weights are loaded from here further down.
MODEL_PATH = "./clip-fairface-finetuned/best_model"

# Validation labels CSV, and the directory its 'file' column is relative to.
VAL_CSV = './fairface_label_val.csv'
BASE_PATH = './'

# Prefer the GPU whenever torch can see one.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

print(f"Using device: {DEVICE}")
print(f"Loading model from: {MODEL_PATH}")
|
|
|
|
|
|
|
train_df = pd.read_csv('./fairface_label_train.csv')

# NOTE(review): a class id is the position of a label in these sorted lists.
# This presumably mirrors how train_clip.py assigned ids during training —
# confirm, otherwise predicted ids map to the wrong label names.
age_labels, gender_labels, race_labels = (
    sorted(train_df[column].unique()) for column in ('age', 'gender', 'race')
)

# Per task: label -> class id.
label_mappings = {
    task: {label: class_id for class_id, label in enumerate(labels)}
    for task, labels in (
        ('age', age_labels),
        ('gender', gender_labels),
        ('race', race_labels),
    )
}

# Per task: class id -> label (inverse of label_mappings).
id_mappings = {
    task: {class_id: label for label, class_id in forward.items()}
    for task, forward in label_mappings.items()
}

# Number of classes per task; passed to MultiTaskClipVisionModel as num_labels.
NUM_LABELS = dict(
    age=len(age_labels),
    gender=len(gender_labels),
    race=len(race_labels),
)
|
|
|
|
|
|
|
print("Loading processor and model...")

processor = CLIPImageProcessor.from_pretrained(MODEL_PATH)

# Build the architecture with one head per task, then restore the fine-tuned
# weights into it (remapped onto DEVICE at load time).
model = MultiTaskClipVisionModel(num_labels=NUM_LABELS)
# NOTE(review): torch.load unpickles the checkpoint. If it could ever come from
# an untrusted source, consider weights_only=True (available in torch >= 1.13).
model.load_state_dict(
    torch.load(
        os.path.join(MODEL_PATH, 'pytorch_model.bin'),
        map_location=torch.device(DEVICE),
    )
)
model.to(DEVICE)
# Inference mode for layers such as dropout/batch-norm.
model.eval()

print("Model loaded successfully.")
|
|
|
|
|
|
|
def evaluate_on_dataset(): |
|
print(f"\nEvaluating on validation data from: {VAL_CSV}") |
|
val_df = pd.read_csv(VAL_CSV) |
|
|
|
|
|
all_preds = {'age': [], 'gender': [], 'race': []} |
|
all_true = {'age': [], 'gender': [], 'race': []} |
|
|
|
|
|
with torch.no_grad(): |
|
|
|
for index, row in tqdm(val_df.iterrows(), total=val_df.shape[0], desc="Evaluating"): |
|
image_path = os.path.join(BASE_PATH, row['file']) |
|
image = Image.open(image_path).convert("RGB") |
|
|
|
|
|
inputs = processor(images=image, return_tensors="pt").to(DEVICE) |
|
|
|
|
|
outputs = model(pixel_values=inputs['pixel_values']) |
|
logits = outputs['logits'] |
|
|
|
|
|
for task in ['age', 'gender', 'race']: |
|
pred_id = torch.argmax(logits[task], dim=-1).item() |
|
true_label = row[task] |
|
true_id = label_mappings[task][true_label] |
|
|
|
all_preds[task].append(pred_id) |
|
all_true[task].append(true_id) |
|
|
|
|
|
print("\n--- Evaluation Results ---") |
|
for task in ['age', 'gender', 'race']: |
|
task_preds = all_preds[task] |
|
task_true = all_true[task] |
|
task_labels = list(label_mappings[task].keys()) |
|
task_target_names = [id_mappings[task][i] for i in range(len(task_labels))] |
|
|
|
accuracy = accuracy_score(task_true, task_preds) |
|
report = classification_report( |
|
task_true, |
|
task_preds, |
|
target_names=task_target_names, |
|
zero_division=0 |
|
) |
|
|
|
print(f"\n--- {task.upper()} CLASSIFICATION REPORT ---") |
|
print(f"Overall Accuracy: {accuracy:.4f}") |
|
print(report) |
|
|
|
|
|
|
|
def predict_single_image(image_path):
    """Predict age, gender and race for one image and print the results.

    Args:
        image_path: Path to an image file that PIL can open.

    Returns:
        Dict mapping ``'age'``, ``'gender'`` and ``'race'`` to the predicted
        labels (values taken from ``id_mappings``). Returns ``None`` after
        printing an error if ``image_path`` is not an existing regular file.
    """
    print(f"\n--- Predicting for single image: {image_path} ---")
    # isfile rather than exists: a directory path is rejected here instead of
    # blowing up inside Image.open.
    if not os.path.isfile(image_path):
        print(f"Error: Image path not found at '{image_path}'")
        return None

    # The context manager releases the file handle once convert() has decoded it.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    inputs = processor(images=image, return_tensors="pt").to(DEVICE)

    with torch.no_grad():
        logits = model(pixel_values=inputs['pixel_values'])['logits']

    predictions = {
        task: id_mappings[task][torch.argmax(logits[task], dim=-1).item()]
        for task in ('age', 'gender', 'race')
    }

    print("Predictions:")
    for task, label in predictions.items():
        print(f" - {task.capitalize()}: {label}")

    return predictions
|
|
|
|
|
if __name__ == "__main__":
    # First report metrics over the whole validation CSV, then run a quick
    # single-image prediction on one sample.
    evaluate_on_dataset()
    predict_single_image('val/1.jpg')