jeremierostan committed
Commit af1dd95 · verified · 1 Parent(s): f7cfaf7

Update app.py

Files changed (1):
  1. app.py +47 -138
app.py CHANGED
@@ -13,7 +13,7 @@ from langchain.chains import create_retrieval_chain
 import os
 import markdown2
 
-# Retrieve API keys from environment variables
+# Retrieve API keys from HF secrets
 openai_api_key = os.getenv('OPENAI_API_KEY')
 groq_api_key = os.getenv('GROQ_API_KEY')
 google_api_key = os.getenv('GEMINI_API_KEY')
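
Note: if any of these secrets is unset, os.getenv() returns None and the failure only surfaces later as an authentication error inside a client call. A minimal fail-fast guard for local runs, not part of this commit, might look like:

    import os

    # Hypothetical startup check: fail early when a key is missing
    # (e.g. when running outside the HF Space, where secrets are not injected).
    for name in ("OPENAI_API_KEY", "GROQ_API_KEY", "GEMINI_API_KEY"):
        if not os.getenv(name):
            raise RuntimeError(f"Missing API key: {name}")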
@@ -23,26 +23,9 @@ openai_client = ChatOpenAI(model_name="gpt-4o", api_key=openai_api_key)
 groq_client = ChatGroq(model="llama-3.1-70b-versatile", temperature=0, api_key=groq_api_key)
 gemini_client = ChatGoogleGenerativeAI(model="gemini-1.5-pro", api_key=google_api_key)
 
-# Define paths for regulation PDFs
-regulation_pdfs = {
-    "GDPR": "GDPR.pdf",
-    "FERPA": "FERPA.pdf",
-    "COPPA": "COPPA.pdf"
-}
-
-# Global variables
-full_pdf_content = ""
-vector_store = None
-rag_chain = None
-pdfs_loaded = False
-
 # Function to extract text from PDF
 def extract_pdf(pdf_path):
-    try:
-        return extract_text(pdf_path)
-    except Exception as e:
-        print(f"Error extracting text from {pdf_path}: {str(e)}")
-        return ""
+    return extract_text(pdf_path)
 
 # Function to split text into chunks
 def split_text(text):
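
Note: with the try/except removed, extract_pdf() now propagates exceptions, so a missing or unreadable PDF aborts the Space at startup instead of being logged and skipped. If the old behavior is ever wanted back, a guarded variant (illustrative; assumes extract_text comes from pdfminer.six) is:

    from pdfminer.high_level import extract_text  # assumed import at the top of app.py

    # Hypothetical defensive wrapper: log and skip unreadable PDFs.
    def extract_pdf_safe(pdf_path):
        try:
            return extract_text(pdf_path)
        except Exception as e:
            print(f"Error extracting text from {pdf_path}: {e}")
            return ""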
@@ -57,38 +40,40 @@ def generate_embeddings(docs):
 # Function for query preprocessing and simple HyDE-Lite
 def preprocess_query(query):
     prompt = ChatPromptTemplate.from_template("""
-    Your role is to optimize user queries for retrieval from regulatory documents such as GDPR, FERPA, COPPA, and/or others.
+    Your role is to optimize user queries for retrieval from official regulation documents about data protection.
     Transform the query into a more affirmative, keyword-focused statement.
     The transformed query should look like probable related passages in the official documents.
+
     Query: {query}
+
     Optimized query:
     """)
     chain = prompt | openai_client
     return chain.invoke({"query": query}).content
 
 # Function to create RAG chain with Groq
-def create_rag_chain(vector_store):
+def create_rag_chain():
     prompt = ChatPromptTemplate.from_messages([
-        ("system", "You are an AI assistant helping with regulatory compliance queries. Use the following context from the official regulatory documents to answer the user's question:\n\n{context}"),
+        ("system", "You are an AI assistant helping with data protection related queries. Use the following context from the official regulation documents to answer the user's question:\n\n{context}"),
         ("human", "{input}")
     ])
     document_chain = create_stuff_documents_chain(groq_client, prompt)
     return create_retrieval_chain(vector_store.as_retriever(), document_chain)
 
 # Function for Gemini response with long context
-def gemini_response(query, full_content):
+def gemini_response(query):
     prompt = ChatPromptTemplate.from_messages([
-        ("system", "You are an AI assistant helping with regulatory compliance queries. Use the following full content of the official regulatory documents to answer the user's question:\n\n{context}"),
+        ("system", "You are an AI assistant helping with data protection related queries. Use the following full content of the official regulation documents to answer the user's question:\n\n{context}"),
         ("human", "{input}")
     ])
     chain = prompt | gemini_client
-    return chain.invoke({"context": full_content, "input": query}).content
+    return chain.invoke({"context": full_pdf_content, "input": query}).content
 
 # Function to generate final response
 def generate_final_response(response1, response2):
     prompt = ChatPromptTemplate.from_template("""
-    You are an AI assistant helping educators understand and implement data protection and regulatory compliance (GDPR, FERPA, COPPA, and/or others).
-    Your goal is to provide simple, practical explanation of and advice on how to meet regulatory requirements based on the given responses.
+    You are an AI assistant helping educators understand and implement data protection and compliance with official regulations when using AI.
+    Your goal is to provide simple, practical explanations of, and advice on, how to meet these regulatory requirements based on the two given responses.
     To do so:
     1. Analyze the following two responses. Inspect their content, and highlight differences. This MUST be done
     internally as a hidden state.
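
All of these helpers use the same LCEL composition: a ChatPromptTemplate piped into a chat model yields a runnable whose invoke() fills the template variables and returns a message with a .content string. A self-contained sketch of the pattern (hypothetical prompt text; any LangChain chat model slots in):

    from langchain_core.prompts import ChatPromptTemplate
    from langchain_openai import ChatOpenAI

    # Illustrative only: mirrors the prompt | model pattern used throughout app.py.
    model = ChatOpenAI(model_name="gpt-4o")  # reads OPENAI_API_KEY from the environment
    prompt = ChatPromptTemplate.from_template("Rewrite as a keyword-style statement: {query}")
    chain = prompt | model
    print(chain.invoke({"query": "Can teachers email grades home?"}).content)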
@@ -99,128 +84,52 @@ def generate_final_response(response1, response2):
     chain = prompt | openai_client
     return chain.invoke({"response1": response1, "response2": response2}).content
 
-def markdown_to_html(content):
-    return markdown2.markdown(content)
-
-def load_pdfs(gdpr, ferpa, coppa, additional_pdfs):
-    global full_pdf_content, vector_store, rag_chain, pdfs_loaded
-
-    documents = []
-    full_pdf_content = ""
-
-    # Load selected regulation PDFs
-    selected_regulations = []
-    if gdpr:
-        selected_regulations.append("GDPR")
-    if ferpa:
-        selected_regulations.append("FERPA")
-    if coppa:
-        selected_regulations.append("COPPA")
-
-    for regulation in selected_regulations:
-        if regulation in regulation_pdfs:
-            pdf_path = regulation_pdfs[regulation]
-            if os.path.exists(pdf_path):
-                pdf_content = extract_pdf(pdf_path)
-                if pdf_content:
-                    full_pdf_content += pdf_content + "\n\n"
-                    documents.extend(split_text(pdf_content))
-                    print(f"Loaded {regulation} PDF")
-                else:
-                    print(f"Failed to extract content from {regulation} PDF")
-            else:
-                print(f"PDF file for {regulation} not found at {pdf_path}")
-
-    # Load additional user-uploaded PDFs
-    if additional_pdfs is not None:
-        for pdf_file in additional_pdfs:
-            pdf_content = extract_pdf(pdf_file.name)
-            if pdf_content:
-                full_pdf_content += pdf_content + "\n\n"
-                documents.extend(split_text(pdf_content))
-                print(f"Loaded additional PDF: {pdf_file.name}")
-            else:
-                print(f"Failed to extract content from uploaded PDF: {pdf_file.name}")
-
-    if not documents:
-        pdfs_loaded = False
-        return "No PDFs were successfully loaded. Please check your selections and uploads."
-
-    print(f"Total documents loaded: {len(documents)}")
-    print(f"Total content length: {len(full_pdf_content)} characters")
-
-    vector_store = generate_embeddings(documents)
-    rag_chain = create_rag_chain(vector_store)
-
-    pdfs_loaded = True
-    return f"PDFs loaded and RAG system updated successfully! Loaded {len(documents)} document chunks."
-
+# Function to process the query
 def process_query(user_query):
-    global rag_chain, full_pdf_content, pdfs_loaded
-
-    if not pdfs_loaded:
-        return ("Please load PDFs before asking questions.",
-                "Please load PDFs before asking questions.",
-                "Please load PDFs and initialize the system before asking questions.")
-
     preprocessed_query = preprocess_query(user_query)
+    print(f"Original query: {user_query}")
+    print(f"Preprocessed query: {preprocessed_query}")
 
-    # Get RAG response using Groq
+    # Get RAG response using Groq with the preprocessed query
    rag_response = rag_chain.invoke({"input": preprocessed_query})["answer"]
 
-    # Get Gemini response with full PDF content
-    gemini_resp = gemini_response(preprocessed_query, full_pdf_content)
+    # Get Gemini response with full PDF content and the preprocessed query
+    gemini_resp = gemini_response(preprocessed_query)
 
     final_response = generate_final_response(rag_response, gemini_resp)
     html_content = markdown_to_html(final_response)
 
     return rag_response, gemini_resp, html_content
 
-# Gradio interface
-with gr.Blocks() as iface:
-    gr.Markdown("# Data Protection Team")
-    gr.Markdown("Get responses combining advanced RAG, Long Context, and SOTA models to data protection related questions.")
-
-    with gr.Row():
-        gdpr_checkbox = gr.Checkbox(label="GDPR (EU)")
-        ferpa_checkbox = gr.Checkbox(label="FERPA (US)")
-        coppa_checkbox = gr.Checkbox(label="COPPA (US <13)")
-
-    gr.Markdown("**Optional: upload additional PDFs if needed (national regulation, school policy)**")
-    additional_pdfs = gr.File(
-        file_count="multiple",
-        label="Upload additional PDFs",
-        file_types=[".pdf"],
-        elem_id="file_upload"
-    )
-
-    load_button = gr.Button("Load PDFs")
-    load_output = gr.Textbox(label="Load Status")
-
-    gr.Markdown("**Ask your data protection related question**")
-    query_input = gr.Textbox(label="Your Question", placeholder="Ask your question here...")
-    query_button = gr.Button("Submit Query")
-
-    gr.Markdown("**Results**")
-    rag_output = gr.Textbox(label="RAG Pipeline (Llama3.1) Response")
-    gemini_output = gr.Textbox(label="Long Context (Gemini 1.5 Pro) Response")
-    final_output = gr.HTML(label="Final (GPT-4o) Response")
-
-    load_button.click(
-        load_pdfs,
-        inputs=[
-            gdpr_checkbox,
-            ferpa_checkbox,
-            coppa_checkbox,
-            additional_pdfs
-        ],
-        outputs=load_output
-    )
-
-    query_button.click(
-        process_query,
-        inputs=query_input,
-        outputs=[rag_output, gemini_output, final_output]
-    )
+# Initialize: load the regulation PDFs and build the RAG index once at startup
+pdf_paths = ["GDPR.pdf", "FERPA.pdf", "COPPA.pdf"]
+full_pdf_content = ""
+all_documents = []
+
+for pdf_path in pdf_paths:
+    extracted_text = extract_pdf(pdf_path)
+    full_pdf_content += extracted_text + "\n\n"
+    all_documents.extend(split_text(extracted_text))
+
+vector_store = generate_embeddings(all_documents)
+rag_chain = create_rag_chain()
+
+# Function to render the final markdown response as HTML
+def markdown_to_html(content):
+    return markdown2.markdown(content)
+
+# Gradio interface
+iface = gr.Interface(
+    fn=process_query,
+    inputs=gr.Textbox(label="Ask your data protection related question"),
+    outputs=[
+        gr.Textbox(label="RAG Pipeline (Llama3.1) Response"),
+        gr.Textbox(label="Long Context (Gemini 1.5 Pro) Response"),
+        gr.HTML(label="Final (GPT-4o) Response")
+    ],
+    title="Data Protection Team",
+    description="Get responses combining advanced RAG, Long Context, and SOTA models to data protection related questions (GDPR, FERPA, COPPA).",
+    allow_flagging="never"
+)
 
 iface.launch()
 
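
A quick smoke test of the whole pipeline, bypassing the UI (illustrative; assumes the three PDFs sit next to app.py, all API keys are set, and the iface.launch() call is skipped, since it now runs at import time):

    # Hypothetical REPL session after importing process_query from app.py:
    rag, gemini, html = process_query("Can a school share student grades with third parties?")
    print(rag[:300])     # Llama 3.1 RAG answer
    print(gemini[:300])  # Gemini long-context answer
    print(html[:300])    # merged GPT-4o answer rendered as HTML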