pierreguillou commited on
Commit
b050ba1
·
1 Parent(s): f43f6f8

Update files/functions.py

Browse files
Files changed (1) hide show
  1. files/functions.py +44 -23
files/functions.py CHANGED
@@ -51,22 +51,13 @@ label2color = {
51
 
52
  # bounding boxes start and end of a sequence
53
  cls_box = [0, 0, 0, 0]
54
- sep_box = cls_box
55
 
56
  # model
57
- from transformers import AutoTokenizer, AutoModelForTokenClassification
58
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
59
 
60
- model_id = "pierreguillou/lilt-xlm-roberta-base-finetuned-with-DocLayNet-base-at-linelevel-ml384"
61
-
62
- tokenizer = AutoTokenizer.from_pretrained(model_id)
63
- model = AutoModelForTokenClassification.from_pretrained(model_id);
64
- model.to(device);
65
-
66
- # get labels
67
- id2label = model.config.id2label
68
- label2id = model.config.label2id
69
- num_labels = len(id2label)
70
 
71
  # (tokenization) The maximum length of a feature (sequence)
72
  if str(384) in model_id:
@@ -81,7 +72,21 @@ doc_stride = 128 # The authorized overlap between two part of the context when s
81
 
82
  # max PDF page images that will be displayed
83
  max_imgboxes = 2
 
 
84
  examples_dir = 'files/'
 
 
 
 
 
 
 
 
 
 
 
 
85
  image_wo_content = examples_dir + "wo_content.png" # image without content
86
  pdf_blank = examples_dir + "blank.pdf" # blank PDF
87
  image_blank = examples_dir + "blank.png" # blank image
@@ -368,8 +373,8 @@ def extraction_data_from_image(images):
368
 
369
  # https://pyimagesearch.com/2021/11/15/tesseract-page-segmentation-modes-psms-explained-how-to-improve-your-ocr-accuracy/
370
  custom_config = r'--oem 3 --psm 3 -l eng' # default config PyTesseract: --oem 3 --psm 3 -l eng+deu+fra+jpn+por+spa+rus+hin+chi_sim
371
- results, lines, row_indexes, par_boxes, line_boxes = dict(), dict(), dict(), dict(), dict()
372
- images_ids_list, lines_list, par_boxes_list, line_boxes_list, images_list, page_no_list, num_pages_list = list(), list(), list(), list(), list(), list(), list()
373
 
374
  try:
375
  for i,image in enumerate(images):
@@ -401,11 +406,15 @@ def extraction_data_from_image(images):
401
  results[i] = pytesseract.image_to_data(img, config=custom_config, output_type=pytesseract.Output.DICT)
402
  # results[i] = os.popen(f'tesseract {img_filepath} - {custom_config}').read()
403
 
 
 
 
404
  lines[i], row_indexes[i], par_boxes[i], line_boxes[i] = get_data(results[i], factor, conf_min=0)
405
  lines_list.append(lines[i])
406
  par_boxes_list.append(par_boxes[i])
407
  line_boxes_list.append(line_boxes[i])
408
  images_ids_list.append(i)
 
409
  images_list.append(images[i])
410
  page_no_list.append(i)
411
  num_pages_list.append(num_imgs)
@@ -414,7 +423,7 @@ def extraction_data_from_image(images):
414
  print(f"There was an error within the extraction of PDF text by the OCR!")
415
  else:
416
  from datasets import Dataset
417
- dataset = Dataset.from_dict({"images_ids": images_ids_list, "images": images_list, "page_no": page_no_list, "num_pages": num_pages_list, "texts": lines_list, "bboxes_line": line_boxes_list})
418
 
419
  # print(f"The text data was successfully extracted by the OCR!")
420
 
@@ -424,11 +433,12 @@ def extraction_data_from_image(images):
424
 
425
  def prepare_inference_features(example, cls_box = cls_box, sep_box = sep_box):
426
 
427
- images_ids_list, chunks_ids_list, input_ids_list, attention_mask_list, bb_list = list(), list(), list(), list(), list()
428
 
429
  # get batch
430
  batch_images_ids = example["images_ids"]
431
  batch_images = example["images"]
 
432
  batch_bboxes_line = example["bboxes_line"]
433
  batch_texts = example["texts"]
434
  batch_images_size = [image.size for image in batch_images]
@@ -439,12 +449,13 @@ def prepare_inference_features(example, cls_box = cls_box, sep_box = sep_box):
439
  if not isinstance(batch_images_ids, list):
440
  batch_images_ids = [batch_images_ids]
441
  batch_images = [batch_images]
 
442
  batch_bboxes_line = [batch_bboxes_line]
443
  batch_texts = [batch_texts]
444
  batch_width, batch_height = [batch_width], [batch_height]
445
 
446
  # process all images of the batch
447
- for num_batch, (image_id, boxes, texts, width, height) in enumerate(zip(batch_images_ids, batch_bboxes_line, batch_texts, batch_width, batch_height)):
448
  tokens_list = []
449
  bboxes_list = []
450
 
@@ -506,6 +517,7 @@ def prepare_inference_features(example, cls_box = cls_box, sep_box = sep_box):
506
  bb_list.append(bb)
507
  images_ids_list.append(image_id)
508
  chunks_ids_list.append(i)
 
509
 
510
  return {
511
  "images_ids": images_ids_list,
@@ -513,6 +525,7 @@ def prepare_inference_features(example, cls_box = cls_box, sep_box = sep_box):
513
  "input_ids": input_ids_list,
514
  "attention_mask": attention_mask_list,
515
  "normalized_bboxes": bb_list,
 
516
  }
517
 
518
  from torch.utils.data import Dataset
@@ -534,18 +547,21 @@ class CustomDataset(Dataset):
534
  encoding["input_ids"] = example["input_ids"]
535
  encoding["attention_mask"] = example["attention_mask"]
536
  encoding["bbox"] = example["normalized_bboxes"]
 
537
 
538
  return encoding
539
 
540
  import torch.nn.functional as F
541
 
 
 
542
  # get predictions at token level
543
  def predictions_token_level(images, custom_encoded_dataset):
544
 
545
  num_imgs = len(images)
546
  if num_imgs > 0:
547
 
548
- chunk_ids, input_ids, bboxes, outputs, token_predictions = dict(), dict(), dict(), dict(), dict()
549
  images_ids_list = list()
550
 
551
  for i,encoding in enumerate(custom_encoded_dataset):
@@ -556,6 +572,7 @@ def predictions_token_level(images, custom_encoded_dataset):
556
  input_id = torch.tensor(encoding['input_ids'])[None]
557
  attention_mask = torch.tensor(encoding['attention_mask'])[None]
558
  bbox = torch.tensor(encoding['bbox'])[None]
 
559
 
560
  # save data in dictionnaries
561
  if image_id not in images_ids_list: images_ids_list.append(image_id)
@@ -569,14 +586,18 @@ def predictions_token_level(images, custom_encoded_dataset):
569
  if image_id in bboxes: bboxes[image_id].append(bbox)
570
  else: bboxes[image_id] = [bbox]
571
 
 
 
 
572
  # get prediction with forward pass
573
  with torch.no_grad():
574
  output = model(
575
- input_ids=input_id,
576
- attention_mask=attention_mask,
577
- bbox=bbox
 
578
  )
579
-
580
  # save probabilities of predictions in dictionnary
581
  if image_id in outputs: outputs[image_id].append(F.softmax(output.logits.squeeze(), dim=-1))
582
  else: outputs[image_id] = [F.softmax(output.logits.squeeze(), dim=-1)]
 
51
 
52
  # bounding boxes start and end of a sequence
53
  cls_box = [0, 0, 0, 0]
54
+ sep_box = [1000, 1000, 1000, 1000]
55
 
56
  # model
57
+ model_id = "pierreguillou/layout-xlm-base-finetuned-with-DocLayNet-base-at-linelevel-ml384"
 
58
 
59
+ # tokenizer
60
+ tokenizer_id = "xlm-roberta-base"
 
 
 
 
 
 
 
 
61
 
62
  # (tokenization) The maximum length of a feature (sequence)
63
  if str(384) in model_id:
 
72
 
73
  # max PDF page images that will be displayed
74
  max_imgboxes = 2
75
+
76
+ # get files
77
  examples_dir = 'files/'
78
+ Path(examples_dir).mkdir(parents=True, exist_ok=True)
79
+ from huggingface_hub import hf_hub_download
80
+ files = ["example.pdf", "blank.pdf", "blank.png", "languages_iso.csv", "languages_tesseract.csv", "wo_content.png"]
81
+ for file_name in files:
82
+ path_to_file = hf_hub_download(
83
+ repo_id = "pierreguillou/Inference-APP-Document-Understanding-at-linelevel-v2",
84
+ filename = "files/" + file_name,
85
+ repo_type = "space"
86
+ )
87
+ shutil.copy(path_to_file,examples_dir)
88
+
89
+ # path to files
90
  image_wo_content = examples_dir + "wo_content.png" # image without content
91
  pdf_blank = examples_dir + "blank.pdf" # blank PDF
92
  image_blank = examples_dir + "blank.png" # blank image
 
373
 
374
  # https://pyimagesearch.com/2021/11/15/tesseract-page-segmentation-modes-psms-explained-how-to-improve-your-ocr-accuracy/
375
  custom_config = r'--oem 3 --psm 3 -l eng' # default config PyTesseract: --oem 3 --psm 3 -l eng+deu+fra+jpn+por+spa+rus+hin+chi_sim
376
+ results, lines, row_indexes, par_boxes, line_boxes, images_pixels = dict(), dict(), dict(), dict(), dict(), dict()
377
+ images_ids_list, lines_list, par_boxes_list, line_boxes_list, images_list, images_pixels_list, page_no_list, num_pages_list = list(), list(), list(), list(), list(), list(), list(), list()
378
 
379
  try:
380
  for i,image in enumerate(images):
 
406
  results[i] = pytesseract.image_to_data(img, config=custom_config, output_type=pytesseract.Output.DICT)
407
  # results[i] = os.popen(f'tesseract {img_filepath} - {custom_config}').read()
408
 
409
+ # get image pixels
410
+ images_pixels[i] = feature_extractor(images[i], return_tensors="pt").pixel_values
411
+
412
  lines[i], row_indexes[i], par_boxes[i], line_boxes[i] = get_data(results[i], factor, conf_min=0)
413
  lines_list.append(lines[i])
414
  par_boxes_list.append(par_boxes[i])
415
  line_boxes_list.append(line_boxes[i])
416
  images_ids_list.append(i)
417
+ images_pixels_list.append(images_pixels[i])
418
  images_list.append(images[i])
419
  page_no_list.append(i)
420
  num_pages_list.append(num_imgs)
 
423
  print(f"There was an error within the extraction of PDF text by the OCR!")
424
  else:
425
  from datasets import Dataset
426
+ dataset = Dataset.from_dict({"images_ids": images_ids_list, "images": images_list, "images_pixels": images_pixels_list, "page_no": page_no_list, "num_pages": num_pages_list, "texts": lines_list, "bboxes_line": line_boxes_list})
427
 
428
  # print(f"The text data was successfully extracted by the OCR!")
429
 
 
433
 
434
  def prepare_inference_features(example, cls_box = cls_box, sep_box = sep_box):
435
 
436
+ images_ids_list, chunks_ids_list, input_ids_list, attention_mask_list, bb_list, images_pixels_list = list(), list(), list(), list(), list(), list()
437
 
438
  # get batch
439
  batch_images_ids = example["images_ids"]
440
  batch_images = example["images"]
441
+ batch_images_pixels = example["images_pixels"]
442
  batch_bboxes_line = example["bboxes_line"]
443
  batch_texts = example["texts"]
444
  batch_images_size = [image.size for image in batch_images]
 
449
  if not isinstance(batch_images_ids, list):
450
  batch_images_ids = [batch_images_ids]
451
  batch_images = [batch_images]
452
+ batch_images_pixels = [batch_images_pixels]
453
  batch_bboxes_line = [batch_bboxes_line]
454
  batch_texts = [batch_texts]
455
  batch_width, batch_height = [batch_width], [batch_height]
456
 
457
  # process all images of the batch
458
+ for num_batch, (image_id, image_pixels, boxes, texts, width, height) in enumerate(zip(batch_images_ids, batch_images_pixels, batch_bboxes_line, batch_texts, batch_width, batch_height)):
459
  tokens_list = []
460
  bboxes_list = []
461
 
 
517
  bb_list.append(bb)
518
  images_ids_list.append(image_id)
519
  chunks_ids_list.append(i)
520
+ images_pixels_list.append(image_pixels)
521
 
522
  return {
523
  "images_ids": images_ids_list,
 
525
  "input_ids": input_ids_list,
526
  "attention_mask": attention_mask_list,
527
  "normalized_bboxes": bb_list,
528
+ "images_pixels": images_pixels_list
529
  }
530
 
531
  from torch.utils.data import Dataset
 
547
  encoding["input_ids"] = example["input_ids"]
548
  encoding["attention_mask"] = example["attention_mask"]
549
  encoding["bbox"] = example["normalized_bboxes"]
550
+ encoding["images_pixels"] = example["images_pixels"]
551
 
552
  return encoding
553
 
554
  import torch.nn.functional as F
555
 
556
+ import torch.nn.functional as F
557
+
558
  # get predictions at token level
559
  def predictions_token_level(images, custom_encoded_dataset):
560
 
561
  num_imgs = len(images)
562
  if num_imgs > 0:
563
 
564
+ chunk_ids, input_ids, bboxes, pixels_values, outputs, token_predictions = dict(), dict(), dict(), dict(), dict(), dict()
565
  images_ids_list = list()
566
 
567
  for i,encoding in enumerate(custom_encoded_dataset):
 
572
  input_id = torch.tensor(encoding['input_ids'])[None]
573
  attention_mask = torch.tensor(encoding['attention_mask'])[None]
574
  bbox = torch.tensor(encoding['bbox'])[None]
575
+ pixel_values = torch.tensor(encoding["images_pixels"])
576
 
577
  # save data in dictionnaries
578
  if image_id not in images_ids_list: images_ids_list.append(image_id)
 
586
  if image_id in bboxes: bboxes[image_id].append(bbox)
587
  else: bboxes[image_id] = [bbox]
588
 
589
+ if image_id in pixels_values: pixels_values[image_id].append(pixel_values)
590
+ else: pixels_values[image_id] = [pixel_values]
591
+
592
  # get prediction with forward pass
593
  with torch.no_grad():
594
  output = model(
595
+ input_ids=input_id.to(device),
596
+ attention_mask=attention_mask.to(device),
597
+ bbox=bbox.to(device),
598
+ image=pixel_values.to(device)
599
  )
600
+
601
  # save probabilities of predictions in dictionnary
602
  if image_id in outputs: outputs[image_id].append(F.softmax(output.logits.squeeze(), dim=-1))
603
  else: outputs[image_id] = [F.softmax(output.logits.squeeze(), dim=-1)]