jykoh commited on
Commit
b1c9822
1 Parent(s): 4b5bfb2

Fix loading

Browse files
Files changed (2) hide show
  1. Dockerfile → blah +0 -0
  2. fromage/models.py +5 -7
Dockerfile → blah RENAMED
File without changes
fromage/models.py CHANGED
@@ -594,17 +594,15 @@ class Fromage(nn.Module):
594
  return return_outputs
595
 
596
 
597
- def load_fromage(embeddings_dir: str, args_path: str, ckpt_path: str) -> Fromage:
598
- model_args_path = os.path.join(model_dir, 'model_args.json')
599
- model_ckpt_path = os.path.join(model_dir, 'pretrained_ckpt.pth.tar')
600
- embs_paths = [s for s in glob.glob(os.path.join(model_dir, 'cc3m_embeddings*.pkl'))]
601
 
602
  if not os.path.exists(model_args_path):
603
- raise ValueError(f'model_args.json does not exist in {model_dir}.')
604
  if not os.path.exists(model_ckpt_path):
605
- raise ValueError(f'pretrained_ckpt.pth.tar does not exist in {model_dir}.')
606
  if len(embs_paths) == 0:
607
- raise ValueError(f'cc3m_embeddings_*.pkl files do not exist in {model_dir}.')
608
 
609
  # Load embeddings.
610
  # Construct embedding matrix for nearest neighbor lookup.
594
  return return_outputs
595
 
596
 
597
+ def load_fromage(embeddings_dir: str, model_args_path: str, model_ckpt_path: str) -> Fromage:
598
+ embs_paths = [s for s in glob.glob(os.path.join(embeddings_dir, 'cc3m_embeddings*.pkl'))]
 
 
599
 
600
  if not os.path.exists(model_args_path):
601
+ raise ValueError(f'model_args.json does not exist at {model_args_path}.')
602
  if not os.path.exists(model_ckpt_path):
603
+ raise ValueError(f'pretrained_ckpt.pth.tar does not exist at {model_ckpt_path}.')
604
  if len(embs_paths) == 0:
605
+ raise ValueError(f'cc3m_embeddings_*.pkl files do not exist in {embeddings_dir}.')
606
 
607
  # Load embeddings.
608
  # Construct embedding matrix for nearest neighbor lookup.