import sys
import os
import streamlit as st
import torch
from PIL import Image
from Model.CLIP.cn_clip.clip import load_from_name
import Model.CLIP.usage.calculate
from Model.fur_rl.models.retriever_rl import DQN_v3
import Model.CLIP.cn_clip.clip as clip
# Not referenced by name below. Kept imported because the cached ranker loaded
# in _ensure_ranker is presumably a pickled recommendation.utils.ranker_1.Ranker,
# which torch.load can only unpickle when its defining modules are importable.
import recommendation.datasets.img_preprocess
import recommendation.utils.ranker_1

# Device every model and history tensor in this app is placed on.
DEVICE = "cpu"
# Number of slots in each per-session history tensor (f_embed_his*, G_imgs_his).
HISTORY_LEN = 12
# Width of the embeddings stored in the history tensors.
EMBED_DIM = 512
# Image id shown before the first interaction and whenever a session restarts.
INITIAL_IMAGE_ID = 408203
# Turn at which the first candidate is forced; any later turn ends the session.
MAX_TURNS = 10
# Selectbox label -> actor checkpoint. Any other label (e.g. 'only CLIP') falls
# back to DEFAULT_CHECKPOINT.
CHECKPOINTS = {
    'p1-t10-g.pth': './models/p1-t10-g.pth',
    'p2-t10-g.pth': './models/p2-t10-g.pth',
}
DEFAULT_CHECKPOINT = './models/p1-t10-g.pth'
# Session-state keys of the per-turn history tensors fed to the policy.
HISTORY_KEYS = ('f_embed_his_t', 'f_embed_his_g', 'f_embed_his', 'G_imgs_his')
# Divider drawn between sidebar history entries.
SEPARATOR = '---------------------------------------------------'

st.title('Recommendation System V1')


def txt_embed(t_txt, g_txt, fb_txt, net, batch_size, device1):
    """Embed three texts with the actor network's text encoder.

    Args:
        t_txt: text or list of texts embedded into ``f_embed_t`` (the app passes
            the "positive" input). Empty entries get an all-zero embedding.
        g_txt: text or list of texts embedded into ``f_embed_g`` (the app passes
            the "negative" input). Empty entries get an all-zero embedding.
        fb_txt: text(s) embedded into ``f_embed`` (no zeroing applied).
        net: model exposing ``actor_net.txt_embed``.
        batch_size: number of leading entries of t_txt / g_txt to check for
            emptiness.
        device1: device the tokenized input is moved to.

    Returns:
        ``(f_embed_t, f_embed_g, f_embed)``. Presumably each is
        (batch, EMBED_DIM); TODO confirm against actor_net.txt_embed.
    """
    # A bare string is a batch of one. Without this, t_txt[i] below would index
    # single characters instead of whole texts.
    if isinstance(t_txt, str):
        t_txt = [t_txt]
    if isinstance(g_txt, str):
        g_txt = [g_txt]
    # Inference only: keep autograd off for all three encodings so no graph is
    # carried into the session-state history tensors.
    with torch.no_grad():
        f_embed_t = net.actor_net.txt_embed(clip.tokenize(t_txt).to(device1))
        f_embed_g = net.actor_net.txt_embed(clip.tokenize(g_txt).to(device1))
        for i in range(batch_size):
            if len(t_txt[i]) == 0:
                f_embed_t[i] = 0
            if len(g_txt[i]) == 0:
                f_embed_g[i] = 0
        f_embed = net.actor_net.txt_embed(clip.tokenize(fb_txt).to(device1))
    return f_embed_t, f_embed_g, f_embed


def _load_net(model_name, device):
    """Build a DQN_v3 on the CLIP encoder and load the actor checkpoint for *model_name*."""
    sentence_clip_model, sentence_clip_preprocess = load_from_name(
        "ViT-B-16",
        device=device,
        download_root='../../data/pretrained_weights/',
        resume='./models/ClipEncoder.pt',
    )
    sentence_clip_model = Model.CLIP.usage.calculate.convert_models_to_fp32(sentence_clip_model)
    net = DQN_v3(sentence_clip_model, sentence_clip_preprocess, device=device)
    checkpoint_path = CHECKPOINTS.get(model_name, DEFAULT_CHECKPOINT)
    # Read the file once. It holds both the actor weights and the optimizer state.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    net.actor_net.load_state_dict(checkpoint['actor_state_dict'])
    net.actor_optimizer.load_state_dict(checkpoint['actor_optimizer'])
    net.actor_net.eval()
    return net


