hieult44 committed on
Commit
113ebc1
1 Parent(s): 8c7b475

Add application file

Files changed (1)
  1. app.py +410 -176
app.py CHANGED
@@ -1,194 +1,428 @@
- """
- This module provides functions for working with PDF files and URLs. It uses the
- urllib.request library to download files from URLs, the fitz library to extract
- text from PDF files, and the OpenAI GPT-3 API to generate text completions.
- """
- import urllib.request
- import fitz
- import re
- import numpy as np
- import tensorflow_hub as hub
  import openai
- import gradio as gr
  import os
- from sklearn.neighbors import NearestNeighbors
-
- def download_pdf(url, output_path):
-     urllib.request.urlretrieve(url, output_path)
-
-
- def preprocess(text):
-     text = text.replace('\n', ' ')
-     text = re.sub('\s+', ' ', text)
-     return text
-
-
- def pdf_to_text(path, start_page=1, end_page=None):
-     doc = fitz.open(path)
-     total_pages = doc.page_count
-
-     if end_page is None:
-         end_page = total_pages
-
-     text_list = []
-
-     for i in range(start_page-1, end_page):
-         text = doc.load_page(i).get_text("text")
-         text = preprocess(text)
-         text_list.append(text)
-
-     doc.close()
-     return text_list
-
-
- def text_to_chunks(texts, word_length=150, start_page=1):
-     text_toks = [t.split(' ') for t in texts]
-     page_nums = []
-     chunks = []
-
-     for idx, words in enumerate(text_toks):
-         for i in range(0, len(words), word_length):
-             chunk = words[i:i+word_length]
-             if (i+word_length) > len(words) and (len(chunk) < word_length) and (
-                     len(text_toks) != (idx+1)):
-                 text_toks[idx+1] = chunk + text_toks[idx+1]
-                 continue
-             chunk = ' '.join(chunk).strip()
-             chunk = f'[{idx+start_page}]' + ' ' + '"' + chunk + '"'
-             chunks.append(chunk)
-     return chunks
-
-
- class SemanticSearch:
-
-     def __init__(self):
-         self.use = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')
-         self.fitted = False
-
-     def fit(self, data, batch=1000, n_neighbors=5):
-         self.data = data
-         self.embeddings = self.get_text_embedding(data, batch=batch)
-         n_neighbors = min(n_neighbors, len(self.embeddings))
-         self.nn = NearestNeighbors(n_neighbors=n_neighbors)
-         self.nn.fit(self.embeddings)
-         self.fitted = True
-
-     def __call__(self, text, return_data=True):
-         inp_emb = self.use([text])
-         neighbors = self.nn.kneighbors(inp_emb, return_distance=False)[0]
-
-         if return_data:
-             return [self.data[i] for i in neighbors]
-         else:
-             return neighbors
-
-     def get_text_embedding(self, texts, batch=1000):
-         embeddings = []
-         for i in range(0, len(texts), batch):
-             text_batch = texts[i:(i+batch)]
-             emb_batch = self.use(text_batch)
-             embeddings.append(emb_batch)
-         embeddings = np.vstack(embeddings)
-         return embeddings
-
-
- def load_recommender(path, start_page=1):
-     global recommender
-     texts = pdf_to_text(path, start_page=start_page)
-     chunks = text_to_chunks(texts, start_page=start_page)
-     recommender.fit(chunks)
-     return 'Corpus Loaded.'
-
-
- def generate_text(openAI_key, prompt, engine="text-davinci-003"):
-     openai.api_key = openAI_key
-     completions = openai.Completion.create(
-         engine=engine,
-         prompt=prompt,
-         max_tokens=512,
-         n=1,
-         stop=None,
-         temperature=0.7,
  )
-     message = completions.choices[0].text
-     return message
-
- def generate_answer(question, openAI_key):
-     topn_chunks = recommender(question)
-     prompt = ""
-     prompt += 'search results:\n\n'
-     for c in topn_chunks:
-         prompt += c + '\n\n'
-
-     prompt += "Instructions: Compose a comprehensive reply to the query using the search results given. "\
-               "Cite each reference using [ Page Number] notation (every result has this number at the beginning). "\
-               "Citation should be done at the end of each sentence. If the search results mention multiple subjects "\
-               "with the same name, create separate answers for each. Only include information found in the results and "\
-               "don't add any additional information. Make sure the answer is correct and don't output false content. "\
-               "If the text does not relate to the query, simply state 'Text Not Found in PDF'. Ignore outlier "\
-               "search results which have nothing to do with the question. Only answer what is asked. The "\
-               "answer should be short and concise. Answer step-by-step.\n\n"
-
-     prompt += f"Query: {question}\nAnswer:"
-     answer = generate_text(openAI_key, prompt, "text-davinci-003")
-     return answer
-
-
- def question_answer(url, file, question, openAI_key):
-     if openAI_key.strip() == '':
-         return '[ERROR]: Please enter your OpenAI key. Get your key here: https://platform.openai.com/account/api-keys'
-     if url.strip() == '' and file == None:
-         return '[ERROR]: Both URL and PDF are empty. Provide at least one.'
-
-     if url.strip() != '' and file != None:
-         return '[ERROR]: Both URL and PDF are provided. Please provide only one (either URL or PDF).'
-
-     if url.strip() != '':
-         glob_url = url
-         download_pdf(glob_url, 'corpus.pdf')
-         load_recommender('corpus.pdf')
-
-     else:
-         old_file_name = file.name
-         file_name = file.name
-         file_name = file_name[:-12] + file_name[-4:]
-         os.rename(old_file_name, file_name)
-         load_recommender(file_name)
-
-     if question.strip() == '':
-         return '[ERROR]: Question field is empty'
-
-     return generate_answer(question, openAI_key)
-
-
- recommender = SemanticSearch()
-
- title = 'PDF GPT'
- description = """PDF GPT lets you chat with your PDF file using the Universal Sentence Encoder and OpenAI. It produces fewer hallucinated responses than other tools because its embeddings are better than OpenAI's, and the returned response can even cite the page number in square brackets ([]) where the information is located, adding credibility to the responses and helping to locate pertinent information quickly."""
-
- with gr.Blocks() as demo:
-
-     gr.Markdown(f'<center><h1>{title}</h1></center>')
-     gr.Markdown(description)
-
-     with gr.Row():
-
-         with gr.Group():
-             gr.Markdown(f'<p style="text-align:center">Get your OpenAI API key <a href="https://platform.openai.com/account/api-keys">here</a></p>')
-             openAI_key = gr.Textbox(label='Enter your OpenAI API key here')
-             url = gr.Textbox(label='Enter PDF URL here')
-             gr.Markdown("<center><h4>OR<h4></center>")
-             file = gr.File(label='Upload your PDF / Research Paper / Book here', file_types=['.pdf'])
-             question = gr.Textbox(label='Enter your question here')
-             btn = gr.Button(value='Submit')
-             btn.style(full_width=True)
-
-         with gr.Group():
-             answer = gr.Textbox(label='The answer to your question is:')
-
-     btn.click(question_answer, inputs=[url, file, question, openAI_key], outputs=[answer])
-     # openai.api_key = os.getenv('Your_Key_Here')
- demo.launch()
+ import streamlit as st
  import openai
+ from streamlit_chat import message as st_message
+ from transformers import BlenderbotTokenizer
+ from transformers import BlenderbotForConditionalGeneration
+ from io import StringIO
+ from io import BytesIO
+ import requests
+ import torch
+ import PyPDF2
+ from transformers import GenerationConfig, LlamaTokenizer, LlamaForCausalLM
+ from langchain.embeddings.openai import OpenAIEmbeddings
+ from langchain.vectorstores import Chroma
+ from langchain.text_splitter import CharacterTextSplitter
+ from langchain.llms import OpenAI
+ from langchain.chains import RetrievalQA
+ from langchain.document_loaders import TextLoader
  import os
+ os.environ['OPENAI_API_KEY'] = "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"
+ import warnings
+ warnings.filterwarnings("ignore")
+
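+ # NOTE: the OpenAI key is hardcoded above and ships with the repository;
+ # reading it from an environment variable or st.secrets would avoid that.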
+
+ st.markdown(
+     """
+     <style>
+     [data-testid="stSidebar"][aria-expanded="true"] > div:first-child {
+         width: 325px;
+     }
+     [data-testid="stSidebar"][aria-expanded="false"] > div:first-child {
+         width: 325px;
+         margin-left: -350px;
+     }
+     </style>
+     """,
+     unsafe_allow_html=True,
+ )
+
+ st.sidebar.title('ChatFAQ')
+ st.sidebar.subheader('Parameters')
+
+ @st.cache_resource
+ def get_models():
+     # It may be necessary to cache the model for other frameworks as well;
+     # PyTorch seems to keep an internal state of the conversation.
+     model_name = "facebook/blenderbot-400M-distill"
+     tokenizer = BlenderbotTokenizer.from_pretrained(model_name)
+     model = BlenderbotForConditionalGeneration.from_pretrained(model_name)
+     return tokenizer, model
+
+ st.title("ChatFAQ")
+
+ app_mode = st.sidebar.selectbox('Choose the App mode',
+     ['Blenderbot_1B', 'Blenderbot-400M-distill', 'ChatGPT-3.5', 'Fine-tune Alpaca 7B', 'Customized Alpaca 7B', 'Alpaca-LORA']
+ )
+
+ # app_mode = st.sidebar.selectbox('Choose the domain',
+ #     ['Law', 'Economic', 'Technology']
+ # )
+
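+ # st.cache_resource keeps a single tokenizer/model pair per process, so the
+ # Blenderbot weights are downloaded and loaded only once across Streamlit reruns.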
+ uploaded_file = st.sidebar.file_uploader("Choose a file")
+ if uploaded_file is not None:
+     string_data = ""
+     file_type = uploaded_file.type
+     if file_type == "application/pdf":
+         bytes_data = uploaded_file.getvalue()
+
+         # Create a BytesIO object from the bytes data
+         bytes_io = BytesIO(bytes_data)
+
+         # Create a PDF reader object
+         pdf_reader = PyPDF2.PdfReader(bytes_io)
+
+         # Get the number of pages in the PDF file
+         num_pages = len(pdf_reader.pages)
+         # Loop through each page and extract the text
+         for i in range(num_pages):
+             page = pdf_reader.pages[i]
+             text = page.extract_text()
+             string_data = string_data + text
+     elif file_type == "text/plain":
+         with st.spinner('Loading the document...'):
+             # Convert the upload to a string-based IO:
+             stringio = StringIO(uploaded_file.getvalue().decode("utf-8"))
+
+             # Read the file as a string:
+             string_data = stringio.read()
+         st.success('Loaded successfully!')
+
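+ # At this point string_data holds the document's full plain text; the
+ # ChatGPT-3.5 mode below reuses it for retrieval-augmented question answering.
+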
+ if app_mode == 'Blenderbot_1B':
+     st.markdown('In this application, the **Blenderbot_1B API** is used, and **StreamLit** is used to create the Web Graphical User Interface (GUI).')
+     st.markdown(
+         """
+         <style>
+         [data-testid="stSidebar"][aria-expanded="true"] > div:first-child {
+             width: 300px;
+         }
+         [data-testid="stSidebar"][aria-expanded="false"] > div:first-child {
+             width: 300px;
+             margin-left: -400px;
+         }
+         </style>
+         """,
+         unsafe_allow_html=True,
  )
+
+     if 'history1' not in st.session_state:
+         st.session_state['history1'] = []
+
+     API_TOKEN = "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"
+     API_URL = "https://api-inference.huggingface.co/models/facebook/blenderbot-1B-distill"
+     headers = {"Authorization": f"Bearer {API_TOKEN}"}
+
+     def query(payload):
+         response = requests.post(API_URL, headers=headers, json=payload)
+         return response.json()
+
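+     # The payload follows the Hosted Inference API's conversational format,
+     # {"inputs": {"past_user_inputs": [...], "generated_responses": [...], "text": ...}};
+     # the response is assumed to carry the reply in its "generated_text" field.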
+     def generate_answer():
+         historyInputs = {"past_user_inputs": [],
+                          "generated_responses": []}
+         for element in st.session_state["history1"]:
+             if element["is_user"] == True:
+                 historyInputs["past_user_inputs"].append(element["message"])
+             else:
+                 historyInputs["generated_responses"].append(element["message"])
+         user_message = st.session_state.input_text
+         historyInputs["text"] = user_message if user_message != "" else " "
+         print(historyInputs)
+         output = query({
+             "inputs": historyInputs,
+         })
+         print(output)
+         print(output["generated_text"])
+         st.session_state['history1'].append({"message": user_message, "is_user": True})
+         st.session_state['history1'].append({"message": output["generated_text"], "is_user": False})
+         print(st.session_state['history1'])
+
+     for chat in st.session_state['history1']:
+         st_message(**chat)  # unpacking
+
+     st.text_input("Talk to the bot", key="input_text", on_change=generate_answer)
+
+     if st.button("Clear"):
+         st.session_state["history1"] = []
+         for chat in st.session_state['history1']:
+             st_message(**chat)  # unpacking
+
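+     # The on_change callback runs at the start of the rerun, so generate_answer
+     # has already appended the new exchange before the history loop above redraws.
+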
+ if app_mode == 'Blenderbot-400M-distill':
+     st.markdown('In this application, the **Blenderbot-400M-distill API** is used, and **StreamLit** is used to create the Web Graphical User Interface (GUI).')
+     st.markdown(
+         """
+         <style>
+         [data-testid="stSidebar"][aria-expanded="true"] > div:first-child {
+             width: 300px;
+         }
+         [data-testid="stSidebar"][aria-expanded="false"] > div:first-child {
+             width: 300px;
+             margin-left: -400px;
+         }
+         </style>
+         """,
+         unsafe_allow_html=True,
+     )
+
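+     # Past user turns are joined with ". " into a single string below, giving
+     # the locally cached 400M model conversational context instead of calling
+     # the hosted API as in the Blenderbot_1B mode.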
+     if 'history2' not in st.session_state:
+         st.session_state['history2'] = []
+
+     def generate_answer():
+         tokenizer, model = get_models()
+         user_message = st.session_state.input_text
+         print(type(user_message), user_message)
+         History_inputs = []
+         for element in st.session_state["history2"]:
+             if element["is_user"] == True:
+                 History_inputs.append(element["message"])
+         historyInputs = ". ".join(History_inputs)
+         print(historyInputs + " " + st.session_state.input_text)
+         inputs = tokenizer(historyInputs + " . " + st.session_state.input_text, return_tensors="pt")
+         result = model.generate(**inputs)
+         message_bot = tokenizer.decode(
+             result[0], skip_special_tokens=True
+         )  # .replace("<s>", "").replace("</s>", "")
+
+         st.session_state['history2'].append({"message": user_message, "is_user": True})
+         st.session_state['history2'].append({"message": message_bot, "is_user": False})
+
+     for chat in st.session_state['history2']:
+         st_message(**chat)  # unpacking
+
+     st.text_input("Talk to the bot", key="input_text", on_change=generate_answer)
+
+     if st.button("Clear"):
+         st.session_state["history2"] = []
+         for chat in st.session_state['history2']:
+             st_message(**chat)  # unpacking
+
+ if app_mode == 'ChatGPT-3.5':
+     counter = 0
+
+     def get_unique_key():
+         global counter
+         counter += 1
+         return f"chat{counter}"
+
+     OPENAI_KEY = "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"
+     openai.api_key = OPENAI_KEY
+     openai_engine = openai.ChatCompletion()
+
+     if 'history3' not in st.session_state:
+         st.session_state['history3'] = []
+
+     if "messages" not in st.session_state:
+         st.session_state["messages"] = []
+
+     if "messagesDocument" not in st.session_state:
+         st.session_state["messagesDocument"] = []
+
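+     # "messages" accumulates the plain chat-completion history, while
+     # "messagesDocument" accumulates the document-grounded history used once a
+     # file has been indexed below; "history3" drives the on-screen transcript.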
+     def generate_answer():
+         st.session_state["messages"] += [{"role": "user", "content": st.session_state.input_text}]
+         response = openai.ChatCompletion.create(
+             model="gpt-3.5-turbo", messages=st.session_state["messages"]
+         )
+         message_response = response["choices"][0]["message"]["content"]
+         st.session_state["messages"] += [
+             {"role": "assistant", "content": message_response}
+         ]
+         st.session_state['history3'].append({"message": st.session_state.input_text, "is_user": True})
+         st.session_state['history3'].append({"message": message_response, "is_user": False})
+         print(st.session_state['history3'])
+         print(st.session_state["messages"])
+
+ if st.button("Retrieve the document's content"):
233
+ if uploaded_file is None:
234
+ st.error("Please input the document!", icon="🚨")
235
+ else:
236
+ with st.spinner('Wait for processing the document...'):
237
+ with open("my_text.txt", "w", encoding='utf-8') as f:
238
+ f.write(string_data)
239
+ loader = TextLoader("my_text.txt", encoding='utf-8')
240
+ documents = loader.load()
241
+ print(type(documents), documents)
242
+ text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
243
+ texts = text_splitter.split_documents(documents)
244
+
245
+ embeddings = OpenAIEmbeddings()
246
+ docsearch = Chroma.from_documents(texts, embeddings)
247
+
248
+ qa = RetrievalQA.from_chain_type(llm=OpenAI(), chain_type="stuff", retriever=docsearch.as_retriever(search_kwargs={"k": 1}))
249
+ st.success('Successful!')
250
+
251
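+             # The document text is split into 1000-character chunks, embedded with
+             # OpenAIEmbeddings, and indexed in Chroma; RetrievalQA then answers over
+             # the single closest chunk (k=1). Since `qa` is created inside this button
+             # branch, it exists only on this rerun; keeping it in st.session_state
+             # would be needed for the handler below to work on later reruns.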
+             def generate_answer():
+                 query = st.session_state.input_text
+                 docs = qa.run(query)
+                 system_prompt_first = """
+                 You are a helpful assistant that helps the user answer questions over content that was pulled from a database
+                 ---CONTENT START---\n
+                 """
+                 system_prompt_second = """
+                 \n---CONTENT END---
+                 Based on the information pulled from the database, answer the question below from the user. If the content pulled from the database is not related to the question, say "I do not have enough information for this question"
+                 Question:
+                 """
+                 system_prompt_ans = "\nAnswer:"
+                 prompt = system_prompt_first + docs + system_prompt_second + query + system_prompt_ans
+                 print(prompt)
+                 st.session_state["messagesDocument"] += [{"role": "user", "content": prompt}]
+                 message_response = openai_engine.create(model='gpt-3.5-turbo', messages=st.session_state["messagesDocument"])
+                 st.session_state['history3'].append({"message": st.session_state.input_text, "is_user": True})
+                 st.session_state['history3'].append({"message": message_response.choices[0].message.content, "is_user": False})
+                 st.session_state["messagesDocument"] += [
+                     {"role": "assistant", "content": message_response.choices[0].message.content}
+                 ]
+                 print(st.session_state["messagesDocument"])
+
+ st.markdown("""
276
+ <style>
277
+ .chatbox {
278
+ max-height: 300px;
279
+ overflow-y: auto;
280
+ }
281
+ </style>
282
+ """, unsafe_allow_html=True)
283
+
284
+ for chat in st.session_state['history3']:
285
+ st_message(**chat, key=get_unique_key()) # unpacking
286
+
287
+
288
+ st.text_input("Talk to the bot: ",placeholder = "Ask me anything ...", key="input_text", on_change=generate_answer)
289
+
290
+ if st.button("Clear"):
291
+ st.session_state["history3"] = []
292
+ st.session_state["messages"] = []
293
+ for chat in st.session_state['history3']:
294
+ st_message(**chat, key=get_unique_key()) # unpacking
295
+
296
+
297
+ if app_mode == 'Fine-tune Alpaca 7B':
+     st.markdown('In this application, the **Fine-tuned Alpaca 7B API** is used, and **StreamLit** is used to create the Web Graphical User Interface (GUI).')
+     st.markdown(
+         """
+         <style>
+         [data-testid="stSidebar"][aria-expanded="true"] > div:first-child {
+             width: 300px;
+         }
+         [data-testid="stSidebar"][aria-expanded="false"] > div:first-child {
+             width: 300px;
+             margin-left: -400px;
+         }
+         </style>
+         """,
+         unsafe_allow_html=True,
+     )
+
+ if app_mode == 'Customized Alpaca 7B':
+     st.markdown('In this application, a **Customized Alpaca 7B** model is used, and **StreamLit** is used to create the Web Graphical User Interface (GUI).')
+     st.markdown(
+         """
+         <style>
+         [data-testid="stSidebar"][aria-expanded="true"] > div:first-child {
+             width: 300px;
+         }
+         [data-testid="stSidebar"][aria-expanded="false"] > div:first-child {
+             width: 300px;
+             margin-left: -400px;
+         }
+         </style>
+         """,
+         unsafe_allow_html=True,
+     )
+
+ if app_mode == 'Alpaca-LORA':
+     st.markdown('In this application, the **Alpaca-LORA** model (chainyo/alpaca-lora-7b) is used, and **StreamLit** is used to create the Web Graphical User Interface (GUI).')
+     st.markdown(
+         """
+         <style>
+         [data-testid="stSidebar"][aria-expanded="true"] > div:first-child {
+             width: 300px;
+         }
+         [data-testid="stSidebar"][aria-expanded="false"] > div:first-child {
+             width: 300px;
+             margin-left: -400px;
+         }
+         </style>
+         """,
+         unsafe_allow_html=True,
+     )
+
+     def generate_prompt(instruction: str, input_ctxt: str = None) -> str:
+         if input_ctxt:
+             return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
+
+ ### Instruction:
+ {instruction}
+
+ ### Input:
+ {input_ctxt}
+
+ ### Response:"""
+         else:
+             return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
+
+ ### Instruction:
+ {instruction}
+
+ ### Response:"""
+
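+     # This is the standard Stanford Alpaca instruction template; the checkpoint
+     # loaded below was fine-tuned on prompts in exactly this format.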
+ tokenizer = LlamaTokenizer.from_pretrained("chainyo/alpaca-lora-7b")
368
+ model = LlamaForCausalLM.from_pretrained(
369
+ "chainyo/alpaca-lora-7b",
370
+ load_in_8bit=True,
371
+ torch_dtype=torch.float16,
372
+ device_map="auto",
373
+ )
374
+ generation_config = GenerationConfig(
375
+ temperature=0.2,
376
+ top_p=0.75,
377
+ top_k=40,
378
+ num_beams=4,
379
+ max_new_tokens=128,
380
+ )
381
 
382
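+     # load_in_8bit=True requires the bitsandbytes package and a CUDA device;
+     # device_map="auto" lets accelerate place the model weights automatically.
+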
+     model.eval()
+     if torch.__version__ >= "2":
+         model = torch.compile(model)
+
+     instruction = "What is the meaning of life?"
+     input_ctxt = None  # For some tasks, you can provide an input context to help the model generate a better response.
+
+     prompt = generate_prompt(instruction, input_ctxt)
+     input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+     input_ids = input_ids.to(model.device)
+
+     with torch.no_grad():
+         outputs = model.generate(
+             input_ids=input_ids,
+             generation_config=generation_config,
+             return_dict_in_generate=True,
+             output_scores=True,
+         )
+
+     response = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
+     print(response)
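+
+     # Note: this mode currently answers a single hard-coded demo instruction at
+     # page load; it is not yet wired to the chat input like the other modes.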