sugiv committed
Commit 452c0e2
Parent: 170c3c0

Adding initial set of files

Files changed (4)
  1. app.py +175 -0
  2. backend.py +163 -0
  3. enriched_pdf.pkl +3 -0
  4. pdf_classes.py +25 -0
app.py ADDED
@@ -0,0 +1,175 @@
+ import gradio as gr
+ from backend import get_answer
+ import pickle
+ import re
+ import os
+
+ from pdf_classes import PDFSegment, PDFPage, RichPDFDocument
+
+ def load_enriched_pdf(file_path):
+     with open(file_path, 'rb') as f:
+         return pickle.load(f)
+
+ # Load the enriched PDF
+ enriched_pdf = load_enriched_pdf('enriched_pdf.pkl')
+
+
+ # Access API tokens from environment variables
+ jina_api_token = os.getenv('JINA_API_TOKEN')
+ gpt4_api_key = os.getenv('GPT4_API_KEY')
+ pinecone_api_key = os.getenv('PINECONE_API_KEY')
+
+ # Initialize Pinecone with environment variables
+ os.environ["PINECONE_API_KEY"] = pinecone_api_key
+ os.environ["PINECONE_ENVIRONMENT"] = "us-east-1"
+
+ # Sample data for papers (5 papers for the grid)
+ papers = [
+     {"id": "1", "title": "Attention Is All You Need", "authors": "Vaswani et al.", "year": 2017},
+     {"id": "2", "title": "BERT", "authors": "Devlin et al.", "year": 2018},
+     {"id": "3", "title": "GPT-3", "authors": "Brown et al.", "year": 2020},
+     {"id": "4", "title": "Transformer-XL", "authors": "Dai et al.", "year": 2019},
+     {"id": "5", "title": "T5", "authors": "Raffel et al.", "year": 2020},
+ ]
+
+ predefined_questions = {
+     '1': [
+         "Explain Equation 1 in layman's terms and explain each of its components.",
+         'List the authors who contributed to the paper in order, from left to right, then top to bottom.',
+         'Explain Figure 2 from left to right, and describe the flow of the diagram.',
+         'Explain the position-wise feed-forward networks and Equation 2.',
+         'Please summarize the findings from Table 1.',
+         'Explain the optimizer used and explain Equation 3.',
+         'What is the BLEU score for the Transformer model in Table 2?',
+         'What does Figure 1 illustrate about the overall architecture of the Transformer model?',
+         'How does Figure 2 depict the difference between Scaled Dot-Product Attention and Multi-Head Attention?',
+         'Based on Figure 1, how many encoder and decoder layers are used in the Transformer model?',
+         'What mathematical formula is shown in Figure 2 for Scaled Dot-Product Attention?',
+         'According to Table 1, how does the complexity of Self-Attention compare to Recurrent and Convolutional layers?',
+         'What does Table 2 reveal about the BLEU scores and training costs of the Transformer compared to other models?',
+         "How does Table 3 visualize the impact of different model variations on the Transformer's performance?",
+         'What does Equation 3 in the paper represent, and how is it visually presented?',
+         'Can you describe the sinusoidal function used for positional encoding as shown in the equations in Section 3.5?',
+         "How does Figure 1 illustrate the flow of information in the Transformer's encoder-decoder structure?"
+     ]
+ }
+
+
+ css = """
+ body { font-family: Arial, sans-serif; }
+ .container { max-width: 800px; margin: 0 auto; padding: 20px; }
+ .hero { text-align: center; margin-bottom: 30px; }
+ .paper-grid { display: grid; grid-template-columns: repeat(5, 1fr); gap: 10px; margin-bottom: 30px; }
+ .paper-tile { background-color: white; border: 2px solid #ddd; border-radius: 8px; padding: 10px; cursor: pointer; transition: all 0.3s; }
+ .paper-tile:hover { transform: translateY(-5px); box-shadow: 0 5px 15px rgba(0,0,0,0.1); }
+ .paper-tile.selected { border-color: #007bff; background-color: #e6f3ff; }
+ .paper-tile h3 { margin-top: 0; font-size: 14px; }
+ .paper-tile p { margin: 5px 0; font-size: 12px; color: #666; }
+ #chat-area { background-color: white; border-radius: 8px; padding: 20px; box-shadow: 0 2px 4px rgba(0,0,0,0.1); }
+ """
+
+ def update_predefined_questions(paper_id):
+     if paper_id in predefined_questions:
+         return gr.Dropdown(choices=predefined_questions[paper_id], visible=True)
+     return gr.Dropdown(choices=[], visible=False)
+
+ def format_answer(answer):
+     # Convert LaTeX-style math delimiters to Markdown-style for gr.Markdown
+     answer = re.sub(r'\\\[(.*?)\\\]', r'$$\1$$', answer)
+     answer = re.sub(r'\\\((.*?)\\\)', r'$\1$', answer)
+
+     # Put headers and standalone bold lines in their own paragraphs
+     lines = answer.split('\n')
+     formatted_lines = []
+     for line in lines:
+         if line.startswith('###'):
+             formatted_lines.append(f"\n{line}\n")
+         elif line.startswith('**') and line.endswith('**'):
+             formatted_lines.append(f"\n{line}\n")
+         else:
+             formatted_lines.append(line)
+
+     # Join lines back together
+     formatted_answer = '\n'.join(formatted_lines)
+
+     # Add spacing around double-backslash-delimited math blocks
+     formatted_answer = re.sub(r'(\\\\.*?\\\\)', r'\n\1\n', formatted_answer)
+
+     return formatted_answer
+
+ def update_chat_area(paper_id, predefined_question):
+     if not paper_id:
+         return "Please select a paper first."
+
+     selected_paper = next((p for p in papers if p['id'] == paper_id), None)
+     if not selected_paper:
+         return "Invalid paper selection."
+
+     if selected_paper['id'] != '1':
+         return "This paper will be supported soon."
+
+     if not predefined_question:
+         return "Please select a predefined question."
+
+     # Call the backend function to get the answer
+     answer = get_answer(predefined_question, enriched_pdf, jina_api_token, gpt4_api_key)
+     return format_answer(answer) if answer else "Failed to generate an answer. Please try again."
+
+ with gr.Blocks(css=css) as demo:
+     gr.HTML('''
+     <div class="hero">
+         <h1>AI Paper Q&A</h1>
+         <p>Select a paper and ask questions about it</p>
+     </div>
+     ''')
+
+     paper_id_input = gr.Textbox(visible=False)
+
+     with gr.Row():
+         paper_tiles = gr.Radio(
+             choices=[f"{p['title']} ({p['authors']}, {p['year']})" for p in papers],
+             label="Select a paper",
+             info="Choose one of the papers to ask questions about."
+         )
+
+     predefined_question_dropdown = gr.Dropdown(label="Select a predefined question", choices=[], visible=False)
+     custom_question_input = gr.Textbox(
+         label="Or ask your own question here...",
+         value="Will be supported later after adding a prompt guard",
+         interactive=False
+     )
+
+     submit_btn = gr.Button("Submit")
+
+     chat_output = gr.Markdown(label="Answer")
+
+     def update_chat_area_with_loading(paper_id, predefined_question):
+         # Display a loading message while the backend processes the question
+         loading_message = "**Generating answer...**"
+
+         # Yield the loading message first so the UI shows progress
+         yield loading_message
+
+         # Call the actual function and yield its result
+         yield update_chat_area(paper_id, predefined_question)
+
+     paper_tiles.change(
+         fn=lambda x: next((p['id'] for p in papers if f"{p['title']} ({p['authors']}, {p['year']})" == x), None),
+         inputs=[paper_tiles],
+         outputs=[paper_id_input]
+     )
+
+     paper_id_input.change(
+         fn=update_predefined_questions,
+         inputs=[paper_id_input],
+         outputs=[predefined_question_dropdown]
+     )
+
+     submit_btn.click(
+         fn=update_chat_area_with_loading,
+         inputs=[paper_id_input, predefined_question_dropdown],
+         outputs=[chat_output]
+     )
+
+ if __name__ == '__main__':
+     demo.launch()
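
Note that app.py's `format_answer` inverts the delimiter conversion applied by backend.py's `format_answer` (see below): the backend normalizes Markdown-style math (`$...$`, `$$...$$`) to LaTeX-style (`\(...\)`, `\[...\]`), and the app converts it back for `gr.Markdown` rendering. A minimal standalone sketch of that round trip; the helper names and sample string here are made up for illustration:

```python
import re

# backend.py direction: $$...$$ -> \[...\] and $...$ -> \(...\)
def to_latex_style(s: str) -> str:
    s = re.sub(r'\$\$(.*?)\$\$', r'\\[\1\\]', s)
    return re.sub(r'\$(.*?)\$', r'\\(\1\\)', s)

# app.py direction: \[...\] -> $$...$$ and \(...\) -> $...$
def to_markdown_style(s: str) -> str:
    s = re.sub(r'\\\[(.*?)\\\]', r'$$\1$$', s)
    return re.sub(r'\\\((.*?)\\\)', r'$\1$', s)

sample = "The attention weights are $softmax(QK^T / sqrt(d_k))V$."
assert to_markdown_style(to_latex_style(sample)) == sample  # the two passes cancel out
```

Since `get_answer` already returns backend-formatted text, the double pass is mostly a no-op on delimiters; the header-spacing logic, however, runs twice.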
backend.py ADDED
@@ -0,0 +1,163 @@
+ import requests
+ import numpy as np
+ import json
+ import os
+ import re
+ import base64
+ import io
+
+ from openai import OpenAI
+ from pinecone import Pinecone
+ from PIL import Image as PILImage
+ from docarray import BaseDoc
+ from docarray import DocList
+ from docarray.typing import ImageTensor, NdArray
+ from typing import List, Dict, Optional
+ from pdf_classes import RichPDFDocument
+
+ # Access environment variables individually and pass them as separate arguments
+ pc = Pinecone(
+     api_key=os.environ["PINECONE_API_KEY"],
+     environment=os.environ["PINECONE_ENVIRONMENT"]
+ )
+ print("Connected to Pinecone")
+ index = pc.Index("rich-pdf-late-chunks")
+
+
+ def create_question_embedding(question: str, api_token: str) -> np.ndarray:
+     url = 'https://api.jina.ai/v1/embeddings'
+     headers = {
+         'Content-Type': 'application/json',
+         'Authorization': f'Bearer {api_token}'
+     }
+     data = {
+         "model": "jina-clip-v1",
+         "input": [{"text": question}]
+     }
+     response = requests.post(url, headers=headers, json=data)
+     if response.status_code == 200:
+         result = response.json()
+         return np.array(result['data'][0]['embedding'])
+     else:
+         raise Exception(f"Error creating embedding: {response.text}")
+
+
+ def create_few_shot_prompt(question: str, rich_pdf: RichPDFDocument, pinecone_index, api_token: str, top_k: int = 3):
+     prompt = f"Question: {question}\n\n"
+     prompt += "Here are relevant excerpts from the document:\n\n"
+     image_data = []
+     included_pages = set()
+
+     question_embedding = create_question_embedding(question, api_token)
+     results = pinecone_index.query(vector=question_embedding.tolist(), top_k=top_k, include_metadata=True)
+
+     for i, match in enumerate(results['matches'], 1):
+         metadata = match['metadata']
+         segment_types = metadata['segment_types'].split(',')
+         page_numbers = [int(pn) for pn in metadata['page_numbers'].split(',')]
+
+         # Handle potential JSON decoding errors
+         try:
+             contents = json.loads(metadata['contents'])
+         except json.JSONDecodeError:
+             contents = [metadata['contents']]  # Treat as a single content item if JSON decoding fails
+
+         prompt += f"Excerpt {i}:\n"
+         prompt += f"Pages: {', '.join(map(str, page_numbers))}\n"
+         prompt += f"Types: {', '.join(segment_types)}\n"
+
+         for j, content in enumerate(contents, 1):
+             if isinstance(content, str) and '[Image' in content:
+                 prompt += f"Image content {j}: {content}\n"
+             else:
+                 prompt += f"Text content {j}: {content[:200]}...\n"  # Limit text content to 200 characters
+
+         prompt += "\n"
+
+         # Add only one full-page screenshot as a reference
+         if not included_pages and page_numbers:
+             page_num = page_numbers[0]
+             prompt += f"\nFull-page context for Page {page_num + 1}: [Full-page screenshot]\n"
+             buffered = io.BytesIO()
+             PILImage.fromarray(rich_pdf.pages[page_num].screenshot).save(buffered, format="PNG")
+             img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
+             image_data.append({
+                 "type": "image_url",
+                 "image_url": {
+                     "url": f"data:image/png;base64,{img_base64}",
+                     "detail": "low"
+                 }
+             })
+             included_pages.add(page_num)
+
+     prompt += "\nInstructions for answering the question:\n"
+     prompt += "1. Carefully review all provided excerpts.\n"
+     prompt += "2. Use the full-page screenshot to understand the overall context.\n"
+     prompt += "3. Refer to specific excerpts in your answer when applicable.\n"
+     prompt += "4. If the question asks for specific information, provide a clear and concise answer.\n"
+     prompt += "5. If the answer isn't directly stated, use the context to infer the most likely answer.\n\n"
+     prompt += f"Now, please answer the following question based on the provided information:\n{question}\n"
+
+     return prompt, image_data
+
+
+ def query_gpt4o(question: str, rich_pdf, pinecone_index, api_token: str, gpt4_api_key: str):
+     client = OpenAI(api_key=gpt4_api_key)
+
+     prompt, image_data = create_few_shot_prompt(question, rich_pdf, pinecone_index, api_token)
+     content_list = [{"type": "text", "text": prompt}] + image_data
+
+     try:
+         response = client.chat.completions.create(
+             model="gpt-4o",  # Vision-capable model; images are passed as image_url content parts
+             messages=[
+                 {
+                     "role": "system",
+                     "content": "You are an advanced AI assistant capable of analyzing various types of documents, including but not limited to research papers, financial reports, and general texts. Your task is to provide accurate and relevant answers to questions by carefully examining both textual and visual information provided from the document. When appropriate, cite specific excerpts, images, or page numbers in your responses. Explain your reasoning clearly, especially when making inferences or connections between different parts of the document."
+                 },
+                 {
+                     "role": "user",
+                     "content": content_list
+                 }
+             ],
+             max_tokens=500  # Increased token limit for more detailed responses
+         )
+         return response.choices[0].message.content
+     except Exception as e:
+         print(f"Failed to execute GPT-4o query: {e}")
+         return None
+
+
+ def format_answer(answer):
+     # Convert Markdown-style math delimiters to LaTeX-style ($$..$$ -> \[..\], $..$ -> \(..\))
+     answer = re.sub(r'\$\$(.*?)\$\$', r'\\[\1\\]', answer)
+     answer = re.sub(r'\$(.*?)\$', r'\\(\1\\)', answer)
+
+     # Put headers and standalone bold lines in their own paragraphs
+     lines = answer.split('\n')
+     formatted_lines = []
+     for line in lines:
+         if line.startswith('###'):
+             formatted_lines.append(f"\n{line}\n")
+         elif line.startswith('**') and line.endswith('**'):
+             formatted_lines.append(f"\n{line}\n")
+         else:
+             formatted_lines.append(line)
+
+     # Join lines back together
+     formatted_answer = '\n'.join(formatted_lines)
+
+     # Add spacing around double-backslash-delimited math blocks
+     formatted_answer = re.sub(r'(\\\\.*?\\\\)', r'\n\1\n', formatted_answer)
+
+     return formatted_answer
+
+
+ # Example usage function
+ def get_answer(question: str, enriched_pdf: RichPDFDocument, jina_api_token: str, gpt4_api_key: str):
+     answer_generated = query_gpt4o(question, enriched_pdf, index, jina_api_token, gpt4_api_key)
+
+     if answer_generated:
+         return format_answer(answer_generated)
+     else:
+         return None
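
`create_few_shot_prompt` assumes each Pinecone match carries `segment_types` and `page_numbers` as comma-separated strings and `contents` as a JSON-encoded list (with a plain-string fallback). The indexing script is not part of this commit, so the following upsert is a hypothetical sketch of that metadata shape, using the module-level `index` from backend.py:

```python
import json

# Hypothetical example chunk; the real pipeline would use a jina-clip-v1
# embedding (768-d) computed for the chunk rather than this placeholder.
chunk_embedding = [0.0] * 768

index.upsert(vectors=[{
    "id": "chunk-0",
    "values": chunk_embedding,
    "metadata": {
        "segment_types": "text,image",  # read back with .split(',')
        "page_numbers": "0,1",          # comma-separated, 0-based page indices
        "contents": json.dumps([         # JSON list; a raw string is the fallback path
            "The dominant sequence transduction models are based on...",
            "[Image 1: Transformer architecture diagram]",
        ]),
    },
}])
```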
enriched_pdf.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e35595825f67b68880e9c7b3dec8ca15a3d92422a27e911974df592d89b39a8b
+ size 30402603
pdf_classes.py ADDED
@@ -0,0 +1,25 @@
+ # pdf_classes.py
+ from docarray import BaseDoc
+ from docarray import DocList
+ from docarray.typing import ImageTensor, NdArray
+ from typing import Dict, Optional
+
+ class PDFSegment(BaseDoc):
+     page_number: int
+     segment_type: str  # 'text', 'image', 'table', or 'hybrid'
+     content: Optional[str]
+     image: Optional[ImageTensor]
+     position: Dict[str, int]  # {x, y, width, height}
+     relationships: Dict[str, Optional[str]]  # {'prev': id, 'next': id, 'parent': id}
+     embedding: Optional[NdArray[768]]
+
+ class PDFPage(BaseDoc):
+     page_number: int
+     screenshot: ImageTensor
+     embedding: Optional[NdArray[768]] = None
+
+ class RichPDFDocument(BaseDoc):
+     file_path: str
+     num_pages: int
+     segments: DocList[PDFSegment]
+     pages: DocList[PDFPage]
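
For orientation, a minimal sketch of loading and inspecting the pickled document with these classes; pdf_classes must be importable so pickle can resolve the types, and the field access follows the definitions above (the printed values are illustrative):

```python
import pickle

from pdf_classes import RichPDFDocument  # needed for unpickling

with open('enriched_pdf.pkl', 'rb') as f:
    doc: RichPDFDocument = pickle.load(f)

print(doc.file_path, doc.num_pages)

for seg in doc.segments[:3]:
    # segment_type is one of 'text', 'image', 'table', 'hybrid'
    print(seg.page_number, seg.segment_type, (seg.content or '')[:60])

# Each page stores a full-page screenshot as an ndarray-like ImageTensor
print(doc.pages[0].screenshot.shape)
```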