Zhonathon commited on
Commit
638f79a
1 Parent(s): 6066f74

update ranker_test.pth

Browse files
Model/fur_rl/models/__pycache__/retriever_rl.cpython-38.pyc CHANGED
Binary files a/Model/fur_rl/models/__pycache__/retriever_rl.cpython-38.pyc and b/Model/fur_rl/models/__pycache__/retriever_rl.cpython-38.pyc differ
 
app.py CHANGED
@@ -27,7 +27,7 @@ def txt_embed(t_txt, g_txt, fb_txt, net, batch_size, device1):
27
 
28
 
29
  def start(modelName):
30
- device1 = "cuda" if torch.cuda.is_available() else "cpu"
31
  if 'model' not in st.session_state:
32
  # load model
33
  sentence_clip_model, sentence_clip_preprocess = load_from_name("ViT-B-16", device=device1,
@@ -49,7 +49,7 @@ def start(modelName):
49
  if 'model' in st.session_state:
50
  if st.session_state['model'] != modelName:
51
  # load model
52
- device1 = "cuda" if torch.cuda.is_available() else "cpu"
53
  sentence_clip_model, sentence_clip_preprocess = load_from_name("ViT-B-16", device=device1,
54
  download_root='../../data/pretrained_weights/',
55
  resume='./models/ClipEncoder.pt')
@@ -195,7 +195,7 @@ def start(modelName):
195
 
196
 
197
  def main():
198
- modelName = st.selectbox('Select model', ('p1-t10-g.pth', 'p2-t10-g.pth'))
199
  start(modelName)
200
 
201
 
 
27
 
28
 
29
  def start(modelName):
30
+ device1 = "cpu"
31
  if 'model' not in st.session_state:
32
  # load model
33
  sentence_clip_model, sentence_clip_preprocess = load_from_name("ViT-B-16", device=device1,
 
49
  if 'model' in st.session_state:
50
  if st.session_state['model'] != modelName:
51
  # load model
52
+ device1 = "cpu"
53
  sentence_clip_model, sentence_clip_preprocess = load_from_name("ViT-B-16", device=device1,
54
  download_root='../../data/pretrained_weights/',
55
  resume='./models/ClipEncoder.pt')
 
195
 
196
 
197
  def main():
198
+ modelName = st.selectbox('Select model', ('p1-t10-g.pth', 'p2-t10-g.pth', 'only CLIP'))
199
  start(modelName)
200
 
201
 
recommendation/utils/__pycache__/ranker_1.cpython-38.pyc CHANGED
Binary files a/recommendation/utils/__pycache__/ranker_1.cpython-38.pyc and b/recommendation/utils/__pycache__/ranker_1.cpython-38.pyc differ
 
recommendation/utils/ranker_1.py CHANGED
@@ -144,9 +144,9 @@ class Ranker:
144
  def actions_metric(self, f_embed_his, f_embed_his_t, f_embed_his_g, batch_size=32, lens=4):
145
  # f_embed_his: [turn, batch, emb]
146
  # distance: [batch, turn, data]
147
- distance_ = torch.zeros(batch_size, lens, self.data_emb.size(0)).to(self.device)
148
- distance_t = torch.zeros(batch_size, lens, self.data_emb.size(0)).to(self.device)
149
- distance_g = torch.zeros(batch_size, lens, self.data_emb.size(0)).to(self.device)
150
  for turn_i in range(lens):
151
  for batch_i in range(batch_size): # batch
152
  # dot distance
 
144
  def actions_metric(self, f_embed_his, f_embed_his_t, f_embed_his_g, batch_size=32, lens=4):
145
  # f_embed_his: [turn, batch, emb]
146
  # distance: [batch, turn, data]
147
+ distance_ = torch.zeros(batch_size, lens, len(self.data_emb))
148
+ distance_t = torch.zeros(batch_size, lens, len(self.data_emb))
149
+ distance_g = torch.zeros(batch_size, lens, len(self.data_emb))
150
  for turn_i in range(lens):
151
  for batch_i in range(batch_size): # batch
152
  # dot distance