ajimeno committed on
Commit ec9e91a
1 Parent(s): 34ca5b1

Updated Chipper model

Files changed (2)
  1. app.py +4 -4
  2. logits_ngrams.py +1 -1
app.py CHANGED

@@ -13,7 +13,7 @@ from logits_ngrams import NoRepeatNGramLogitsProcessor, get_table_token_ids
 def run_prediction(sample, model, processor, mode):
 
     skip_tokens = get_table_token_ids(processor)
-    no_repeat_ngram_size = 10
+    no_repeat_ngram_size = 15
 
     if mode == "OCR":
         prompt = "<s><s_pretraining>"
@@ -35,9 +35,9 @@ def run_prediction(sample, model, processor, mode):
             decoder_input_ids=processor.tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids.to(device),
             logits_processor=[NoRepeatNGramLogitsProcessor(no_repeat_ngram_size, skip_tokens)],
             do_sample=True,
-            top_p=0.92, #.92,
+            top_p=0.92,
             top_k=5,
-            no_repeat_ngram_size=0,
+            no_repeat_ngram_size=25,
             num_beams=3,
             output_attentions=False,
             output_hidden_states=False,
@@ -81,7 +81,7 @@ else:
     st.image(image, caption='Your target document')
 
     with st.spinner(f'Processing the document ...'):
-        pre_trained_model = "unstructuredio/chipper-fast-fine-tuning"
+        pre_trained_model = "unstructuredio/chipper-fast-fine-tuning-oct-23-release"
         processor = DonutProcessor.from_pretrained(pre_trained_model, token=os.environ['HF_TOKEN'])
 
         device = "cuda" if torch.cuda.is_available() else "cpu"
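For context, the pieces touched by this diff fit together roughly as sketched below. This is a minimal sketch, not the full app: the checkpoint name, prompt, and generation parameters mirror the hunks above, while the model class (VisionEncoderDecoderModel), the blank placeholder image, and the decoding step are assumptions that do not appear in this diff.

```python
# Minimal sketch of how the updated checkpoint and generation settings are wired
# together. Assumptions (not shown in this diff): the checkpoint loads as a
# VisionEncoderDecoderModel, HF_TOKEN is set, and a blank image stands in for a
# real document. Checkpoint, prompt, and generate() arguments mirror the hunks above.
import os

import torch
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel

from logits_ngrams import NoRepeatNGramLogitsProcessor, get_table_token_ids

pre_trained_model = "unstructuredio/chipper-fast-fine-tuning-oct-23-release"
processor = DonutProcessor.from_pretrained(pre_trained_model, token=os.environ["HF_TOKEN"])
model = VisionEncoderDecoderModel.from_pretrained(pre_trained_model, token=os.environ["HF_TOKEN"])  # assumed model class

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

skip_tokens = get_table_token_ids(processor)  # table tokens are exempt from the custom n-gram ban
prompt = "<s><s_pretraining>"                 # OCR-mode prompt from app.py

image = Image.new("RGB", (1700, 2200), "white")  # placeholder document image
pixel_values = processor(image, return_tensors="pt").pixel_values

with torch.no_grad():
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=processor.tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids.to(device),
        logits_processor=[NoRepeatNGramLogitsProcessor(15, skip_tokens)],  # custom ban, size bumped 10 -> 15
        do_sample=True,
        top_p=0.92,
        top_k=5,
        no_repeat_ngram_size=25,  # built-in ban re-enabled (was 0) on top of the custom processor
        num_beams=3,
    )

print(processor.batch_decode(outputs)[0])
```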
logits_ngrams.py CHANGED

@@ -59,5 +59,5 @@ def _calc_banned_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len
 
 
 def get_table_token_ids(processor):
-    skip_tokens = {token_id for token, token_id in processor.tokenizer.get_added_vocab().items() if token.startswith("<t") or token.startswith("</t") }
+    return {token_id for token, token_id in processor.tokenizer.get_added_vocab().items() if token.startswith("<t") or token.startswith("</t") }
 
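The logits_ngrams.py hunk fixes a silent bug: the old body bound the set comprehension to a local skip_tokens and fell off the end of the function, so callers received None instead of the set of table token ids. A small standalone check of the fixed filter follows; FakeTokenizer, FakeProcessor, and the token ids are hypothetical stand-ins for DonutProcessor's added vocabulary.

```python
# Standalone check of the fixed get_table_token_ids filter. FakeTokenizer /
# FakeProcessor and the token ids are hypothetical stand-ins for DonutProcessor;
# only added tokens starting with "<t" or "</t" (the table tags) are kept.

class FakeTokenizer:
    def get_added_vocab(self):
        return {
            "<table>": 57523,
            "</table>": 57524,
            "<td>": 57525,
            "</td>": 57526,
            "<s_pretraining>": 57527,  # not a table tag, so it is filtered out
        }

class FakeProcessor:
    tokenizer = FakeTokenizer()

def get_table_token_ids(processor):
    # Fixed version: the set is returned instead of being bound to a local name.
    return {
        token_id
        for token, token_id in processor.tokenizer.get_added_vocab().items()
        if token.startswith("<t") or token.startswith("</t")
    }

print(get_table_token_ids(FakeProcessor()))  # {57523, 57524, 57525, 57526} (set order may vary)
```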