Commit
·
fdaaec7
1
Parent(s):
4c1f7f8
Update README.md
Browse files
README.md
CHANGED
@@ -25,17 +25,73 @@ It achieves the following results on the evaluation set:
|
|
25 |
- Overall F1: 0.8900
|
26 |
- Overall Accuracy: 0.8204
|
27 |
|
28 |
-
## Model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
-
More information needed
|
31 |
-
|
32 |
-
## Intended uses & limitations
|
33 |
-
|
34 |
-
More information needed
|
35 |
-
|
36 |
-
## Training and evaluation data
|
37 |
-
|
38 |
-
More information needed
|
39 |
|
40 |
## Training procedure
|
41 |
|
|
|
25 |
- Overall F1: 0.8900
|
26 |
- Overall Accuracy: 0.8204
|
27 |
|
28 |
+
## Model Usage
|
29 |
+
|
30 |
+
```python
|
31 |
+
from transformers import LiltForTokenClassification, LayoutLMv3Processor
|
32 |
+
from PIL import Image, ImageDraw, ImageFont
|
33 |
+
import torch
|
34 |
+
|
35 |
+
# load model and processor from huggingface hub
checkpoint = "philschmid/lilt-en-funsd"
model = LiltForTokenClassification.from_pretrained(checkpoint)
processor = LayoutLMv3Processor.from_pretrained(checkpoint)
|
38 |
+
|
39 |
+
|
40 |
+
# helper function to unnormalize bboxes for drawing onto the image
def unnormalize_box(bbox, width, height):
    """Scale a 0-1000 normalized bounding box back to pixel coordinates.

    Args:
        bbox: [x0, y0, x1, y1] with each coordinate in the 0-1000 range.
        width: target image width in pixels.
        height: target image height in pixels.

    Returns:
        [x0, y0, x1, y1] as floats in absolute pixel coordinates.
    """
    # x-coordinates scale with width, y-coordinates with height
    dims = (width, height, width, height)
    return [dim * (coord / 1000) for coord, dim in zip(bbox, dims)]
|
48 |
+
|
49 |
+
|
50 |
+
# outline/text color per entity type; B- and I- tags of the same entity
# share a color so multi-token entities render uniformly
label2color = {
    f"{prefix}-{entity}": color
    for prefix in ("B", "I")
    for entity, color in [("HEADER", "blue"), ("QUESTION", "red"), ("ANSWER", "green")]
}
|
58 |
+
# draw results onto the image
def draw_boxes(image, boxes, predictions):
    """Draw predicted entity boxes and their labels onto *image* (in place).

    Args:
        image: PIL image the boxes refer to; it is mutated and returned.
        boxes: bounding boxes normalized to the 0-1000 range (as produced
            by the processor), aligned one-to-one with *predictions*.
        predictions: label strings; "O" (outside any entity) entries are skipped.

    Returns:
        The same image, with one colored rectangle and label text per entity box.
    """
    width, height = image.size
    normalized_boxes = [unnormalize_box(box, width, height) for box in boxes]

    # draw predictions over the image
    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default()
    for prediction, box in zip(predictions, normalized_boxes):
        if prediction == "O":
            continue
        # FIX: the original drew a black rectangle first, then a colored one at
        # the exact same coordinates — the black one was always fully overdrawn,
        # so only the colored rectangle is kept.
        draw.rectangle(box, outline=label2color[prediction])
        draw.text((box[0] + 10, box[1] - 10), text=prediction, fill=label2color[prediction], font=font)
    return image
|
73 |
+
|
74 |
+
|
75 |
+
# run inference
def run_inference(image, model=model, processor=processor, output_image=True):
    """Run LiLT token classification on a document image.

    Args:
        image: input document image accepted by the processor.
        model: token-classification model; defaults to the module-level model.
        processor: processor that builds the model inputs; defaults to the
            module-level processor.
        output_image: when True, return the image annotated with the predicted
            boxes; when False, return the raw label strings.

    Returns:
        An annotated PIL image, or a list of label strings.
    """
    # create model input; LiLT consumes text + layout only, so drop pixels
    encoding = processor(image, return_tensors="pt")
    del encoding["pixel_values"]

    # run inference and map argmax logits to label strings
    outputs = model(**encoding)
    predictions = outputs.logits.argmax(-1).squeeze().tolist()
    labels = [model.config.id2label[pred] for pred in predictions]

    if not output_image:
        return labels
    return draw_boxes(image, encoding["bbox"][0], labels)
|
89 |
+
|
90 |
+
|
91 |
+
# NOTE(review): `dataset` is not defined anywhere in this snippet — a FUNSD
# dataset must be loaded beforehand (e.g. via the `datasets` library) for
# this example call to run.
run_inference(dataset["test"][34]["image"])
|
92 |
+
|
93 |
+
```
|
94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
|
96 |
## Training procedure
|
97 |
|