|
import sys |
|
import os |
|
import streamlit as st |
|
import torch |
|
from PIL import Image |
|
from Model.CLIP.cn_clip.clip import load_from_name |
|
import Model.CLIP.usage.calculate |
|
from Model.fur_rl.models.retriever_rl import DQN_v3 |
|
import Model.CLIP.cn_clip.clip as clip |
|
import recommendation.datasets.img_preprocess |
|
import recommendation.utils.ranker_1 |
|
|
|
|
|
# Page heading; rendered at module level, ahead of everything main() draws.
st.title('Recommendation System V1')
|
|
|
def txt_embed(t_txt, g_txt, fb_txt, net, batch_size, device1):
    """Encode the positive, negative and feedback texts with the actor's text encoder.

    Args:
        t_txt: Positive text(s): a list of strings, or a single string which is
            treated as a batch of one.
        g_txt: Negative text(s), same form as ``t_txt``.
        fb_txt: Feedback text(s), passed straight to ``clip.tokenize``.
        net: Agent whose ``actor_net.txt_embed`` encodes token tensors.
        batch_size: Number of leading entries of ``t_txt`` / ``g_txt`` checked
            for emptiness.
        device1: Device the token tensors are moved to.

    Returns:
        ``(f_embed_t, f_embed_g, f_embed)``: the encoder outputs for the
        positive, negative and feedback texts. Rows whose positive/negative
        text is empty are zeroed.
    """
    # A bare string would otherwise be indexed character by character below:
    # len(t_txt[i]) could then never be 0, and '' would raise IndexError.
    if isinstance(t_txt, str):
        t_txt = [t_txt]
    if isinstance(g_txt, str):
        g_txt = [g_txt]

    with torch.no_grad():
        f_embed_t = net.actor_net.txt_embed(clip.tokenize(t_txt).to(device1))
        f_embed_g = net.actor_net.txt_embed(clip.tokenize(g_txt).to(device1))
        for i in range(batch_size):
            # Empty text contributes nothing: replace its embedding row with
            # zeros (in place, whatever the embedding width is).
            if not t_txt[i]:
                f_embed_t[i].zero_()
            if not g_txt[i]:
                f_embed_g[i].zero_()
        f_embed = net.actor_net.txt_embed(clip.tokenize(fb_txt).to(device1))
    return f_embed_t, f_embed_g, f_embed
|
|
|
|
|
# Selectbox label -> actor checkpoint. Labels without an entry (e.g.
# 'only CLIP') fall back to _DEFAULT_CHECKPOINT.
_CHECKPOINTS = {
    'p1-t10-g.pth': './models/p1-t10-g.pth',
    'p2-t10-g.pth': './models/p2-t10-g.pth',
}
_DEFAULT_CHECKPOINT = './models/p1-t10-g.pth'

# Image id shown at the start of an episode and returned once it is over.
_INITIAL_IMG_ID = 408203

# Interactive turns per episode. Turn t writes history slot t and the chosen
# image goes into slot t + 1, so _HISTORY_LEN must be >= _MAX_TURNS + 2.
_MAX_TURNS = 10
_HISTORY_LEN = 12
_EMBED_DIM = 512


def _load_net(modelName, device):
    """Build a DQN_v3 agent for ``modelName`` and restore its actor checkpoint.

    Loads the ViT-B-16 CLIP encoder (resuming from ./models/ClipEncoder.pt),
    converts it to fp32, wraps it in DQN_v3, restores the actor network and
    optimizer state, and puts the actor in eval mode.
    """
    clip_model, clip_preprocess = load_from_name(
        "ViT-B-16", device=device,
        download_root='../../data/pretrained_weights/',
        resume='./models/ClipEncoder.pt')
    clip_model = Model.CLIP.usage.calculate.convert_models_to_fp32(clip_model)
    net = DQN_v3(clip_model, clip_preprocess, device=device)

    # Read the checkpoint once and restore both entries from it.
    # torch.load unpickles: only point it at trusted local files.
    checkpoint = torch.load(_CHECKPOINTS.get(modelName, _DEFAULT_CHECKPOINT),
                            map_location=device)
    net.actor_net.load_state_dict(checkpoint['actor_state_dict'])
    net.actor_optimizer.load_state_dict(checkpoint['actor_optimizer'])
    net.actor_net.eval()
    return net


