gchhablani commited on
Commit
24e1057
1 Parent(s): 8b0c1f6

Switch to local checkpoints

Browse files
Files changed (2) hide show
  1. apps/mlm.py +2 -1
  2. apps/vqa.py +4 -3
apps/mlm.py CHANGED
@@ -43,7 +43,8 @@ def app(state):
43
  def load_model(ckpt):
44
  return FlaxCLIPVisionBertForMaskedLM.from_pretrained(ckpt)
45
 
46
- mlm_checkpoints = ["flax-community/clip-vision-bert-cc12m-70k"]
 
47
  dummy_data = pd.read_csv("cc12m_data/vqa_val.tsv", sep="\t")
48
 
49
  first_index = 15
 
43
  def load_model(ckpt):
44
  return FlaxCLIPVisionBertForMaskedLM.from_pretrained(ckpt)
45
 
46
+ #mlm_checkpoints = ["flax-community/clip-vision-bert-cc12m-70k"]
47
+ mlm_checkpoints = ["./ckpt/mlm/ckpt-60k"]
48
  dummy_data = pd.read_csv("cc12m_data/vqa_val.tsv", sep="\t")
49
 
50
  first_index = 15
apps/vqa.py CHANGED
@@ -45,9 +45,10 @@ def app(state):
45
  def load_model(ckpt):
46
  return FlaxCLIPVisionBertForSequenceClassification.from_pretrained(ckpt)
47
 
48
- vqa_checkpoints = [
49
- "flax-community/clip-vision-bert-vqa-ft-6k"
50
- ] # TODO: Maybe add more checkpoints?
 
51
  dummy_data = pd.read_csv("dummy_vqa_multilingual.tsv", sep="\t")
52
  code_to_name = {
53
  "en": "English",
 
45
  def load_model(ckpt):
46
  return FlaxCLIPVisionBertForSequenceClassification.from_pretrained(ckpt)
47
 
48
+ #vqa_checkpoints = [
49
+ # "flax-community/clip-vision-bert-vqa-ft-6k"
50
+ #] # TODO: Maybe add more checkpoints?
51
+ vqa_checkpoints = ["./ckpt/vqa/ckpt-60k-5999"]
52
  dummy_data = pd.read_csv("dummy_vqa_multilingual.tsv", sep="\t")
53
  code_to_name = {
54
  "en": "English",