import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.transforms import v2
from PIL import Image
import pandas as pd
from tqdm import tqdm

# DEVICE SETUP
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("\nπŸš€ Using device:", device)

# Load tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# ----- HELPER FUNCTIONS -----
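# Tokenize a single title into fixed-length (80-token, padded/truncated) input_ids and attention_mask.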
def get_bert_embedding(text):
    inputs = tokenizer.encode_plus(
        text, add_special_tokens=True,
        return_tensors='pt', max_length=80,
        truncation=True, padding='max_length'
    )
    return inputs['input_ids'].squeeze(0), inputs['attention_mask'].squeeze(0)

# ----- DATASET CLASS -----
class FakedditDataset(Dataset):
    def __init__(self, df, text_field="clean_title", label_field="binary_label", image_id="id"):
        self.df = df.reset_index(drop=True)
        self.text_field = text_field
        self.label_field = label_field
        self.image_id = image_id

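        # ImageNet-style preprocessing: resize, convert to a float tensor in [0, 1], normalize with ImageNet statistics.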
        self.transform = v2.Compose([
            v2.Resize((256, 256)),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
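        # Each sample pairs a title with an image expected at ./val_images/<id>.jpg.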
        text = self.df.at[idx, self.text_field]
        label = self.df.at[idx, self.label_field]
        image_path = f"./val_images/{self.df.at[idx, self.image_id]}.jpg"

        image = Image.open(image_path).convert('RGB')
        image = self.transform(image)
        input_ids, attention_mask = get_bert_embedding(str(text))

        return image, input_ids, attention_mask, torch.tensor(label, dtype=torch.long)

# ----- MODEL CLASSES -----
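# Fuses text and image features with learned attention: a linear layer maps the
# concatenated features to two logits, softmax turns them into per-modality
# weights, and the fused vector is the weighted sum of the two features.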
class SelfAttentionFusion(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.attn = nn.Linear(embed_dim * 2, 2)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x_text, x_img):
        stacked = torch.stack([x_text, x_img], dim=1)                              # (B, 2, D)
        attn_weights = self.softmax(self.attn(torch.cat([x_text, x_img], dim=1)))  # (B, 2)
        attn_weights = attn_weights.unsqueeze(2)                                   # (B, 2, 1)
        fused = (attn_weights * stacked).sum(dim=1)                                 # weighted sum over modalities -> (B, D)
        return fused

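# Two-branch classifier: ResNet-50's 1000-d logits and BERT's [CLS] embedding are
# each projected to 512-d, fused by SelfAttentionFusion, then classified.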
class BERTResNetClassifier(nn.Module):
    def __init__(self, num_classes=2):
        super().__init__()
        self.image_model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
        self.fc_image = nn.Linear(1000, 512)
        self.drop_img = nn.Dropout(0.3)

        self.text_model = BertModel.from_pretrained("bert-base-uncased")
        self.fc_text = nn.Linear(self.text_model.config.hidden_size, 512)
        self.drop_text = nn.Dropout(0.3)

        self.fusion = SelfAttentionFusion(512)
        self.fc_final = nn.Linear(512, num_classes)

    def forward(self, image, input_ids, attention_mask):
        x_img = self.image_model(image)                 # 1000-d ImageNet logits from ResNet-50
        x_img = self.drop_img(x_img)
        x_img = self.fc_image(x_img)                    # project to 512-d

        # [CLS] token of BERT's last hidden state as the sentence representation
        x_text = self.text_model(input_ids=input_ids, attention_mask=attention_mask)[0][:, 0, :]
        x_text = self.drop_text(x_text)
        x_text = self.fc_text(x_text)                   # project to 512-d

        x_fused = self.fusion(x_text, x_img)
        return self.fc_final(x_fused)

# ----- LOAD DATA -----
df = pd.read_csv("./val_output.csv")
print("πŸ“„ Loaded validation CSV with", len(df), "samples")
val_dataset = FakedditDataset(df)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

# ----- LOAD MODEL STATE -----
def remove_module_prefix(state_dict):
    """Strip the 'module.' prefix that nn.DataParallel adds to checkpoint keys."""
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[len('module.'):] if k.startswith('module.') else k
        new_state_dict[name] = v
    return new_state_dict

print("πŸ“¦ Loading model weights...")
state_dict = torch.load("state_dict.pth", map_location=device)
clean_state_dict = remove_module_prefix(state_dict)

model = BERTResNetClassifier(num_classes=2)
model.load_state_dict(clean_state_dict)
model.to(device)
model.eval()
print("βœ… Model loaded and ready for evaluation")

# ----- EVALUATION -----
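# Batched inference without gradients; accumulates top-1 accuracy over the validation set.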
correct = 0
total = 0
print("\nπŸ” Starting evaluation...")
with torch.no_grad():
    for batch in tqdm(val_loader, desc="Evaluating"):
        images, input_ids, attention_mask, labels = batch
        images = images.to(device)
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)

        outputs = model(images, input_ids, attention_mask)
        preds = torch.argmax(outputs, dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

accuracy = correct / total * 100
print(f"\n🎯 Final Validation Accuracy: {accuracy:.2f}%")