DRv2 / app.py
Zhonathon's picture
Update app.py
bba8baa
raw
history blame
10.4 kB
import sys
import os
import streamlit as st
import torch
from PIL import Image
from Model.CLIP.cn_clip.clip import load_from_name
import Model.CLIP.usage.calculate
from Model.fur_rl.models.retriever_rl import DQN_v3
import Model.CLIP.cn_clip.clip as clip
import recommendation.datasets.img_preprocess
import recommendation.utils.ranker_1
# Page heading rendered at the top of every Streamlit rerun.
_APP_TITLE = 'Recommendation System V1'
st.title(_APP_TITLE)
def txt_embed(t_txt, g_txt, fb_txt, net, batch_size, device1):
    """Encode three text batches with the actor network's text encoder.

    Args:
        t_txt: str or list of str (the app passes the positive feedback).
        g_txt: str or list of str (the app passes the negative feedback).
        fb_txt: str or list of str, encoded as-is.
        net: object exposing ``actor_net.txt_embed`` (e.g. DQN_v3).
        batch_size: how many leading entries of t_txt / g_txt are checked
            for empty strings.
        device1: torch device the token tensors are moved to.

    Returns:
        Tuple ``(f_embed_t, f_embed_g, f_embed)`` with the embeddings of
        t_txt, g_txt and fb_txt. Rows of f_embed_t / f_embed_g whose input
        entry is the empty string are zeroed.
    """
    # A bare string is a batch of one. Wrapping it makes the per-entry check
    # below test the whole text instead of its first character, and stops an
    # empty string from raising IndexError.
    if isinstance(t_txt, str):
        t_txt = [t_txt]
    if isinstance(g_txt, str):
        g_txt = [g_txt]
    with torch.no_grad():
        f_embed_t = net.actor_net.txt_embed(clip.tokenize(t_txt).to(device1))
        f_embed_g = net.actor_net.txt_embed(clip.tokenize(g_txt).to(device1))
        for i in range(batch_size):
            # Zero the embedding of any empty entry (works for any embed width).
            if not t_txt[i]:
                f_embed_t[i].zero_()
            if not g_txt[i]:
                f_embed_g[i].zero_()
        f_embed = net.actor_net.txt_embed(clip.tokenize(fb_txt).to(device1))
    return f_embed_t, f_embed_g, f_embed
# Device every model and tensor in the app is placed on.
_DEVICE = "cpu"
# Selectbox labels that have a dedicated checkpoint; any other label
# (e.g. 'only CLIP') uses _DEFAULT_CHECKPOINT.
_CHECKPOINTS = {
    'p1-t10-g.pth': './models/p1-t10-g.pth',
    'p2-t10-g.pth': './models/p2-t10-g.pth',
}
_DEFAULT_CHECKPOINT = './models/p1-t10-g.pth'
# Image id shown before the first interaction; also returned once the turn
# limit has been exceeded.
_INITIAL_IMG_ID = 408203
# Interactive turns per session; the last one always takes candidate 0.
_MAX_TURNS = 10
# Every per-session history buffer has shape (1, _HISTORY_LEN, _EMBED_DIM).
_HISTORY_LEN = 12
_EMBED_DIM = 512
_HISTORY_KEYS = ('f_embed_his_t', 'f_embed_his_g', 'f_embed_his', 'G_imgs_his')


def _load_net(modelName):
    """Build the CLIP-backed DQN_v3 and load the checkpoint for *modelName*."""
    clip_model, clip_preprocess = load_from_name(
        "ViT-B-16", device=_DEVICE,
        download_root='../../data/pretrained_weights/',
        resume='./models/ClipEncoder.pt')
    clip_model = Model.CLIP.usage.calculate.convert_models_to_fp32(clip_model)
    net = DQN_v3(clip_model, clip_preprocess, device=_DEVICE)
    checkpoint_path = _CHECKPOINTS.get(modelName, _DEFAULT_CHECKPOINT)
    # Deserialize the checkpoint once and take both state dicts from it.
    checkpoint = torch.load(checkpoint_path, map_location=_DEVICE)
    net.actor_net.load_state_dict(checkpoint['actor_state_dict'])
    net.actor_optimizer.load_state_dict(checkpoint['actor_optimizer'])
    net.actor_net.eval()
    return net


