sonoisa commited on
Commit
187910a
·
verified ·
1 Parent(s): c16227d

Add a flag to enable/disable RAG feature

Browse files
Files changed (1) hide show
  1. index.html +40 -26
index.html CHANGED
@@ -553,7 +553,7 @@ actual_total_cost_prompt = 0
553
  actual_total_cost_completion = 0
554
 
555
 
556
- async def process_prompt(prompt, history, context, platform, endpoint, azure_deployment, azure_api_version, api_key, model_name, max_tokens, temperature):
557
  """
558
  ユーザーのプロンプトを処理し、ChatGPTによる生成結果を返す。
559
 
@@ -569,6 +569,7 @@ async def process_prompt(prompt, history, context, platform, endpoint, azure_dep
569
  model_name (str): 使用するAIモデルの名前
570
  max_tokens (int): 生成する最大トークン数
571
  temperature (float): クリエイティビティの度合いを示す温度パラメータ
 
572
 
573
  Returns:
574
  str: ChatGPTによる生成結果
@@ -593,15 +594,24 @@ async def process_prompt(prompt, history, context, platform, endpoint, azure_dep
593
  http_client=http_client
594
  )
595
 
596
- completion = openai_client.chat.completions.create(
597
- messages=messages,
598
- model=model_name,
599
- max_tokens=max_tokens,
600
- temperature=temperature,
601
- tools=CHAT_TOOLS,
602
- tool_choice="auto",
603
- stream=False
604
- )
 
 
 
 
 
 
 
 
 
605
 
606
  bot_response = ""
607
  if hasattr(completion, "error"):
@@ -742,6 +752,7 @@ DEFAULT_SETTINGS = {
742
  "model_name": "gpt-4-turbo-preview",
743
  "max_tokens": 4096,
744
  "temperature": 0.2,
 
745
  "save_chat_history_to_url": False
746
  };
747
 
@@ -791,6 +802,7 @@ def main():
791
  entry["model_name"] || default_model_name,
792
  entry["max_tokens"] || default_max_tokens,
793
  entry["temperature"] || default_temperature,
 
794
  entry["save_chat_history_to_url"] || default_save_chat_history_to_url
795
  ]);
796
  }
@@ -798,7 +810,7 @@ def main():
798
  saved_settings = default_saved_settings;
799
  }
800
 
801
- return [setting_name, platform, endpoint, azure_deployment, azure_api_version, model_name, max_tokens, temperature, save_chat_history_to_url, saved_settings];
802
  };
803
 
804
  globalThis.resetSettings = () => {
@@ -1007,6 +1019,10 @@ def main():
1007
  temperature.change(None, inputs=temperature, outputs=None,
1008
  js='(x) => saveItem("temperature", x)', show_progress="hidden")
1009
 
 
 
 
 
1010
  save_chat_history_to_url = gr.Checkbox(label="Save Chat History to URL", interactive=True)
1011
 
1012
  reset_button = gr.Button("Reset Settings")
@@ -1017,10 +1033,10 @@ def main():
1017
  saved_settings_df = gr.Dataframe(
1018
  elem_id="saved_settings",
1019
  value=[default_saved_settings],
1020
- headers=["Name", "Platform", "Endpoint", "Azure Deployment", "Azure API Version", "Model", "Max Tokens", "Temperature", "Save Chat History to URL"],
1021
  row_count=(0, "dynamic"),
1022
- col_count=(9, "fixed"),
1023
- datatype=["str", "str", "str", "str", "str", "str", "number", "number", "bool"],
1024
  type="array",
1025
  label="Saved Settings",
1026
  show_label=True,
@@ -1057,15 +1073,15 @@ def main():
1057
 
1058
  row_index = selected_setting[0]
1059
 
1060
- setting_name, platform, endpoint, azure_deployment, azure_api_version, model_name, max_tokens, temperature, save_chat_history_to_url = saved_settings[row_index]
1061
 
1062
- return u(setting_name), u(platform), u(endpoint), u(azure_deployment), u(azure_api_version), u(model_name), u(max_tokens), u(temperature), u(save_chat_history_to_url), None
1063
 
1064
 
1065
- load_saved_settings_button.click(load_saved_setting, inputs=[saved_settings_df, selected_setting], outputs=[setting_name, platform, endpoint, azure_deployment, azure_api_version, model_name, max_tokens, temperature, save_chat_history_to_url, selected_setting], queue=False, show_progress="hidden")
1066
 
1067
 
1068
- def append_or_overwrite_setting(saved_settings, setting_name, platform, endpoint, azure_deployment, azure_api_version, model_name, max_tokens, temperature, save_chat_history_to_url):
1069
 
1070
  setting_name = setting_name.strip()
1071
 
@@ -1073,13 +1089,13 @@ def main():
1073
  new_saved_settings = []
1074
  for entry in saved_settings:
1075
  if entry[0] == setting_name:
1076
- new_saved_settings.append([setting_name, platform, endpoint, azure_deployment, azure_api_version,model_name, max_tokens, temperature, save_chat_history_to_url])
1077
  found = True
1078
  else:
1079
  new_saved_settings.append(entry)
1080
 
1081
  if not found:
1082
- new_saved_settings.append([setting_name, platform, endpoint, azure_deployment, azure_api_version,model_name, max_tokens, temperature, save_chat_history_to_url])
1083
 
1084
  return new_saved_settings, None
1085
 
@@ -1095,7 +1111,7 @@ def main():
1095
 
1096
 
1097
  append_or_overwrite_saved_settings_button.click(
1098
- append_or_overwrite_setting, inputs=[saved_settings_df, setting_name, platform, endpoint, azure_deployment, azure_api_version,model_name, max_tokens, temperature, save_chat_history_to_url], outputs=[saved_settings_df, selected_setting], queue=False, show_progress="hidden"
1099
  ).then(
1100
  serialize_saved_settings, inputs=saved_settings_df, outputs=serialized_saved_settings_state, queue=False, show_progress="hidden",
1101
  ).then(
@@ -1127,8 +1143,7 @@ def main():
1127
  temp_saved_settings = gr.JSON(visible=False)
1128
  temp_saved_settings.change(lambda x: x, inputs=temp_saved_settings, outputs=saved_settings_df, queue=False, show_progress="hidden")
1129
 
1130
- setting_items = [setting_name, platform, endpoint, azure_deployment, azure_api_version, model_name, max_tokens,
1131
- temperature, save_chat_history_to_url, temp_saved_settings]
1132
  reset_button.click(None, inputs=None, outputs=setting_items,
1133
  js="() => resetSettings()", show_progress="hidden")
1134
 
@@ -1147,8 +1162,7 @@ def main():
1147
 
1148
  with gr.Column(scale=2):
1149
 
1150
- additional_inputs = [context, platform, endpoint, azure_deployment, azure_api_version, api_key,
1151
- model_name, max_tokens, temperature]
1152
 
1153
  with gr.Blocks() as chat:
1154
  gr.Markdown(f"# Chat with your PDF")
@@ -1326,4 +1340,4 @@ main()
1326
  </script>
1327
  <script type="module" crossorigin src="https://cdn.jsdelivr.net/npm/@gradio/lite@4.29.0/dist/lite.js"></script>
1328
  </body>
1329
- </html>
 
553
  actual_total_cost_completion = 0
554
 
555
 
556
+ async def process_prompt(prompt, history, context, platform, endpoint, azure_deployment, azure_api_version, api_key, model_name, max_tokens, temperature, enable_rag):
557
  """
