ImranzamanML commited on
Commit
bb4c945
1 Parent(s): bd1fdb8

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +140 -0
app.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ from PyPDF2 import PdfReader
4
+ import google.generativeai as genai
5
+ from langchain.vectorstores import FAISS
6
+ from langchain.prompts import PromptTemplate
7
+ from langchain_google_genai import ChatGoogleGenerativeAI
8
+ from langchain.chains.question_answering import load_qa_chain
9
+ from langchain_google_genai import GoogleGenerativeAIEmbeddings
10
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
11
+
12
+
13
def process_pdf_files(pdf_files, embedding_model_name):
    """Extract text from uploaded PDFs, chunk it, embed it, and persist a FAISS index.

    Args:
        pdf_files: Iterable of file-like objects (e.g. Streamlit UploadedFile)
            readable by PyPDF2's PdfReader.
        embedding_model_name: Name of the Google embedding model to use.

    Returns:
        The FAISS vector store built from the chunked text. The index is also
        saved locally to "pdf_database" so get_response() can reload it later.
    """
    text = ""
    for pdf in pdf_files:
        reader = PdfReader(pdf)
        for page in reader.pages:
            # extract_text() returns None/"" for image-only or empty pages;
            # the `or ""` guard prevents a TypeError on concatenation.
            text += page.extract_text() or ""
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=10000, chunk_overlap=500)
    text_chunks = text_splitter.split_text(text)
    embeddings = GoogleGenerativeAIEmbeddings(model=embedding_model_name)
    vector_store = FAISS.from_texts(text_chunks, embedding=embeddings)
    # Persist to disk so later Streamlit reruns can answer without re-ingesting.
    vector_store.save_local("pdf_database")
    return vector_store
25
+
26
def setup_qa_chain(chat_model_name):
    """Build a "stuff"-type question-answering chain bound to a Google chat model.

    Args:
        chat_model_name: Name of the Google generative chat model to drive the chain.

    Returns:
        A loaded QA chain callable with {"input_documents": ..., "question": ...}.
    """
    template_text = """
    Give answer to the asked question using the provided custom knowledge or given context only and if there is no related content then simply say "Your document dont contain related context to answer". Make sure to not answer incorrect.\n\n
    Context:\n{context}\n
    Question:\n{question}\n
    Answer:
    """
    # Prompt and model are independent; build the prompt first, then the LLM.
    qa_prompt = PromptTemplate(
        template=template_text,
        input_variables=["context", "question"],
    )
    llm = ChatGoogleGenerativeAI(model=chat_model_name, temperature=0.3)
    return load_qa_chain(llm, chain_type="stuff", prompt=qa_prompt)
36
+
37
def get_response(user_question, chat_model_name, embedding_model_name):
    """Answer a user question from the locally saved FAISS index.

    Loads the "pdf_database" index written by process_pdf_files, retrieves the
    chunks most similar to the question, and runs them through the QA chain.

    Args:
        user_question: The question text typed by the user.
        chat_model_name: Name of the Google chat model for answering.
        embedding_model_name: Name of the embedding model (must match the one
            used when the index was built).

    Returns:
        The chain's answer text.
    """
    embedder = GoogleGenerativeAIEmbeddings(model=embedding_model_name)
    # The index was produced locally by this app, so deserializing it is trusted.
    store = FAISS.load_local("pdf_database", embedder, allow_dangerous_deserialization=True)
    matching_docs = store.similarity_search(user_question)
    qa_chain = setup_qa_chain(chat_model_name)
    result = qa_chain(
        {"input_documents": matching_docs, "question": user_question},
        return_only_outputs=True,
    )
    return result["output_text"]
