Guna0pro commited on
Commit
c53fb3a
β€’
1 Parent(s): c29aba1

created a app.py file

Files changed (1) hide show
  1. app.py +164 -0
app.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # !pip -q install git+https://github.com/huggingface/transformers # need to install from github
2
+ !pip install -q datasets loralib sentencepiece
3
+ !pip -q install bitsandbytes accelerate xformers
4
+ !pip -q install langchain
5
+ !pip -q install gradio
6
+
7
+ !pip -q install peft chromadb
8
+ !pip -q install unstructured
9
+ !pip install -q sentence_transformers
10
+ !pip -q install pypdf
11
+
12
+ from google.colab import drive
13
+ drive.mount('/content/drive')
14
+
15
+ """## LLaMA2 7B Chat
16
+
17
+ """
18
+
19
+ import torch
20
+ from peft import PeftModel, PeftConfig
21
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
22
+
23
+ bnb_config = BitsAndBytesConfig(load_in_4bit=True,
24
+ bnb_4bit_quant_type="nf4",
25
+ bnb_4bit_compute_dtype=torch.bfloat16,
26
+ bnb_4bit_use_double_quant=False)
27
+
28
+ model_id = "meta-llama/Llama-2-7b-chat-hf"
29
+
30
+ #daryl149/llama-2-7b-chat-hf
31
+ #meta-llama/Llama-2-7b-chat-hf
32
+
33
+ tokenizer = AutoTokenizer.from_pretrained(model_id,token='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')
34
+ model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config = bnb_config,device_map={"":0},token='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')
35
+
36
+ import json
37
+ import textwrap
38
+
39
+ B_INST, E_INST = "[INST]", "[/INST]"
40
+ B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
41
+ DEFAULT_SYSTEM_PROMPT = """\
42
+ You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
43
+
44
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
45
+
46
+
47
+
48
+ def get_prompt(instruction, new_system_prompt=DEFAULT_SYSTEM_PROMPT ):
49
+ SYSTEM_PROMPT = B_SYS + new_system_prompt + E_SYS
50
+ prompt_template = B_INST + SYSTEM_PROMPT + instruction + E_INST
51
+ return prompt_template
52
+
53
+ from langchain.embeddings import HuggingFaceEmbeddings
54
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
55
+ from langchain.vectorstores import Chroma
56
+ from langchain.document_loaders import PyPDFLoader
57
+
58
+ loader = PyPDFLoader("/content/drive/MyDrive/Gen AI and LLM/data/data.pdf")
59
+
60
+ text_splitter = RecursiveCharacterTextSplitter(
61
+ # Set a really small chunk size, just to show.
62
+ chunk_size = 800,
63
+ chunk_overlap = 50,
64
+ length_function = len,
65
+ )
66
+
67
+ pages = loader.load_and_split(text_splitter)
68
+
69
+ db = Chroma.from_documents(pages, HuggingFaceEmbeddings())
70
+
71
+ instruction = "Given the context that has been provided. \n {context}, Answer the following question - \n{question}"
72
+
73
+ system_prompt = """You are an expert in question and answering.
74
+ You will be given a context to answer from. Be precise in your answers wherever possible.
75
+ In case you are sure you don't know the answer then you say that based on the context you don't know the answer.
76
+ In all other instances you provide an answer to the best of your capability. Cite urls when you can access them related to the context."""
77
+
78
+ get_prompt(instruction, system_prompt)
79
+
80
+ """## Setting up with LangChain"""
81
+
82
+ from langchain import HuggingFacePipeline
83
+ from langchain import PromptTemplate, LLMChain
84
+ from langchain.chains import ConversationalRetrievalChain
85
+ from langchain.memory import ConversationBufferWindowMemory
86
+
87
+ template = get_prompt(instruction, system_prompt)
88
+ print(template)
89
+
90
+ prompt = PromptTemplate(template=template, input_variables=["context", "question"])
91
+
92
+ memory = ConversationBufferWindowMemory(
93
+ memory_key="chat_history", k=5,
94
+ return_messages=True
95
+ )
96
+
97
+ retriever = db.as_retriever()
98
+
99
+ def create_pipeline(max_new_tokens=512):
100
+ pipe = pipeline("text-generation",
101
+ model=model,
102
+ tokenizer = tokenizer,
103
+ max_new_tokens = max_new_tokens,
104
+ temperature = 0.1)
105
+ return pipe
106
+
107
+ class GunaBot:
108
+ def __init__(self, memory, prompt, task:str = "text-generation", retriever = retriever):
109
+ self.memory = memory
110
+ self.prompt = prompt
111
+ self.retriever = retriever
112
+
113
+
114
+
115
+ def create_chat_bot(self, max_new_tokens = 512):
116
+ hf_pipe = create_pipeline(max_new_tokens)
117
+ llm = HuggingFacePipeline(pipeline =hf_pipe)
118
+ qa = ConversationalRetrievalChain.from_llm(
119
+ llm=llm,
120
+ retriever=self.retriever,
121
+ memory=self.memory,
122
+ combine_docs_chain_kwargs={"prompt": self.prompt}
123
+ )
124
+ return qa
125
+
126
+ Guna_bot = GunaBot(memory = memory, prompt = prompt)
127
+
128
+ bot = Guna_bot.create_chat_bot()
129
+
130
+ import gradio as gr
131
+ import random
132
+ import time
133
+
134
+ def clear_llm_memory():
135
+ bot.memory.clear()
136
+
137
+ def update_prompt(sys_prompt):
138
+ if sys_prompt == "":
139
+ sys_prompt = system_prompt
140
+ template = get_prompt(instruction, sys_prompt)
141
+
142
+ prompt = PromptTemplate(template=template, input_variables=["context", "question"])
143
+
144
+ bot.combine_docs_chain.llm_chain.prompt = prompt
145
+
146
+ with gr.Blocks() as demo:
147
+ update_sys_prompt = gr.Textbox(label = "Update System Prompt")
148
+ chatbot = gr.Chatbot(label="Guna Bot", height = 300)
149
+ msg = gr.Textbox(label = "Question")
150
+ clear = gr.ClearButton([msg, chatbot])
151
+ clear_memory = gr.Button(value = "Clear LLM Memory")
152
+
153
+
154
+ def respond(message, chat_history):
155
+ bot_message = bot({"question": message})['answer']
156
+ chat_history.append((message, bot_message))
157
+ return "", chat_history
158
+
159
+ msg.submit(respond, inputs=[msg, chatbot], outputs=[msg, chatbot])
160
+ clear_memory.click(clear_llm_memory)
161
+ update_sys_prompt.submit(update_prompt, inputs=update_sys_prompt)
162
+
163
+ demo.launch(share=False, debug=True)
164
+