srinivas-mushroom commited on
Commit
63459a6
1 Parent(s): 24236d0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -150
app.py CHANGED
@@ -1,151 +1,42 @@
1
  import gradio as gr
2
- import csv
3
- from io import StringIO
4
- from PIL import Image
5
- import numpy as np
6
- import base64
7
-
8
- # Define the annotation types
9
- ANNOTATION_TYPES = ['rect', 'circle']
10
-
11
- # Define the Annotation class
12
- class Annotation:
13
- def __init__(self, x, y, width, height, annotation_type):
14
- self.x = x
15
- self.y = y
16
- self.width = width
17
- self.height = height
18
- self.type = annotation_type
19
-
20
- # Define the Gradio interface
21
- def annotate_images(images):
22
- # Define the canvas size
23
- canvas_size = (600, 600)
24
-
25
- # Define the initial state
26
- state = {
27
- 'image': None,
28
- 'annotations': [],
29
- 'annotation_type': ANNOTATION_TYPES[0],
30
- 'start_point': None
31
- }
32
-
33
- # Define the canvas drawing function
34
- def draw_canvas(canvas, image_data, annotations):
35
- # Convert the image data to a PIL Image object
36
- image = Image.fromarray(image_data)
37
-
38
- # Resize the image to fit the canvas
39
- image = image.resize(canvas_size)
40
-
41
- # Draw the image on the canvas
42
- canvas.draw_image(image, (canvas_size[0]/2, canvas_size[1]/2))
43
-
44
- # Draw the annotations on the canvas
45
- for annotation in annotations:
46
- x, y, width, height = annotation.x, annotation.y, annotation.width, annotation.height
47
- if annotation.type == 'rect':
48
- canvas.draw_rect(x, y, width, height, stroke_color='red')
49
- elif annotation.type == 'circle':
50
- radius = np.sqrt(np.power(width, 2) + np.power(height, 2)) / 2
51
- center_x, center_y = x + width / 2, y + height / 2
52
- canvas.draw_circle(center_x, center_y, radius, stroke_color='red')
53
-
54
- # Define the canvas mousedown event handler
55
- def canvas_mousedown(canvas, x, y):
56
- state['start_point'] = (x, y)
57
-
58
- # Define the canvas mousemove event handler
59
- def canvas_mousemove(canvas, x, y):
60
- if state['start_point'] is not None:
61
- start_x, start_y = state['start_point']
62
- end_x, end_y = x, y
63
- annotation_type = state['annotation_type']
64
- draw_annotation(canvas, start_x, start_y, end_x, end_y, annotation_type)
65
-
66
- # Define the canvas mouseup event handler
67
- def canvas_mouseup(canvas, x, y):
68
- if state['start_point'] is not None:
69
- start_x, start_y = state['start_point']
70
- end_x, end_y = x, y
71
- annotation_type = state['annotation_type']
72
- add_annotation(start_x, start_y, end_x, end_y, annotation_type)
73
- state['start_point'] = None
74
-
75
- # Define the add annotation function
76
- def add_annotation(start_x, start_y, end_x, end_y, annotation_type):
77
- # Calculate the width and height of the annotation
78
- width = np.abs(start_x - end_x)
79
- height = np.abs(start_y - end_y)
80
-
81
- # Create the annotation object
82
- annotation = Annotation(start_x, start_y, width, height, annotation_type)
83
-
84
- # Add the annotation to the array
85
- state['annotations'].append(annotation)
86
-
87
- # Redraw the canvas
88
- draw_canvas(canvas, state['image'], state['annotations'])
89
-
90
- #
91
-
92
- # Define the draw annotation function
93
- def draw_annotation(canvas, start_x, start_y, end_x, end_y, annotation_type):
94
- canvas.clear()
95
- draw_canvas(canvas, state['image'], state['annotations'])
96
- width = np.abs(start_x - end_x)
97
- height = np.abs(start_y - end_y)
98
- if annotation_type == 'rect':
99
- canvas.draw_rect(start_x, start_y, width, height, stroke_color='red')
100
- elif annotation_type == 'circle':
101
- radius = np.sqrt(np.power(width, 2) + np.power(height, 2)) / 2
102
- center_x, center_y = start_x + width / 2, start_y + height / 2
103
- canvas.draw_circle(center_x, center_y, radius, stroke_color='red')
104
-
105
- # Define the annotation type dropdown event handler
106
- def annotation_type_changed(value):
107
- state['annotation_type'] = value
108
-
109
- # Define the download annotations button click event handler
110
- def download_annotations_clicked():
111
- # Define the csv headers
112
- headers = ['x', 'y', 'width', 'height', 'type']
113
-
114
- # Define the csv data
115
- rows = [[str(annotation.x), str(annotation.y), str(annotation.width), str(annotation.height), annotation.type]
116
- for annotation in state['annotations']]
117
-
118
- # Create the csv string
119
- csv_string = StringIO()
120
- csv_writer = csv.writer(csv_string)
121
- csv_writer.writerow(headers)
122
- for row in rows:
123
- csv_writer.writerow(row)
124
-
125
- # Download the csv file
126
- b64_csv = base64.b64encode(csv_string.getvalue().encode()).decode()
127
- href = f'data:text/csv;base64,{b64_csv}'
128
- download_link = f'<a href="{href}" download="annotations.csv">Download Annotations CSV</a>'
129
- gr.Interface.show(download_link)
130
-
131
- # Define the interface components
132
- image = gr.inputs.Image(label='Image')
133
- annotation_type = gr.inputs.Dropdown(ANNOTATION_TYPES, label='Annotation Type', default=ANNOTATION_TYPES[0], onchange=annotation_type_changed)
134
- download_annotations = gr.outputs.Button(label='Download Annotations', type='button', onclick=download_annotations_clicked)
135
- canvas = gr.outputs.Canvas(draw_event_handlers={
136
- 'mousedown': canvas_mousedown,
137
- 'mousemove': canvas_mousemove,
138
- 'mouseup': canvas_mouseup
139
- })
140
-
141
- # Define the interface function
142
- def annotate_images(images):
143
- state['image'] = images[0]
144
- draw_canvas(canvas, state['image'], state['annotations'])
145
- return canvas, annotation_type, download_annotations
146
-
147
- # Create the interface
148
- interface = gr.Interface(annotate_images, inputs=image, outputs=[canvas, annotation_type, download_annotations], capture_session=True)
149
-
150
- return interface
151
-
 
