import torch
import clip
import torch.nn.functional as F
from tqdm import tqdm
from PIL import Image
from pathlib import Path


def read_image(imgid):
    """Resolve an image id either as a direct path or relative to data_en/images/."""
    vanilla = Path(imgid)
    fixed = Path(f"data_en/images/{imgid}")
    # Exactly one of the two candidate paths must exist; anything else is ambiguous.
    assert vanilla.exists() != fixed.exists()

    path = vanilla if vanilla.exists() else fixed
    return Image.open(path).convert("RGB")

class CLIPScore():
    def __init__(self, device="cuda"):
        # ViT-B/32 is the CLIP backbone used by the original CLIPScore metric.
        self.clip, self.clip_preprocess = clip.load("ViT-B/32", device=device)
        self.device = device

    def batchify(self, targets, batch_size):
        """Split a list into consecutive chunks of at most batch_size items."""
        return [targets[i:i + batch_size] for i in range(0, len(targets), batch_size)]
    def __call__(self, mt_list, refs_list, img_list, no_ref=False):
        B = 32
        mt_list, refs_list, img_list = [self.batchify(x, B) for x in [mt_list, refs_list, img_list]]
        scores = []
        assert len(mt_list) == len(refs_list) == len(img_list)
        for mt, refs, imgs in (pbar := tqdm(zip(mt_list, refs_list, img_list), total=len(mt_list))):
            pbar.set_description(f"CLIPScore (noref: {no_ref})")
            imgs = [read_image(imgid) for imgid in imgs]

            # Tokenize candidates and references (one list of references per candidate).
            refs_token = []
            for ref_list in refs:
                refs_token.append(torch.cat(
                    [clip.tokenize("A photo depicts " + ref, truncate=True) for ref in ref_list],
                    dim=0).to(self.device))
            mts = clip.tokenize(["A photo depicts " + x for x in mt], truncate=True).to(self.device)
            imgs = torch.cat([self.clip_preprocess(img).unsqueeze(0) for img in imgs], dim=0).to(self.device)

            with torch.no_grad():
                img_feats = self.clip.encode_image(imgs)
                mt_feats = self.clip.encode_text(mts)

                # Reference-free CLIPScore: w * max(cos(image, candidate), 0) with w = 2.5.
                cos = F.cosine_similarity(img_feats, mt_feats)
                cos[cos < 0.] = 0.
                clip_score = 2.5 * cos

                if no_ref:
                    scores.extend(clip_score.tolist())
                    continue

                # RefCLIPScore: for each candidate, take the maximum cosine similarity
                # between its text embedding and the embeddings of its references,
                # then combine it with the image-text score via the harmonic mean.
                ref_cos = []
                for i, ref_tok in enumerate(refs_token):
                    ref_feats = self.clip.encode_text(ref_tok)
                    ref_cos.append(F.cosine_similarity(mt_feats[i].unsqueeze(0), ref_feats).max())
                ref_cos = torch.stack(ref_cos)
                ref_cos[ref_cos < 0.] = 0.

                assert clip_score.shape == ref_cos.shape
                ref_clip_score = 2.0 * clip_score * ref_cos / (clip_score + ref_cos)
                scores.extend(ref_clip_score.tolist())

        return scores
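

# --- Usage sketch (not part of the original module) ---------------------------
# A minimal example of how the metric might be invoked. The captions and the
# file name below are hypothetical placeholders; image ids are resolved by
# read_image(), which falls back to data_en/images/ when the bare path does
# not exist.
if __name__ == "__main__":
    scorer = CLIPScore(device="cuda" if torch.cuda.is_available() else "cpu")

    candidates = ["a dog running on the beach"]          # system outputs (mt_list)
    references = [["a brown dog runs along the shore"]]  # one list of references per candidate
    images = ["example.jpg"]                             # hypothetical image id

    # Reference-free CLIPScore
    print(scorer(candidates, references, images, no_ref=True))
    # Reference-augmented RefCLIPScore
    print(scorer(candidates, references, images, no_ref=False))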