Li committed on
Commit
d3fbc73
1 Parent(s): f407227

“update”

Browse files
Files changed (1) hide show
  1. app.py +7 -2
app.py CHANGED
@@ -92,6 +92,9 @@ def generate(
92
  all_ids = set(range(flamingo.lang_encoder.lm_head.out_features))
93
  bad_words_ids = list(all_ids - set(loc_token_ids))
94
  bad_words_ids = [[b] for b in bad_words_ids]
 
 
 
95
  min_loc_token_id = min(loc_token_ids)
96
  max_loc_token_id = max(loc_token_ids)
97
  image_ori = image
@@ -103,9 +106,11 @@ def generate(
103
  if idx == 1:
104
  prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|><|#obj#|>{text.rstrip('.')}<|#loc#|>"]
105
  bad_words_ids = None
 
106
  else:
107
  prompt = [f"<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|>{text.rstrip('.')}"]
108
- bad_words_ids = None
 
109
  encodings = tokenizer(
110
  prompt,
111
  padding="longest",
@@ -122,7 +127,7 @@ def generate(
122
  model=flamingo,
123
  batch_images=batch_images,
124
  attention_mask=attention_mask,
125
- max_generation_length=5,
126
  min_generation_length=4,
127
  num_beams=1,
128
  length_penalty=1.0,
 
92
  all_ids = set(range(flamingo.lang_encoder.lm_head.out_features))
93
  bad_words_ids = list(all_ids - set(loc_token_ids))
94
  bad_words_ids = [[b] for b in bad_words_ids]
95
+ loc_word_ids = list(set(loc_token_ids))
96
+ loc_word_ids = [[b] for b in loc_word_ids]
97
+
98
  min_loc_token_id = min(loc_token_ids)
99
  max_loc_token_id = max(loc_token_ids)
100
  image_ori = image
 
106
  if idx == 1:
107
  prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|><|#obj#|>{text.rstrip('.')}<|#loc#|>"]
108
  bad_words_ids = None
109
+ max_generation_length = 5
110
  else:
111
  prompt = [f"<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|>{text.rstrip('.')}"]
112
+ bad_words_ids = loc_word_ids
113
+ max_generation_length = 100
114
  encodings = tokenizer(
115
  prompt,
116
  padding="longest",
 
127
  model=flamingo,
128
  batch_images=batch_images,
129
  attention_mask=attention_mask,
130
+ max_generation_length=max_generation_length,
131
  min_generation_length=4,
132
  num_beams=1,
133
  length_penalty=1.0,