pierreguillou committed on
Commit
dca869a
1 Parent(s): fe811a3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -19
app.py CHANGED
@@ -52,21 +52,25 @@ os.system('python -m pip install --upgrade pip')
52
 
53
  ## model / feature extractor / tokenizer
54
 
 
 
 
 
 
55
  import torch
56
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
57
 
58
- # model 1
 
59
  from transformers import AutoTokenizer, AutoModelForTokenClassification
60
- model_id = "pierreguillou/lilt-xlm-roberta-base-finetuned-with-DocLayNet-base-at-linelevel-ml384"
61
- tokenizer1 = AutoTokenizer.from_pretrained(model_id)
62
- model1 = AutoModelForTokenClassification.from_pretrained(model_id);
63
- model1.to(device);
64
 
65
- # model 2
66
  from transformers import LayoutLMv2ForTokenClassification # LayoutXLMTokenizerFast,
67
- model_id = "pierreguillou/layout-xlm-base-finetuned-with-DocLayNet-base-at-linelevel-ml384"
68
- model2 = LayoutLMv2ForTokenClassification.from_pretrained(model_id);
69
- model2.to(device);
70
 
71
  # feature extractor
72
  from transformers import LayoutLMv2FeatureExtractor
@@ -74,27 +78,27 @@ feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
74
 
75
  # tokenizer
76
  from transformers import AutoTokenizer
77
- tokenizer_id = "xlm-roberta-base"
78
- tokenizer2 = AutoTokenizer.from_pretrained(tokenizer_id)
79
 
80
- # APP outputs
81
- def app_outputs(uploaded_pdf):
82
  filename, msg, images = pdf_to_images(uploaded_pdf)
83
  num_images = len(images)
84
 
85
  if not msg.startswith("Error with the PDF"):
86
-
87
  # Extraction of image data (text and bounding boxes)
88
  dataset, lines, row_indexes, par_boxes, line_boxes = extraction_data_from_image(images)
89
  # prepare our data in the format of the model
90
- encoded_dataset = dataset.map(prepare_inference_features, batched=True, batch_size=64, remove_columns=dataset.column_names)
 
91
  custom_encoded_dataset = CustomDataset(encoded_dataset, tokenizer)
92
  # Get predictions (token level)
93
- outputs, images_ids_list, chunk_ids, input_ids, bboxes = predictions_token_level(images, custom_encoded_dataset)
94
  # Get predictions (line level)
95
- probs_bbox, bboxes_list_dict, input_ids_dict_dict, probs_dict_dict, df = predictions_line_level(dataset, outputs, images_ids_list, chunk_ids, input_ids, bboxes)
96
  # Get labeled images with lines bounding boxes
97
- images = get_labeled_images(dataset, images_ids_list, bboxes_list_dict, probs_dict_dict)
98
 
99
  img_files = list()
100
  # get image of PDF without bounding boxes
@@ -130,7 +134,7 @@ def app_outputs(uploaded_pdf):
130
  df, df_empty = dict(), pd.DataFrame()
131
  df[0], df[1] = df_empty.to_csv(csv_file, encoding="utf-8", index=False), df_empty.to_csv(csv_file, encoding="utf-8", index=False)
132
 
133
- return msg, img_files[0], img_files[1], images[0], images[1], csv_files[0], csv_files[1], df[0], df[1]
134
 
135
  # gradio APP
136
  with gr.Blocks(title="Inference APP for Document Understanding at line level (v1 - LiLT base vs LayoutXLM base)", css=".gradio-container") as demo:
 
52
 
53
  ## model / feature extractor / tokenizer
54
 
55
+ # models
56
+ model_id_lilt = "pierreguillou/lilt-xlm-roberta-base-finetuned-with-DocLayNet-base-at-linelevel-ml384"
57
+ model_id_layoutxlm = "pierreguillou/layout-xlm-base-finetuned-with-DocLayNet-base-at-linelevel-ml384"
58
+
59
+ # get device
60
  import torch
61
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
62
 
63
+ ## model LiLT
64
+ import transformers
65
  from transformers import AutoTokenizer, AutoModelForTokenClassification
66
+ tokenizer_lilt = AutoTokenizer.from_pretrained(model_id_lilt)
67
+ model_lilt = AutoModelForTokenClassification.from_pretrained(model_id_lilt);
68
+ model_lilt.to(device);
 
69
 
70
+ ## model LayoutXLM
71
  from transformers import LayoutLMv2ForTokenClassification # LayoutXLMTokenizerFast,
72
+ model_layoutxlm = LayoutLMv2ForTokenClassification.from_pretrained(model_id_layoutxlm);
73
+ model_layoutxlm.to(device);
 
74
 
75
  # feature extractor
76
  from transformers import LayoutLMv2FeatureExtractor
 
78
 
79
  # tokenizer
80
  from transformers import AutoTokenizer
81
+ tokenizer_layoutxlm = AutoTokenizer.from_pretrained(tokenizer_id_layoutxlm)
 
82
 
83
+ # APP outputs by model
84
+ def app_outputs_by_model(uploaded_pdf, model_id, model, tokenizer, max_length, id2label, cls_box, sep_box):
85
  filename, msg, images = pdf_to_images(uploaded_pdf)
86
  num_images = len(images)
87
 
88
  if not msg.startswith("Error with the PDF"):
89
+
90
  # Extraction of image data (text and bounding boxes)
91
  dataset, lines, row_indexes, par_boxes, line_boxes = extraction_data_from_image(images)
92
  # prepare our data in the format of the model
93
+ prepare_inference_features_partial = partial(prepare_inference_features, tokenizer=tokenizer, max_length=max_length, cls_box=cls_box, sep_box=sep_box)
94
+ encoded_dataset = dataset.map(prepare_inference_features_partial, batched=True, batch_size=64, remove_columns=dataset.column_names)
95
  custom_encoded_dataset = CustomDataset(encoded_dataset, tokenizer)
96
  # Get predictions (token level)
97
+ outputs, images_ids_list, chunk_ids, input_ids, bboxes = predictions_token_level(images, custom_encoded_dataset, model_id, model)
98
  # Get predictions (line level)
99
+ probs_bbox, bboxes_list_dict, input_ids_dict_dict, probs_dict_dict, df = predictions_line_level(max_length, tokenizer, id2label, dataset, outputs, images_ids_list, chunk_ids, input_ids, bboxes, cls_box, sep_box)
100
  # Get labeled images with lines bounding boxes
101
+ images = get_labeled_images(id2label, dataset, images_ids_list, bboxes_list_dict, probs_dict_dict)
102
 
103
  img_files = list()
104
  # get image of PDF without bounding boxes
 
134
  df, df_empty = dict(), pd.DataFrame()
135
  df[0], df[1] = df_empty.to_csv(csv_file, encoding="utf-8", index=False), df_empty.to_csv(csv_file, encoding="utf-8", index=False)
136
 
137
+ return msg, img_files[0], images[0], csv_files[0], df[0]
138
 
139
  # gradio APP
140
  with gr.Blocks(title="Inference APP for Document Understanding at line level (v1 - LiLT base vs LayoutXLM base)", css=".gradio-container") as demo: