File size: 9,429 Bytes
3c6e15d 2353bb5 3c6e15d aa7fb02 2353bb5 3c6e15d aa7fb02 3c6e15d aa7fb02 3c6e15d aa7fb02 3c6e15d aa7fb02 3c6e15d aa7fb02 3c6e15d aa7fb02 3c6e15d d3372d9 aa7fb02 d3372d9 aa7fb02 d3372d9 2353bb5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 |
import sys
import os
import streamlit as st
import torch
from PIL import Image
from Model.CLIP.cn_clip.clip import load_from_name
import Model.CLIP.usage.calculate
from Model.fur_rl.models.retriever_rl import DQN_v3
import Model.CLIP.cn_clip.clip as clip
import recommendation.datasets.img_preprocess
import recommendation.utils.ranker_1
import numpy as np
# Page header rendered at the top of the Streamlit app on every rerun.
st.title('Recommendation System V1')
def txt_embed(t_txt, g_txt, fb_txt, net, batch_size, device1):
    """Embed positive, negative and feedback text via the actor's CLIP text tower.

    Rows whose corresponding input element is empty are overwritten with a
    zero vector so that "no text this turn" contributes nothing downstream.

    Returns a tuple ``(positive_embed, negative_embed, feedback_embed)``,
    each of shape (batch, 512) — TODO confirm embed dim against the encoder.
    """
    with torch.no_grad():
        pos_embed = net.actor_net.txt_embed(clip.tokenize(t_txt).to(device1))
        neg_embed = net.actor_net.txt_embed(clip.tokenize(g_txt).to(device1))
        for idx in range(batch_size):
            # Empty element => blank out the embedding (tokenizer still
            # produces a non-zero vector for "", which we must not keep).
            if len(t_txt[idx]) == 0:
                pos_embed[idx] = torch.zeros((1, 512)).to(device1)
            if len(g_txt[idx]) == 0:
                neg_embed[idx] = torch.zeros((1, 512)).to(device1)
        fb_embed = net.actor_net.txt_embed(clip.tokenize(fb_txt).to(device1))
    return pos_embed, neg_embed, fb_embed
def _load_net(model_name, device1):
    """Build the DQN_v3 agent and load the checkpoint named *model_name*.

    Raises ValueError for an unknown model name (previously this fell
    through to ``torch.load(None)`` and crashed with a confusing error).
    """
    sentence_clip_model, sentence_clip_preprocess = load_from_name(
        "ViT-B-16", device=device1,
        download_root='../../data/pretrained_weights/',
        resume=r'models\ClipEncoder.pt')
    sentence_clip_model = Model.CLIP.usage.calculate.convert_models_to_fp32(sentence_clip_model)
    net = DQN_v3(sentence_clip_model, sentence_clip_preprocess, device=device1)
    if model_name not in ('p1-t10-g.pth', 'p2-t10-g.pth'):
        raise ValueError('Unknown model: ' + str(model_name))
    # Load the checkpoint ONCE; the original loaded it twice (actor weights
    # and optimizer state in separate torch.load calls).
    checkpoint = torch.load('./models/' + model_name, map_location=device1)
    net.actor_net.load_state_dict(checkpoint['actor_state_dict'])
    net.actor_optimizer.load_state_dict(checkpoint['actor_optimizer'])
    return net


