File size: 9,530 Bytes
3c6e15d
 
2353bb5
3c6e15d
 
aa7fb02
 
 
 
 
 
6066f74
2353bb5
3c6e15d
 
aa7fb02
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ba2f4a
aa7fb02
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4a3c00
aa7fb02
 
 
 
 
 
 
 
 
 
6066f74
aa7fb02
 
 
 
 
 
 
 
 
 
3c6e15d
 
aa7fb02
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6066f74
 
 
 
 
 
aa7fb02
6066f74
 
aa7fb02
 
 
 
 
3c6e15d
 
 
 
aa7fb02
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3c6e15d
aa7fb02
 
 
3c6e15d
 
aa7fb02
 
 
 
 
 
3c6e15d
aa7fb02
 
 
 
3c6e15d
 
d3372d9
aa7fb02
 
 
d3372d9
 
aa7fb02
 
d3372d9
 
 
2353bb5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
import sys
import os
import streamlit as st
import torch
from PIL import Image
from Model.CLIP.cn_clip.clip import load_from_name
import Model.CLIP.usage.calculate
from Model.fur_rl.models.retriever_rl import DQN_v3
import Model.CLIP.cn_clip.clip as clip
import recommendation.datasets.img_preprocess
import recommendation.utils.ranker_1


# Page title shown at the top of the Streamlit demo UI.
st.title('Recommendation System V1')

def txt_embed(t_txt, g_txt, fb_txt, net, batch_size, device1):
    """Encode positive, negative, and feedback texts with the actor's CLIP encoder.

    Each input is tokenized and embedded under no_grad; rows whose source
    text is empty are overwritten with a zero vector so blank feedback
    contributes nothing downstream.

    Args:
        t_txt: batch of positive-feedback strings.
        g_txt: batch of negative-feedback strings.
        fb_txt: combined feedback text (embedded as-is, no zeroing).
        net: agent whose ``actor_net.txt_embed`` produces 512-d embeddings.
        batch_size: number of rows in t_txt / g_txt.
        device1: torch device string the tensors are placed on.

    Returns:
        Tuple ``(f_embed_t, f_embed_g, f_embed)`` of embedding tensors.
    """
    def _encode(texts):
        return net.actor_net.txt_embed(clip.tokenize(texts).to(device1))

    with torch.no_grad():
        f_embed_t = _encode(t_txt)
        f_embed_g = _encode(g_txt)
        blank = torch.zeros((1, 512)).to(device1)
        for row in range(batch_size):
            if not t_txt[row]:
                f_embed_t[row] = blank
            if not g_txt[row]:
                f_embed_g[row] = blank
        f_embed = _encode(fb_txt)
    return f_embed_t, f_embed_g, f_embed