def _init_session_state():
    """Create missing per-session state, or reset it after a stopped session."""
    state = st.session_state
    if 'stop' not in state:
        state['stop'] = 0
    stopped = state['stop'] == 1
    if 'chat_history' not in state or stopped:
        state['chat_history'] = []
    if 'turn' not in state or stopped:
        state['turn'] = 0
        state['r_t'] = 0
        state['g_t_id'] = _INITIAL_IMG_ID
        if stopped:
            st.warning('It was the last turn and Retry Please!!')
    if 'f_embed_his_t' not in state or stopped:
        for key in _HISTORY_KEYS:
            state[key] = torch.zeros((1, _HISTORY_LEN, _EMBED_DIM))
        state['stop'] = 0
        st.write('f_embed_his ready')
    if 'ranker' not in state:
        # Pickled Ranker built offline (Ranker(...), then
        # update_emb(model=net.actor_net) and torch.save). Regenerate it if the
        # encoder changes.
        # NOTE(review): this unpickles a full object. torch>=2.6 defaults to
        # torch.load(weights_only=True), which rejects it, so pass
        # weights_only=False if the runtime torch is that new. Only load
        # trusted local files this way.
        ranker = torch.load('./models/ranker_test.pth', map_location=_DEVICE)
        st.write('ranker ready')
        state['ranker'] = ranker


def _on_stop():
    """Button callback: flag the session so the next rerun resets it."""
    st.session_state['stop'] = 1


def _next_image_id(fg_t, fp_t):
    """Run one retrieval step for the current turn and return the next image id.

    Writes this turn's text embeddings (slot t) and the chosen candidate's
    image embedding (slot t + 1) into the session history buffers. When the
    turn counter exceeds _MAX_TURNS, it flags the session as stopped and
    returns _INITIAL_IMG_ID.
    """
    state = st.session_state
    t = state['turn']
    ranker = state['ranker']
    net = state['net']
    f_embed_his_t = state['f_embed_his_t'].to(_DEVICE)
    f_embed_his_g = state['f_embed_his_g'].to(_DEVICE)
    f_embed_his = state['f_embed_his'].to(_DEVICE)
    G_imgs_his = state['G_imgs_his'].to(_DEVICE)

    f_embed_t, f_embed_g, f_embed = txt_embed([fp_t], [fg_t], [" "], net, 1, _DEVICE)
    f_embed_his_t[:, t] = f_embed_t
    f_embed_his_g[:, t] = f_embed_g
    f_embed_his[:, t] = f_embed
    _, distance_t, distance_g = ranker.actions_metric(
        f_embed_his.to(torch.float32), f_embed_his_t.to(torch.float32),
        f_embed_his_g.to(torch.float32), 1, _HISTORY_LEN)
    p, _, _, _, C_imgs, C_ids, _, _ = net.actor_net.forward(
        G_imgs_his, torch.cat((f_embed_his_t, f_embed_his_g), dim=1),
        torch.cat((distance_t, -distance_g), dim=1),
        ranker, k=4)

    if t >= _MAX_TURNS:
        if t > _MAX_TURNS:
            state['stop'] = 1
            return _INITIAL_IMG_ID
        # Final turn: always take candidate 0.
        max_index = torch.zeros(1).long().to(_DEVICE)
        print("final action: ", max_index)
    elif state['model'] == './models/p1-t10-g.pth':
        # NOTE(review): st.session_state['model'] holds the selectbox label
        # (e.g. 'p1-t10-g.pth'), never this path, so this branch never fires
        # and every model uses argmax. Confirm whether p1 was meant to always
        # take candidate 0.
        max_index = torch.zeros(1).long().to(_DEVICE)
        print("net action: ", max_index)
    else:
        max_index = p.argmax(dim=1)
        print("net action: ", max_index)

    choice = max_index[0]
    # int() directly, avoiding a float32 round trip that loses precision for
    # ids >= 2**24.
    next_img_id = int(C_ids[0][choice])
    G_imgs_his[:, t + 1] = C_imgs[0][choice]
    # Persist under the string keys (previously keyed by the tensors themselves).
    state['G_imgs_his'] = G_imgs_his
    state['f_embed_his_g'] = f_embed_his_g
    state['f_embed_his_t'] = f_embed_his_t
    state['f_embed_his'] = f_embed_his
    return next_img_id