def _ensure_model(model_name):
    """Load the selected model on first run or when the selection changes."""
    first_load = 'model' not in st.session_state
    if not first_load and st.session_state['model'] == model_name:
        return
    net = _load_net(model_name, DEVICE)
    # Always store the selectbox label so the comparison above matches on
    # later reruns and the model is not reloaded needlessly.
    st.session_state['model'] = model_name
    st.session_state['net'] = net
    st.write(model_name + (' first loaded' if first_load else ' reloaded'))
    # A (re)loaded model starts a fresh interaction.
    st.session_state['turn'] = 0
    st.session_state['f_his_txt'] = [' ']
    st.session_state['r_t'] = 0
    st.session_state['g_t_id'] = INITIAL_IMAGE_ID
    # Drop the previous model's history tensors so _init_session_state
    # re-creates them empty, instead of feeding stale turns to the new model.
    # The sidebar chat history is intentionally kept.
    for key in HISTORY_KEYS:
        if key in st.session_state:
            del st.session_state[key]


def _init_session_state():
    """Create missing session keys, and reset them after the session was stopped."""
    if 'stop' not in st.session_state:
        st.session_state['stop'] = 0
    restart = st.session_state['stop'] == 1
    if 'chat_history' not in st.session_state or restart:
        st.session_state['chat_history'] = []
    if 'turn' not in st.session_state or restart:
        st.session_state['turn'] = 0
        st.session_state['r_t'] = 0
        st.session_state['g_t_id'] = INITIAL_IMAGE_ID
    if restart:
        st.warning('It was the last turn and Retry Please!!')
    if 'f_embed_his_t' not in st.session_state or restart:
        for key in HISTORY_KEYS:
            st.session_state[key] = torch.zeros((1, HISTORY_LEN, EMBED_DIM))
        st.session_state['stop'] = 0
        st.write('f_embed_his ready')


def _ensure_ranker():
    """Load the cached ranker into session state once per session."""
    if 'ranker' in st.session_state:
        return
    # The cache is presumably produced offline by building
    # recommendation.utils.ranker_1.Ranker(device, Image_preprocess(
    # './recommendation/datasets/test_img_id_r.csv'), batch_size=64), calling
    # ranker.update_emb(model=net.actor_net) (slow), then torch.save(ranker, ...).
    # NOTE(review): this unpickles a whole object. Only load trusted files;
    # torch>=2.6 needs weights_only=False for this call.
    ranker_test = torch.load('./models/ranker_test.pth', map_location=DEVICE)
    st.write('ranker ready')
    st.session_state['ranker'] = ranker_test


def _recommend(fg_t, fp_t, device):
    """Run one policy step for the current turn and return the next image id.

    Writes this turn's text embeddings at slot ``turn``, and the chosen image
    embedding at slot ``turn + 1``, into the session's history tensors. Past
    MAX_TURNS it instead flags the session as stopped and returns
    INITIAL_IMAGE_ID.
    """
    t = st.session_state['turn']
    if t > MAX_TURNS:
        # Past the last turn: end the session. The next rerun resets all state
        # in _init_session_state, so the policy is not run.
        st.session_state['stop'] = 1
        return INITIAL_IMAGE_ID

    ranker = st.session_state['ranker']
    net = st.session_state['net']
    f_embed_his_t = st.session_state['f_embed_his_t'].to(device)
    f_embed_his_g = st.session_state['f_embed_his_g'].to(device)
    f_embed_his = st.session_state['f_embed_his'].to(device)
    G_imgs_his = st.session_state['G_imgs_his'].to(device)

    # Inference only: no autograd graph should reach the stored histories.
    with torch.no_grad():
        # Positive text feeds the "_t" history, negative text the "_g" history.
        # The feedback slot is always a single space.
        f_embed_t, f_embed_g, f_embed = txt_embed(fp_t, fg_t, " ", net, 1, device)
        f_embed_his_t[:, t] = f_embed_t
        f_embed_his_g[:, t] = f_embed_g
        f_embed_his[:, t] = f_embed
        # The trailing (1, HISTORY_LEN) args are presumably batch size and history
        # length; TODO confirm against Ranker.actions_metric.
        _, distance_t, distance_g = ranker.actions_metric(
            f_embed_his.to(torch.float32),
            f_embed_his_t.to(torch.float32),
            f_embed_his_g.to(torch.float32),
            1,
            HISTORY_LEN,
        )
        p, _q, _maxC_img, _maxC_ids, C_imgs, C_ids, _C_pre, _p_db = net.actor_net.forward(
            G_imgs_his,
            torch.cat((f_embed_his_t, f_embed_his_g), dim=1),
            torch.cat((distance_t, -distance_g), dim=1),
            ranker,
            k=4,
        )

    if t >= MAX_TURNS:
        # Final turn: take candidate 0.
        max_index = torch.zeros(1, dtype=torch.long, device=device)
        print("final action: ", max_index)
    elif st.session_state['model'] == './models/p1-t10-g.pth':
        # NOTE(review): session_state['model'] holds the selectbox label
        # (e.g. 'p1-t10-g.pth'), never a checkpoint path, so this branch is
        # never taken and every model uses argmax. Confirm the intended policy.
        max_index = torch.zeros(1, dtype=torch.long, device=device)
        print("net action: ", max_index)
    else:
        max_index = p.argmax(dim=1)
        print("net action: ", max_index)

    chosen = max_index[0]
    # Convert the id straight to int. Passing it through a float tensor would
    # silently corrupt ids above 2**24.
    next_id = int(C_ids[0][chosen])
    G_imgs_his[:, t + 1] = C_imgs[0][chosen]

    # Persist the updated histories under their string keys.
    st.session_state['G_imgs_his'] = G_imgs_his
    st.session_state['f_embed_his_g'] = f_embed_his_g
    st.session_state['f_embed_his_t'] = f_embed_his_t
    st.session_state['f_embed_his'] = f_embed_his
    return next_id