def start(modelName):
    """Render one Streamlit pass of the interactive recommender.

    Loads (or hot-swaps) the DQN_v3 checkpoint selected in ``main()``,
    initialises per-session interaction state, and draws the feedback
    widgets plus the currently recommended image.

    Args:
        modelName: checkpoint file name, 'p1-t10-g.pth' or 'p2-t10-g.pth'.
    """
    device1 = "cuda" if torch.cuda.is_available() else "cpu"

    def _load_net(name):
        """Build a DQN_v3 agent and restore the checkpoint called ``name``.

        Raises:
            ValueError: if ``name`` is not a known checkpoint (the original
                code fell through to ``torch.load(None)``, which fails with
                an obscure error).
        """
        sentence_clip_model, sentence_clip_preprocess = load_from_name(
            "ViT-B-16", device=device1,
            download_root='../../data/pretrained_weights/',
            resume='./models/ClipEncoder.pt')
        sentence_clip_model = Model.CLIP.usage.calculate.convert_models_to_fp32(sentence_clip_model)
        net = DQN_v3(sentence_clip_model, sentence_clip_preprocess, device=device1)
        if name not in ('p1-t10-g.pth', 'p2-t10-g.pth'):
            raise ValueError('unknown model name: ' + str(name))
        # Load the checkpoint file once and reuse it; the original called
        # torch.load twice on the same file.
        ckpt = torch.load('./models/' + name, map_location=device1)
        net.actor_net.load_state_dict(ckpt['actor_state_dict'])
        net.actor_optimizer.load_state_dict(ckpt['actor_optimizer'])
        # Inference-only app; the original called eval() only on reload.
        net.actor_net.eval()
        return net

    if 'model' not in st.session_state:
        # First run of this session. Load before storing the key so a failed
        # load does not leave 'model' set without a matching 'net'.
        net = _load_net(modelName)
        # BUGFIX: store the plain model *name*. The original stored the file
        # path here but the bare name on reload, so the inequality check
        # below always fired and reloaded the model right after first load.
        st.session_state['model'] = modelName
        st.session_state['net'] = net
        st.write(modelName + ' first loaded')
    if st.session_state['model'] != modelName:
        # User switched checkpoints: reload the net and reset the dialogue.
        net = _load_net(modelName)
        st.session_state['model'] = modelName
        st.session_state['net'] = net
        st.write(modelName + ' reloaded')
        st.session_state['turn'] = 0
        st.session_state['f_his_txt'] = [' ']
        st.session_state['r_t'] = 0
        st.session_state['g_t_id'] = 408203  # id of the starting image
    if 'stop' not in st.session_state:
        st.session_state['stop'] = 0
    if 'turn' not in st.session_state or st.session_state['stop'] == 1:
        # Fresh dialogue, or the previous one ran past its last turn.
        st.session_state['turn'] = 0
        st.session_state['r_t'] = 0
        st.session_state['g_t_id'] = 408203
        if st.session_state['stop'] == 1:
            st.warning('It was the last turn and Retry Please!!')
    if 'f_embed_his_t' not in st.session_state or st.session_state['stop'] == 1:
        # One (1, 12, 512) buffer per history stream: 12 turn slots of
        # 512-d CLIP embeddings, batch size fixed at 1 in the demo.
        st.session_state['f_embed_his_t'] = torch.zeros((1, 12, 512))
        st.session_state['f_embed_his_g'] = torch.zeros((1, 12, 512))
        st.session_state['f_embed_his'] = torch.zeros((1, 12, 512))
        st.session_state['G_imgs_his'] = torch.zeros((1, 12, 512))
        st.session_state['stop'] = 0
        st.write('f_embed_his ready')
    if 'ranker' not in st.session_state:
        test_img_ids = "./recommendation/datasets/test_img_id_r.csv"
        dataset_test = recommendation.datasets.img_preprocess.Image_preprocess(test_img_ids)
        ranker_test = recommendation.utils.ranker_1.Ranker(device1, dataset_test, batch_size=64)
        # Recomputing gallery embeddings (Ranker.update_emb) takes ~78s on a
        # 3090, so a precomputed ranker is loaded from disk instead.
        ranker_test = torch.load('./models/ranker_test.pth', map_location=device1)
        st.write('ranker ready')
        st.session_state['ranker'] = ranker_test

    st.write(st.session_state['model'] + ' loaded')
    st.write('turn: ' + str(st.session_state['turn']) + '/10')

    def interactive(r_t, fg_t, fp_t):
        """Button callback: consume one round of feedback and advance a turn.

        Args:
            r_t: reward string from the text box (stored as int).
            fg_t: negative-feedback text.
            fp_t: positive-feedback text.
        """
        st.session_state['turn'] += 1
        st.session_state['r_t'] = int(r_t)

        def model(fg_t, fp_t):
            """Run one policy step and return the next image id to show."""
            t = st.session_state['turn']
            ranker = st.session_state['ranker']
            net = st.session_state['net']
            f_embed_his_t = st.session_state['f_embed_his_t'].to(device1)
            f_embed_his_g = st.session_state['f_embed_his_g'].to(device1)
            f_embed_his = st.session_state['f_embed_his'].to(device1)
            G_imgs_his = st.session_state['G_imgs_his'].to(device1)

            # Encode this turn's feedback and write it into slot t of each history.
            f_embed_t, f_embed_g, f_embed = txt_embed(fp_t, fg_t, " ", net, 1, device1)
            f_embed_his_t[:, t] = f_embed_t
            f_embed_his_g[:, t] = f_embed_g
            f_embed_his[:, t] = f_embed
            distence_, distance_t, distance_g = ranker.actions_metric(
                f_embed_his.to(torch.float32), f_embed_his_t.to(torch.float32),
                f_embed_his_g.to(torch.float32), 1, 12)

            p, q, maxC_img, maxC_ids, C_imgs, C_ids, C_pre, p_db = net.actor_net.forward(
                G_imgs_his, torch.cat((f_embed_his_t, f_embed_his_g), dim=1),
                torch.cat((distance_t, -distance_g), dim=1),
                ranker, k=4)
            if t >= 10:
                # Final allowed turn: always take candidate 0; past it, stop.
                max_index = torch.zeros(1).long().to(device1)
                if t > 10:
                    st.session_state['stop'] = 1
                    return 408203
                print("final action: ", max_index)
            else:
                # BUGFIX: the original compared against the checkpoint *path*
                # './models/p1-t10-g.pth', which never matched the stored
                # model name after a reload; compare the plain name.
                if st.session_state['model'] == 'p1-t10-g.pth':
                    max_index = torch.zeros(1).long().to(device1)
                else:
                    max_index = p.argmax(dim=1)
                print("net action: ", max_index)

            G_next_ids = torch.zeros(1).to(device1)
            G_next_imgs = torch.zeros(1, 512).to(device1)
            for i in range(1):  # batch size is fixed at 1 in the demo
                G_next_ids[i] = C_ids[i][max_index[i]]
                G_next_imgs[i] = C_imgs[i][max_index[i]]
            G_next_ids = G_next_ids.cpu().numpy().astype(int)
            G_imgs_his[:, t + 1] = G_next_imgs
            # BUGFIX: the original indexed session_state with the tensor
            # objects themselves (st.session_state[G_imgs_his]), so the
            # updated histories were never persisted between turns.
            st.session_state['G_imgs_his'] = G_imgs_his
            st.session_state['f_embed_his_g'] = f_embed_his_g
            st.session_state['f_embed_his_t'] = f_embed_his_t
            st.session_state['f_embed_his'] = f_embed_his
            return G_next_ids[0]

        st.session_state['g_t_id'] = model(fg_t, fp_t)

    fp_t = st.text_input('Input text(positive)', value=' ')
    fg_t = st.text_input('Input text(negative)', value=' ')
    # Map empty inputs to a single space so CLIP tokenization always gets a
    # non-empty string.
    if fp_t == '':
        fp_t = ' '
    if fg_t == '':
        fg_t = ' '
    r_t = st.text_input('Input reward', value='0')
    st.button('Interactive', on_click=interactive, args=(r_t, fg_t, fp_t))
    g_t_id = st.session_state['g_t_id']
    st.image(Image.open('./imgs/' + str(g_t_id) + '.jpg'),
             caption='g_t_next:' + str(g_t_id), use_column_width='auto')



def main():
    """Entry point: let the user pick a checkpoint, then render the app."""
    chosen_model = st.selectbox('Select model', ('p1-t10-g.pth', 'p2-t10-g.pth'))
    start(chosen_model)


if __name__ == "__main__":
    main()