import numpy as np
import torch

import Model.CLIP.cn_clip.clip as clip
def convert_models_to_fp32(model):
    """Upcast all parameters (and their gradients, if present) to fp32 in place."""
    for p in model.parameters():
        p.data = p.data.to(torch.float32)
        if p.grad is not None:  # `if p.grad:` would raise on a multi-element tensor
            p.grad.data = p.grad.data.to(torch.float32)
    return model
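
# Usage sketch (illustrative, not part of this module): in mixed-precision CLIP
# fine-tuning the optimizer step is typically taken in fp32 and the weights are
# converted back afterwards. `optimizer` and the fp16 re-conversion helper are
# assumptions based on the standard CLIP training recipe:
#
#     convert_models_to_fp32(model)  # upcast so the update happens in fp32
#     optimizer.step()
#     # ...then convert weights back to fp16 for the forward pass, e.g. with
#     # cn_clip's clip.model.convert_weights, if training in half precision.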
def calculate_similarity(model, img1, img2, texts1, texts2, device, turn=0):
    """Score each candidate text against the image pair (img1[i], img2[i]) and
    build a Chinese feedback sentence for every sample in the batch.

    texts1 and texts2 are each expected to hold three per-sample text lists;
    for turn > 0 the feedback rejects the text that best matches img1 and asks
    for the text that best matches img2.
    """
    batch_size = len(img1)
    probs_all = []
    text_all = []
    feed_backs = []
    G_fbs = []
    T_fbs = []
    model.eval()
    for i in range(batch_size):
        img1_i = img1[i].unsqueeze(0).to(device)
        img2_i = img2[i].unsqueeze(0).to(device)
        image = torch.cat((img1_i, img2_i), dim=0).to(device)
        # Collect this sample's six candidate texts: three from texts1,
        # then three from texts2.
        text = [texts1[j][i] for j in range(len(texts1))]
        text += [texts2[j][i] for j in range(len(texts2))]
        text_all.append(text)
        text_token = clip.tokenize(text).to(device)
        image_features, text_features, logit_scale = model(image, text_token)
        logit_scale = logit_scale.mean()
        logits_per_text = logit_scale * text_features @ image_features.t()
        # probs[k] is the softmax over the two images for text k:
        # column 0 -> img1, column 1 -> img2.
        probs = logits_per_text.detach().softmax(dim=-1).cpu().numpy()
        probs = np.around(probs, 3)
        probs_all.append(probs)
        if turn == 0:
            # First turn: nothing to reject yet, just ask for the first
            # texts2 description.
            G_fb = ""
            T_fb = text[3]
            feed_back = "我想要" + T_fb + "。"  # "I want <T_fb>."
            feed_backs.append(feed_back)
            G_fbs.append(G_fb)
            T_fbs.append(T_fb)
        else:
            # G_fb: among texts 0-2, the non-empty text that favors img1
            # over img2 by the largest probability margin.
            G_fb = ""
            delta = -1
            for k in [2, 1, 0]:
                if probs[k][0] - probs[k][1] > 0:
                    if text[k] == "":
                        continue
                    elif delta < probs[k][0] - probs[k][1]:
                        delta = probs[k][0] - probs[k][1]
                        G_fb = text[k]
            # T_fb: among texts 3-5, the non-empty text that favors img2
            # over img1 by the largest margin.
            T_fb = ""
            delta = -1
            for k in [5, 4, 3]:
                if text[k] == "":
                    continue
                elif delta <= probs[k][1] - probs[k][0]:
                    delta = probs[k][1] - probs[k][0]
                    T_fb = text[k]
            if len(G_fb) != 0 and len(T_fb) != 0:
                # "I don't want <G_fb>, I want <T_fb>."
                feed_back = "我不要" + G_fb + ",我想要" + T_fb + "。"
            elif len(G_fb) != 0 and len(T_fb) == 0:
                feed_back = "我不要" + G_fb + "。"  # "I don't want <G_fb>."
            elif len(G_fb) == 0 and len(T_fb) != 0:
                feed_back = "我想要" + T_fb + "。"  # "I want <T_fb>."
            else:
                feed_back = "换一个。"  # "Show me another one."
            feed_backs.append(feed_back)
            G_fbs.append(G_fb)
            T_fbs.append(T_fb)
    return probs_all, text_all, feed_backs, G_fbs, T_fbs
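
# Usage sketch (illustrative; assumes `model, preprocess = load_from_name(...)`
# from cn_clip and batches of preprocessed image tensors):
#
#     probs, texts, fbs, g_fbs, t_fbs = calculate_similarity(
#         model, img1_batch, img2_batch, texts1, texts2, device, turn=1)
#     # fbs[i] is a sentence such as "我不要灰地毯,我想要白色吊顶。"
#     # ("I don't want the grey carpet, I want the white ceiling.")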
def calculate_similarity_one(model, img1, texts1, device, objs, dict_id, T_ids):
    """For each sample, score its candidate texts against img1[i] and store the
    comma-joined texts in dict_id under the sample's text id and object key."""
    batch_size = len(img1)
    model.eval()
    for i in range(batch_size):
        dict_text = dict_id[T_ids[i].cpu().item()]
        img1_i = img1[i].unsqueeze(0).to(device)
        text = []
        text3 = ""
        for j in range(len(texts1)):
            text.append(texts1[j][i])
            text3 += texts1[j][i] + ','
        text_token = clip.tokenize(text).to(device)
        image_features, text_features, logit_scale = model(img1_i, text_token)
        logit_scale = logit_scale.mean()
        logits_per_text = logit_scale * text_features @ image_features.t()
        # Raw text-image scores; kept only for inspection/debugging.
        text_score = logits_per_text.detach().cpu().numpy()
        dict_text[objs[i]] = text3
    return dict_id
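
# Usage sketch (illustrative; assumes dict_id maps each text id to a per-object
# dict and T_ids is a tensor of text ids aligned with the image batch):
#
#     dict_id = {tid.item(): {} for tid in T_ids}
#     dict_id = calculate_similarity_one(
#         model, img_batch, texts1, device, objs, dict_id, T_ids)
#     # dict_id[tid][obj] now holds "text_a,text_b,text_c," (trailing comma)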
def get_obj(model, G, T, objs, top_k1, top_k2):
    """Sample an object index: three times out of four from the top_k1 objects
    most similar to image G, otherwise from the top_k2 objects most similar to
    image T."""
    model.eval()
    image_features, text_features, logit_scale = model(G, objs)
    logits_per_text = logit_scale * text_features @ image_features.t()
    top_list = logits_per_text.detach().cpu().topk(top_k1, dim=0)[1]
    obj_pick = np.random.choice(range(top_k1))
    obj_index1 = top_list[obj_pick]
    image_features, text_features, logit_scale = model(T, objs)
    logits_per_text = logit_scale * text_features @ image_features.t()
    top_list = logits_per_text.detach().cpu().topk(top_k2, dim=0)[1]
    obj_pick = np.random.choice(range(top_k2))
    obj_index2 = top_list[obj_pick]
    # Pick from T's ranking with probability 1/4, else from G's.
    random_pick = np.random.choice(range(4))
    if random_pick == 0:
        obj_index = obj_index2
    else:
        obj_index = obj_index1
    return obj_index
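
# Usage sketch (illustrative; `objs` here would be pre-tokenized object
# phrases, e.g. clip.tokenize(texts1).to(device)):
#
#     obj_index = get_obj(model, G_img, T_img, objs_tokens, top_k1=5, top_k2=3)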
def get_objs(model, T, objs, top_k2):
    """Return the top_k2 (scores, indices) of objs ranked against image T."""
    model.eval()
    image_features, text_features, logit_scale = model(T, objs)
    logits_per_text = logit_scale * text_features @ image_features.t()
    top_list = logits_per_text.detach().cpu().topk(top_k2, dim=0)
    return top_list[0], top_list[1]
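
# Usage sketch (illustrative):
#
#     scores, indices = get_objs(model, T_img, objs_tokens, top_k2=5)
#     # For a single image T, scores and indices have shape (top_k2, 1);
#     # indices[0][0] is the best-matching object.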
if __name__ == "__main__":
    # Interior-design keywords used as candidate object texts for the Chinese
    # CLIP model (e.g. 空间 "space", 客厅 "living room", 卧室 "bedroom",
    # 地毯 "carpet", 吊顶 "ceiling").
    texts1 = ["空间","客厅","卧室","墙面","餐厅","公寓","住宅","沙发","家具","地毯","厨房","书房","背景墙","吊灯","墙",
              "卫生间","儿童","床品","装饰","壁纸","地板","窗帘","吊顶","餐椅","别墅","地面","结构","布艺","餐桌","画"]
    # Candidate caption texts.
    texts2 = [
        "打造了一个现代、讲究的温馨的空间。",  # "Creates a modern, refined, cozy space."
        "与地毯的图案",  # "With the carpet's pattern"
        "灰地毯的中性色调与床品、窗帘以深浅对比",  # "The grey carpet's neutral tones contrast with the bedding and curtains"
        "白灰色床品搭配灰色地毯"  # "White-grey bedding paired with a grey carpet"
    ]
    # NOTE: get_clip_score is not defined in this module; it is assumed to be
    # provided elsewhere in the project.
    print(get_clip_score('902/epoch_50.pt', 423943, texts1))