nxphi47 committed
Commit
7d3d5c9
1 Parent(s): 5d40c70

Update app.py

Files changed (1): app.py (+173, −6)
app.py CHANGED
 
172
  SYSTEM_PROMPT_1 = """You are a helpful, respectful, honest and safe AI assistant built by Alibaba Group."""
173
 
174
 
175
+
176
+ # ######### RAG PREPARE
177
+ RAG_CURRENT_FILE, RAG_EMBED, RAG_CURRENT_VECTORSTORE = None, None, None
178
+
179
+ RAG_EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
180
+
181
+
182
+ def load_embeddings():
183
+ global RAG_EMBED
184
+ if RAG_EMBED is None:
185
+ from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceBgeEmbeddings
186
+ print(f'LOading embeddings: {RAG_EMBED_MODEL_NAME}')
187
+ RAG_EMBED = HuggingFaceEmbeddings(model_name=RAG_EMBED_MODEL_NAME, model_kwargs={'trust_remote_code':True})
188
+ else:
189
+ print(f'RAG_EMBED ALREADY EXIST: {RAG_EMBED_MODEL_NAME}: {RAG_EMBED=}')
190
+ return RAG_EMBED
191
+
192
+
193
+ def get_rag_embeddings():
194
+ return load_embeddings()
195
+
196
+ _ = get_rag_embeddings()
197
+
198
+ RAG_CURRENT_VECTORSTORE = None
199
+
200
+ def load_document_split_vectorstore(file_path):
201
+ global RAG_CURRENT_FILE, RAG_EMBED, RAG_CURRENT_VECTORSTORE
202
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
203
+ from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceBgeEmbeddings
204
+ from langchain_community.vectorstores import Chroma, FAISS
205
+ from langchain_community.document_loaders import PyPDFLoader, Docx2txtLoader, TextLoader
206
+ # assert RAG_EMBED is not None
207
+ splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=50)
208
+ if file_path.endswith('.pdf'):
209
+ loader = PyPDFLoader(file_path)
210
+ elif file_path.endswith('.docx'):
211
+ loader = Docx2txtLoader(file_path)
212
+ elif file_path.endswith('.txt'):
213
+ loader = TextLoader(file_path)
214
+ splits = loader.load_and_split(splitter)
215
+ RAG_CURRENT_VECTORSTORE = FAISS.from_texts(texts=[s.page_content for s in splits], embedding=get_rag_embeddings())
216
+ return RAG_CURRENT_VECTORSTORE
217
+
218
+
219
+ def docs_to_rag_context(docs: List[str]):
220
+ contexts = "\n".join([d.page_content for d in docs])
221
+ context = f"""### Begin document
222
+ {contexts}
223
+ ### End document
224
+ Asnwer the following query exclusively based on the information provided in the document above. \
225
+ Remember to follow the language of the user query.
226
+ """
227
+ return context
228
+
229
+ def maybe_get_doc_context(message, file_input, rag_num_docs: Optional[int] = 3):
230
+ global RAG_CURRENT_FILE, RAG_EMBED, RAG_CURRENT_VECTORSTORE
231
+ doc_context = None
232
+ if file_input is not None:
233
+ assert os.path.exists(file_input), f"not found: {file_input}"
234
+ if file_input == RAG_CURRENT_FILE:
235
+ # reuse
236
+ vectorstore = RAG_CURRENT_VECTORSTORE
237
+ print(f'Reuse vectorstore: {file_input}')
238
+ else:
239
+ vectorstore = load_document_split_vectorstore(file_input)
240
+ print(f'New vectorstore: {RAG_CURRENT_FILE} {file_input}')
241
+ RAG_CURRENT_FILE = file_input
242
+ docs = vectorstore.similarity_search(message, k=rag_num_docs)
243
+ doc_context = docs_to_rag_context(docs)
244
+ return doc_context
245
+
246
+ # ######### RAG PREPARE
247
+
248
+
249
  # ============ CONSTANT ============
250
  # https://github.com/gradio-app/gradio/issues/884
251
  MODEL_NAME = "SeaLLM-7B"
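For context, the retrieval path this hunk adds reduces to: split the uploaded document into chunks, embed the chunks, index them in FAISS, and fetch the top-k chunks per query. Below is a minimal standalone sketch of that same pattern, assuming langchain-community and sentence-transformers are installed; the sample text and query are hypothetical, while the model name and splitter settings are copied from the hunk above.

# Standalone sketch of the RAG PREPARE path above; not part of the commit.
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

embed = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=50)

# Hypothetical document text; app.py gets this from PyPDFLoader/Docx2txtLoader/TextLoader.
raw_text = "SeaLLM-7B is a multilingual chat model. The demo serves it with vLLM."
chunks = splitter.split_text(raw_text)

# Index the chunks, then retrieve the best match for a query.
vectorstore = FAISS.from_texts(texts=chunks, embedding=embed)
docs = vectorstore.similarity_search("Which engine serves the model?", k=1)
print(docs[0].page_content)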
 
@@ -771,7 +845,7 @@ def chat_response_stream_multiturn(
     presence_penalty: float,
     system_prompt: Optional[str] = SYSTEM_PROMPT_1,
     current_time: Optional[float] = None,
-    profile: Optional[gr.OAuthProfile] = None,
+    # profile: Optional[gr.OAuthProfile] = None,
 ) -> str:
     """
     gr.Number(value=temperature, label='Temperature (higher -> more random)'),
 
@@ -794,7 +868,8 @@ def chat_response_stream_multiturn(
     global llm, RES_PRINTED
     assert llm is not None
     assert system_prompt.strip() != '', 'system prompt is empty'
-    is_by_pass = False if profile is None else profile.username in BYPASS_USERS
+    # is_by_pass = False if profile is None else profile.username in BYPASS_USERS
+    is_by_pass = False
 
     tokenizer = llm.get_tokenizer()
     # force removing all
 
@@ -876,6 +951,32 @@
 
 
 
+def chat_response_stream_rag_multiturn(
+    message: str,
+    history: List[Tuple[str, str]],
+    file_input: str,
+    temperature: float,
+    max_tokens: int,
+    # frequency_penalty: float,
+    # presence_penalty: float,
+    system_prompt: Optional[str] = SYSTEM_PROMPT_1,
+    current_time: Optional[float] = None,
+    rag_num_docs: Optional[int] = 3,
+):
+    message = message.strip()
+    frequency_penalty = FREQUENCE_PENALTY
+    presence_penalty = PRESENCE_PENALTY
+    if len(message) == 0:
+        raise gr.Error("The message cannot be empty!")
+    doc_context = maybe_get_doc_context(message, file_input, rag_num_docs=rag_num_docs)
+    if doc_context is not None:
+        message = f"{doc_context}\n\n{message}"
+    yield from chat_response_stream_multiturn(
+        message, history, temperature, max_tokens, frequency_penalty,
+        presence_penalty, system_prompt, current_time
+    )
+
+
 def debug_generate_free_form_stream(message):
     output = " This is a debugging message...."
     for i in range(len(output)):
 
@@ -1450,6 +1551,61 @@ def create_chat_demo(title=None, description=None):
     return demo_chat
 
 
+def upload_file(file):
+    # file_paths = [file.name for file in files]
+    # return file_paths
+    return file.name
+
+def create_chat_demo_rag(title=None, description=None):
+    sys_prompt = SYSTEM_PROMPT_1
+    max_tokens = MAX_TOKENS
+    temperature = TEMPERATURE
+    frequence_penalty = FREQUENCE_PENALTY
+    presence_penalty = PRESENCE_PENALTY
+
+    # with gr.Blocks(title="RAG") as rag_demo:
+    additional_inputs = [
+        # gr.File(label='Upload Document', file_count='single', file_types=['pdf', 'docx', 'txt', 'json']),
+        gr.Textbox(value=None, label='Document path', lines=1, interactive=False),
+        gr.Number(value=temperature, label='Temperature (higher -> more random)'),
+        gr.Number(value=max_tokens, label='Max generated tokens (increase if want more generation)'),
+        # gr.Number(value=frequence_penalty, label='Frequency penalty (> 0 encourage new tokens over repeated tokens)'),
+        # gr.Number(value=presence_penalty, label='Presence penalty (> 0 encourage new tokens, < 0 encourage existing tokens)'),
+        gr.Textbox(value=sys_prompt, label='System prompt', lines=1, interactive=False),
+        gr.Number(value=0, label='current_time', visible=False),
+    ]
+
+    demo_rag_chat = gr.ChatInterface(
+        chat_response_stream_rag_multiturn,
+        chatbot=gr.Chatbot(
+            label=MODEL_NAME + "-RAG",
+            bubble_full_width=False,
+            latex_delimiters=[
+                {"left": "$", "right": "$", "display": False},
+                {"left": "$$", "right": "$$", "display": True},
+            ],
+            show_copy_button=True,
+        ),
+        textbox=gr.Textbox(placeholder='Type message', lines=1, max_lines=128, min_width=200),
+        submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
+        # ! consider preventing the stop button
+        # stop_btn=None,
+        title=title,
+        description=description,
+        additional_inputs=additional_inputs,
+        additional_inputs_accordion=gr.Accordion("Additional Inputs", open=True),
+        # examples=CHAT_EXAMPLES,
+        cache_examples=False,
+    )
+    with demo_rag_chat:
+        upload_button = gr.UploadButton("Click to Upload document", file_types=['pdf', 'docx', 'txt', 'json'], file_count="single")
+        upload_button.upload(upload_file, upload_button, additional_inputs[0])
+
+    # return demo_chat
+    return demo_rag_chat
+
+
 
 def launch_demo():
     global demo, llm, DEBUG, LOG_FILE
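The hunk above hangs a gr.UploadButton off the ChatInterface and routes the uploaded file's path into the read-only "Document path" textbox, which gr.ChatInterface then passes to the handler as an additional input. A minimal sketch of that wiring in isolation, assuming gradio is installed; the component labels are illustrative and mirror the commit:

# Minimal sketch of the upload-to-textbox wiring used by create_chat_demo_rag.
import gradio as gr

def upload_file(file):
    # the UploadButton handler receives a temp file whose .name is its path on disk
    return file.name

with gr.Blocks() as demo:
    doc_path = gr.Textbox(label="Document path", interactive=False)
    upload_button = gr.UploadButton("Click to Upload document",
                                    file_types=[".pdf", ".docx", ".txt"],
                                    file_count="single")
    # .upload() fires once the file lands; its output fills the textbox
    upload_button.upload(upload_file, upload_button, doc_path)

# demo.launch()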
 
@@ -1544,18 +1700,29 @@ def launch_demo():
 
     if ENABLE_BATCH_INFER:
 
-        demo_file_upload = create_file_upload_demo()
+        # demo_file_upload = create_file_upload_demo()
 
     demo_free_form = create_free_form_generation_demo()
 
     demo_chat = create_chat_demo()
+    demo_chat_rag = create_chat_demo_rag()
     descriptions = model_desc
     if DISPLAY_MODEL_PATH:
         descriptions += f"<br> {path_markdown.format(model_path=model_path)}"
 
     demo = CustomTabbedInterface(
-        interface_list=[demo_chat, demo_file_upload, demo_free_form],
-        tab_names=["Chat Interface", "Batch Inference", "Free-form"],
+        interface_list=[
+            demo_chat,
+            demo_chat_rag,
+            demo_free_form,
+            # demo_file_upload,
+        ],
+        tab_names=[
+            "Chat Interface",
+            "RAG Chat Interface",
+            "Text completion",
+            # "Batch Inference",
+        ],
         title=f"{model_title}",
         description=descriptions,
     )
 
@@ -1582,7 +1749,7 @@ def launch_demo():
     if ENABLE_AGREE_POPUP:
         demo.load(None, None, None, _js=AGREE_POP_SCRIPTS)
 
-    login_btn = gr.LoginButton()
+    # login_btn = gr.LoginButton()
 
     demo.queue(api_open=False)
     return demo
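End to end, a RAG turn in the new tab works like this: if a document path is present, retrieve the top-k chunks, wrap them in the Begin/End document template, prepend that to the user message, and delegate to the unchanged streaming handler. A hedged sketch of that control flow, with hypothetical stubs (retrieve_stub, stream_stub) standing in for maybe_get_doc_context and the vLLM-backed chat_response_stream_multiturn:

# Illustrative control flow of chat_response_stream_rag_multiturn; stubs only.
from typing import Iterator, Optional

def retrieve_stub(message: str, file_input: Optional[str]) -> Optional[str]:
    # Stand-in for maybe_get_doc_context: returns None when no document is attached.
    if file_input is None:
        return None
    return f"### Begin document\n(top-k chunks relevant to: {message})\n### End document"

def stream_stub(message: str) -> Iterator[str]:
    # Stand-in for the vLLM-backed chat_response_stream_multiturn generator.
    yield f"(streamed answer to: {message[:60]}...)"

def rag_turn(message: str, file_input: Optional[str]) -> Iterator[str]:
    message = message.strip()
    if not message:
        raise ValueError("The message cannot be empty!")
    doc_context = retrieve_stub(message, file_input)
    if doc_context is not None:
        # Context is prepended to the user message, so the base handler is reused as-is.
        message = f"{doc_context}\n\n{message}"
    yield from stream_stub(message)

print(next(rag_turn("What is SeaLLM?", "paper.pdf")))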