aravind-selvam committed on
Commit
21889db
1 Parent(s): dffa852

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +55 -0
README.md CHANGED
@@ -1,3 +1,58 @@
1
  ---
2
  license: mit
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: mit
3
  ---
4
+ ```
5
import re

import torch
from datasets import load_dataset
from transformers import DonutProcessor, VisionEncoderDecoderModel

# Load the fine-tuned Donut chart-extraction model and its processor from the Hub.
fine_tuned_model = VisionEncoderDecoderModel.from_pretrained("aravind-selvam/donut_finetuned_chart")
processor = DonutProcessor.from_pretrained("aravind-selvam/donut_finetuned_chart")

# Move the model to GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
fine_tuned_model.to(device)

# Load a sample document image from the test split.
# NOTE: `load_dataset` was missing from the original imports — fixed above.
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
sample = dataset[1]


def run_prediction(sample, model=fine_tuned_model, processor=processor):
    """Run Donut inference on one dataset row and return (prediction, target).

    Args:
        sample: a dataset row — assumed to carry an "image" (PIL image) and a
            "target_sequence" token string; TODO confirm against the dataset schema.
        model: the VisionEncoderDecoderModel used for generation.
        processor: the DonutProcessor used for pre/post-processing.

    Returns:
        Tuple of (prediction, target), both decoded via ``processor.token2json``.
    """
    # Pixel values for the encoder.
    # BUG FIX: the original referenced an undefined name `image`; the function
    # parameter is `sample`, a dataset row, so we pass its image field.
    pixel_values = processor(sample["image"], return_tensors="pt").pixel_values

    # Prepare decoder inputs: Donut starts generation from a task prompt.
    task_prompt = "<s>"
    decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids

    # Run inference.
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=2,
        # bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # Post-process: decode tokens, restore "1" from the <one> special token,
    # then convert the token sequence to a JSON-like structure.
    prediction = processor.batch_decode(outputs.sequences)[0]
    prediction = re.sub(r"<one>", "1", prediction)
    prediction = processor.token2json(prediction)

    # Load the reference target.
    # BUG FIX: the original referenced an undefined `test_sample`; the row
    # passed to this function carries the target sequence.
    target = processor.token2json(sample["target_sequence"])
    return prediction, target


prediction, target = run_prediction(sample)
print(f"Reference:\n {target}")
print(f"Prediction:\n {prediction}")
58
+ ```