558
  ユーザーのプロンプトを処理し、ChatGPTによる生成結果を返す。
559
 
 
569
  model_name (str): 使用するAIモデルの名前
570
  max_tokens (int): 生成する最大トークン数
571
  temperature (float): クリエイティビティの度合いを示す温度パラメータ
572
+ enable_rag (bool): RAG機能を有効にするかどうか
573
 
574
  Returns:
575
  str: ChatGPTによる生成結果
 
594
  http_client=http_client
595
  )
596
 
597
+ if enable_rag:
598
+ completion = openai_client.chat.completions.create(
599
+ messages=messages,
600
+ model=model_name,
601
+ max_tokens=max_tokens,
602
+ temperature=temperature,
603
+ tools=CHAT_TOOLS,
604
+ tool_choice="auto",
605
+ stream=False
606
+ )
607
+ else:
608
+ completion = openai_client.chat.completions.create(
609
+ messages=messages,
610
+ model=model_name,
611
+ max_tokens=max_tokens,
612
+ temperature=temperature,
613
+ stream=False
614
+ )
615
 
616
  bot_response = ""
617
  if hasattr(completion, "error"):
 
752
  "model_name": "gpt-4-turbo-preview",
753
  "max_tokens": 4096,
754
  "temperature": 0.2,
755
+ "enable_rag": True,
756
  "save_chat_history_to_url": False
757
  };
758
 
 
802
  entry["model_name"] || default_model_name,
803
  entry["max_tokens"] || default_max_tokens,
804
  entry["temperature"] || default_temperature,
805
+ entry["enable_rag"] || default_enable_rag,
806
  entry["save_chat_history_to_url"] || default_save_chat_history_to_url
807
  ]);
808
  }
 
810
  saved_settings = default_saved_settings;
811
  }
812
 
813
+ return [setting_name, platform, endpoint, azure_deployment, azure_api_version, model_name, max_tokens, temperature, enable_rag, save_chat_history_to_url, saved_settings];
814
  };
815
 