def _on_interactive(r_t, fg_t, fp_t):
    """Button callback: advance one turn using the user's feedback and reward."""
    try:
        reward = int(r_t)
    except ValueError:
        # Reject bad input before touching the turn counter.
        st.error('Reward must be an integer, got: ' + repr(r_t))
        return
    st.session_state['turn'] += 1
    st.session_state['r_t'] = reward
    st.session_state['g_t_id'] = _next_image_id(fg_t, fp_t)
    st.session_state.chat_history.append(
        {"user": "(positive)" + fp_t + ';(negative)' + fg_t,
         "bot": str(st.session_state['g_t_id']), "reward": r_t})


def _render_history():
    """Show each past turn (feedback, reward, returned image) in the sidebar."""
    for turn_idx, entry in enumerate(st.session_state.chat_history):
        st.sidebar.write("User: " + entry['user'])
        st.sidebar.write("reward: " + entry['reward'])
        st.sidebar.write("turn: " + str(turn_idx))
        st.sidebar.write('---------------------------------------------------')
        # Caption with this turn's own image id, not the current one.
        st.sidebar.image(Image.open('./imgs/' + entry['bot'] + '.jpg'),
                         caption='g_t_next:' + entry['bot'], width=224)
        st.sidebar.write('---------------------------------------------------')


def start(modelName):
    """Render the interactive recommendation page for the selected model.

    Loads (or swaps) the network kept in session state, initializes the
    per-session buffers and ranker, then draws the feedback widgets, the
    current recommendation and the sidebar history.

    Args:
        modelName: selectbox label. 'p1-t10-g.pth' / 'p2-t10-g.pth' load their
            own checkpoints; any other label loads _DEFAULT_CHECKPOINT.
    """
    if st.session_state.get('model') != modelName:
        first_load = 'model' not in st.session_state
        st.session_state['net'] = _load_net(modelName)
        # Always store the label, so the next rerun does not reload the model.
        st.session_state['model'] = modelName
        st.write(modelName + (' first loaded' if first_load else ' reloaded'))
        st.session_state['turn'] = 0
        st.session_state['f_his_txt'] = [' ']
        st.session_state['r_t'] = 0
        st.session_state['g_t_id'] = _INITIAL_IMG_ID

    _init_session_state()

    st.write(st.session_state['model'] + ' loaded')
    st.write(f"turn: {st.session_state['turn']}/{_MAX_TURNS}")

    # An empty box is normalized to a single space.
    fp_t = st.text_input('Input text(positive)', value=' ') or ' '
    fg_t = st.text_input('Input text(negative)', value=' ') or ' '
    r_t = st.text_input('Input reward', value='0')
    st.button('Interactive', on_click=_on_interactive, args=(r_t, fg_t, fp_t))
    st.button('Stop', on_click=_on_stop)

    g_t_id = st.session_state['g_t_id']
    st.image(Image.open('./imgs/' + str(g_t_id) + '.jpg'),
             caption='g_t_next:' + str(g_t_id), use_column_width='auto')
    _render_history()
# Options offered in the model picker, in display order.
_MODEL_CHOICES = ('p1-t10-g.pth', 'p2-t10-g.pth', 'only CLIP')


def main():
    """Entry point: ask which model to use, then render the app with it."""
    chosen = st.selectbox('Select model', _MODEL_CHOICES)
    start(chosen)


if __name__ == "__main__":
    main()