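"""Evaluation script for a CLIP ViT-H-14 preference-score model.

Depending on --data-type it evaluates the ImageReward test metric, ranking accuracy
on a test split, the per-style benchmark, or DrawBench.

Hypothetical invocation (script name and paths are placeholders, not taken from the
original source):

    python evaluate.py --data-type test --data-path ./data --image-path ./data/images \
        --checkpoint ./checkpoints/model.pt --batch-size 20
"""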
import os
import json
import numpy as np
from tqdm import tqdm
from argparse import ArgumentParser
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader

from src.open_clip import create_model_and_transforms, get_tokenizer
from src.training.train import calc_ImageReward, inversion_score
from src.training.data import ImageRewardDataset, collate_rank, RankingDataset


parser = ArgumentParser()
parser.add_argument('--data-type', type=str, choices=['benchmark', 'test', 'ImageReward', 'drawbench'], help='which evaluation protocol to run')
parser.add_argument('--data-path', type=str, help='path to dataset')
parser.add_argument('--image-path', type=str, help='path to image files')
parser.add_argument('--checkpoint', type=str, help='path to checkpoint')
parser.add_argument('--batch-size', type=int, default=20)
args = parser.parse_args()

batch_size = args.batch_size
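# Build the ViT-H-14 CLIP backbone (initialized from the LAION-2B pretrained weights);
# the fine-tuned weights are loaded from --checkpoint below.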
args.model = "ViT-H-14"
args.precision = 'amp'
print(args.model)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model, preprocess_train, preprocess_val = create_model_and_transforms(
    args.model,
    'laion2B-s32B-b79K',
    precision=args.precision,
    device=device,
    jit=False,
    force_quick_gelu=False,
    force_custom_text=False,
    force_patch_dropout=False,
    force_image_size=None,
    pretrained_image=False,
    image_mean=None,
    image_std=None,
    light_augmentation=True,
    aug_cfg={},
    output_dict=True,
    with_score_predictor=False,
    with_region_predictor=False
)

checkpoint = torch.load(args.checkpoint, map_location=device)
model.load_state_dict(checkpoint['state_dict'])
tokenizer = get_tokenizer(args.model)
model.eval()

class BenchmarkDataset(Dataset):
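    """Dataset of (image, tokenized caption) pairs for the benchmark/DrawBench evaluations.

    Images are expected to be named '00000.jpg', '00001.jpg', ... inside `image_folder`,
    one per caption in the JSON `meta_file`.
    """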
    def __init__(self, meta_file, image_folder, transforms, tokenizer):
        self.transforms = transforms
        self.image_folder = image_folder
        self.tokenizer = tokenizer
        self.open_image = Image.open
        with open(meta_file, 'r') as f:
            self.annotations = json.load(f)
            
    def __len__(self):
        return len(self.annotations)
    
    def __getitem__(self, idx):
        img_path = os.path.join(self.image_folder, f'{idx:05d}.jpg')
        try:
            images = self.transforms(self.open_image(img_path))
            caption = self.tokenizer(self.annotations[idx])
            return images, caption
        except OSError:
            # Missing or unreadable image: fall back to the next sample.
            print(f'file does not exist: {img_path}')
            return self.__getitem__((idx + 1) % len(self))

def evaluate_IR(data_path, image_folder, model):
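    """Score the model on the ImageReward test split and print the averaged
    per-prompt metric returned by calc_ImageReward.
    """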
    meta_file = os.path.join(data_path, 'ImageReward_test.json')
    dataset = ImageRewardDataset(meta_file, image_folder, preprocess_val, tokenizer)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4, collate_fn=collate_rank)
    
    score = 0
    total = len(dataset)
    with torch.no_grad():
        for batch in tqdm(dataloader):
            images, num_images, labels, texts = batch
            images = images.to(device=device, non_blocking=True)
            texts = texts.to(device=device, non_blocking=True)
            num_images = num_images.to(device=device, non_blocking=True)
            labels = labels.to(device=device, non_blocking=True)

            with torch.cuda.amp.autocast():
                outputs = model(images, texts)
                image_features, text_features, logit_scale = outputs["image_features"], outputs["text_features"], outputs["logit_scale"]
                logits_per_image = logit_scale * image_features @ text_features.T
                paired_logits_list = [logit[:,i] for i, logit in enumerate(logits_per_image.split(num_images.tolist()))]

            # Rank the images within each prompt by descending predicted score.
            predicted = [torch.argsort(-k) for k in paired_logits_list]
            hps_ranking = [[predicted[i].tolist().index(j) for j in range(n)] for i, n in enumerate(num_images)]
            labels = list(labels.split(num_images.tolist()))
            score += sum([calc_ImageReward(paired_logits_list[i].tolist(), labels[i]) for i in range(len(hps_ranking))])
    print('ImageReward:', score/total)

def evaluate_rank(data_path, image_folder, model):
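    """Evaluate ranking accuracy against the human rankings in test.json and write the
    predicted per-prompt rankings to logs/hps_rank.json.
    """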
    meta_file = os.path.join(data_path, 'test.json')
    dataset = RankingDataset(meta_file, image_folder, preprocess_val, tokenizer)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4, collate_fn=collate_rank)
    
    score = 0
    total = len(dataset)
    all_rankings = []
    with torch.no_grad():
        for batch in tqdm(dataloader):
            images, num_images, labels, texts = batch
            images = images.to(device=device, non_blocking=True)
            texts = texts.to(device=device, non_blocking=True)
            num_images = num_images.to(device=device, non_blocking=True)
            labels = labels.to(device=device, non_blocking=True)

            with torch.cuda.amp.autocast():
                outputs = model(images, texts)
                image_features, text_features, logit_scale = outputs["image_features"], outputs["text_features"], outputs["logit_scale"]
                logits_per_image = logit_scale * image_features @ text_features.T
                paired_logits_list = [logit[:,i] for i, logit in enumerate(logits_per_image.split(num_images.tolist()))]

            # Rank the images within each prompt by descending predicted score.
            predicted = [torch.argsort(-k) for k in paired_logits_list]
            hps_ranking = [[predicted[i].tolist().index(j) for j in range(n)] for i, n in enumerate(num_images)]
            labels = list(labels.split(num_images.tolist()))
            all_rankings.extend(hps_ranking)
            score += sum([inversion_score(hps_ranking[i], labels[i]) for i in range(len(hps_ranking))])
    print('ranking_acc:', score/total)
    os.makedirs('logs', exist_ok=True)
    with open('logs/hps_rank.json', 'w') as f:
        json.dump(all_rankings, f)

def collate_eval(batch):
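    """Stack images and concatenate tokenized captions into batched tensors."""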
    images = torch.stack([sample[0] for sample in batch])
    captions = torch.cat([sample[1] for sample in batch])
    return images, captions


def evaluate_benchmark(data_path, root_dir, model):
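    """Score every <model>/<style> folder under root_dir against its prompt file and
    print the per-style mean and standard deviation of the CLIP image-text scores
    (computed over 10 groups of 80 images, i.e. 800 images per style).
    """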
    meta_dir = data_path
    model_list = os.listdir(root_dir)
    style_list = os.listdir(os.path.join(root_dir, model_list[0]))

    score = {}
    for model_id in model_list:
        score[model_id] = {}
        for style in style_list:
            score[model_id][style] = []
            image_folder = os.path.join(root_dir, model_id, style)
            image_folder = os.path.join(root_dir, model_id, style)
            meta_file = os.path.join(meta_dir, f'{style}.json')
            dataset = BenchmarkDataset(meta_file, image_folder, preprocess_val, tokenizer)
            dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_eval)

            with torch.no_grad():
                for i, batch in enumerate(dataloader):
                    images, texts = batch
                    images = images.to(device=device, non_blocking=True)
                    texts = texts.to(device=device, non_blocking=True)

                    with torch.cuda.amp.autocast():
                        outputs = model(images, texts)
                        image_features, text_features = outputs["image_features"], outputs["text_features"]
                        logits_per_image = image_features @ text_features.T
                    # Keep the per-image scores; the diagonal pairs each image with its own caption.
                    score[model_id][style].extend(torch.diagonal(logits_per_image).cpu().tolist())
    print('----------- benchmark score ----------------')
    for model_id, data in score.items():
        for style, res in data.items():
            # Scores are grouped into 10 chunks of 80 images; report mean and std across chunks.
            avg_score = [np.mean(res[i:i+80]) for i in range(0, 800, 80)]
            print(model_id, '\t', style, '\t', np.mean(avg_score), '\t', np.std(avg_score))


def evaluate_benchmark_DB(data_path, root_dir, model):
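    """Score each model folder under root_dir on the DrawBench prompts and print the
    mean CLIP image-text score per model.
    """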
    meta_file = os.path.join(data_path, 'drawbench.json')
    model_list = os.listdir(root_dir)

    score = {}
    for model_id in model_list:
        image_folder = os.path.join(root_dir, model_id)
        dataset = BenchmarkDataset(meta_file, image_folder, preprocess_val, tokenizer)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4, collate_fn=collate_eval)
        score[model_id] = 0
        with torch.no_grad():
            for batch in tqdm(dataloader):
                images, texts = batch
                images = images.to(device=device, non_blocking=True)
                texts = texts.to(device=device, non_blocking=True)

                with torch.cuda.amp.autocast():
                    outputs = model(images, texts)
                    image_features, text_features = outputs["image_features"], outputs["text_features"]
                    logits_per_image = image_features @ text_features.T
                    diag = torch.diagonal(logits_per_image)
                score[model_id] += torch.sum(diag).cpu().item()
            score[model_id] = score[model_id] / len(dataset)
    print('----------- drawbench score ----------------')
    for model_id, avg in score.items():
        print(model_id, '\t', avg)


if args.data_type == 'ImageReward':
    evaluate_IR(args.data_path, args.image_path, model)
elif args.data_type == 'test':
    evaluate_rank(args.data_path, args.image_path, model)
elif args.data_type == 'benchmark':
    evaluate_benchmark(args.data_path, args.image_path, model)
elif args.data_type == 'drawbench':
    evaluate_benchmark_DB(args.data_path, args.image_path, model)
else:
    raise NotImplementedError(f'unsupported data type: {args.data_type}')