jeremierostan committed on
Commit
59b3a37
1 Parent(s): 20251e3

Create app.py

Files changed (1)
  1. app.py +123 -0
app.py ADDED
@@ -0,0 +1,123 @@
+ import gradio as gr
+ from pdfminer.high_level import extract_text
+ from langchain_groq import ChatGroq
+ from langchain_google_genai import ChatGoogleGenerativeAI
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.vectorstores import FAISS  # FAISS moved to langchain_community in newer LangChain releases
+ from langchain_openai import ChatOpenAI, OpenAIEmbeddings  # both OpenAI wrappers live in langchain_openai
+ from langchain.schema import Document
+ from langchain.prompts import ChatPromptTemplate
+ from langchain.chains.combine_documents import create_stuff_documents_chain
+ from langchain.chains import create_retrieval_chain
+ import os
+ import markdown2
+
+ # Retrieve API keys from HF secrets
+ openai_api_key = os.getenv('OPENAI_API_KEY')
+ groq_api_key = os.getenv('GROQ_API_KEY')
+ google_api_key = os.getenv('GEMINI')
+
+ # Initialize API clients with the API keys
+ openai_client = ChatOpenAI(model_name="gpt-4o", api_key=openai_api_key)
+ groq_client = ChatGroq(model="llama-3.1-70b-versatile", temperature=0, api_key=groq_api_key)
+ gemini_client = ChatGoogleGenerativeAI(model="gemini-1.5-pro", api_key=google_api_key)
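+ # Three models, three roles: Llama 3.1 (via Groq) answers over retrieved chunks,
+ # Gemini 1.5 Pro reads the full regulation in its long context window, and
+ # GPT-4o rewrites the query and merges the two answers into the final response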
+
+ # Function to extract text from PDF
+ def extract_pdf(pdf_path):
+     return extract_text(pdf_path)
+
+ # Function to split text into chunks
+ def split_text(text):
+     splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
+     return [Document(page_content=t) for t in splitter.split_text(text)]
+
+ # Function to generate embeddings and store them in a vector database
+ def generate_embeddings(docs):
+     embeddings = OpenAIEmbeddings(api_key=openai_api_key)
+     return FAISS.from_documents(docs, embeddings)
+
+ # Function for query preprocessing and simple HyDE-Lite
+ def preprocess_query(query):
+     prompt = ChatPromptTemplate.from_template("""
+     Your role is to optimize user queries for retrieval from a GDPR regulation document.
+     Transform the query into a more affirmative, keyword-focused statement.
+     The transformed query should look like probable related passages in the official document.
+
+     Query: {query}
+
+     Optimized query:
+     """)
+     chain = prompt | openai_client
+     return chain.invoke({"query": query}).content
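+ # (HyDE-lite: rather than generating a full hypothetical document, the question is
+ # merely rephrased to resemble regulation text; e.g. "Do we need parental consent
+ # for AI tools?" might come back as something like "conditions applicable to
+ # child's consent in relation to information society services")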
+
+ # Function to create RAG chain with Groq
+ def create_rag_chain():
+     prompt = ChatPromptTemplate.from_messages([
+         ("system", "You are an AI assistant helping with GDPR-related queries. Use the following context from the official GDPR regulation document to answer the user's question:\n\n{context}"),
+         ("human", "{input}")
+     ])
+     document_chain = create_stuff_documents_chain(groq_client, prompt)
+     return create_retrieval_chain(vector_store.as_retriever(), document_chain)
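+ # The retrieval chain embeds the incoming query, fetches the most similar GDPR
+ # chunks from FAISS, stuffs them into {context}, and has the Groq-hosted Llama
+ # model answer; its invoke() output is a dict with the text under the "answer" key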
+
+ # Function for Gemini response with long context
+ def gemini_response(query):
+     prompt = ChatPromptTemplate.from_messages([
+         ("system", "You are an AI assistant helping with GDPR-related queries. Use the following full content of the official GDPR regulation document to answer the user's question:\n\n{context}"),
+         ("human", "{input}")
+     ])
+     chain = prompt | gemini_client
+     return chain.invoke({"context": full_pdf_content, "input": query}).content
+
+ # Function to generate the final response by merging the two answers
+ def generate_final_response(response1, response2):
+     prompt = ChatPromptTemplate.from_template("""
+     You are an AI assistant helping educators understand and implement AI data protection and GDPR compliance.
+     Your goal is to provide simple, practical explanations of, and advice on, how to meet GDPR requirements based on the given responses.
+     To do so, analyze the following two responses, combining similar elements and highlighting any differences. This analysis MUST be done
+     internally, as a hidden step. Only output your own final response.
+     If the responses contradict each other on important points, say so in your response.
+
+     Response 1: {response1}
+
+     Response 2: {response2}
+     """)
+     chain = prompt | openai_client
+     return chain.invoke({"response1": response1, "response2": response2}).content
+
+ # Convert the model's markdown output to HTML for the gr.HTML component
+ def markdown_to_html(content):
+     return markdown2.markdown(content)
+
+ def process_query(user_query):
+     preprocessed_query = preprocess_query(user_query)
+
+     # Get RAG response using Groq
+     rag_response = rag_chain.invoke({"input": preprocessed_query})["answer"]
+
+     # Get Gemini response with full PDF content
+     gemini_resp = gemini_response(preprocessed_query)
+
+     final_response = generate_final_response(rag_response, gemini_resp)
+     html_content = markdown_to_html(final_response)
+
+     return rag_response, gemini_resp, html_content
100
+
101
+ # Initialize
102
+ GDPR_PDF_PATH = "/content/GDPR.pdf"
103
+ full_pdf_content = extract_pdf(GDPR_PDF_PATH)
104
+ extracted_text = extract_pdf(GDPR_PDF_PATH)
105
+ documents = split_text(extracted_text)
106
+ vector_store = generate_embeddings(documents)
107
+ rag_chain = create_rag_chain()
+
+ # Gradio interface
+ iface = gr.Interface(
+     fn=process_query,
+     inputs=gr.Textbox(label="Ask your data protection related question"),
+     outputs=[
+         gr.Textbox(label="RAG Pipeline (Llama 3.1) Response"),
+         gr.Textbox(label="Long Context (Gemini 1.5 Pro) Response"),
+         gr.HTML(label="Final (GPT-4o) Response")
+     ],
+     title="Data Protection Team",
+     description="Get responses to data protection questions, combining advanced RAG, long context, and SOTA models.",
+     allow_flagging="never"
+ )
+
+ iface.launch(debug=True)
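
Once the Space is running, it can also be queried programmatically. Below is a minimal sketch using gradio_client; the Space id "jeremierostan/data-protection-team" and the example question are assumptions, not part of the commit:

    # Minimal client-side sketch (hypothetical Space id); a single gr.Interface
    # exposes one endpoint, "/predict" by default, returning all three outputs
    from gradio_client import Client

    client = Client("jeremierostan/data-protection-team")  # substitute the real Space id
    rag, long_ctx, final_html = client.predict(
        "What does GDPR require before a school adopts a new AI tool?",
        api_name="/predict",
    )
    print(final_html)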