update ranker_test.pth
Browse files
app.py
CHANGED
@@ -9,7 +9,7 @@ from Model.fur_rl.models.retriever_rl import DQN_v3
|
|
9 |
import Model.CLIP.cn_clip.clip as clip
|
10 |
import recommendation.datasets.img_preprocess
|
11 |
import recommendation.utils.ranker_1
|
12 |
-
|
13 |
|
14 |
st.title('Recommendation System V1')
|
15 |
|
@@ -63,6 +63,7 @@ def start(modelName):
|
|
63 |
load_model_name = None
|
64 |
net.actor_net.load_state_dict(torch.load(load_model_name, map_location=device1)['actor_state_dict'])
|
65 |
net.actor_optimizer.load_state_dict(torch.load(load_model_name, map_location=device1)['actor_optimizer'])
|
|
|
66 |
st.session_state['model'] = modelName
|
67 |
st.session_state['net'] = net
|
68 |
st.write(modelName + ' reloaded')
|
@@ -92,12 +93,16 @@ def start(modelName):
|
|
92 |
if 'ranker' not in st.session_state:
|
93 |
test_img_ids = "./recommendation/datasets/test_img_id_r.csv"
|
94 |
dataset_test = recommendation.datasets.img_preprocess.Image_preprocess(test_img_ids)
|
95 |
-
dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=1, shuffle=False, num_workers=4)
|
96 |
ranker_test = recommendation.utils.ranker_1.Ranker(device1, dataset_test, batch_size=64)
|
97 |
-
st.session_state['
|
98 |
-
|
99 |
-
|
|
|
|
|
|
|
100 |
st.write('ranker ready')
|
|
|
|
|
101 |
|
102 |
# st.write(st.session_state)
|
103 |
st.write(st.session_state['model'] + ' loaded')
|
|
|
9 |
import Model.CLIP.cn_clip.clip as clip
|
10 |
import recommendation.datasets.img_preprocess
|
11 |
import recommendation.utils.ranker_1
|
12 |
+
|
13 |
|
14 |
st.title('Recommendation System V1')
|
15 |
|
|
|
63 |
load_model_name = None
|
64 |
net.actor_net.load_state_dict(torch.load(load_model_name, map_location=device1)['actor_state_dict'])
|
65 |
net.actor_optimizer.load_state_dict(torch.load(load_model_name, map_location=device1)['actor_optimizer'])
|
66 |
+
net.actor_net.eval()
|
67 |
st.session_state['model'] = modelName
|
68 |
st.session_state['net'] = net
|
69 |
st.write(modelName + ' reloaded')
|
|
|
93 |
if 'ranker' not in st.session_state:
|
94 |
test_img_ids = "./recommendation/datasets/test_img_id_r.csv"
|
95 |
dataset_test = recommendation.datasets.img_preprocess.Image_preprocess(test_img_ids)
|
|
|
96 |
ranker_test = recommendation.utils.ranker_1.Ranker(device1, dataset_test, batch_size=64)
|
97 |
+
# net = st.session_state['net']
|
98 |
+
# ranker_test.update_emb(model=net.actor_net) # 220.0789999961853s; 78s on 3090
|
99 |
+
# save ranker
|
100 |
+
# torch.save(ranker_test, './ranker_test.pth')
|
101 |
+
# load ranker from pth
|
102 |
+
ranker_test = torch.load('./models/ranker_test.pth', map_location=device1)
|
103 |
st.write('ranker ready')
|
104 |
+
st.session_state['ranker'] = ranker_test
|
105 |
+
|
106 |
|
107 |
# st.write(st.session_state)
|
108 |
st.write(st.session_state['model'] + ' loaded')
|
models/ranker_test.pth
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 1703781047
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:53a0ab3288b8fda1c8a9599b50b9bccf0a8d33ab309a5ef3742e49ec419c6ffb
|
3 |
size 1703781047
|
recommendation/datasets/__pycache__/img_preprocess.cpython-38.pyc
CHANGED
Binary files a/recommendation/datasets/__pycache__/img_preprocess.cpython-38.pyc and b/recommendation/datasets/__pycache__/img_preprocess.cpython-38.pyc differ
|
|
recommendation/utils/__pycache__/ranker_1.cpython-38.pyc
CHANGED
Binary files a/recommendation/utils/__pycache__/ranker_1.cpython-38.pyc and b/recommendation/utils/__pycache__/ranker_1.cpython-38.pyc differ
|
|