update ranker_test.pth
Browse files
Model/fur_rl/models/__pycache__/retriever_rl.cpython-38.pyc
CHANGED
Binary files a/Model/fur_rl/models/__pycache__/retriever_rl.cpython-38.pyc and b/Model/fur_rl/models/__pycache__/retriever_rl.cpython-38.pyc differ
|
|
app.py
CHANGED
@@ -27,7 +27,7 @@ def txt_embed(t_txt, g_txt, fb_txt, net, batch_size, device1):
|
|
27 |
|
28 |
|
29 |
def start(modelName):
|
30 |
-
device1 = "
|
31 |
if 'model' not in st.session_state:
|
32 |
# load model
|
33 |
sentence_clip_model, sentence_clip_preprocess = load_from_name("ViT-B-16", device=device1,
|
@@ -49,7 +49,7 @@ def start(modelName):
|
|
49 |
if 'model' in st.session_state:
|
50 |
if st.session_state['model'] != modelName:
|
51 |
# load model
|
52 |
-
device1 = "
|
53 |
sentence_clip_model, sentence_clip_preprocess = load_from_name("ViT-B-16", device=device1,
|
54 |
download_root='../../data/pretrained_weights/',
|
55 |
resume='./models/ClipEncoder.pt')
|
@@ -195,7 +195,7 @@ def start(modelName):
|
|
195 |
|
196 |
|
197 |
def main():
|
198 |
-
modelName = st.selectbox('Select model', ('p1-t10-g.pth', 'p2-t10-g.pth'))
|
199 |
start(modelName)
|
200 |
|
201 |
|
|
|
27 |
|
28 |
|
29 |
def start(modelName):
|
30 |
+
device1 = "cpu"
|
31 |
if 'model' not in st.session_state:
|
32 |
# load model
|
33 |
sentence_clip_model, sentence_clip_preprocess = load_from_name("ViT-B-16", device=device1,
|
|
|
49 |
if 'model' in st.session_state:
|
50 |
if st.session_state['model'] != modelName:
|
51 |
# load model
|
52 |
+
device1 = "cpu"
|
53 |
sentence_clip_model, sentence_clip_preprocess = load_from_name("ViT-B-16", device=device1,
|
54 |
download_root='../../data/pretrained_weights/',
|
55 |
resume='./models/ClipEncoder.pt')
|
|
|
195 |
|
196 |
|
197 |
def main():
|
198 |
+
modelName = st.selectbox('Select model', ('p1-t10-g.pth', 'p2-t10-g.pth', 'only CLIP'))
|
199 |
start(modelName)
|
200 |
|
201 |
|
recommendation/utils/__pycache__/ranker_1.cpython-38.pyc
CHANGED
Binary files a/recommendation/utils/__pycache__/ranker_1.cpython-38.pyc and b/recommendation/utils/__pycache__/ranker_1.cpython-38.pyc differ
|
|
recommendation/utils/ranker_1.py
CHANGED
@@ -144,9 +144,9 @@ class Ranker:
|
|
144 |
def actions_metric(self, f_embed_his, f_embed_his_t, f_embed_his_g, batch_size=32, lens=4):
|
145 |
# f_embed_his: [turn, batch, emb]
|
146 |
# distance: [batch, turn, data]
|
147 |
-
distance_ = torch.zeros(batch_size, lens, self.data_emb
|
148 |
-
distance_t = torch.zeros(batch_size, lens, self.data_emb
|
149 |
-
distance_g = torch.zeros(batch_size, lens, self.data_emb
|
150 |
for turn_i in range(lens):
|
151 |
for batch_i in range(batch_size): # batch
|
152 |
# dot distance
|
|
|
144 |
def actions_metric(self, f_embed_his, f_embed_his_t, f_embed_his_g, batch_size=32, lens=4):
|
145 |
# f_embed_his: [turn, batch, emb]
|
146 |
# distance: [batch, turn, data]
|
147 |
+
distance_ = torch.zeros(batch_size, lens, len(self.data_emb))
|
148 |
+
distance_t = torch.zeros(batch_size, lens, len(self.data_emb))
|
149 |
+
distance_g = torch.zeros(batch_size, lens, len(self.data_emb))
|
150 |
for turn_i in range(lens):
|
151 |
for batch_i in range(batch_size): # batch
|
152 |
# dot distance
|