Zhonathon commited on
Commit
6066f74
1 Parent(s): 8950f9f

update ranker_test.pth

Browse files
app.py CHANGED
@@ -9,7 +9,7 @@ from Model.fur_rl.models.retriever_rl import DQN_v3
9
  import Model.CLIP.cn_clip.clip as clip
10
  import recommendation.datasets.img_preprocess
11
  import recommendation.utils.ranker_1
12
- import numpy as np
13
 
14
  st.title('Recommendation System V1')
15
 
@@ -63,6 +63,7 @@ def start(modelName):
63
  load_model_name = None
64
  net.actor_net.load_state_dict(torch.load(load_model_name, map_location=device1)['actor_state_dict'])
65
  net.actor_optimizer.load_state_dict(torch.load(load_model_name, map_location=device1)['actor_optimizer'])
 
66
  st.session_state['model'] = modelName
67
  st.session_state['net'] = net
68
  st.write(modelName + ' reloaded')
@@ -92,12 +93,16 @@ def start(modelName):
92
  if 'ranker' not in st.session_state:
93
  test_img_ids = "./recommendation/datasets/test_img_id_r.csv"
94
  dataset_test = recommendation.datasets.img_preprocess.Image_preprocess(test_img_ids)
95
- dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=1, shuffle=False, num_workers=4)
96
  ranker_test = recommendation.utils.ranker_1.Ranker(device1, dataset_test, batch_size=64)
97
- st.session_state['ranker'] = ranker_test
98
- net = st.session_state['net']
99
- ranker_test.update_emb(model=net.actor_net) # 220.0789999961853s; 78s on 3090
 
 
 
100
  st.write('ranker ready')
 
 
101
 
102
  # st.write(st.session_state)
103
  st.write(st.session_state['model'] + ' loaded')
 
9
  import Model.CLIP.cn_clip.clip as clip
10
  import recommendation.datasets.img_preprocess
11
  import recommendation.utils.ranker_1
12
+
13
 
14
  st.title('Recommendation System V1')
15
 
 
63
  load_model_name = None
64
  net.actor_net.load_state_dict(torch.load(load_model_name, map_location=device1)['actor_state_dict'])
65
  net.actor_optimizer.load_state_dict(torch.load(load_model_name, map_location=device1)['actor_optimizer'])
66
+ net.actor_net.eval()
67
  st.session_state['model'] = modelName
68
  st.session_state['net'] = net
69
  st.write(modelName + ' reloaded')
 
93
  if 'ranker' not in st.session_state:
94
  test_img_ids = "./recommendation/datasets/test_img_id_r.csv"
95
  dataset_test = recommendation.datasets.img_preprocess.Image_preprocess(test_img_ids)
 
96
  ranker_test = recommendation.utils.ranker_1.Ranker(device1, dataset_test, batch_size=64)
97
+ # net = st.session_state['net']
98
+ # ranker_test.update_emb(model=net.actor_net) # 220.0789999961853s; 78s on 3090
99
+ # save ranker
100
+ # torch.save(ranker_test, './ranker_test.pth')
101
+ # load ranker from pth
102
+ ranker_test = torch.load('./models/ranker_test.pth', map_location=device1)
103
  st.write('ranker ready')
104
+ st.session_state['ranker'] = ranker_test
105
+
106
 
107
  # st.write(st.session_state)
108
  st.write(st.session_state['model'] + ' loaded')
models/ranker_test.pth CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:8c8c80fce74913e38a3b0c20e7ff06db179ba069c945c504c18af2fba445ba9a
3
  size 1703781047
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:53a0ab3288b8fda1c8a9599b50b9bccf0a8d33ab309a5ef3742e49ec419c6ffb
3
  size 1703781047
recommendation/datasets/__pycache__/img_preprocess.cpython-38.pyc CHANGED
Binary files a/recommendation/datasets/__pycache__/img_preprocess.cpython-38.pyc and b/recommendation/datasets/__pycache__/img_preprocess.cpython-38.pyc differ
 
recommendation/utils/__pycache__/ranker_1.cpython-38.pyc CHANGED
Binary files a/recommendation/utils/__pycache__/ranker_1.cpython-38.pyc and b/recommendation/utils/__pycache__/ranker_1.cpython-38.pyc differ