mrm8488 committed on
Commit
3f23a54
1 Parent(s): b3deade

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +142 -1
README.md CHANGED
@@ -1,3 +1,144 @@
1
 
 
2
 
3
- # LayoutLM fine-tuned on FUNSD for Document token classification
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
 
2
+ # LayoutLM fine-tuned on FUNSD for Document token classification
3
 
4
+ ## Usage
5
+
6
+ ```python
7
+ import torch
8
+ import numpy as np
9
+ from PIL import Image, ImageDraw, ImageFont
10
+ import pytesseract
11
+ from transformers import LayoutLMForTokenClassification
12
+
13
+
14
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+
16
+ model = LayoutLMForTokenClassification.from_pretrained("mrm8488/layoutlm-finetuned-funsd")  # num_labels is read from the checkpoint's config
17
+ model.to(device)
18
+
19
+
20
+ image = Image.open("/83443897.png")
21
+ image = image.convert("RGB")
22
+
23
+ # Display the image
24
+
25
+
26
+ # Run Tesseract (OCR) on the image
27
+
28
+ width, height = image.size
29
+ w_scale = 1000/width
30
+ h_scale = 1000/height
31
+
32
+ ocr_df = pytesseract.image_to_data(image, output_type='data.frame')
33
+
34
+ ocr_df = ocr_df.dropna() \
35
+ .assign(left_scaled = ocr_df.left*w_scale,
36
+ width_scaled = ocr_df.width*w_scale,
37
+ top_scaled = ocr_df.top*h_scale,
38
+ height_scaled = ocr_df.height*h_scale,
39
+ right_scaled = lambda x: x.left_scaled + x.width_scaled,
40
+ bottom_scaled = lambda x: x.top_scaled + x.height_scaled)
41
+
42
+ float_cols = ocr_df.select_dtypes('float').columns
43
+ ocr_df[float_cols] = ocr_df[float_cols].round(0).astype(int)
44
+ ocr_df = ocr_df.replace(r'^\s*$', np.nan, regex=True)
45
+ ocr_df = ocr_df.dropna().reset_index(drop=True)
46
+ ocr_df[:20]
47
+
48
+ # create a list of words, actual bounding boxes, and normalized boxes
49
+
50
+ words = list(ocr_df.text)
51
+ coordinates = ocr_df[['left', 'top', 'width', 'height']]
52
+ actual_boxes = []
53
+ for idx, row in coordinates.iterrows():
54
+ x, y, w, h = tuple(row) # the row comes in (left, top, width, height) format
55
+ actual_box = [x, y, x+w, y+h] # we turn it into (left, top, left+width, top+height) to get the actual box
56
+ actual_boxes.append(actual_box)
57
+
58
+ def normalize_box(box, width, height):
59
+ return [
60
+ int(1000 * (box[0] / width)),
61
+ int(1000 * (box[1] / height)),
62
+ int(1000 * (box[2] / width)),
63
+ int(1000 * (box[3] / height)),
64
+ ]
65
+
66
+ boxes = []
67
+ for box in actual_boxes:
68
+ boxes.append(normalize_box(box, width, height))
69
+
70
+ # Display boxes
71
+
72
+ def convert_example_to_features(image, words, boxes, actual_boxes, tokenizer, args, cls_token_box=[0, 0, 0, 0],
73
+ sep_token_box=[1000, 1000, 1000, 1000],
74
+ pad_token_box=[0, 0, 0, 0]):
75
+ width, height = image.size
76
+
77
+ tokens = []
78
+ token_boxes = []
79
+ actual_bboxes = [] # we use an extra b because actual_boxes is already used
80
+ token_actual_boxes = []
81
+ for word, box, actual_bbox in zip(words, boxes, actual_boxes):
82
+ word_tokens = tokenizer.tokenize(word)
83
+ tokens.extend(word_tokens)
84
+ token_boxes.extend([box] * len(word_tokens))
85
+ actual_bboxes.extend([actual_bbox] * len(word_tokens))
86
+ token_actual_boxes.extend([actual_bbox] * len(word_tokens))
87
+
88
+ # Truncation: account for [CLS] and [SEP] with "- 2".
89
+ special_tokens_count = 2
90
+ if len(tokens) > args.max_seq_length - special_tokens_count:
91
+ tokens = tokens[: (args.max_seq_length - special_tokens_count)]
92
+ token_boxes = token_boxes[: (args.max_seq_length - special_tokens_count)]
93
+ actual_bboxes = actual_bboxes[: (args.max_seq_length - special_tokens_count)]
94
+ token_actual_boxes = token_actual_boxes[: (args.max_seq_length - special_tokens_count)]
95
+
96
+ # add [SEP] token, with corresponding token boxes and actual boxes
97
+ tokens += [tokenizer.sep_token]
98
+ token_boxes += [sep_token_box]
99
+ actual_bboxes += [[0, 0, width, height]]
100
+ token_actual_boxes += [[0, 0, width, height]]
101
+
102
+ segment_ids = [0] * len(tokens)
103
+
104
+ # next: [CLS] token
105
+ tokens = [tokenizer.cls_token] + tokens
106
+ token_boxes = [cls_token_box] + token_boxes
107
+ actual_bboxes = [[0, 0, width, height]] + actual_bboxes
108
+ token_actual_boxes = [[0, 0, width, height]] + token_actual_boxes
109
+ segment_ids = [1] + segment_ids
110
+
111
+ input_ids = tokenizer.convert_tokens_to_ids(tokens)
112
+
113
+ # The mask has 1 for real tokens and 0 for padding tokens. Only real
114
+ # tokens are attended to.
115
+ input_mask = [1] * len(input_ids)
116
+
117
+ # Zero-pad up to the sequence length.
118
+ padding_length = args.max_seq_length - len(input_ids)
119
+ input_ids += [tokenizer.pad_token_id] * padding_length
120
+ input_mask += [0] * padding_length
121
+ segment_ids += [tokenizer.pad_token_id] * padding_length
122
+ token_boxes += [pad_token_box] * padding_length
123
+ token_actual_boxes += [pad_token_box] * padding_length
124
+
125
+ assert len(input_ids) == args.max_seq_length
126
+ assert len(input_mask) == args.max_seq_length
127
+ assert len(segment_ids) == args.max_seq_length
128
+ assert len(token_boxes) == args.max_seq_length
129
+ assert len(token_actual_boxes) == args.max_seq_length
130
+
131
+ return input_ids, input_mask, segment_ids, token_boxes, token_actual_boxes
132
+
133
+ input_ids, input_mask, segment_ids, token_boxes, token_actual_boxes = convert_example_to_features(image=image, words=words, boxes=boxes, actual_boxes=actual_boxes, tokenizer=tokenizer, args=args)
134
+
135
+ input_ids = torch.tensor(input_ids, device=device).unsqueeze(0)
136
+ attention_mask = torch.tensor(input_mask, device=device).unsqueeze(0)
137
+ token_type_ids = torch.tensor(segment_ids, device=device).unsqueeze(0)
138
+ bbox = torch.tensor(token_boxes, device=device).unsqueeze(0)
139
+
140
+
141
+ outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, token_type_ids=token_type_ids)
142
+
143
+
144
+ ```