def start(modelName):
    """Render the interactive recommendation page for the selected model.

    Called on every script run (main() passes the selectbox label). It
    (re)loads the agent when the label changes, initialises or resets the
    per-episode state in ``st.session_state``, loads the pickled ranker once,
    and draws the inputs, the current recommendation and the history sidebar.

    Session keys: 'model' (selectbox label), 'net', 'ranker', 'turn', 'r_t',
    'g_t_id', 'f_his_txt', 'stop', 'chat_history', 'f_embed_his_t',
    'f_embed_his_g', 'f_embed_his', 'G_imgs_his'.
    """
    device1 = "cpu"

    # 'model' always stores the selectbox label, so an unchanged selection does
    # not trigger a reload on the next run.
    first_load = 'model' not in st.session_state
    if first_load or st.session_state['model'] != modelName:
        st.session_state['net'] = _load_net(modelName, device1)
        st.session_state['model'] = modelName
        st.write(modelName + (' first loaded' if first_load else ' reloaded'))
        # A freshly loaded agent starts a new episode.
        st.session_state['turn'] = 0
        st.session_state['f_his_txt'] = [' ']
        st.session_state['r_t'] = 0
        st.session_state['g_t_id'] = _INITIAL_IMG_ID

    if 'stop' not in st.session_state:
        st.session_state['stop'] = 0
    # 'stop' is set by the Stop button or after the last turn; the checks
    # below all reset the episode when it is set.
    episode_over = st.session_state['stop'] == 1

    if 'chat_history' not in st.session_state or episode_over:
        st.session_state['chat_history'] = []

    if 'turn' not in st.session_state or episode_over:
        st.session_state['turn'] = 0
        st.session_state['r_t'] = 0
        st.session_state['g_t_id'] = _INITIAL_IMG_ID
        if episode_over:
            st.warning('It was the last turn and Retry Please!!')

    if 'f_embed_his_t' not in st.session_state or episode_over:
        # Per-turn history buffers: positive-text, negative-text and
        # feedback-text embeddings, plus the chosen candidates' image features.
        for key in ('f_embed_his_t', 'f_embed_his_g', 'f_embed_his', 'G_imgs_his'):
            st.session_state[key] = torch.zeros((1, _HISTORY_LEN, _EMBED_DIM))
        st.session_state['stop'] = 0
        st.write('f_embed_his ready')

    if 'ranker' not in st.session_state:
        # The ranker is restored as a whole pickled object from disk.
        ranker_test = torch.load('./models/ranker_test.pth', map_location=device1)
        st.write('ranker ready')
        st.session_state['ranker'] = ranker_test

    def stop():
        """Stop-button callback: flag the episode so the next run resets it."""
        st.session_state['stop'] = 1

    def model(fg_t, fp_t):
        """Run one agent step for the current turn; return the next image id.

        ``fp_t`` / ``fg_t`` are this turn's positive / negative texts.
        """
        t = st.session_state['turn']
        if t > _MAX_TURNS:
            # Past the last turn: end the episode and show the initial image.
            st.session_state['stop'] = 1
            return _INITIAL_IMG_ID

        ranker = st.session_state['ranker']
        net = st.session_state['net']
        f_embed_his_t = st.session_state['f_embed_his_t'].to(device1)
        f_embed_his_g = st.session_state['f_embed_his_g'].to(device1)
        f_embed_his = st.session_state['f_embed_his'].to(device1)
        G_imgs_his = st.session_state['G_imgs_his'].to(device1)

        # Inference only: no_grad keeps the stored history tensors out of any
        # autograd graph across turns.
        with torch.no_grad():
            f_embed_t, f_embed_g, f_embed = txt_embed(fp_t, fg_t, " ", net, 1, device1)
            f_embed_his_t[:, t] = f_embed_t
            f_embed_his_g[:, t] = f_embed_g
            f_embed_his[:, t] = f_embed

            _, distance_t, distance_g = ranker.actions_metric(
                f_embed_his.to(torch.float32), f_embed_his_t.to(torch.float32),
                f_embed_his_g.to(torch.float32), 1, _HISTORY_LEN)

            p, _, _, _, C_imgs, C_ids, _, _ = net.actor_net.forward(
                G_imgs_his, torch.cat((f_embed_his_t, f_embed_his_g), dim=1),
                torch.cat((distance_t, -distance_g), dim=1),
                ranker, k=4)

            if t >= _MAX_TURNS:
                # Last turn: take the first candidate.
                max_index = torch.zeros(1, dtype=torch.long, device=device1)
                print("final action: ", max_index)
            else:
                if st.session_state['model'] == './models/p1-t10-g.pth':
                    # NOTE(review): 'model' holds the selectbox label, never
                    # this path, so this branch does not run. Confirm which
                    # label was meant to force the first candidate.
                    max_index = torch.zeros(1, dtype=torch.long, device=device1)
                else:
                    max_index = p.argmax(dim=1)
                print("net action: ", max_index)

            best = max_index[0]
            # The chosen candidate's image features become next turn's history.
            G_imgs_his[:, t + 1] = C_imgs[0][best]
            next_id = int(C_ids[0][best])

        st.session_state['G_imgs_his'] = G_imgs_his
        st.session_state['f_embed_his_g'] = f_embed_his_g
        st.session_state['f_embed_his_t'] = f_embed_his_t
        st.session_state['f_embed_his'] = f_embed_his
        return next_id

    def interactive(r_t, fg_t, fp_t):
        """Interactive-button callback: advance one turn and log it to the chat history."""
        try:
            reward = int(r_t)
        except ValueError:
            # Reject a non-integer reward before touching any episode state.
            st.warning('Reward must be an integer, got: ' + repr(r_t))
            return
        st.session_state['turn'] += 1
        st.session_state['r_t'] = reward

        g_t_id_next = model(fg_t, fp_t)
        st.session_state['g_t_id'] = g_t_id_next
        st.session_state.chat_history.append(
            {"user": "(positive)" + fp_t + ';(negative)' + fg_t,
             "bot": str(g_t_id_next), "reward": r_t})

    st.write(st.session_state['model'] + ' loaded')
    st.write('turn: ' + str(st.session_state['turn']) + '/' + str(_MAX_TURNS))

    # Blank text inputs are sent as a single space.
    fp_t = st.text_input('Input text(positive)', value=' ') or ' '
    fg_t = st.text_input('Input text(negative)', value=' ') or ' '
    r_t = st.text_input('Input reward', value='0')
    st.button('Interactive', on_click=interactive, args=(r_t, fg_t, fp_t))
    st.button('Stop', on_click=stop)

    g_t_id = st.session_state['g_t_id']
    # Context manager closes the image file once st.image has consumed it.
    with Image.open('./imgs/' + str(g_t_id) + '.jpg') as img:
        st.image(img, caption='g_t_next:' + str(g_t_id), use_column_width='auto')

    for i, entry in enumerate(st.session_state.chat_history):
        st.sidebar.write("User: " + entry['user'])
        st.sidebar.write("reward: " + entry['reward'])
        st.sidebar.write("turn: " + str(i))
        st.sidebar.write('---------------------------------------------------')
        # Caption each history image with its own id, not the current one.
        with Image.open('./imgs/' + entry['bot'] + '.jpg') as img:
            st.sidebar.image(img, caption='g_t_next:' + entry['bot'], width=224)
        st.sidebar.write('---------------------------------------------------')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Entry point: let the user pick a model, then render the demo for it."""
    choices = ('p1-t10-g.pth', 'p2-t10-g.pth', 'only CLIP')
    selected = st.selectbox('Select model', choices)
    start(selected)
|
|
|
|
|
# Run the app only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|
|
|
|
|
|
|
|
|