hsarfraz committed on
Commit
2cdc2be
1 Parent(s): 0e770bf

Update README.md

Files changed (1)
  1. README.md +55 -0
README.md CHANGED
@@ -14,6 +14,61 @@ The base model is ['naver-clova-ix/donut-base'][base], the model is finetuned fo
 
  For inference use image size width: 1536 px and height: 1536 px
 
+ ```python
+ from transformers import DonutProcessor, VisionEncoderDecoderModel
+ from PIL import Image
+ import torch
+ import re
 
+ model_name = 'hsarfraz/irs-tax-form-1040-2023-doc-parser'
+
+ processor = DonutProcessor.from_pretrained(model_name)
+ model = VisionEncoderDecoderModel.from_pretrained(model_name)
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model.to(device)
+ model.eval()
+
+ image_name = 'replace with name of the form 1040 (2023) image file'
+
+ img = Image.open(image_name)
+ new_width = 1536
+ new_height = 1536
+
+ # resize the input image to the size used during fine-tuning
+ img = img.resize((new_width, new_height), Image.LANCZOS)
+
+ pixel_values = processor(img.convert("RGB"), return_tensors="pt").pixel_values
+ pixel_values = pixel_values.to(device)
+
+ # task prompt used by the fine-tuned Donut decoder
+ task_prompt = "<s_cord-v2>"
+ decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt")["input_ids"]
+ decoder_input_ids = decoder_input_ids.to(device)
+
+ outputs = model.generate(
+     pixel_values,
+     decoder_input_ids=decoder_input_ids,
+     max_length=model.decoder.config.max_position_embeddings,
+     early_stopping=True,
+     pad_token_id=processor.tokenizer.pad_token_id,
+     eos_token_id=processor.tokenizer.eos_token_id,
+     use_cache=True,
+     num_beams=1,
+     bad_words_ids=[[processor.tokenizer.unk_token_id]],
+     return_dict_in_generate=True,
+     # output_scores=True,
+ )
+
+ sequence = processor.batch_decode(outputs.sequences)[0]
+ sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
+ sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove the first task start token
+ output_json = processor.token2json(sequence)
+
+ print('----------------------------------')
+ print('--- Parsed data in json format ---')
+ print('----------------------------------')
+ print(output_json)
+ ```
 
  [base]: https://huggingface.co/naver-clova-ix/donut-base