816
  globalThis.resetSettings = () =&gt; {
 
1019
  temperature.change(None, inputs=temperature, outputs=None,
1020
  js='(x) =&gt; saveItem("temperature", x)', show_progress="hidden")
1021
 
1022
+ enable_rag = gr.Checkbox(label="Enable RAG (Retrieval Augmented Generation)", interactive=True)
1023
+ enable_rag.change(None, inputs=enable_rag, outputs=None,
1024
+ js='(x) =&gt; saveItem("enable_rag", x)', show_progress="hidden")
1025
+
1026
  save_chat_history_to_url = gr.Checkbox(label="Save Chat History to URL", interactive=True)
1027
 
1028
  reset_button = gr.Button("Reset Settings")
 
1033
  saved_settings_df = gr.Dataframe(
1034
  elem_id="saved_settings",
1035
  value=[default_saved_settings],
1036
+ headers=["Name", "Platform", "Endpoint", "Azure Deployment", "Azure API Version", "Model", "Max Tokens", "Temperature", "Enable RAG", "Save Chat History to URL"],
1037
  row_count=(0, "dynamic"),
1038
+ col_count=(10, "fixed"),
1039
+ datatype=["str", "str", "str", "str", "str", "str", "number", "number", "bool", "bool"],
1040
  type="array",
1041
  label="Saved Settings",
1042
  show_label=True,
 
1073
 
1074
  row_index = selected_setting[0]
1075
 
1076
+ setting_name, platform, endpoint, azure_deployment, azure_api_version, model_name, max_tokens, temperature, enable_rag, save_chat_history_to_url = saved_settings[row_index]
1077
 
1078
+ return u(setting_name), u(platform), u(endpoint), u(azure_deployment), u(azure_api_version), u(model_name), u(max_tokens), u(temperature), u(enable_rag), u(save_chat_history_to_url), None
1079
 
1080
 
1081
+ load_saved_settings_button.click(load_saved_setting, inputs=[saved_settings_df, selected_setting], outputs=[setting_name, platform, endpoint, azure_deployment, azure_api_version, model_name, max_tokens, temperature, enable_rag, save_chat_history_to_url, selected_setting], queue=False, show_progress="hidden")
1082
 
1083
 
1084
+ def append_or_overwrite_setting(saved_settings, setting_name, platform, endpoint, azure_deployment, azure_api_version, model_name, max_tokens, temperature, enable_rag, save_chat_history_to_url):
1085
 
1086
  setting_name = setting_name.strip()
1087
 
 
1089
  new_saved_settings = []
1090
  for entry in saved_settings:
1091
  if entry[0] == setting_name:
1092
+ new_saved_settings.append([setting_name, platform, endpoint, azure_deployment, azure_api_version,model_name, max_tokens, temperature, enable_rag, save_chat_history_to_url])
1093
  found = True
1094
  else:
1095
  new_saved_settings.append(entry)
1096
 
1097
  if not found:
1098
+ new_saved_settings.append([setting_name, platform, endpoint, azure_deployment, azure_api_version,model_name, max_tokens, temperature, enable_rag, save_chat_history_to_url])
1099
 
1100
  return new_saved_settings, None
1101
 
 
1111
 
1112
 
1113
  append_or_overwrite_saved_settings_button.click(
1114
+ append_or_overwrite_setting, inputs=[saved_settings_df, setting_name, platform, endpoint, azure_deployment, azure_api_version,model_name, max_tokens, temperature, enable_rag, save_chat_history_to_url], outputs=[saved_settings_df, selected_setting], queue=False, show_progress="hidden"
1115
  ).then(
1116
  serialize_saved_settings, inputs=saved_settings_df, outputs=serialized_saved_settings_state, queue=False, show_progress="hidden",
1117
  ).then(
 
1143
  temp_saved_settings = gr.JSON(visible=False)
1144
  temp_saved_settings.change(lambda x: x, inputs=temp_saved_settings, outputs=saved_settings_df, queue=False, show_progress="hidden")
1145
 
1146
+ setting_items = [setting_name, platform, endpoint, azure_deployment, azure_api_version, model_name, max_tokens, temperature, enable_rag, save_chat_history_to_url, temp_saved_settings]
 
1147
  reset_button.click(None, inputs=None, outputs=setting_items,
1148
  js="() =&gt; resetSettings()", show_progress="hidden")
1149
 
 
1162
 
1163
  with gr.Column(scale=2):
1164
 
1165
+ additional_inputs = [context, platform, endpoint, azure_deployment, azure_api_version, api_key, model_name, max_tokens, temperature, enable_rag]
 
1166
 
1167
  with gr.Blocks() as chat:
1168
  gr.Markdown(f"# Chat with your PDF")
 
1340
  </script>
1341
  <script type="module" crossorigin src="https://cdn.jsdelivr.net/npm/@gradio/lite@4.29.0/dist/lite.js"></script>
1342
  </body>
1343
+ </html>