def start(modelName):
    """Render one Streamlit page for the interactive recommender.

    Loads (or reloads) the selected model into ``st.session_state``, keeps
    the per-session dialogue history there, and draws the input widgets and
    the currently recommended image.
    """
    device1 = "cuda" if torch.cuda.is_available() else "cpu"
    if 'model' not in st.session_state:
        # First page load: build the network and cache it in the session.
        net = _load_net(modelName, device1)
        # BUG FIX: store the selectbox name, not the checkpoint path.
        # The original stored './models/<name>' here while the reload check
        # below compares against the bare name, forcing a spurious full
        # reload on the very next rerun.
        st.session_state['model'] = modelName
        st.session_state['net'] = net
        st.write(modelName + ' first loaded')
    if st.session_state['model'] != modelName:
        # The user switched models: rebuild the network and reset the dialogue.
        net = _load_net(modelName, device1)
        st.session_state['model'] = modelName
        st.session_state['net'] = net
        st.write(modelName + ' reloaded')
        st.session_state['turn'] = 0
        st.session_state['f_his_txt'] = [' ']
        st.session_state['r_t'] = 0
        st.session_state['g_t_id'] = 408203  # default/target image id — TODO confirm meaning
    if 'stop' not in st.session_state:
        st.session_state['stop'] = 0
    if 'turn' not in st.session_state or st.session_state['stop'] == 1:
        # New session, or the previous episode hit the turn limit: restart.
        st.session_state['turn'] = 0
        st.session_state['r_t'] = 0
        st.session_state['g_t_id'] = 408203
        if st.session_state['stop'] == 1:
            st.warning('It was the last turn and Retry Please!!')
    if 'f_embed_his_t' not in st.session_state or st.session_state['stop'] == 1:
        # Fresh 12-slot history buffers (one 512-d embedding per turn).
        st.session_state['f_embed_his_t'] = torch.zeros((1, 12, 512))
        st.session_state['f_embed_his_g'] = torch.zeros((1, 12, 512))
        st.session_state['f_embed_his'] = torch.zeros((1, 12, 512))
        st.session_state['G_imgs_his'] = torch.zeros((1, 12, 512))
        st.session_state['stop'] = 0
        st.write('f_embed_his ready')
    if 'ranker' not in st.session_state:
        test_img_ids = "./recommendation/datasets/test_img_id_r.csv"
        dataset_test = recommendation.datasets.img_preprocess.Image_preprocess(test_img_ids)
        dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=1, shuffle=False, num_workers=4)
        ranker_test = recommendation.utils.ranker_1.Ranker(device1, dataset_test, batch_size=64)
        st.session_state['ranker'] = ranker_test
        net = st.session_state['net']
        ranker_test.update_emb(model=net.actor_net)  # 220.0789999961853s; 78s on 3090
        st.write('ranker ready')
    st.write(st.session_state['model'] + ' loaded')
    st.write('turn: ' + str(st.session_state['turn']) + '/10')

    def interactive(r_t, fg_t, fp_t):
        """Button callback: advance one turn using the user's feedback."""
        st.session_state['turn'] += 1
        st.session_state['r_t'] = int(r_t)
        g_t_id = st.session_state['g_t_id']

        def model(r_t, fg_t, fp_t):
            """Run one policy step; return the next recommended image id."""
            t = st.session_state['turn']
            ranker = st.session_state['ranker']
            net = st.session_state['net']
            f_embed_his_t = st.session_state['f_embed_his_t'].to(device1)
            f_embed_his_g = st.session_state['f_embed_his_g'].to(device1)
            f_embed_his = st.session_state['f_embed_his'].to(device1)
            G_imgs_his = st.session_state['G_imgs_his'].to(device1)
            f_embed_t, f_embed_g, f_embed = txt_embed(fp_t, fg_t, " ", net, 1, device1)
            f_embed_his_t[:, t] = f_embed_t
            f_embed_his_g[:, t] = f_embed_g
            f_embed_his[:, t] = f_embed
            distence_, distance_t, distance_g = ranker.actions_metric(
                f_embed_his.to(torch.float32), f_embed_his_t.to(torch.float32),
                f_embed_his_g.to(torch.float32), 1, 12)
            p, q, maxC_img, maxC_ids, C_imgs, C_ids, C_pre, p_db = net.actor_net.forward(
                G_imgs_his, torch.cat((f_embed_his_t, f_embed_his_g), dim=1),
                torch.cat((distance_t, -distance_g), dim=1),
                ranker, k=4)
            if t >= 10:
                # Episode over: force the greedy/first candidate.
                max_index = torch.zeros(1).long().to(device1)
                for batch_i in range(1):
                    max_index[batch_i] = 0
                if t > 10:
                    st.session_state['stop'] = 1
                    return 408203
                print("final action: ", max_index)
            else:
                # BUG FIX: compare against the stored model *name*; the
                # original compared against the './models/...' path, which
                # (combined with the storage fix above) never matched.
                if st.session_state['model'] == 'p1-t10-g.pth':
                    max_index = torch.zeros(1).long().to(device1)
                    for batch_i in range(1):
                        max_index[batch_i] = 0
                else:
                    max_index = p.argmax(dim=1)
                print("net action: ", max_index)
            G_next_ids = torch.zeros(1).to(device1)
            G_next_imgs = torch.zeros(1, 512).to(device1)
            G_next_pre = torch.zeros(1, 3, 224, 224).to(device1)
            for i in range(1):
                G_next_ids[i] = C_ids[i][max_index[i]]  # [16]
                G_next_imgs[i] = C_imgs[i][max_index[i]]  # [16, 512]
                G_next_pre[i] = C_pre[i][max_index[i]]  # [16, 3, 224, 224]
            G_next_ids = G_next_ids.cpu().numpy().astype(int)
            G_imgs_his_next = G_imgs_his
            G_imgs_his_next[:, t + 1] = G_next_imgs
            # BUG FIX: persist under the string keys. The original wrote
            # st.session_state[G_imgs_his] (the tensor object itself as the
            # key), so the history was never stored back and every turn
            # re-read stale buffers.
            st.session_state['G_imgs_his'] = G_imgs_his_next
            st.session_state['f_embed_his_g'] = f_embed_his_g
            st.session_state['f_embed_his_t'] = f_embed_his_t
            st.session_state['f_embed_his'] = f_embed_his
            return G_next_ids[0]

        def run_model(r_t, fg_t, fp_t):
            g_t_id_next = model(r_t, fg_t, fp_t)
            st.session_state['g_t_id'] = g_t_id_next
            return g_t_id_next

        g_t_id_next = run_model(r_t, fg_t, fp_t)
        st.session_state['g_t_id'] = g_t_id_next

    fp_t = st.text_input('Input text(positive)', value=' ')
    fg_t = st.text_input('Input text(negative)', value=' ')
    # Streamlit returns '' when the user clears a field; the embedding code
    # expects a non-empty string whose emptiness is signalled per-element.
    if fp_t == '':
        fp_t = ' '
    if fg_t == '':
        fg_t = ' '
    r_t = st.text_input('Input reward', value='0')
    st.button('Interactive', on_click=interactive, args=(r_t, fg_t, fp_t))
    g_t_id = st.session_state['g_t_id']
    st.image(Image.open('./imgs/' + str(g_t_id) + '.jpg'),
             caption='g_t_next:' + str(g_t_id), use_column_width='auto')
def main():
    """Entry point: let the user pick a checkpoint, then run the page for it."""
    chosen_model = st.selectbox('Select model', ('p1-t10-g.pth', 'p2-t10-g.pth'))
    start(chosen_model)
# Run the Streamlit page when executed as a script.
if __name__ == "__main__":
    main()
|