AalianK committed (verified)
Commit 32f3cb0 · 1 Parent(s): 8941872

Create app.py

Files changed (1): app.py (+148, -0)
app.py ADDED
import gradio as gr
import torch
import os

from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Static model name
llm_name = "meta-llama/Llama-2-7b-chat-hf"

# Static file paths for multiple files
static_file_paths = [
    "IRM ISO_IEC_27001_2022(en) (1).pdf",
    #"SCF - Cybersecurity & Data Privacy Capability Maturity Model (CP-CMM) (2023.4).pdf",
    #"AG_Level1_V2.0_Final_20211210.pdf",
    #"CIS_Controls_v8_v21.10.pdf",
    #"CSF PDF v11.1.0-1.pdf",
    #"ISO_31000_2018(en)-1.pdf",
    #"OWASP Application Security Verification Standard 4.0.3-en-1.pdf",
    #"NIST.CSWP.29.ipd The NIST Cybersecurity Framework 2.0 202308-1 (1).pdf",
    #"ISO_IEC_27002_2022(en)-1.pdf",
]

# Use CUDA for faster processing when available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load documents
loaders = [PyPDFLoader(x) for x in static_file_paths]
pages = []
for loader in loaders:
    pages.extend(loader.load())
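# Split the loaded pages into ~600-character chunks with a 40-character
# overlap, so each retrieved passage stays small enough for the prompt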
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=600,
    chunk_overlap=40,
)
doc_splits = text_splitter.split_documents(pages)
embedding = HuggingFaceEmbeddings()
vectordb = Chroma.from_documents(
    documents=doc_splits,
    embedding=embedding,
)
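# No persist_directory is passed, so Chroma keeps the index in memory and
# rebuilds it from the PDFs on every app start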

# Load model
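# Llama-2 checkpoints are gated on the Hugging Face Hub, so the token in
# HUGGINGFACEHUB_API_TOKEN must belong to an account that has been granted access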
tokenizer = AutoTokenizer.from_pretrained(llm_name, token=os.environ["HUGGINGFACEHUB_API_TOKEN"])
model = AutoModelForCausalLM.from_pretrained(
    llm_name,
    token=os.environ["HUGGINGFACEHUB_API_TOKEN"],
    torch_dtype=torch.float16,
)
model = model.to(device)

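# Wrap the model in a transformers text-generation pipeline, then expose it
# to LangChain via HuggingFacePipeline so chains can call it like any other LLM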
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,
    device=device,
)
hf = HuggingFacePipeline(pipeline=pipe)

# Set up template and memory
template = """You are a helpful and appreciative cybersecurity expert who gives comprehensive answers using lists, step-by-step instructions, and other aids. Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know; don't try to make up an answer.
{context}
Question: {question}
Helpful Answer:
"""
prompt = PromptTemplate.from_template(template)
memory = ConversationBufferMemory(
    memory_key="chat_history",
    output_key="answer",
    return_messages=True,
)
retriever = vectordb.as_retriever()
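# ConversationalRetrievalChain first condenses the incoming question with the
# chat history into a standalone query, retrieves matching chunks, and then
# "stuffs" them into the prompt above as {context}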
qachain = ConversationalRetrievalChain.from_llm(
    hf,
    retriever=retriever,
    chain_type="stuff",
    memory=memory,
    return_source_documents=True,
    combine_docs_chain_kwargs={
        "prompt": prompt,
    },
)

def format_chat_history(message, chat_history):
    formatted_chat_history = []
    for user_message, bot_message in chat_history:
        formatted_chat_history.append(f"User: {user_message}")
        formatted_chat_history.append(f"Assistant: {bot_message}")
    return formatted_chat_history

# Conversation with chatbot
def conversation(qa_chain, message, history):
    formatted_chat_history = format_chat_history(message, history)
    response = qa_chain({"question": message, "chat_history": formatted_chat_history})
    response_answer = response["answer"]
    response_sources = response["source_documents"]
    response_source1 = response_sources[0].page_content.strip()
    response_source2 = response_sources[1].page_content.strip()
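    # PyPDFLoader stores zero-indexed page numbers in metadata; add 1 for display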
    response_source1_page = response_sources[0].metadata["page"] + 1
    response_source2_page = response_sources[1].metadata["page"] + 1
    new_history = history + [(message, response_answer)]
    return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page

def demo():
    with gr.Blocks(theme="base") as demo:
        qa_chain = gr.State(qachain)

        gr.Markdown(
            """<center><h2>Context Chatbot</h2></center>
            <h3>Ask any questions about your PDF documents, along with follow-ups</h3>
            <i>When generating answers, the chatbot takes past questions into account (via conversational memory) and includes document references for clarity.</i>
            """)

        # Conversation with chatbot
        with gr.Tab("Step 3 - Conversation with chatbot"):
            chatbot = gr.Chatbot(height=600)
            with gr.Row():
                msg = gr.Textbox(placeholder="Type message", container=True)
            with gr.Row():
                submit_btn = gr.Button("Submit")
                clear_btn = gr.ClearButton([msg, chatbot])
            with gr.Accordion("Advanced - Document references", open=False):
                with gr.Row():
                    response_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
                    response_source1_page = gr.Number(label="Page", scale=1)
                with gr.Row():
                    response_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
                    response_source2_page = gr.Number(label="Page", scale=1)

        # Preprocessing events
        #db_btn.click(initialize_database, outputs=[vector_db, db_progress])

        # Chatbot events
        submit_btn.click(conversation,
            inputs=[qa_chain, msg, chatbot],
            outputs=[qa_chain, msg, chatbot, response_source1, response_source1_page, response_source2, response_source2_page],
            queue=False)
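        # ClearButton already empties msg and chatbot on its own; this handler
        # additionally resets the reference boxes and page numbers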
        clear_btn.click(lambda: [None, "", 0, "", 0],
            inputs=None,
            outputs=[chatbot, response_source1, response_source1_page, response_source2, response_source2_page],
            queue=False)

    demo.queue().launch(debug=True)

if __name__ == "__main__":
    demo()