47
+
48
def main():
    """Streamlit entry point: sidebar model config, PDF ingestion, and Q&A UI."""
    st.set_page_config(page_title="Talk to PDF", layout="wide")

    # Background image is optional; previously a missing image.png crashed the
    # app with FileNotFoundError on startup.
    if os.path.exists('image.png'):
        st.markdown(
            f"""
            <style>
            .stApp {{
                background: url(data:image/png;base64,{get_base64_of_image('image.png')});
                background-size: cover
            }}
            </style>
            """,
            unsafe_allow_html=True
        )

    st.title("Chat using Google Gemini Models")

    # BUG FIX: the sidebar configuration must run BEFORE the upload/submit
    # section. Streamlit executes the script top-to-bottom, and the original
    # order called process_pdf_files(pdf_files, embedding_model_name) before
    # embedding_model_name was assigned, raising NameError on the first click.
    st.sidebar.header("Configuration")
    api_key = st.sidebar.text_input("Google API Key:", type="password")

    default_chat_models = ["gemini-pro", "chat-model-2", "chat-model-3"]
    selected_chat_model = st.sidebar.selectbox("Select a chat model", default_chat_models, index=0)
    custom_chat_model = st.sidebar.text_input("Or enter a custom chat model name")
    # A non-empty custom name overrides the dropdown selection.
    chat_model_name = custom_chat_model if custom_chat_model else selected_chat_model

    default_embedding_models = ["models/embedding-001", "embedding-model-2", "embedding-model-3"]
    selected_embedding_model = st.sidebar.selectbox("Select an embedding model", default_embedding_models, index=0)
    custom_embedding_model = st.sidebar.text_input("Or enter a custom embedding model name")
    embedding_model_name = custom_embedding_model if custom_embedding_model else selected_embedding_model

    st.subheader("Upload your PDF Files")
    pdf_files = st.file_uploader("Upload your files", accept_multiple_files=True)
    if st.button("Submit data") and pdf_files:
        with st.spinner("Processing the data . . ."):
            process_pdf_files(pdf_files, embedding_model_name)
            st.success("Files submitted successfully")

    if api_key:
        genai.configure(api_key=api_key)

        user_question = st.text_input("Ask questions from your custom knowledge base!")

        if user_question:
            with st.spinner("Generating response..."):
                response = get_response(user_question, chat_model_name, embedding_model_name)
                st.write("**Reply:** ", response)

    else:
        st.sidebar.warning("Please enter your Google API key!")

    # Footer: social links rendered as clickable logos in five columns.
    st.markdown("---")
    st.write("Happy to Connect:")
    kaggle, linkedin, google_scholar, youtube, github = st.columns(5)

    image_urls = {
        "kaggle": "https://www.kaggle.com/static/images/site-logo.svg",
        "linkedin": "https://upload.wikimedia.org/wikipedia/commons/thumb/c/ca/LinkedIn_logo_initials.png/600px-LinkedIn_logo_initials.png",
        "google_scholar": "https://upload.wikimedia.org/wikipedia/commons/thumb/c/c7/Google_Scholar_logo.svg/768px-Google_Scholar_logo.svg.png",
        "youtube": "https://upload.wikimedia.org/wikipedia/commons/thumb/7/72/YouTube_social_white_square_%282017%29.svg/640px-YouTube_social_white_square_%282017%29.svg.png",
        "github": "https://github.githubassets.com/assets/GitHub-Mark-ea2971cee799.png"
    }

    social_links = {
        "kaggle": "https://www.kaggle.com/muhammadimran112233",
        "linkedin": "https://www.linkedin.com/in/muhammad-imran-zaman",
        "google_scholar": "https://scholar.google.com/citations?user=ulVFpy8AAAAJ&hl=en",
        "youtube": "https://www.youtube.com/@consolioo",
        "github": "https://github.com/Imran-ml"
    }

    kaggle.markdown(f'<a href="{social_links["kaggle"]}"><img src="{image_urls["kaggle"]}" width="50" height="50"></a>', unsafe_allow_html=True)
    linkedin.markdown(f'<a href="{social_links["linkedin"]}"><img src="{image_urls["linkedin"]}" width="50" height="50"></a>', unsafe_allow_html=True)
    google_scholar.markdown(f'<a href="{social_links["google_scholar"]}"><img src="{image_urls["google_scholar"]}" width="50" height="50"></a>', unsafe_allow_html=True)
    youtube.markdown(f'<a href="{social_links["youtube"]}"><img src="{image_urls["youtube"]}" width="50" height="50"></a>', unsafe_allow_html=True)
    github.markdown(f'<a href="{social_links["github"]}"><img src="{image_urls["github"]}" width="50" height="50"></a>', unsafe_allow_html=True)
    st.markdown("---")
132
+
133
def get_base64_of_image(image_path):
    """Return the contents of the file at *image_path* as a base64 ASCII string.

    Args:
        image_path: Path to the image file to encode.

    Returns:
        The base64 encoding of the file's raw bytes, decoded to str.
    """
    import base64
    with open(image_path, "rb") as fh:
        raw = fh.read()
    return base64.b64encode(raw).decode()
138
+
139
# Launch the Streamlit app only when executed directly (not when imported).
if __name__ == "__main__":
    main()