SahilJ2 committed on
Commit
6087f11
1 Parent(s): 50fdcbb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -20
app.py CHANGED
@@ -99,31 +99,42 @@ def m2(que, image):
99
  return processor3.batch_decode(generated_ids, skip_special_tokens=True)
100
 
101
  def m3(que, image):
102
- processor3 = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
103
- model3 = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
104
 
105
- model3.to(device)
106
 
107
- prompt = "<s_docvqa><s_question>{que}</s_question><s_answer>"
108
- decoder_input_ids = processor3.tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids
109
 
110
- pixel_values = processor3(image, return_tensors="pt").pixel_values
111
 
112
- outputs = model3.generate(
113
- pixel_values.to(device),
114
- decoder_input_ids=decoder_input_ids.to(device),
115
- max_length=model3.decoder.config.max_position_embeddings,
116
- pad_token_id=processor3.tokenizer.pad_token_id,
117
- eos_token_id=processor3.tokenizer.eos_token_id,
118
- use_cache=True,
119
- bad_words_ids=[[processor3.tokenizer.unk_token_id]],
120
- return_dict_in_generate=True,
121
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
- sequence = processor3.batch_decode(outputs.sequences)[0]
124
- sequence = sequence.replace(processor3.tokenizer.eos_token, "").replace(processor3.tokenizer.pad_token, "")
125
- sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
126
- return processor3.token2json(sequence)['answer']
127
 
128
  def m4(que, image):
129
  processor3 = Pix2StructProcessor.from_pretrained('google/matcha-plotqa-v2')
 
99
  return processor3.batch_decode(generated_ids, skip_special_tokens=True)
100
 
101
  def m3(que, image):
102
+ # processor3 = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
103
+ # model3 = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
104
 
105
+ # model3.to(device)
106
 
107
+ # prompt = "<s_docvqa><s_question>{que}</s_question><s_answer>"
108
+ # decoder_input_ids = processor3.tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids
109
 
110
+ # pixel_values = processor3(image, return_tensors="pt").pixel_values
111
 
112
+ # outputs = model3.generate(
113
+ # pixel_values.to(device),
114
+ # decoder_input_ids=decoder_input_ids.to(device),
115
+ # max_length=model3.decoder.config.max_position_embeddings,
116
+ # pad_token_id=processor3.tokenizer.pad_token_id,
117
+ # eos_token_id=processor3.tokenizer.eos_token_id,
118
+ # use_cache=True,
119
+ # bad_words_ids=[[processor3.tokenizer.unk_token_id]],
120
+ # return_dict_in_generate=True,
121
+ # )
122
+
123
+ # sequence = processor3.batch_decode(outputs.sequences)[0]
124
+ # sequence = sequence.replace(processor3.tokenizer.eos_token, "").replace(processor3.tokenizer.pad_token, "")
125
+ # sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
126
+ # return processor3.token2json(sequence)['answer']
127
+
128
+
129
+ processor3 = AutoProcessor.from_pretrained("google/pix2struct-docvqa-large")
130
+ model3 = AutoModelForSeq2SeqLM.from_pretrained("google/pix2struct-docvqa-large")
131
+
132
+ inputs = processor3(images=image, text=que, return_tensors="pt")
133
+
134
+ predictions = model3.generate(**inputs)
135
+ return processor3.decode(predictions[0], skip_special_tokens=True)
136
 
137
+
 
 
 
138
 
139
  def m4(que, image):
140
  processor3 = Pix2StructProcessor.from_pretrained('google/matcha-plotqa-v2')