1
  import gradio as gr
2
+ import requests
3
+ import io
4
+ import json
5
+ from transformers import AutoTokenizer, AutoModelForQuestionAnswering
6
+
7
+ # Download and load pre-trained model and tokenizer
8
+ model_name = "distilbert-base-cased-distilled-squad"
9
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
10
+ model = AutoModelForQuestionAnswering.from_pretrained(model_name)
11
+
12
+ def answer_question(pdf_file, question):
13
+ # Convert PDF to text
14
+ pdf_data = pdf_file.read()
15
+ pdf_stream = io.BytesIO(pdf_data)
16
+ response = requests.post(
17
+ 'https://pdftotext.com/ExtractText',
18
+ files={'pdffile': pdf_stream},
19
+ data={'form': 'pdftotext'}
20
+ )
21
+ text = response.text.strip()
22
+
23
+ # Tokenize question and text
24
+ input_ids = tokenizer.encode(question, text)
25
+
26
+ # Perform question answering
27
+ outputs = model(torch.tensor([input_ids]), return_dict=True)
28
+ answer_start = outputs.start_logits.argmax().item()
29
+ answer_end = outputs.end_logits.argmax().item()
30
+ answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end+1]))
31
+
32
+ return answer
33
+
34
+ inputs = [
35
+ gr.inputs.File(label="PDF document"),
36
+ gr.inputs.Textbox(label="Question")
37
+ ]
38
+
39
+ outputs = gr.outputs.Textbox(label="Answer")
40
+
41
+ gr.Interface(fn=answer_question, inputs=inputs, outputs=outputs, title="PDF Question Answering Tool",
42
+ description="Upload a PDF document and ask a question. The app will use a pre-trained model to find the answer.").launch()