import torch
from PIL import Image
import numpy as np

import Model.CLIP.cn_clip.clip as clip
from Model.CLIP.cn_clip.clip import load_from_name

def convert_models_to_fp32(model):
    # Cast parameters (and any existing gradients) back to fp32, e.g. before
    # an optimizer step when the model is otherwise kept in fp16.
    for p in model.parameters():
        p.data = p.data.to(torch.float32)
        if p.grad is not None:  # truth-testing a tensor directly raises for >1 element
            p.grad.data = p.grad.data.to(torch.float32)
    return model
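
# Usage sketch for convert_models_to_fp32 (an assumption about the training
# loop, which is not shown in this file): with the CLIP weights kept in fp16,
# the optimizer update is done in fp32, roughly
#
#     loss.backward()
#     convert_models_to_fp32(model)
#     optimizer.step()
#
# after which the caller may cast the model back to fp16 for the next forward.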


def calculate_similarity(model, img1, img2, texts1, texts2, device, turn=0):
    """For each sample in the batch, score two candidate images against the
    per-sample texts and build a Chinese feedback sentence from the most
    discriminative negative/positive texts."""
    batch_size = len(img1)
    probs_all = []
    text_all = []
    feed_backs = []
    G_fbs = []
    T_fbs = []
    model.eval()
    for i in range(batch_size):
        img1_i = img1[i].unsqueeze(0).to(device)
        img2_i = img2[i].unsqueeze(0).to(device)
        image = torch.cat((img1_i, img2_i), dim=0).to(device)
        text = []
        for j in range(len(texts1)):
            text.append(texts1[j][i])
        for j in range(len(texts2)):
            text.append(texts2[j][i])
        text_all.append(text)
        text_token = clip.tokenize(text).to(device)
        # Score both images against every text; probs[j][0] / probs[j][1] are
        # the softmax weights of text j on the first / second image.
        with torch.no_grad():
            image_features, text_features, logit_scale = model(image, text_token)
        logit_scale = logit_scale.mean()
        logits_per_text = logit_scale * text_features @ image_features.t()

        probs = logits_per_text.softmax(dim=-1).cpu().numpy()
        probs = np.around(probs, 3)
        probs_all.append(probs)

        if turn == 0:
            # First turn: no negative feedback yet; ask for the first target text.
            G_fb = ""
            T_fb = text[3]
            feed_back = "我想要" + T_fb + "。"  # "I want <T_fb>."
            feed_backs.append(feed_back)
            G_fbs.append(G_fb)
            T_fbs.append(T_fb)

        else:
            # Negative feedback: among indices 2..0 (texts1), pick the
            # non-empty text that most prefers the first image (column 0)
            # over the second (column 1). Use k, not i, so the batch index
            # is not clobbered by the inner loop.
            G_fb = ""
            delta = -1
            for k in [2, 1, 0]:
                if probs[k][0] - probs[k][1] > 0:
                    if text[k] == "":
                        continue
                    elif delta < probs[k][0] - probs[k][1]:
                        delta = probs[k][0] - probs[k][1]
                        G_fb = text[k]
            # Positive feedback: among indices 5..3 (texts2), pick the
            # non-empty text that most prefers the second image.
            T_fb = ""
            delta = -1
            for k in [5, 4, 3]:
                if text[k] == "":
                    continue
                elif delta <= probs[k][1] - probs[k][0]:
                    delta = probs[k][1] - probs[k][0]
                    T_fb = text[k]

            if len(G_fb) != 0 and len(T_fb) != 0:
                feed_back = "我不要" + G_fb + ",我想要" + T_fb + "。"  # "I don't want <G_fb>; I want <T_fb>."
            elif len(G_fb) != 0 and len(T_fb) == 0:
                feed_back = "我不要" + G_fb + "。"  # "I don't want <G_fb>."
            elif len(G_fb) == 0 and len(T_fb) != 0:
                feed_back = "我想要" + T_fb + "。"  # "I want <T_fb>."
            else:
                feed_back = "换一个。"  # "Try another one."
            feed_backs.append(feed_back)
            G_fbs.append(G_fb)
            T_fbs.append(T_fb)
    return probs_all, text_all, feed_backs, G_fbs, T_fbs
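
# A minimal sketch of the image-text scoring pattern used above, for a single
# image and a few texts (an assumption: the model comes from cn_clip's
# load_from_name; the model name and image path are illustrative):
#
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#     model, preprocess = load_from_name("ViT-B-16", device=device)
#     image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)
#     tokens = clip.tokenize(["卧室", "客厅"]).to(device)  # "bedroom", "living room"
#     with torch.no_grad():
#         image_features, text_features, logit_scale = model(image, tokens)
#     probs = (logit_scale.mean() * text_features @ image_features.t()).softmax(dim=-1)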

def calculate_similarity_one(model, img1, texts1, device, objs, dict_id, T_ids):
    """Score one image per sample against its texts and record the
    comma-joined texts under the sample's object key in dict_id."""
    batch_size = len(img1)
    model.eval()
    for i in range(batch_size):
        dict_text = dict_id[T_ids[i].cpu().item()]
        img1_i = img1[i].unsqueeze(0).to(device)
        text = []
        text3 = ""
        for j in range(len(texts1)):
            text.append(texts1[j][i])
            text3 += texts1[j][i] + ','
        text_token = clip.tokenize(text).to(device)
        with torch.no_grad():
            image_features, text_features, logit_scale = model(img1_i, text_token)
        logit_scale = logit_scale.mean()
        logits_per_text = logit_scale * text_features @ image_features.t()
        text_score = logits_per_text.cpu().numpy()  # per-text scores (currently unused)
        dict_text[objs[i]] = text3

    return dict_id

def get_obj(model, G, T, objs, top_k1, top_k2):
    """Pick one object index: usually (3/4 of the time) sampled from the
    top_k1 objects most similar to G, otherwise from the top_k2 objects
    most similar to T."""
    model.eval()
    with torch.no_grad():
        image_features, text_features, logit_scale = model(G, objs)
        logits_per_text = logit_scale * text_features @ image_features.t()
        top_list = logits_per_text.cpu().topk(top_k1, dim=0)[1]
        obj_index1 = top_list[np.random.choice(top_k1)]

        image_features, text_features, logit_scale = model(T, objs)
        logits_per_text = logit_scale * text_features @ image_features.t()
        top_list = logits_per_text.cpu().topk(top_k2, dim=0)[1]
        obj_index2 = top_list[np.random.choice(top_k2)]

    # Bias the choice 3:1 toward the G-based candidate.
    if np.random.choice(4) == 0:
        obj_index = obj_index2
    else:
        obj_index = obj_index1
    return obj_index
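
# Usage sketch (an assumption about the caller: G and T are preprocessed
# image batches, objs are tokenized object names, and the top-k values are
# illustrative):
#
#     idx = get_obj(model, G, T, objs, top_k1=5, top_k2=3)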

def get_objs(model, T, objs, top_k2):
    """Return the top_k2 similarity scores and their indices for objs
    against T."""
    model.eval()
    with torch.no_grad():
        image_features, text_features, logit_scale = model(T, objs)
        logits_per_text = logit_scale * text_features @ image_features.t()
        top_list = logits_per_text.cpu().topk(top_k2, dim=0)

    return top_list[0], top_list[1]  # (values, indices)


if __name__ == "__main__":

    # 30 common interior-design terms used as candidate labels, e.g. "空间"
    # (space), "客厅" (living room), "卧室" (bedroom), "沙发" (sofa).
    texts1 = ["空间","客厅","卧室","墙面","餐厅","公寓","住宅","沙发","家具","地毯","厨房","书房","背景墙","吊灯","墙",
           "卫生间","儿童","床品","装饰","壁纸","地板","窗帘","吊顶","餐椅","别墅","地面","结构","布艺","餐桌","画"]
    # Example descriptive sentences, e.g. "白灰色床品搭配灰色地毯"
    # ("white-gray bedding paired with a gray carpet").
    texts2 = [
        "打造了一个现代、讲究的温馨的空间。",
        "与地毯的图案",
        "灰地毯的中性色调与床品、窗帘以深浅对比",
        "白灰色床品搭配灰色地毯"
    ]
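
    # A sketch of how the fine-tuned checkpoint used below might be loaded
    # (an assumption: load_from_name is cn_clip's loader, and the checkpoint
    # format/key shown here is illustrative, not confirmed by this file):
    #
    #     device = "cuda" if torch.cuda.is_available() else "cpu"
    #     model, preprocess = load_from_name("ViT-B-16", device=device)
    #     state = torch.load('902/epoch_50.pt', map_location=device)
    #     model.load_state_dict(state.get('state_dict', state))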

    # get_clip_score is not defined in this file; it is assumed to be provided
    # elsewhere in the project (scoring image id 423943 against texts1 with
    # the checkpoint '902/epoch_50.pt').
    print(get_clip_score('902/epoch_50.pt', 423943, texts1))