
lilt-en-funsd

This model is a fine-tuned version of SCUT-DLVCLab/lilt-roberta-en-base on the funsd-layoutlmv3 dataset. It achieves the following results on the evaluation set (a short sketch of how such entity-level scores are computed follows the list):

  • Loss: 1.6117
  • Answer: precision 0.8821, recall 0.9070, F1 0.8944 (817 entities)
  • Header: precision 0.6126, recall 0.5714, F1 0.5913 (119 entities)
  • Question: precision 0.9045, recall 0.9322, F1 0.9182 (1077 entities)
  • Overall Precision: 0.8797
  • Overall Recall: 0.9006
  • Overall F1: 0.8900
  • Overall Accuracy: 0.8204
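
The per-class numbers above follow the output format of the seqeval metric, which the Hugging Face token-classification training scripts typically use for evaluation. A minimal sketch of computing such entity-level scores (this assumes the evaluate and seqeval packages are installed; the two label sequences are made-up placeholders, not model output):

import evaluate

# entity-level precision/recall/F1 in the same format as the results above;
# the label sequences below are illustrative placeholders
seqeval = evaluate.load("seqeval")
predictions = [["B-QUESTION", "I-QUESTION", "O", "B-ANSWER", "O"]]
references = [["B-QUESTION", "I-QUESTION", "O", "B-ANSWER", "B-HEADER"]]
results = seqeval.compute(predictions=predictions, references=references)
print(results["QUESTION"])    # {'precision': ..., 'recall': ..., 'f1': ..., 'number': ...}
print(results["overall_f1"])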

Model Usage

from transformers import LiltForTokenClassification, LayoutLMv3Processor
from PIL import Image, ImageDraw, ImageFont
import torch

# load model and processor from huggingface hub
model = LiltForTokenClassification.from_pretrained("philschmid/lilt-en-funsd")
processor = LayoutLMv3Processor.from_pretrained("philschmid/lilt-en-funsd")


# helper function to unnormalize bboxes for drawing onto the image
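# (LayoutLM-family processors return box coordinates normalized to a 0-1000 scale)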
def unnormalize_box(bbox, width, height):
    return [
        width * (bbox[0] / 1000),
        height * (bbox[1] / 1000),
        width * (bbox[2] / 1000),
        height * (bbox[3] / 1000),
    ]


label2color = {
    "B-HEADER": "blue",
    "B-QUESTION": "red",
    "B-ANSWER": "green",
    "I-HEADER": "blue",
    "I-QUESTION": "red",
    "I-ANSWER": "green",
}
# draw results onto the image
def draw_boxes(image, boxes, predictions):
    width, height = image.size
    unnormalized_boxes = [unnormalize_box(box, width, height) for box in boxes]

    # draw predictions over the image
    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default()
    for prediction, box in zip(predictions, unnormalized_boxes):
        if prediction == "O":
            continue
        draw.rectangle(box, outline=label2color[prediction])
        draw.text((box[0] + 10, box[1] - 10), text=prediction, fill=label2color[prediction], font=font)
    return image


# run inference
def run_inference(image, model=model, processor=processor, output_image=True):
    # create model input
    encoding = processor(image, return_tensors="pt")
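    # LiLT consumes only text and layout, so the image features are not needed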
    del encoding["pixel_values"]
    # run inference
    outputs = model(**encoding)
    predictions = outputs.logits.argmax(-1).squeeze().tolist()
    # get labels
    labels = [model.config.id2label[prediction] for prediction in predictions]
    if output_image:
        return draw_boxes(image, encoding["bbox"][0], labels)
    else:
        return labels


# load the dataset used above; the Hub id "nielsr/funsd-layoutlmv3" is
# assumed from the dataset name in this card
from datasets import load_dataset

dataset = load_dataset("nielsr/funsd-layoutlmv3")
run_inference(dataset["test"][34]["image"])
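
Setting output_image=False returns the predicted BIO label strings instead of an annotated image, which is handy for downstream processing:

# labels only, no drawing
labels = run_inference(dataset["test"][34]["image"], output_image=False)
print(labels[:5])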

Training procedure

Training hyperparameters

The following hyperparameters were used during training (a TrainingArguments sketch follows the list):

  • learning_rate: 5e-05
  • train_batch_size: 8
  • eval_batch_size: 8
  • seed: 42
  • optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
  • lr_scheduler_type: linear
  • training_steps: 2500
  • mixed_precision_training: Native AMP
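
A minimal sketch of how these values map onto transformers.TrainingArguments; the output_dir name is a hypothetical placeholder, max_steps corresponds to training_steps, and fp16=True stands in for Native AMP (the Adam betas and epsilon listed above are the transformers defaults):

from transformers import TrainingArguments

# sketch only: output_dir is a hypothetical name
training_args = TrainingArguments(
    output_dir="lilt-en-funsd",
    learning_rate=5e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    seed=42,
    lr_scheduler_type="linear",
    max_steps=2500,  # training_steps
    fp16=True,       # Native AMP mixed precision
)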

Training results

| Training Loss | Epoch | Step | Validation Loss | Answer (P / R / F1, n=817) | Header (P / R / F1, n=119) | Question (P / R / F1, n=1077) | Overall Precision | Overall Recall | Overall F1 | Overall Accuracy |
|---|---|---|---|---|---|---|---|---|---|---|
| 0.0211 | 10.53 | 200 | 1.5528 | 0.8459 / 0.9070 / 0.8754 | 0.5684 / 0.4538 / 0.5047 | 0.8966 / 0.8932 / 0.8949 | 0.8596 | 0.8728 | 0.8662 | 0.8011 |
| 0.0132 | 21.05 | 400 | 1.3143 | 0.8447 / 0.8788 / 0.8614 | 0.6020 / 0.4958 / 0.5438 | 0.8854 / 0.8969 / 0.8911 | 0.8548 | 0.8659 | 0.8603 | 0.8095 |
| 0.0052 | 31.58 | 600 | 1.5747 | 0.8482 / 0.9168 / 0.8812 | 0.6283 / 0.5966 / 0.6121 | 0.8997 / 0.8830 / 0.8913 | 0.8626 | 0.8798 | 0.8711 | 0.8030 |
| 0.0073 | 42.11 | 800 | 1.4848 | 0.8488 / 0.9070 / 0.8769 | 0.5191 / 0.5714 / 0.5440 | 0.8942 / 0.8867 / 0.8904 | 0.8514 | 0.8763 | 0.8636 | 0.7969 |
| 0.0057 | 52.63 | 1000 | 1.3993 | 0.8852 / 0.9155 / 0.9001 | 0.5455 / 0.6050 / 0.5737 | 0.8991 / 0.9183 / 0.9086 | 0.8710 | 0.8987 | 0.8846 | 0.8198 |
| 0.0023 | 63.16 | 1200 | 1.6463 | 0.8961 / 0.8764 / 0.8861 | 0.5625 / 0.5294 / 0.5455 | 0.8880 / 0.9276 / 0.9074 | 0.8733 | 0.8833 | 0.8782 | 0.8082 |
| 0.0010 | 73.68 | 1400 | 1.6476 | 0.8677 / 0.9070 / 0.8869 | 0.6571 / 0.5798 / 0.6161 | 0.9083 / 0.9192 / 0.9137 | 0.8785 | 0.8942 | 0.8863 | 0.8137 |
| 0.0014 | 84.21 | 1600 | 1.6493 | 0.8815 / 0.8739 / 0.8777 | 0.6195 / 0.5882 / 0.6034 | 0.8944 / 0.9201 / 0.9071 | 0.8740 | 0.8818 | 0.8778 | 0.8041 |
| 0.0006 | 94.74 | 1800 | 1.6193 | 0.8766 / 0.8960 / 0.8862 | 0.6068 / 0.5966 / 0.6017 | 0.8946 / 0.9304 / 0.9122 | 0.8711 | 0.8967 | 0.8837 | 0.8137 |
| 0.0001 | 105.26 | 2000 | 1.6048 | 0.8751 / 0.9094 / 0.8920 | 0.6140 / 0.5882 / 0.6009 | 0.9063 / 0.9248 / 0.9154 | 0.8773 | 0.8987 | 0.8879 | 0.8194 |
| 0.0001 | 115.79 | 2200 | 1.6117 | 0.8821 / 0.9070 / 0.8944 | 0.6126 / 0.5714 / 0.5913 | 0.9045 / 0.9322 / 0.9182 | 0.8797 | 0.9006 | 0.8900 | 0.8204 |
| 0.0001 | 126.32 | 2400 | 1.6163 | 0.8799 / 0.9058 / 0.8926 | 0.6053 / 0.5798 / 0.5923 | 0.9063 / 0.9248 / 0.9154 | 0.8788 | 0.8967 | 0.8876 | 0.8192 |

Framework versions

  • Transformers 4.24.0
  • Pytorch 1.12.1+cu113
  • Datasets 2.7.0
  • Tokenizers 0.12.1
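
Pinning these versions helps reproduce the results; a minimal install sketch (the cu113 extra index URL assumes a CUDA 11.3 environment):

pip install transformers==4.24.0 datasets==2.7.0 tokenizers==0.12.1
pip install torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113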