jykoh commited on
Commit
05e5f88
1 Parent(s): 294a555
Files changed (5) hide show
  1. .gitattributes +1 -0
  2. README.md +1 -1
  3. app.py +3 -4
  4. cc3m_embeddings.pkl +3 -0
  5. fromage/models.py +1 -1
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ cc3m_embeddings.pkl filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,7 +1,7 @@
1
  ---
2
  title: FROMAGe
3
  emoji: 🧀
4
- sdk: docker
5
  colorFrom: blue
6
  colorTo: red
7
  pinned: true
1
  ---
2
  title: FROMAGe
3
  emoji: 🧀
4
+ sdk: gradio
5
  colorFrom: blue
6
  colorTo: red
7
  pinned: true
app.py CHANGED
@@ -13,10 +13,9 @@ import tempfile
13
  class FromageChatBot:
14
  def __init__(self):
15
  # Download model from HF Hub.
16
- huggingface_hub.hf_hub_download(repo_id='jykoh/fromage', filename='pretrained_ckpt.pth.tar')
17
- huggingface_hub.hf_hub_download(repo_id='jykoh/fromage', filename='model_args.json')
18
- huggingface_hub.hf_hub_download(repo_id='jykoh/fromage', filename='cc3m_embeddings.pkl')
19
- self.model = models.load_fromage('./')
20
  self.chat_history = ''
21
  self.input_image = None
22
 
13
  class FromageChatBot:
14
  def __init__(self):
15
  # Download model from HF Hub.
16
+ ckpt_path = huggingface_hub.hf_hub_download(repo_id='jykoh/fromage', filename='pretrained_ckpt.pth.tar')
17
+ args_path = huggingface_hub.hf_hub_download(repo_id='jykoh/fromage', filename='model_args.json')
18
+ self.model = models.load_fromage('./', args_path, ckpt_path)
 
19
  self.chat_history = ''
20
  self.input_image = None
21
 
cc3m_embeddings.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a20fa8168bd72e848ff088820b767383dded455a57ac5dd2d97d43e600402195
3
+ size 2979901225
fromage/models.py CHANGED
@@ -594,7 +594,7 @@ class Fromage(nn.Module):
594
  return return_outputs
595
 
596
 
597
- def load_fromage(model_dir: str) -> Fromage:
598
  model_args_path = os.path.join(model_dir, 'model_args.json')
599
  model_ckpt_path = os.path.join(model_dir, 'pretrained_ckpt.pth.tar')
600
  embs_paths = [s for s in glob.glob(os.path.join(model_dir, 'cc3m_embeddings*.pkl'))]
594
  return return_outputs
595
 
596
 
597
+ def load_fromage(embeddings_dir: str, args_path: str, ckpt_path: str) -> Fromage:
598
  model_args_path = os.path.join(model_dir, 'model_args.json')
599
  model_ckpt_path = os.path.join(model_dir, 'pretrained_ckpt.pth.tar')
600
  embs_paths = [s for s in glob.glob(os.path.join(model_dir, 'cc3m_embeddings*.pkl'))]