karida committed on
Commit
d4a6a10
1 Parent(s): 8b5a1c6

Add gradio

Browse files
Files changed (2) hide show
  1. README.md +1 -1
  2. main.py +165 -0
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: purple
5
  colorTo: purple
6
  sdk: gradio
7
  sdk_version: 4.15.0
8
- app_file: app.py
9
  pinned: false
10
  license: other
11
  ---
 
5
  colorTo: purple
6
  sdk: gradio
7
  sdk_version: 4.15.0
8
+ app_file: main.py
9
  pinned: false
10
  license: other
11
  ---
main.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import ImageDraw, ImageFont, Image
3
+ from transformers import AutoModelForTokenClassification, AutoProcessor
4
+ import fitz # PyMuPDF
5
+ import io
6
+
7
+
8
+ def extract_data_from_pdf(pdf_path, page_number=0):
9
+ """
10
+ Extracts image, words, and bounding boxes from a specified page of a PDF.
11
+
12
+ Args:
13
+ - pdf_path (str): Path to the PDF file.
14
+ - page_number (int): Page number to extract data from (0-indexed).
15
+
16
+ Returns:
17
+ - image: An image of the specified page.
18
+ - words: A list of words found on the page.
19
+ - boxes: A list of bounding boxes corresponding to the words.
20
+ """
21
+ # Open the PDF
22
+ doc = fitz.open(pdf_path)
23
+ page = doc.load_page(page_number)
24
+
25
+ # Extract image of the page
26
+ pix = page.get_pixmap()
27
+ image_bytes = pix.tobytes("png")
28
+ image = Image.open(io.BytesIO(image_bytes))
29
+
30
+ # Extract words and their bounding boxes
31
+ words = []
32
+ boxes = []
33
+ for word in page.get_text("words"):
34
+ words.append(word[4])
35
+ boxes.append(word[:4]) # (x0, y0, x1, y1)
36
+
37
+ doc.close()
38
+ return image, words, boxes
39
+
40
+
41
+ def merge_pairs_v2(pairs):
42
+ if not pairs:
43
+ return []
44
+
45
+ merged = [pairs[0]]
46
+ for current in pairs[1:]:
47
+ last = merged[-1]
48
+ if last[0] == current[0]:
49
+ # Merge 'y' values (as strings) if 'x' values are the same
50
+ merged[-1] = [last[0], last[1] + " " + current[1]]
51
+ else:
52
+ merged.append(current)
53
+
54
+ return merged
55
+
56
+
57
+ def create_pretty_table(data):
58
+ table = "<div>"
59
+ for row in data:
60
+ color = (
61
+ "blue"
62
+ if row[0] == "Heder"
63
+ else "green"
64
+ if row[0] == "Section"
65
+ else "black"
66
+ )
67
+ table += "<p style='color:{};'>---{}---</p>{}".format(
68
+ color, row[0], row[1]
69
+ )
70
+ table += "</div>"
71
+ return table
72
+
73
+
74
+ # When using this function in Gradio, set the output type to 'html'
75
+
76
+
77
+ def interference(example, page_number=0):
78
+ image, words, boxes = extract_data_from_pdf(example, page_number)
79
+ boxes = [list(map(int, box)) for box in boxes]
80
+
81
+ # Process the image and words
82
+ model = AutoModelForTokenClassification.from_pretrained(
83
+ "karida/LayoutLMv3_RFP"
84
+ )
85
+ processor = AutoProcessor.from_pretrained(
86
+ "microsoft/layoutlmv3-base", apply_ocr=False
87
+ )
88
+ encoding = processor(image, words, boxes=boxes, return_tensors="pt")
89
+
90
+ # Prediction
91
+ with torch.no_grad():
92
+ outputs = model(**encoding)
93
+
94
+ logits = outputs.logits
95
+ predictions = logits.argmax(-1).squeeze().tolist()
96
+ model_words = encoding.word_ids()
97
+
98
+ # Process predictions
99
+ token_boxes = encoding.bbox.squeeze().tolist()
100
+ width, height = image.size
101
+
102
+ true_predictions = [model.config.id2label[pred] for pred in predictions]
103
+ true_boxes = token_boxes
104
+ # Draw annotations on the image
105
+ draw = ImageDraw.Draw(image)
106
+ font = ImageFont.load_default()
107
+
108
+ def iob_to_label(label):
109
+ label = label[2:]
110
+ return "other" if not label else label.lower()
111
+
112
+ label2color = {
113
+ "question": "blue",
114
+ "answer": "green",
115
+ "header": "orange",
116
+ "other": "violet",
117
+ }
118
+
119
+ # print(len(true_predictions), len(true_boxes), len(model_words))
120
+
121
+ table = []
122
+ ids = set()
123
+
124
+ for prediction, box, model_word in zip(
125
+ true_predictions, true_boxes, model_words
126
+ ):
127
+ predicted_label = iob_to_label(prediction)
128
+ draw.rectangle(box, outline=label2color[predicted_label], width=2)
129
+ # draw.text((box[0] + 10, box[1] - 10), text=predicted_label, fill=label2color[predicted_label], font=font)
130
+ if model_word and model_word not in ids and predicted_label != "other":
131
+ ids.add(model_word)
132
+ table.append([predicted_label[0], words[model_word]])
133
+
134
+ values = merge_pairs_v2(table)
135
+ values = [
136
+ ["Heder", x[1]] if x[0] == "q" else ["Section", x[1]] for x in values
137
+ ]
138
+ table = create_pretty_table(values)
139
+ return image, table
140
+
141
+
142
+ import gradio as gr
143
+
144
+ description_text = """
145
+ <p>
146
+ Heading - <span style='color: blue;'>shown in blue</span><br>
147
+ Section - <span style='color: green;'>shown in green</span><br>
148
+ other - (ignored)<span style='color: violet;'>shown in violet</span>
149
+ </p>
150
+ """
151
+
152
+ flagging_options = ["great example", "bad example"]
153
+
154
+
155
+ iface = gr.Interface(
156
+ fn=interference,
157
+ inputs=["file", "number"],
158
+ outputs=["image", "html"],
159
+ # examples=[["output.pdf", 1]],
160
+ description=description_text,
161
+ flagging_options=flagging_options,
162
+ )
163
+ # iface.save(".")
164
+ if __name__ == "__main__":
165
+ iface.launch()