def _interactive(r_t, fg_t, fp_t):
    """'Interactive' button callback: advance one turn and log the exchange."""
    # Validate before touching state, so a bad reward does not consume a turn.
    try:
        reward = int(r_t)
    except ValueError:
        st.warning('Input reward must be an integer, got ' + repr(r_t))
        return
    st.session_state['turn'] += 1
    st.session_state['r_t'] = reward
    st.session_state['g_t_id'] = _recommend(fg_t, fp_t, DEVICE)
    st.session_state.chat_history.append(
        {"user": "(positive)" + fp_t + ';(negative)' + fg_t,
         "bot": str(st.session_state['g_t_id']),
         "reward": r_t})


def _stop():
    """'Stop' button callback: flag the session for reset on the next rerun."""
    st.session_state['stop'] = 1


def _render_history():
    """Draw every logged exchange, with its own image, in the sidebar."""
    for i, entry in enumerate(st.session_state.chat_history):
        st.sidebar.write("User: " + entry['user'])
        st.sidebar.write("reward: " + entry['reward'])
        st.sidebar.write("turn: " + str(i))
        st.sidebar.write(SEPARATOR)
        # Caption each entry with its own image id, not the current one.
        with Image.open('./imgs/' + entry['bot'] + '.jpg') as img:
            st.sidebar.image(img, caption='g_t_next:' + entry['bot'], width=224)
        st.sidebar.write(SEPARATOR)


def start(modelName):
    """Render the app for the selected model.

    Loads (or reuses) the model, initialises session state and the ranker, then
    draws the text/reward inputs, the action buttons, the current image and the
    sidebar history.

    Args:
        modelName: selectbox label, e.g. 'p1-t10-g.pth', 'p2-t10-g.pth' or
            'only CLIP'.
    """
    _ensure_model(modelName)
    _init_session_state()
    _ensure_ranker()

    st.write(st.session_state['model'] + ' loaded')
    st.write('turn: ' + str(st.session_state['turn']) + '/' + str(MAX_TURNS))

    # Empty inputs are replaced by a single space, as the rest of the app expects.
    fp_t = st.text_input('Input text(positive)', value=' ') or ' '
    fg_t = st.text_input('Input text(negative)', value=' ') or ' '
    r_t = st.text_input('Input reward', value='0')
    st.button('Interactive', on_click=_interactive, args=(r_t, fg_t, fp_t))
    st.button('Stop', on_click=_stop)

    g_t_id = st.session_state['g_t_id']
    with Image.open('./imgs/' + str(g_t_id) + '.jpg') as img:
        st.image(img, caption='g_t_next:' + str(g_t_id), use_column_width='auto')

    _render_history()


def main():
    """Entry point: pick a model from the selectbox and render the app."""
    modelName = st.selectbox('Select model', ('p1-t10-g.pth', 'p2-t10-g.pth', 'only CLIP'))
    start(modelName)


if __name__ == "__main__":
    main()